### LoRA (Low-Rank Adaptation) concept

In [1]:
import torch

# Fix the global RNG so every random init / random tensor below is reproducible.
SEED = 2024
torch.manual_seed(SEED)

<torch._C.Generator at 0x77b435ab4390>

In [2]:
class SimpleModel(torch.nn.Module):
    """Tiny demo network: token embedding -> hidden linear layer -> output head.

    Args:
        vocab_size: number of rows in the input embedding table, i.e. valid
            input ids are ``0 .. vocab_size - 1``.
        dim: width of the embedding and of the hidden (dim -> dim) linear layer.
        num_classes: number of output scores produced for every position.
    """

    def __init__(self, vocab_size, dim, num_classes):
        super().__init__()
        # Registration order (emb, linear, lm_head) fixes both the order in which
        # the weights draw from the RNG and the order shown by print(model).
        self.emb = torch.nn.Embedding(vocab_size, dim)
        self.linear = torch.nn.Linear(dim, dim)
        self.lm_head = torch.nn.Linear(dim, num_classes)

    def forward(self, input_ids):
        """Map token ids to per-position scores of size ``num_classes``."""
        hidden = self.linear(self.emb(input_ids))
        return self.lm_head(hidden)

# Output vocabulary: class index i of the model's scores is decoded as vocab[i].
vocab = "red orange yellow green blue indigo violet magenta marigold chartreuse".split()

In [3]:
# Create a model.
# NOTE: only ids 0 .. vocab_size-1 are valid *inputs* (embedding rows), while the
# head scores num_classes (== len(vocab) == 10) words as *outputs*.
vocab_size = 5
dim = 1024
num_classes = 10
simple_model = SimpleModel(vocab_size=vocab_size, dim=dim, num_classes=num_classes)

In [4]:
# Create a sample input: a batch holding one sequence of three token ids.
sample_token_ids = torch.tensor([[2, 4, 0]], dtype=torch.long)
print(sample_token_ids.shape, sample_token_ids)


torch.Size([1, 3]) tensor([[2, 4, 0]])


In [5]:
def generate_next_token(model, vocab, **kwargs):
    """Greedily choose the next token for every sequence in the batch.

    Args:
        model: callable returning logits; indexed as ``logits[:, -1, :]``, so it
            is expected to return (batch, seq_len, num_classes).
        vocab: sequence mapping a class index to its token string.
        **kwargs: forwarded unchanged to ``model`` (e.g. ``input_ids=...``).

    Returns:
        A list with one token string per batch row. The chosen ids and the
        decoded tokens are also printed.
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        last_position_logits = model(**kwargs)[:, -1, :]

    # Highest-scoring class at the final position of each sequence.
    next_token_ids = last_position_logits.argmax(dim=1)
    print(next_token_ids)

    # Decode with plain Python ints rather than 0-d tensors.
    next_tokens = [vocab[idx] for idx in next_token_ids.tolist()]
    print(next_tokens)
    return next_tokens


In [6]:
# Greedy next-token prediction from the untouched base model.
next_token = generate_next_token(model=simple_model, vocab=vocab, input_ids=sample_token_ids)

tensor([8])
['marigold']


In [7]:
# LoRA concept on a linear layer: keep the base layer as-is and add a low-rank
# correction, total = linear(x) + (x @ A) @ B with A: (dim, rank), B: (rank, dim).

# Example linear layer input, shaped (bs, seq_len, dim)
seq_len, dim = 3, 1024
x = torch.randn(1, seq_len, dim)

# Example low-rank factors (random here purely to illustrate shapes)
rank = 2
A = torch.randn(dim, rank)
B = torch.randn(rank, dim)

# Base path and LoRA path produce tensors of the same shape, so they can be summed
base_output = simple_model.linear(x)
lora_output = (x @ A) @ B
total_output = base_output + lora_output
print(*(t.shape for t in (base_output, lora_output, total_output)))

# Parameter count: the two low-rank factors vs. the base layer's weight (bias excluded)
base_num_params = simple_model.linear.weight.numel()
lora_num_params = sum(factor.numel() for factor in (A, B))
print(f"LoRA trainable params: {lora_num_params}, Actual params: {base_num_params}")


torch.Size([1, 3, 1024]) torch.Size([1, 3, 1024]) torch.Size([1, 3, 1024])
LoRA trainable params: 4096, Actual params: 1048576


In [8]:
# LoRA layer in PyTorch

# Main LoRA
import math


class LoraLayerModule(torch.nn.Module):
    """Low-rank update ``alpha * (x @ A @ B)`` meant to run next to a base layer.

    Args:
        din: input feature size (last dimension of ``x``).
        dout: output feature size of the update.
        rank: inner rank of the factorization; A is (din, rank), B is (rank, dout).
        alpha: constant multiplier applied to the low-rank product.
    """

    def __init__(self, din, dout, rank, alpha):
        super().__init__()
        self.alpha = alpha
        # A: random Kaiming-uniform init.
        self.A = torch.nn.Parameter(torch.empty(din, rank))
        torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        # B starts at zero, so the update is exactly zero before training and a
        # freshly wrapped layer initially behaves like its base layer.
        self.B = torch.nn.Parameter(torch.zeros(rank, dout))

    def forward(self, x):
        """Return ``alpha * (x @ A @ B)``; last dim goes din -> dout."""
        return self.alpha * (x @ self.A @ self.B)


# LoRA on any base layer
class LoraLayer(torch.nn.Module):
    """Wrap ``base_layer`` so its output gets an added LoRA update.

    ``base_layer`` must expose a 2-D ``weight`` laid out like ``torch.nn.Linear``,
    i.e. ``(out_features, in_features)``.

    Args:
        base_layer: layer whose output is kept and augmented.
        rank: LoRA rank.
        alpha: LoRA scaling factor.
    """

    def __init__(self, base_layer, rank, alpha):
        super().__init__()
        self.base_layer = base_layer
        # nn.Linear stores its weight as (out_features, in_features). Unpacking in
        # that order keeps A: (in, rank) and B: (rank, out) correct for
        # non-square layers too (e.g. a dim -> num_classes head), not only dim -> dim.
        dout, din = base_layer.weight.shape
        self.lora = LoraLayerModule(din, dout, rank, alpha)

    def forward(self, x):
        """Base output plus the low-rank update."""
        return self.base_layer(x) + self.lora(x)


In [9]:
# Test the LoRA wrapper on the demo model's hidden linear layer
rank = 2
alpha = 4
lora_layer = LoraLayer(base_layer=simple_model.linear, rank=rank, alpha=alpha)

# Example linear layer input, shaped (bs, seq_len, dim)
seq_len = 3
dim = 1024
x = torch.randn(1, seq_len, dim)
total_output = lora_layer(x)
print(total_output.shape)

torch.Size([1, 3, 1024])


In [10]:
# Simple way to add LoRA: build a fresh base model, wrap one of its layers,
# and assign the wrapper back onto the attribute the layer came from.
vocab_size, dim, num_classes = 5, 1024, 10
simple_model = SimpleModel(vocab_size=vocab_size, dim=dim, num_classes=num_classes)

rank, alpha = 2, 4
lora_layer = LoraLayer(base_layer=simple_model.linear, rank=rank, alpha=alpha)
simple_model.linear = lora_layer  # re-registers the submodule under "linear"
print(simple_model)

SimpleModel(
  (emb): Embedding(5, 1024)
  (linear): LoraLayer(
    (base_layer): Linear(in_features=1024, out_features=1024, bias=True)
    (lora): LoraLayerModule()
  )
  (lm_head): Linear(in_features=1024, out_features=10, bias=True)
)


In [11]:
# Prediction after wrapping `linear` with LoRA. B is zero-initialised, so the
# wrapper does not change the layer's output yet; any difference from the
# earlier base-model prediction comes from `simple_model` having been
# re-created with fresh random weights, not from LoRA.
next_token = generate_next_token(model=simple_model, vocab=vocab, input_ids=sample_token_ids)

tensor([9])
['chartreuse']


In [12]:
# General way to create lora layer & replace existing model layers

# Get modules for LoRA
def find_all_linear_names(model):
    """Return the leaf names of every ``torch.nn.Linear`` inside ``model``.

    Only the last component of each dotted module path is kept
    (``"layers.0.q_proj"`` -> ``"q_proj"``), so repeated layers collapse to a
    single entry. ``lm_head`` is excluded so the output head is not wrapped.

    Returns:
        Sorted list of unique names. Sorting makes the result, and anything
        printed from it, identical on every run; plain set order depends on
        string-hash randomization.
    """
    names = {
        qualified_name.split('.')[-1]
        for qualified_name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
    }
    names.discard('lm_head')  # keep the output head out of LoRA
    return sorted(names)

# Replace layers with Lora
def replace_layers_with_lora(model, replace_layers, rank, alpha):
    """Wrap every submodule whose leaf name is in ``replace_layers`` with ``LoraLayer``.

    Each match is replaced on its *own parent* module, so nested layers such as
    ``model.block.proj`` are wrapped in place. Attaching to the root would
    leave the nested layer untouched and add a stray ``proj`` attribute
    instead. Modules that already are ``LoraLayer`` instances are skipped, so
    calling this twice does not double-wrap. ``model`` is modified in place;
    nothing is returned.

    Args:
        model: root ``torch.nn.Module`` to modify.
        replace_layers: collection of leaf names to wrap (e.g. the output of
            ``find_all_linear_names``).
        rank: LoRA rank passed to ``LoraLayer``.
        alpha: LoRA scaling factor passed to ``LoraLayer``.
    """
    # Collect first, then replace: swapping children while named_modules() is
    # still walking the tree would mix old and new modules in the traversal.
    targets = [
        (qualified_name, module)
        for qualified_name, module in model.named_modules()
        if qualified_name  # skip the root module itself
        and qualified_name.split('.')[-1] in replace_layers
        and not isinstance(module, LoraLayer)
    ]
    for qualified_name, module in targets:
        parent_name, _, leaf_name = qualified_name.rpartition('.')
        parent = model.get_submodule(parent_name)  # '' resolves to model itself
        setattr(parent, leaf_name, LoraLayer(module, rank, alpha))

# Get trainable params
def get_num_trainable_params(model):
    """Return the total element count of ``model``'s parameters that require grad."""
    count = 0
    for param in model.parameters():
        if param.requires_grad:
            count += param.numel()
    return count

# Create a base model
vocab_size, dim, num_classes = 5, 1024, 10
simple_model = SimpleModel(vocab_size, dim, num_classes)

# Freeze every base parameter; only LoRA factors added below stay trainable
print(f"Num trainable parameters before freeze: {get_num_trainable_params(simple_model):,}")
simple_model.requires_grad_(False)
print(f"Num trainable parameters after freeze: {get_num_trainable_params(simple_model):,}")

# Pick the Linear layers to adapt
replace_layers = find_all_linear_names(simple_model)
print(replace_layers)

# Wrap them with LoRA and report the resulting trainable parameter count
rank, alpha = 2, 4
replace_layers_with_lora(simple_model, replace_layers, rank, alpha)
print(simple_model)
print(f"Num Lora trainable parameters: {get_num_trainable_params(simple_model):,}")




Num trainable parameters before freeze: 1,064,970
Num trainable parameters after freeze: 0
['linear']
SimpleModel(
  (emb): Embedding(5, 1024)
  (linear): LoraLayer(
    (base_layer): Linear(in_features=1024, out_features=1024, bias=True)
    (lora): LoraLayerModule()
  )
  (lm_head): Linear(in_features=1024, out_features=10, bias=True)
)
Num Lora trainable parameters: 4,096


In [13]:
# Prediction from the LoRA-wrapped model. B is zero-initialised, so this matches
# the frozen base model's prediction until the LoRA factors are trained.
next_token = generate_next_token(model=simple_model, vocab=vocab, input_ids=sample_token_ids)

tensor([7])
['magenta']


### References:

> https://www.deeplearning.ai/short-courses/efficiently-serving-llms/

> https://github.com/rasbt/LLMs-from-scratch

> https://www.coursera.org/specializations/generative-ai-engineering-with-llms
