In [76]:
import torch.nn as nn
import torch



class LoRALayer(nn.Module):
    """Low-rank adapter computing ``scaling * x @ (B @ A).T``.

    Uses the LoRA parameterisation ``delta_W = B @ A``, where ``delta_W`` has the
    same ``(out_features, in_features)`` layout as ``nn.Linear.weight``:

    * ``A`` has shape ``(rank, hidden_size_in)`` and starts as Gaussian noise
      with standard deviation ``1 / sqrt(rank)``;
    * ``B`` has shape ``(hidden_size_out, rank)`` and starts at zero, so the
      adapter contributes nothing until it has been trained.

    Args:
        rank: inner dimension of the low-rank factorisation (must be > 0).
        hidden_size_in: size of the last dimension of the input.
        hidden_size_out: size of the last dimension of the output.
        alpha: scaling numerator; the update is multiplied by ``alpha / rank``.

    Raises:
        ValueError: if ``rank`` is not positive.
    """

    def __init__(self, rank, hidden_size_in, hidden_size_out, alpha):
        super().__init__()
        if rank <= 0:
            raise ValueError(f"rank must be positive, got {rank}")
        self.rank = rank
        self.alpha = alpha
        # LoRA scaling alpha / rank: the effective update size stays comparable
        # when the rank is changed (the previous alpha * sqrt(rank) grew with it).
        self.scaling = alpha / rank

        # Standard deviation of A's initial entries. A plain float, so it can
        # never end up on a different device than the parameters.
        self.std_dev = rank ** -0.5
        self.A = nn.Parameter(torch.randn(rank, hidden_size_in) * self.std_dev)
        self.B = nn.Parameter(torch.zeros(hidden_size_out, rank))

    def forward(self, x):
        """Return the low-rank update for ``x`` of shape ``(..., hidden_size_in)``.

        The result has shape ``(..., hidden_size_out)``; any number of leading
        batch dimensions (including none) is supported.
        """
        # x @ A.T @ B.T == x @ (B @ A).T without materialising the full
        # (out, in) update matrix.
        return (x @ self.A.T @ self.B.T) * self.scaling



class LoRADecorator(nn.Module):
    """Wrap a linear ``module`` so its output receives a trainable low-rank update.

    ``forward(x)`` returns ``module(x) + lora(x)``. The wrapped module's
    ``weight`` and ``bias`` are re-exposed on the wrapper so code that reads
    ``layer.weight`` / ``layer.bias`` directly keeps working after wrapping.

    Args:
        module: layer to wrap; must expose ``weight``, ``bias``,
            ``in_features`` and ``out_features`` (e.g. ``nn.Linear``).
        rank: rank of the adapter, forwarded to ``LoRALayer``.
        alpha: adapter scaling factor, forwarded to ``LoRALayer``.
    """

    def __init__(self, module, rank, alpha):
        super().__init__()

        self.module = module
        # Aliases of the wrapped layer's parameters (the same objects, shared,
        # not copied).
        self.weight = module.weight
        self.bias = module.bias

        # Build the adapter on the wrapped weight's device/dtype; otherwise a
        # CUDA or half-precision layer would be paired with a CPU/float32
        # adapter and forward() would fail when adding the two outputs.
        self.lora = LoRALayer(
            rank, module.in_features, module.out_features, alpha
        ).to(device=module.weight.device, dtype=module.weight.dtype)

    def forward(self, x):
        """Return the wrapped layer's output plus the adapter's low-rank update."""
        return self.module(x) + self.lora(x)


In [77]:
# A standalone 128 -> 128 linear layer wrapped with a rank-3 adapter.
mod = nn.Linear(128, 128)
lora = LoRADecorator(mod, rank=3, alpha=0.01)

In [79]:
import torch.nn as nn
import torch
# Smoke test: push one random 128-dim vector through the wrapped layer.
lora(torch.randn(128)).size()

torch.Size([128])

In [55]:
# # let's set up a quick training loop so we can see the matrices change
# import torch.optim as optim
# import torch.nn.functional as F


# lora = LoRALayer(rank=3, hidden_size_in=128, hidden_size_out=128, alpha = 10)
# optimizer = optim.Adam(lora.parameters(), lr=0.01)
# criterion = nn.MSELoss()

# for i in range(1000):
#     optimizer.zero_grad()
#     x = torch.randn(128)
#     y = lora(x)
#     loss = criterion(x, y)
#     loss.backward()
#     optimizer.step()

# print(loss.item())



In [56]:
from torch.nn import TransformerDecoder

In [83]:
# Two stacked decoder layers, each with d_model=512 and 8 attention heads.
model = TransformerDecoder(nn.TransformerDecoderLayer(d_model=512, nhead=8), num_layers=2)

In [84]:
# Recursively traverse the model and wrap every leaf layer whose name
# matches an entry of lora_to_replace in a LoRADecorator.
from functools import reduce  


# Base names (trailing digits ignored) of the leaf layers apply_lora wraps.
# NOTE(review): nn.MultiheadAttention.forward passes out_proj.weight and
# out_proj.bias straight to the functional attention kernel rather than calling
# out_proj(...), so the LoRA branch of a wrapped out_proj is most likely never
# used — verify before relying on 'out_proj' here.
lora_to_replace = ['linear', 'out_proj']


def apply_lora(model, rank, alpha, target_names=None, wrapper=None):
    """Wrap matching ``nn.Linear`` leaves of ``model`` with LoRA adapters, in place.

    A leaf ``nn.Linear`` is wrapped when its attribute name, with all trailing
    digits removed, is in ``target_names`` (so ``linear1``, ``linear2`` and
    ``linear12`` all match ``'linear'``). A LoRADecorator keeps the original
    layer under ``.module``, so calling this twice does not double-wrap.

    Args:
        model: module to modify in place.
        rank: LoRA rank passed to ``wrapper``.
        alpha: LoRA scaling factor passed to ``wrapper``.
        target_names: base names to match; defaults to ``lora_to_replace``.
        wrapper: callable ``wrapper(layer, rank, alpha) -> nn.Module`` that
            builds each replacement; defaults to ``LoRADecorator``.

    Returns:
        ``model`` itself (mutated in place).
    """
    if target_names is None:
        target_names = lora_to_replace
    if wrapper is None:
        wrapper = LoRADecorator
    targets = set(target_names)

    # Collect matches before replacing anything: swapping children while
    # named_modules() is still walking the tree is fragile.
    # The root (empty name) is skipped because it has no parent to patch.
    to_wrap = [
        name
        for name, layer in model.named_modules()
        if name
        and isinstance(layer, nn.Linear)
        and next(layer.children(), None) is None
        and name.rsplit('.', 1)[-1].rstrip('0123456789') in targets
    ]

    for name in to_wrap:
        # rpartition handles top-level children (no '.'): the parent is then
        # '' and get_submodule('') returns the model itself.
        parent_name, _, child_name = name.rpartition('.')
        parent = model.get_submodule(parent_name)
        layer = getattr(parent, child_name)
        setattr(parent, child_name, wrapper(layer, rank, alpha))
    return model


In [85]:
# Wrap the decoder's matching layers in place (apply_lora returns the same object).
model = apply_lora(model, rank=3, alpha=0.01)

In [None]:
# The decoder layer was built without batch_first, so PyTorch's default
# (batch_first=False) reads these as (seq_len=10, batch=32, d_model=512).
model(tgt=torch.randn(10, 32, 512), memory=torch.randn(10, 32, 512))

In [75]:
# Apply the adapter's B @ A product to a random vector and inspect the result's size.
torch.matmul(lora.lora.B @ lora.lora.A, torch.randn(128)).size()

torch.Size([128])