# LoRA

![lora](./lora.png)

![lora formula](./lora_formula.png)

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class LoRALayer(nn.Module):
    """Linear layer whose weight can be augmented by a low-rank product.

    Wraps an ``nn.Linear``. When ``rank > 0`` two extra parameters are
    created, ``lora_A`` of shape ``(rank, in_features)`` and ``lora_B`` of
    shape ``(out_features, rank)``, and the wrapped layer's weight (but not
    its bias) is excluded from gradient computation. With ``merge`` enabled
    the forward pass uses ``weight + scaling * lora_B @ lora_A`` as the
    effective weight, where ``scaling = lora_alpha / rank``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        rank: Inner dimension of the low-rank update; values ``<= 0``
            disable the update and create no extra parameters.
        lora_alpha: Numerator of the scaling factor.
        dropout: Dropout probability; values ``<= 0`` install
            ``nn.Identity`` instead of ``nn.Dropout``.
        merge: Whether the forward pass adds the low-rank update to the
            weight.
    """

    def __init__(self, in_features, out_features, rank=16, lora_alpha=16, dropout=0.5, merge=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.merge = merge

        # Note: self.linear.weight has shape [out_features, in_features],
        # not [in_features, out_features].
        self.linear = nn.Linear(in_features, out_features)

        lora_enabled = self.rank > 0
        if lora_enabled:
            self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
            self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
            self.scaling = self.lora_alpha / self.rank
            # Only the weight is frozen here; the bias keeps requires_grad.
            self.linear.weight.requires_grad = False

        # ``self.dropout`` always ends up as a module so forward() can call
        # it unconditionally.
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        self.initial_weight()

    def initial_weight(self):
        """Initialise the low-rank factors (no-op when ``rank <= 0``).

        ``lora_A`` gets a Kaiming-uniform init and ``lora_B`` is zeroed, so
        the product ``lora_B @ lora_A`` is zero right after construction.
        """
        if self.rank <= 0:
            return
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        """Apply the (optionally updated) linear map, then dropout.

        NOTE(review): dropout is applied to the full layer output, not only
        to the low-rank branch -- confirm this placement is intended.
        """
        use_low_rank = self.rank > 0 and self.merge
        if use_low_rank:
            delta = self.scaling * self.lora_B @ self.lora_A
            projected = F.linear(x, self.linear.weight + delta, self.linear.bias)
        else:
            projected = self.linear(x)

        return self.dropout(projected)

# Demo: build a small layer and show that the wrapped nn.Linear stores its
# weight as (out_features, in_features).
lora = LoRALayer(in_features=1, out_features=10, rank=1, lora_alpha=0.5, dropout=0.5)
print(lora.linear.weight.shape)

torch.Size([10, 1])
