### Math intuition for LoRA

This scratchpad goes over the intuition of how, for a given rank-deficient matrix (i.e., our layer weight W), we can capture most of the 'information' of W using the product of two much smaller, low-rank matrices.

In [1]:
import torch as t
import numpy as np
# Fix torch's global RNG seed so the random matrices generated below are reproducible.
SEED = 0
t.manual_seed(SEED)

<torch._C.Generator at 0x7fcbe471cf30>

In [None]:
# Shape of the weight matrix W: d rows by k columns.
d = 10
k = 10

# Build W as the product of a (d x W_rank) factor and a (W_rank x k) factor,
# so its rank cannot exceed W_rank even though it has d x k entries.
W_rank = 2
left_factor = t.randn(d, W_rank)
right_factor = t.randn(W_rank, k)
W = t.matmul(left_factor, right_factor)

W

tensor([[ 2.8501, -4.1679, -1.2931, -1.7376, -2.5698, -3.2220, -1.4271, -1.2982,
          0.2702,  1.2163],
        [ 3.2737, -4.7411, -1.4644, -1.9621, -2.9216, -3.6760, -1.6166, -1.4949,
          0.2975,  1.3819],
        [-0.0141, -3.3560, -1.5177, -2.4550, -2.1852, -1.7979, -1.6433,  0.2801,
          0.9375,  1.1010],
        [-0.8365,  0.4910,  0.0490, -0.0243,  0.2776,  0.5523,  0.0609,  0.4404,
          0.1243, -0.1169],
        [-3.9740, -0.6857, -1.1295, -2.3176, -0.6460,  1.0025, -1.1858,  2.3367,
          1.4298,  0.4341],
        [ 0.7376, -0.9989, -0.2987, -0.3915, -0.6132, -0.7910, -0.3304, -0.3424,
          0.0478,  0.2886],
        [-2.2472,  1.8582,  0.3750,  0.3281,  1.0966,  1.7733,  0.4272,  1.1393,
          0.1840, -0.4908],
        [ 0.7821, -0.5984, -0.1087, -0.0790, -0.3502, -0.5912, -0.1251, -0.4004,
         -0.0775,  0.1550],
        [-0.0482, -0.4016, -0.1912, -0.3150, -0.2638, -0.1991, -0.2066,  0.0602,
          0.1267,  0.1342],
        [ 0.6151, -

In [5]:
# Confirm numerically that W has the rank it was built with, and show its shape.
rank_of_W = np.linalg.matrix_rank(W)
print(f"rank of W: {rank_of_W}")
print(f"shape of W: {W.shape}")

rank of W: 2
shape of W: torch.Size([10, 10])


Calculate the singular value decomposition (SVD) of W

In [None]:
# Reduced SVD: W = U @ diag(S) @ Vh, where Vh is V transposed.
# torch.linalg.svd replaces the deprecated torch.svd; note that it returns
# Vh (= V^T) instead of V.
U, S, Vh = t.linalg.svd(W, full_matrices=False)
V = Vh.t()  # kept so that V means the same thing torch.svd used to return

# W was built with rank W_rank, so keeping only the first W_rank singular
# values/vectors reconstructs W up to floating-point error.
U_r = U[:, :W_rank]
S_r = t.diag(S[:W_rank])
V_r = Vh[:W_rank, :]

# LoRA-style factors: fold the singular values into the left factor so that
# B @ A == U_r @ S_r @ V_r.
B = U_r @ S_r
A = V_r

print(f"shape of B: {B.shape}")
print(f"shape of A: {A.shape}")

shape of B: torch.Size([10, 2])
shape of A: torch.Size([2, 10])


Given the same input, check that using W and using B @ A produce the same output

In [None]:
# Random bias and input for a linear layer y = W @ x + bias.
# W is (d x k), so the input needs k entries and the bias needs d entries.
# (x was previously drawn with length d, which only worked because d == k.)
bias = t.randn(d)
x = t.randn(k)

# using original W
y_ref = W @ x + bias

# using B & A matrices
y = (B @ A) @ x + bias

# Same output up to floating-point tolerance, with fewer stored numbers
# in the two factors than in W itself.
print(f"is similar: {t.allclose(y_ref, y)}")
print(f"num elem W: {W.numel()}")
print(f"num elem B + A: {B.numel() + A.numel()}")

is similar: True
num elem W: 100
num elem B + A: 40
