In [1]:
# Singular Value Decomposition (SVD) demo: factor a low-rank matrix W into B @ A
import torch
import numpy as np

# Seed torch's RNG so the random matrices below are reproducible on re-run;
# the returned object is assigned to `_` so the cell does not echo it as output.
_ = torch.manual_seed(0)

In [2]:
# Build a d x k matrix of known rank by multiplying two thin random factors:
# a (d x w_rank) matrix times a (w_rank x k) matrix has rank at most w_rank.
d = 10
k = 10
w_rank = 2
left_factor = torch.randn(d, w_rank)    # drawn first, same RNG order as the inline product
right_factor = torch.randn(w_rank, k)
W = left_factor @ right_factor
W

tensor([[-1.0797,  0.5545,  0.8058, -0.7140, -0.1518,  1.0773,  2.3690,  0.8486,
         -1.1825, -3.2632],
        [-0.3303,  0.2283,  0.4145, -0.1924, -0.0215,  0.3276,  0.7926,  0.2233,
         -0.3422, -0.9614],
        [-0.5256,  0.9864,  2.4447, -0.0290,  0.2305,  0.5000,  1.9831, -0.0311,
         -0.3369, -1.1376],
        [ 0.7900, -1.1336, -2.6746,  0.1988, -0.1982, -0.7634, -2.5763, -0.1696,
          0.6227,  1.9294],
        [ 0.1258,  0.1458,  0.5090,  0.1768,  0.1071, -0.1327, -0.0323, -0.2294,
          0.2079,  0.5128],
        [ 0.7697,  0.0050,  0.5725,  0.6870,  0.2783, -0.7818, -1.2253, -0.8533,
          0.9765,  2.5786],
        [ 1.4157, -0.7814, -1.2121,  0.9120,  0.1760, -1.4108, -3.1692, -1.0791,
          1.5325,  4.2447],
        [-0.0119,  0.6050,  1.7245,  0.2584,  0.2528, -0.0086,  0.7198, -0.3620,
          0.1865,  0.3410],
        [ 1.0485, -0.6394, -1.0715,  0.6485,  0.1046, -1.0427, -2.4174, -0.7615,
          1.1147,  3.1054],
        [ 0.9088,  

In [3]:
# Confirm the constructed matrix really has the intended rank. torch's own
# matrix_rank avoids an implicit numpy conversion, which fails for tensors
# on a GPU or tensors that require grad.
r = int(torch.linalg.matrix_rank(W))
assert r == w_rank, f"expected rank {w_rank}, got {r}"
r

2

##### Calculate SVD of the W matrix

In [4]:
# The SVD factors W as U @ diag(S) @ Vh. torch.svd is deprecated in favour of
# torch.linalg.svd, which returns Vh (= V^T) directly instead of V.
# full_matrices=False gives the reduced SVD, which is all a truncation needs.
U, S, Vh = torch.linalg.svd(W, full_matrices=False)

# For a rank-r factorization, keep only the first r singular values
# (returned in descending order) and the matching columns of U / rows of Vh.
U_r = U[:, :w_rank]
S_r = torch.diag(S[:w_rank])
V_r = Vh[:w_rank, :]  # already (w_rank x k); no transpose needed

# W ≈ B @ A with B = U_r @ S_r and A = V_r
B = U_r @ S_r
A = V_r

# W was built with rank w_rank, so the truncated factors should reproduce it
# up to float32 round-off.
assert torch.allclose(B @ A, W, atol=1e-4), "rank-w_rank factors do not reconstruct W"

print(f"B shape: {B.shape}")
print(f"A shape: {A.shape}")

B shape: torch.Size([10, 2])
A shape: torch.Size([2, 10])


- Given the same input, check the output using the original W matrix and the matrices resulting from the decomposition.

In [6]:
# Generate a random bias and input vector of length d
bias = torch.randn(d)
x = torch.randn(d)

# y = W x + b, using the full matrix
y = W @ x + bias
# y' = B (A x) + b, using the low-rank factors. Applying A first avoids
# rebuilding the full d x k matrix B @ A, which is the point of the factorization.
y_prime = B @ (A @ x) + bias

print(f"Original y using W: \n {y}")
print(f"\n Y' computed  usin BxA: \n {y_prime}")

# The two results agree up to float32 round-off, not bit-for-bit
assert torch.allclose(y_prime, y, atol=1e-4), "factorized output differs from W @ x + b"

Original y using W: 
 tensor([-4.2808, -1.2647,  2.9707, -0.9784,  0.9193,  3.5211,  5.3049,  5.5870,
         1.8145,  4.7644])

 Y' computed  usin BxA: 
 tensor([-4.2808, -1.2647,  2.9707, -0.9784,  0.9193,  3.5211,  5.3049,  5.5870,
         1.8145,  4.7644])


In [7]:
# Compare storage cost: the full d x k matrix vs. the two low-rank factors
full_params = W.numel()
factored_params = sum(factor.numel() for factor in (B, A))
print(f"Total parameters of W: {full_params}")
print(f"Total parameters of B and A: {factored_params}")

Total parameters of W: 100
Total parameters of B and A: 40
