# Rotary LoRA

In [1]:
import torch
import torch.nn as nn

class RotaryLora(nn.Module):
    """Low-rank adapter ``x -> x @ A @ B`` whose product ``A @ B`` is zero at init.

    ``A`` has shape ``(model_dim, lora_rank)`` and ``B`` has shape
    ``(lora_rank, model_dim)``. At initialization the two column halves of
    ``A`` are identical copies ``h_A`` and the two row halves of ``B`` are
    ``-h_B`` and ``h_B``. Hence ``A @ B = h_A @ (-h_B) + h_A @ h_B = 0``, and
    the adapter starts as a zero map (up to floating-point rounding). Both
    ``A`` and ``B`` are non-zero, so both receive non-zero gradients from the
    first step.

    Args:
        model_dim: Size of the last dimension of the input and of the output.
        lora_rank: Inner (low-rank) dimension. It must be a positive even
            integer because it is split into two equal halves.

    Raises:
        ValueError: If ``lora_rank`` is not a positive even integer.
    """

    def __init__(self, model_dim, lora_rank):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if lora_rank <= 0 or lora_rank % 2 != 0:
            raise ValueError(
                f"lora_rank must be a positive even integer, got {lora_rank}"
            )

        self.model_dim = model_dim
        self.lora_rank = lora_rank

        self.A = nn.Parameter(torch.empty(model_dim, lora_rank))
        self.B = nn.Parameter(torch.empty(lora_rank, model_dim))

        self._init_parameters()

    def _init_parameters(self):
        """(Re)initialize ``A`` and ``B`` in place so that ``A @ B == 0``.

        Each half is drawn from N(0, 1 / lora_rank). This gives every row of
        ``A`` and every column of ``B`` an expected squared L2 norm of 1.

        Values are written with ``copy_`` under ``no_grad`` so the parameters
        keep their current dtype and device. Assigning ``.data`` would
        silently replace them with default-dtype CPU tensors.
        """
        half_rank = self.lora_rank // 2
        # Divisor, not a standard deviation: the resulting std is 1/sqrt(rank).
        scale = self.lora_rank ** 0.5

        with torch.no_grad():
            half_a = torch.randn(
                self.model_dim, half_rank, dtype=self.A.dtype, device=self.A.device
            ) / scale
            # Identical halves: [h_A, h_A].
            self.A.copy_(torch.cat([half_a, half_a], dim=-1))

            half_b = torch.randn(
                half_rank, self.model_dim, dtype=self.B.dtype, device=self.B.device
            ) / scale
            # Opposite halves: [-h_B; h_B], so the two contributions cancel.
            self.B.copy_(torch.cat([-half_b, half_b], dim=0))

    def forward(self, x):
        """Return ``x @ A @ B``.

        ``x`` may have any shape whose last dimension is ``model_dim``, e.g.
        ``(model_dim,)`` or ``(batch_size, sequence_length, model_dim)``.
        ``torch.matmul`` broadcasts over the leading dimensions, and the output
        has the same shape as ``x``.
        """
        return torch.matmul(torch.matmul(x, self.A), self.B)

In [2]:
# A tiny model width paired with a large (even) adapter rank.
model_dim, lora_rank = 8, 8096
rotary_lora = RotaryLora(model_dim=model_dim, lora_rank=lora_rank)

In [3]:
# Elementwise check that A @ B is ~0. Compare magnitudes: a plain `< 1e-7`
# would also pass arbitrarily large negative entries.
(rotary_lora.A @ rotary_lora.B).abs() < 1e-7

tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])

In [4]:
# L2 norm of each row of A (expected ~1 given the 1/sqrt(lora_rank) scaling).
# torch.linalg.vector_norm replaces the deprecated torch.norm.
torch.linalg.vector_norm(rotary_lora.A, dim=1)

tensor([0.9893, 1.0137, 0.9990, 0.9998, 0.9942, 1.0083, 1.0102, 1.0074],
       grad_fn=<LinalgVectorNormBackward0>)

In [5]:
# L2 norm of each column of B (identical to the row norms of B.T).
# torch.linalg.vector_norm replaces the deprecated torch.norm.
torch.linalg.vector_norm(rotary_lora.B, dim=0)

tensor([0.9798, 1.0130, 0.9973, 0.9792, 0.9905, 1.0020, 1.0121, 1.0138],
       grad_fn=<LinalgVectorNormBackward0>)

In [6]:
import torch.optim as optim

optimizer = optim.SGD(params=rotary_lora.parameters(), lr=0.01, momentum=0.9)

# Push random single vectors through the adapter for a few SGD steps and print
# the gradient and output norms. The goal is to observe that the output starts
# near zero while the gradients of A and B are already non-zero.
for idx in range(10):
    print('-' * 100)
    optimizer.zero_grad()
    x = torch.randn((model_dim, ))
    y = rotary_lora(x)
    z = y.mean()  # arbitrary scalar objective, used only to drive backward()
    z.backward()
    # vector_norm over all elements == the deprecated torch.norm's default
    # Frobenius norm.
    print(f'step{idx} norm(lora.A.grad):', torch.linalg.vector_norm(rotary_lora.A.grad))
    print(f'step{idx} norm(lora.B.grad):', torch.linalg.vector_norm(rotary_lora.B.grad))
    print(f'step{idx} norm(y):', torch.linalg.vector_norm(y))
    optimizer.step()

----------------------------------------------------------------------------------------------------
step0 norm(lora.A.grad): tensor(1.0158)
step0 norm(lora.B.grad): tensor(1.0171)
step0 norm(y): tensor(6.9340e-08, grad_fn=<LinalgVectorNormBackward0>)
----------------------------------------------------------------------------------------------------
step1 norm(lora.A.grad): tensor(0.5577)
step1 norm(lora.B.grad): tensor(0.5593)
step1 norm(y): tensor(0.0062, grad_fn=<LinalgVectorNormBackward0>)
----------------------------------------------------------------------------------------------------
step2 norm(lora.A.grad): tensor(1.0429)
step2 norm(lora.B.grad): tensor(1.0445)
step2 norm(y): tensor(0.0551, grad_fn=<LinalgVectorNormBackward0>)
----------------------------------------------------------------------------------------------------
step3 norm(lora.A.grad): tensor(0.8516)
step3 norm(lora.B.grad): tensor(0.8463)
step3 norm(y): tensor(0.0544, grad_fn=<LinalgVectorNormBackward0>)
----