# Singular Value Decomposition (SVD)



In [1]:
import torch 
import numpy as np

# Seed PyTorch's global random number generator so the torch.randn draws in
# this notebook are repeatable; assigning the return value to `_` keeps the
# cell from displaying it.
_ = torch.manual_seed(42)

Generate a rank-deficient matrix W

In [2]:
# W will have d rows and k columns.
d, k = 10, 10

# Multiplying a (d x W_rank) factor by a (W_rank x k) factor yields a d x k
# matrix whose rank cannot exceed W_rank, i.e. a rank-deficient W.
W_rank = 2
W = torch.matmul(torch.randn(d, W_rank), torch.randn(W_rank, k))
print(W)

tensor([[ 2.6746e+00, -1.4811e+00, -1.1993e+00, -4.3986e-01, -3.2132e+00,
         -9.8070e-01,  5.2384e+00,  1.1440e+00,  3.4292e-02, -8.0985e-01],
        [-2.6293e+00,  5.5638e-01,  2.8071e+00, -2.1880e+00,  1.5113e+00,
         -1.9586e+00, -8.4328e-01, -1.0369e+00,  1.2841e+00,  1.7830e+00],
        [ 1.2519e+00, -1.5893e-01, -1.5284e+00,  1.3505e+00, -5.2550e-01,
          1.2769e+00, -1.0580e-01,  4.8339e-01, -7.6663e-01, -9.6521e-01],
        [ 2.5833e+00, -1.0893e+00, -1.7760e+00,  5.6927e-01, -2.4785e+00,
          1.6158e-01,  3.4259e+00,  1.0717e+00, -4.6680e-01, -1.1566e+00],
        [-8.3761e-01,  1.5506e-01,  9.3442e-01, -7.6165e-01,  4.4083e-01,
         -6.9603e-01, -1.6246e-01, -3.2818e-01,  4.4156e-01,  5.9234e-01],
        [ 1.7718e-01,  1.0177e-01, -4.4121e-01,  5.5308e-01,  1.5319e-01,
          5.8442e-01, -6.0979e-01,  5.6301e-02, -2.9052e-01, -2.7291e-01],
        [-1.6017e-02,  7.9710e-02, -1.2103e-01,  2.0897e-01,  1.4897e-01,
          2.3602e-01, -3.7046e-0

In [3]:
# Numerical rank of W, computed with PyTorch directly instead of handing the
# tensor to NumPy (the implicit tensor -> ndarray conversion breaks for
# tensors on an accelerator or tensors that require grad).
# int(...) converts the 0-dim result to a plain Python int, since W_rank is
# used as a slice bound when truncating the SVD factors.
W_rank = int(torch.linalg.matrix_rank(W))
print(f"Rank of W: {W_rank}")

Rank of W: 2


Compute the singular value decomposition (SVD) of the matrix W

In [4]:
# Perform SVD on W:  W = U @ diag(S) @ Vh
# torch.svd is deprecated; torch.linalg.svd is its replacement. It returns Vh
# (the transpose of V for a real matrix) rather than V, and
# full_matrices=False selects the reduced form that torch.svd returned by default.
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
V = Vh.t()  # keep V defined with the same orientation torch.svd produced
print(U.shape, S.shape, V.shape)

# For rank r-factorization, keep only the first r singular values
# (and the corresponding columns of U and rows of Vh)
U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank])
V_r = Vh[:W_rank, :]
print(U_r.shape, S_r.shape, V_r.shape)

# Fold the singular values into the left factor so that B @ A reconstructs W
B = U_r @ S_r
A = V_r
print(f"Shape of B: {B.shape}")
print(f"Shape of A: {A.shape}")




torch.Size([10, 10]) torch.Size([10]) torch.Size([10, 10])
torch.Size([10, 2]) torch.Size([2, 2]) torch.Size([2, 10])
Shape of B: torch.Size([10, 2])
Shape of A: torch.Size([2, 10])


In [8]:
# W has shape (d, k): the input vector must have k entries and the bias d entries.
bias = torch.randn(d)
x = torch.randn(k)  # previously torch.randn(d), which only worked because d == k

# Apply the same affine map twice: once with the full matrix W and once with
# the product of its low-rank factors B and A.
y = W @ x + bias
y_prime = (B @ A) @ x + bias

print("Original y using W:\n", y)
print("")
print("y' computed using BA:\n", y_prime)

Original y using W:
 tensor([ 8.7740, -3.2261,  0.4063,  7.2494, -0.3587, -0.8469, -1.2226,  1.8885,
         0.1394,  0.2454])

y' computed using BA:
 tensor([ 8.7740, -3.2261,  0.4063,  7.2494, -0.3587, -0.8469, -1.2226,  1.8885,
         0.1394,  0.2454])


In [15]:
# Total number of entries stored by the full matrix W
print(W.numel())

100


In [16]:
# Total number of entries stored by the two low-rank factors together
print(sum(factor.numel() for factor in (A, B)))

40


In [17]:
import torchvision