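# LoRA

- [大模型高效微调-LoRA原理详解和训练过程深入分析](https://www.cnblogs.com/justLittleStar/p/18242820) (in Chinese: "Efficient fine-tuning of large models: LoRA principles explained and an in-depth analysis of the training process")

The cells below compare two ways of initialising the low-rank factors in `y = x @ A @ B`: `A` random with `B` zero (`Lora`), and the swapped choice of `A` zero with `B` random (`Lora2`). Each experiment prints gradient and output norms over a few SGD steps.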

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

# Hyperparameters shared by both experiments below: the rank r of the
# low-rank factors and the input/output feature dimension d.
lora_rank, model_dim = 2, 16

class Lora(nn.Module):
    """Low-rank adapter that maps ``x`` to ``x @ A @ B``.

    ``A`` has shape ``(model_dim, lora_rank)`` and is drawn from a standard
    normal, scaled by ``1 / sqrt(lora_rank)``. ``B`` has shape
    ``(lora_rank, model_dim)`` and starts at zero, so the adapter's output is
    exactly zero before the first parameter update.
    """

    def __init__(self, lora_rank, model_dim):
        super().__init__()
        scale = torch.sqrt(torch.tensor(lora_rank))
        # Registration order (A, then B) is kept: it fixes the order of
        # self.parameters(), which the optimizer consumes.
        self.A = nn.Parameter(torch.randn(model_dim, lora_rank) / scale)
        # Scaling an all-zero matrix would be a no-op, so B is just zeros.
        self.B = nn.Parameter(torch.zeros(lora_rank, model_dim))

    def forward(self, x):
        """Return ``(x @ A) @ B``; ``x``'s last dimension must equal ``model_dim``."""
        projected = torch.matmul(x, self.A)
        return torch.matmul(projected, self.B)

In [2]:
lora = Lora(lora_rank=lora_rank, model_dim=model_dim)  # experiment 1: A random, B zero

In [3]:
optimizer = optim.SGD(lora.parameters(), lr=0.01, momentum=0.9)  # momentum SGD over the adapter's A and B

In [4]:
# Run 10 momentum-SGD steps on fresh random inputs and print the norms of
# both gradients and of the output, to show how gradients flow when B starts
# at zero. After this loop, `idx` remains bound to 9 at module level.
for idx in range(10):
    print('-' * 100)
    optimizer.zero_grad()
    # One random input vector of length model_dim per step.
    x = torch.randn((model_dim, ))
    y = lora(x)
    # Reduce the output to a scalar surrogate loss so backward() can run.
    z = y.mean()
    z.backward()
    # With B == 0 at step 0, the gradient w.r.t. A (x^T g B^T) is zero while
    # B's gradient ((x A)^T g) is not; A only starts to receive gradient once
    # B has been moved away from zero by the first optimizer step.
    print(f'step{idx} norm(lora.A.grad):', torch.norm(lora.A.grad))
    print(f'step{idx} norm(lora.B.grad):', torch.norm(lora.B.grad))
    print(f'step{idx} norm(y):', torch.norm(y))
    # print('lora.A:', lora.A)
    # print('lora.B:', lora.B)
    optimizer.step()

----------------------------------------------------------------------------------------------------
step0 norm(lora.A.grad): tensor(0.)
step0 norm(lora.B.grad): tensor(0.7942)
step0 norm(y): tensor(0., grad_fn=<LinalgVectorNormBackward0>)
----------------------------------------------------------------------------------------------------
step1 norm(lora.A.grad): tensor(0.0096)
step1 norm(lora.B.grad): tensor(0.6021)
step1 norm(y): tensor(0.0174, grad_fn=<LinalgVectorNormBackward0>)
----------------------------------------------------------------------------------------------------
step2 norm(lora.A.grad): tensor(0.0129)
step2 norm(lora.B.grad): tensor(0.0650)
step2 norm(y): tensor(0.0054, grad_fn=<LinalgVectorNormBackward0>)
----------------------------------------------------------------------------------------------------
step3 norm(lora.A.grad): tensor(0.0273)
step3 norm(lora.B.grad): tensor(0.5162)
step3 norm(y): tensor(0.0678, grad_fn=<LinalgVectorNormBackward0>)
----------------

In [5]:
class Lora2(nn.Module):
    """Low-rank adapter ``x @ A @ B`` with the initialisation of ``Lora`` swapped.

    ``A`` has shape ``(model_dim, lora_rank)`` and starts at zero. ``B`` has
    shape ``(lora_rank, model_dim)`` and is drawn from a standard normal,
    scaled by ``1 / sqrt(lora_rank)``. The output is still zero before the
    first update, but now ``B`` receives no gradient on the first backward
    pass, because ``x @ A`` is zero.
    """

    def __init__(self, lora_rank, model_dim):
        super().__init__()
        scale = torch.sqrt(torch.tensor(lora_rank))
        # Scaling an all-zero matrix would be a no-op, so A is just zeros.
        self.A = nn.Parameter(torch.zeros(model_dim, lora_rank))
        self.B = nn.Parameter(torch.randn(lora_rank, model_dim) / scale)

    def forward(self, x):
        """Return ``(x @ A) @ B``; ``x``'s last dimension must equal ``model_dim``."""
        projected = torch.matmul(x, self.A)
        return torch.matmul(projected, self.B)

In [6]:
lora = Lora2(lora_rank=lora_rank, model_dim=model_dim)  # experiment 2: A zero, B random

In [7]:
optimizer = optim.SGD(lora.parameters(), lr=0.1, momentum=0.9)  # NOTE(review): lr 0.1 vs 0.01 in experiment 1 confounds the init comparison; confirm intended

In [8]:
# Run 10 momentum-SGD steps for the second experiment (`lora` / `optimizer`
# as bound by the preceding cells) and print gradient and output norms.
# Fix: the loop variable was `_` while the prints read `idx`, so every line
# reported the stale value left over from the first loop ("step9"). Binding
# `idx` here makes the step labels count 0..9 as intended.
for idx in range(10):
    print('-' * 100)
    optimizer.zero_grad()
    # One random input vector of length model_dim per step.
    x = torch.randn((model_dim, ))
    y = lora(x)
    # Reduce the output to a scalar surrogate loss so backward() can run.
    z = y.mean()
    z.backward()
    print(f'step{idx} norm(lora.A.grad):', torch.norm(lora.A.grad))
    print(f'step{idx} norm(lora.B.grad):', torch.norm(lora.B.grad))
    print(f'step{idx} norm(y):', torch.norm(y))
    # Uncomment to inspect the raw parameter values at each step.
    # print('lora.A:', lora.A)
    # print('lora.B:', lora.B)
    optimizer.step()

----------------------------------------------------------------------------------------------------
step9 norm(lora.A.grad): tensor(0.6384)
step9 norm(lora.B.grad): tensor(0.)
step9 norm(y): tensor(0., grad_fn=<LinalgVectorNormBackward0>)
----------------------------------------------------------------------------------------------------
step9 norm(lora.A.grad): tensor(0.6017)
step9 norm(lora.B.grad): tensor(0.0074)
step9 norm(y): tensor(0.0705, grad_fn=<LinalgVectorNormBackward0>)
----------------------------------------------------------------------------------------------------
step9 norm(lora.A.grad): tensor(0.6816)
step9 norm(lora.B.grad): tensor(0.0046)
step9 norm(y): tensor(0.0440, grad_fn=<LinalgVectorNormBackward0>)
----------------------------------------------------------------------------------------------------
step9 norm(lora.A.grad): tensor(0.7764)
step9 norm(lora.B.grad): tensor(0.0587)
step9 norm(y): tensor(0.5567, grad_fn=<LinalgVectorNormBackward0>)
----------------