## Singular value decomposition

In [1]:
import torch
import numpy as np
# Fix PyTorch's global RNG seed so every random tensor drawn below is
# reproducible; the result is bound to `_` to keep the returned Generator
# from being echoed as cell output.
SEED = 0
_ = torch.manual_seed(SEED)

In [2]:
## Construct a rank-deficient matrix

# Multiplying a (d x W_rank) factor by a (W_rank x k) factor yields a d x k
# matrix whose rank cannot exceed W_rank, no matter how large d and k are.
d = k = 10
W_rank = 2
W = torch.matmul(torch.randn(d, W_rank), torch.randn(W_rank, k))
print(W, W.shape)

tensor([[-1.0797,  0.5545,  0.8058, -0.7140, -0.1518,  1.0773,  2.3690,  0.8486,
         -1.1825, -3.2632],
        [-0.3303,  0.2283,  0.4145, -0.1924, -0.0215,  0.3276,  0.7926,  0.2233,
         -0.3422, -0.9614],
        [-0.5256,  0.9864,  2.4447, -0.0290,  0.2305,  0.5000,  1.9831, -0.0311,
         -0.3369, -1.1376],
        [ 0.7900, -1.1336, -2.6746,  0.1988, -0.1982, -0.7634, -2.5763, -0.1696,
          0.6227,  1.9294],
        [ 0.1258,  0.1458,  0.5090,  0.1768,  0.1071, -0.1327, -0.0323, -0.2294,
          0.2079,  0.5128],
        [ 0.7697,  0.0050,  0.5725,  0.6870,  0.2783, -0.7818, -1.2253, -0.8533,
          0.9765,  2.5786],
        [ 1.4157, -0.7814, -1.2121,  0.9120,  0.1760, -1.4108, -3.1692, -1.0791,
          1.5325,  4.2447],
        [-0.0119,  0.6050,  1.7245,  0.2584,  0.2528, -0.0086,  0.7198, -0.3620,
          0.1865,  0.3410],
        [ 1.0485, -0.6394, -1.0715,  0.6485,  0.1046, -1.0427, -2.4174, -0.7615,
          1.1147,  3.1054],
        [ 0.9088,  

In [3]:
# Evaluate the numerical rank of the matrix W.
# torch.linalg.matrix_rank operates on the tensor directly, so this does not
# rely on an implicit tensor -> NumPy conversion (which is not available for
# tensors living on an accelerator or tensors that require grad).
# .item() turns the 0-dim result into a plain Python int, which is what the
# slicing in the truncated-SVD step expects.
W_rank = torch.linalg.matrix_rank(W).item()
print(f"Rank of W: {W_rank}")

Rank of W: 2


In [4]:
## Perform SVD on W (W = U x diag(S) x V^T)
# torch.svd is deprecated; torch.linalg.svd is its replacement. It returns
# V^T (named Vh here) instead of V, and full_matrices=False requests the
# reduced decomposition that torch.svd produced by default.
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
V = Vh.t()  # right singular vectors as columns, i.e. what torch.svd returned as V

# For a rank-r factorization, keep only the first r singular values and the
# matching singular vectors.
U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank])
V_r = Vh[:W_rank, :]  # first r rows of V^T; equal to V[:, :W_rank].t()

# Compute B = U_r x S_r and A = V_r, so that W is approximated by B @ A
B = U_r @ S_r
A = V_r
print(f"Shape of B: {B.shape}")
print(f"Shape of A: {A.shape}")

Shape of B: torch.Size([10, 2])
Shape of A: torch.Size([2, 10])


In [6]:
## Given the same input, check the output using the original W matrix and the matrices resulting from the decomposition

# Generate random bias and input.
# W is d x k, so the input must have k entries and the bias (added to the
# output) must have d entries. The sizes only coincide because d == k here.
bias = torch.randn(d)
x = torch.randn(k)

# Compute y = Wx + bias
output_W = W @ x + bias

# Compute y' = B(Ax) + bias; applying A first keeps the intermediate result
# at size r instead of materializing the full d x k product B @ A.
output_BA = B @ (A @ x) + bias

# Check if the outputs are the same
print(f"Output using W: {output_W}")
print(f"Output using BA: {output_BA}")

# Find the differences between the two outputs (L2 norm of the residual;
# expected to be at floating-point round-off level)
diff = torch.norm(output_W - output_BA)
print(f"Difference between the two outputs: {diff}")

Output using W: tensor([-4.2808, -1.2647,  2.9707, -0.9784,  0.9193,  3.5211,  5.3049,  5.5870,
         1.8145,  4.7644])
Output using BA: tensor([-4.2808, -1.2647,  2.9707, -0.9784,  0.9193,  3.5211,  5.3049,  5.5870,
         1.8145,  4.7644])
Difference between the two outputs: 4.846704541705549e-06


In [8]:
# Compare storage cost: the dense matrix W versus its two low-rank factors.
print("Total params in W:", W.numel())
print("Total params in B and A:", sum(factor.numel() for factor in (B, A)))

Total params in W: 100
Total params in B and A: 40
