# LoRA Implementation

In [1]:
import copy
import matplotlib.pyplot as plt
import numpy as np
import random
import time
import torch
import torch.nn.functional as F
from tqdm import tqdm
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Seed every RNG the notebook draws from. Seeding Python's `random` alone does
# not cover torch.randn or nn.Module weight initialisation, so without
# torch.manual_seed the model weights (and the predicted tokens below) would
# change on every run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Pick the best available backend: CUDA, then Apple MPS, then CPU.
# NOTE(review): nothing in this section moves models/tensors to `device`,
# so the cells below actually run on the CPU — confirm whether that is intended.
if torch.cuda.is_available():
    logger.info("Using GPU")
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    logger.info("Using MPS")
    device = torch.device("mps")
else:
    logger.info("Using CPU")
    device = torch.device("cpu")
device

INFO:__main__:Using MPS


device(type='mps')

## Helpers

In [2]:
def generate_token(ins: torch.Tensor, model: torch.nn.Module, detokenizer: list[str]) -> list[str]:
    """Greedily pick the next token for every sequence in a batch.

    Runs ``model`` on ``ins`` without gradient tracking, keeps the logits at
    the final position of each sequence, and maps the highest-scoring id
    through ``detokenizer``.

    Args:
        ins: batch of token ids accepted by ``model``; assumed (batch, seq) — TODO confirm.
        model: module whose output is indexed as ``[batch, position, vocab]``.
        detokenizer: list mapping a token id to its string.

    Returns:
        One predicted token string per batch row.
    """
    with torch.no_grad():
        last_logits = model(ins)[:, -1, :]
        best_ids = last_logits.argmax(dim=1).tolist()
    return [detokenizer[token_id] for token_id in best_ids]

## Creating a Model

In [16]:
class FakeModel(torch.nn.Module):
    """Toy token model: embedding -> linear hidden layer -> linear head.

    Produces one logit per vocabulary entry at every input position.

    Args:
        vocab_size: number of token ids; sizes both the embedding table and the head.
        embedding_size: width of each token embedding.
        hidden_size: width of the hidden linear layer (default 1024, the
            previously hard-coded value).
    """

    def __init__(self, vocab_size, embedding_size, hidden_size=1024):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_size)
        self.hidden = torch.nn.Linear(embedding_size, hidden_size)
        # The head must emit exactly one logit per vocabulary entry. A fixed
        # width (previously 10) would make generate_token's detokenizer lookup
        # wrong or out of range for any other vocabulary size.
        self.head = torch.nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        """Map integer token ids to logits of shape ``x.shape + (vocab_size,)``."""
        x = self.embedding(x.long())
        x = self.hidden(x)
        x = self.head(x)
        return x

In [4]:
_x = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]])
vocab = [
    "red",
    "orange",
    "yellow",
    "green",
    "blue",
    "indigo",
    "violet",
    "magenta",
    "marigold",
    "chartreuse",
]

In [5]:
# One embedding row per vocabulary entry, each 1024 wide.
fake_model = FakeModel(embedding_size=1024, vocab_size=len(vocab))
fake_model

FakeModel(
  (embedding): Embedding(10, 1024)
  (hidden): Linear(in_features=1024, out_features=1024, bias=True)
  (head): Linear(in_features=1024, out_features=10, bias=True)
)

In [6]:
# Greedy next-token prediction from the model before any adapter is attached.
_generated_tokens = generate_token(_x, fake_model, vocab)
_generated_tokens

['yellow']

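## Implementing LoRA

LoRA (Low-Rank Adaptation) leaves a layer's weight $W$ unchanged and adds a low-rank update alongside it. For a linear layer the output becomes $y = xW^\top + b + xAB$, with $A \in \mathbb{R}^{d_{in} \times r}$, $B \in \mathbb{R}^{r \times d_{out}}$ and $r \ll d_{in}, d_{out}$. The adapter therefore adds only $r(d_{in} + d_{out})$ parameters instead of $d_{in} d_{out}$.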

In [7]:
# Random activations shaped (batch=1, seq=8, features=1024), used below as input to the hidden layer.
X = torch.randn((1, 8, 1024))
X.shape

torch.Size([1, 8, 1024])

In [8]:
W = fake_model.hidden.weight
# nn.Linear stores its weight as (out_features, in_features); derive the
# adapter dims from it instead of hard-coding 1024 so this also holds for
# non-square layers.
d_out, d_in = W.shape
lora_rank = 2

# Low-rank factors: x @ lora_a @ lora_b maps d_in -> lora_rank -> d_out.
lora_a = torch.randn(d_in, lora_rank)
lora_b = torch.randn(lora_rank, d_out)

lora_numel = lora_a.numel() + lora_b.numel()
base_numel = W.numel()
print("|A+B| / |W|:", lora_numel / base_numel)

|A+B| / |W|: 0.00390625


In [9]:
# Base path: the original linear layer, which computes X @ W.T + b.
base_output = fake_model.hidden(X)
# Adapter path: X @ A @ B, evaluated left to right through the rank-r bottleneck.
lora_output = (X @ lora_a) @ lora_b
total_output = torch.add(base_output, lora_output)

# Adding the adapter must not change the output shape:
# (batch_size, sequence_length, hidden_size)
total_output.shape

torch.Size([1, 8, 1024])

In [10]:
class LoraLayer(torch.nn.Module):
    """Wrap a linear layer with a low-rank (LoRA) adapter.

    ``forward(x)`` returns ``base_layer(x) + x @ lora_a @ lora_b``, where
    ``lora_a`` has shape ``(in_features, r)`` and ``lora_b`` has shape
    ``(r, out_features)``.

    Args:
        base_layer: layer being adapted; must expose a 2-D ``weight`` laid out
            as ``(out_features, in_features)``, like ``torch.nn.Linear``.
        r: rank of the adapter.
    """

    def __init__(self, base_layer, r):
        super().__init__()
        self.base_layer = base_layer

        weight = self.base_layer.weight
        # nn.Linear stores its weight as (out_features, in_features) and computes
        # x @ W.T, so the FIRST dimension is the output width. Unpacking it as
        # (d_in, d_out) only happens to work for square layers.
        d_out, d_in = weight.shape
        # nn.Parameter registers the adapters with the module so they appear in
        # .parameters()/state_dict() and follow .to(device). They are created on
        # the base weight's device/dtype so the matmuls in forward line up.
        # Both factors stay random-normal, as before; note the LoRA paper
        # zero-initialises B so the wrapped layer initially matches the base.
        self.lora_a = torch.nn.Parameter(
            torch.randn(d_in, r, device=weight.device, dtype=weight.dtype)
        )
        self.lora_b = torch.nn.Parameter(
            torch.randn(r, d_out, device=weight.device, dtype=weight.dtype)
        )

    def forward(self, x):
        """Return the base layer's output plus the adapter's contribution."""
        y1 = self.base_layer(x)
        y2 = x @ self.lora_a @ self.lora_b
        return y1 + y2

    def during_inference(self, x):
        """Return ONLY the adapter branch ``x @ lora_a @ lora_b`` (no base-layer term)."""
        return x @ self.lora_a @ self.lora_b

In [11]:
# Wrap the toy model's hidden linear layer with a rank-2 adapter.
# If the swap cell has already run, fake_model.hidden is itself a LoraLayer
# (which has no `.weight`); unwrap it so this cell can be re-executed safely.
hidden_layer = fake_model.hidden
if isinstance(hidden_layer, LoraLayer):
    hidden_layer = hidden_layer.base_layer
lora_layer = LoraLayer(hidden_layer, 2)
lora_layer(X).shape

torch.Size([1, 8, 1024])

In [12]:
# Replace the hidden Linear with its LoRA-wrapped version; the repr shows the nesting.
fake_model.add_module("hidden", lora_layer)
fake_model

FakeModel(
  (embedding): Embedding(10, 1024)
  (hidden): LoraLayer(
    (base_layer): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (head): Linear(in_features=1024, out_features=10, bias=True)
)

In [18]:
# Same input as the pre-adapter prediction (`_x`), so the two results are directly comparable.
next_tokens = generate_token(_x, fake_model, vocab)
next_tokens

['red']