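# Singular Value Decomposition

This notebook builds a rank-deficient $10 \times 10$ matrix $W$ of rank 2. It factorises $W$ with a truncated SVD into $W = BA$, where $B \in \mathbb{R}^{10 \times 2}$ and $A \in \mathbb{R}^{2 \times 10}$. It then checks that $BAx + b$ reproduces $Wx + b$ while storing 40 parameters instead of 100.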

In [1]:
import numpy as np
import torch

# Seed torch's global RNG so every random tensor below is reproducible.
# The returned Generator is bound to `_` only so the cell shows no output.
_ = torch.manual_seed(0)

### Generate a rank-deficient matrix W

In [2]:
# W is a d x k matrix built as a (d x W_rank) factor times a (W_rank x k) factor,
# both Gaussian. Such a product has rank at most W_rank (here 2), so the
# 10 x 10 result is rank-deficient.
d = k = 10
W_rank = 2
W = torch.randn(d, W_rank).matmul(torch.randn(W_rank, k))
print(W)

tensor([[-1.0797,  0.5545,  0.8058, -0.7140, -0.1518,  1.0773,  2.3690,  0.8486,
         -1.1825, -3.2632],
        [-0.3303,  0.2283,  0.4145, -0.1924, -0.0215,  0.3276,  0.7926,  0.2233,
         -0.3422, -0.9614],
        [-0.5256,  0.9864,  2.4447, -0.0290,  0.2305,  0.5000,  1.9831, -0.0311,
         -0.3369, -1.1376],
        [ 0.7900, -1.1336, -2.6746,  0.1988, -0.1982, -0.7634, -2.5763, -0.1696,
          0.6227,  1.9294],
        [ 0.1258,  0.1458,  0.5090,  0.1768,  0.1071, -0.1327, -0.0323, -0.2294,
          0.2079,  0.5128],
        [ 0.7697,  0.0050,  0.5725,  0.6870,  0.2783, -0.7818, -1.2253, -0.8533,
          0.9765,  2.5786],
        [ 1.4157, -0.7814, -1.2121,  0.9120,  0.1760, -1.4108, -3.1692, -1.0791,
          1.5325,  4.2447],
        [-0.0119,  0.6050,  1.7245,  0.2584,  0.2528, -0.0086,  0.7198, -0.3620,
          0.1865,  0.3410],
        [ 1.0485, -0.6394, -1.0715,  0.6485,  0.1046, -1.0427, -2.4174, -0.7615,
          1.1147,  3.1054],
        [ 0.9088,  

Check the numerical rank of $W$. It should equal the rank of 2 used to construct it.

In [4]:
# Measure the numerical rank of W (SVD-based, with a floating-point tolerance).
# Staying in torch avoids an implicit tensor -> NumPy conversion, which fails for
# CUDA tensors or tensors that require grad. Casting to a plain Python int makes
# the value safe to use as a slice bound in the next cell.
measured_rank = int(torch.linalg.matrix_rank(W))
# W was built as a (d x W_rank) @ (W_rank x k) product, so its rank cannot exceed
# the construction rank; fail loudly rather than silently overwrite W_rank.
assert measured_rank <= W_rank, f"unexpected rank {measured_rank} > {W_rank}"
W_rank = measured_rank
print(f"Rank of W: {W_rank}")

Rank of W: 2


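Compute the singular value decomposition (SVD) of $W$ and keep only the top-$r$ singular triplets, giving a rank-$r$ factorisation $W = BA$ with $B = U_r \Sigma_r$ and $A = V_r^\top$.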

In [5]:
# Perform the reduced SVD of W: W = U @ diag(S) @ Vh, with Vh = V^T.
# torch.linalg.svd replaces the deprecated torch.svd, which returned V rather than V^T.
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
V = Vh.T  # right singular vectors as columns, matching the old torch.svd layout

# For a rank-r factorisation, keep only the first r singular values
# (and the corresponding columns of U and rows of Vh).
U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank])
V_r = Vh[:W_rank, :]  # first r rows of V^T, already shaped (r x k)

# Fold the singular values into the left factor: W ~= B @ A with
# B = U_r @ S_r (d x r) and A = V_r (r x k).
B = U_r @ S_r
A = V_r
# Sanity check: assuming W_rank is W's numerical rank (measured above), the
# truncated factors reconstruct W up to float32 round-off.
assert torch.allclose(B @ A, W, atol=1e-4), "rank-r factors do not reconstruct W"
print(f"Shape of B: {B.shape}")
print(f"Shape of A: {A.shape}")

Shape of B: torch.Size([10, 2])
Shape of A: torch.Size([2, 10])


In [9]:
# Generate a random bias (one entry per output row of W) and an input vector
# (one entry per column of W). x must have length k for W @ x to be defined;
# the previous torch.randn(d) only worked because d == k.
bias = torch.randn(d)
x = torch.randn(k)

# Compute y = Wx + b with the full matrix.
y = W @ x + bias
# Compute y' = BAx + b. Multiplying right-to-left (A @ x first) keeps every
# intermediate low-rank, so the full d x k product B @ A is never materialised.
y_prime = B @ (A @ x) + bias
# The low-rank path must agree with the full one up to float round-off.
assert torch.allclose(y, y_prime, atol=1e-4), "y and y' differ"

print("Original y using W:\n", y)
print("Total parameters of W: ", W.nelement())
print("")
print("y' computed using BA:\n", y_prime)
print("Total parameters of BA: ", B.nelement() + A.nelement())

Original y using W:
 tensor([ 2.4027,  1.2373,  0.5844, -2.9918,  1.1609,  0.9889, -3.3747, -0.9453,
        -4.6008, -0.7128])
Total parameters of W:  100

y' computed using BA:
 tensor([ 2.4027,  1.2373,  0.5844, -2.9918,  1.1609,  0.9889, -3.3747, -0.9453,
        -4.6008, -0.7128])
Total parameters of BA:  40
