# svd-helpers

> Helper functions related to [Singular Value Decomposition](https://jeremykun.com/2016/04/18/singular-value-decomposition-part-1-perspectives-on-linear-algebra/) and its applications. 

In [None]:
#| default_exp common.svd_helpers

In [None]:
# | hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| hide
from fastcore.test import *

In [None]:
#| export
import torch
from torch.nn import functional as F


In [None]:
# | export
def adjust_singular_vector_sign(
    singular_vector: torch.Tensor, original_matrix: torch.Tensor
) -> torch.Tensor:
    """Resolve the sign ambiguity of a singular vector by majority vote.

    An SVD routine may legitimately return either `v` or `-v` for the same
    singular vector. This picks the sign that agrees with most rows of
    `original_matrix`: each row whose cosine similarity with `singular_vector`
    is strictly negative counts as a vote for flipping. The vector is negated
    only when those rows form a strict majority. Ties, and rows with zero
    similarity, leave the input sign unchanged. Choosing the sign from the data
    in this way follows the idea in https://www.osti.gov/servlets/purl/920802.

    Args:
        singular_vector: 1-D tensor of length n.
        original_matrix: 2-D tensor of shape (m, n) whose rows are compared
            against `singular_vector`.

    Returns:
        A new tensor equal to `singular_vector` or to its negation.
    """
    assert singular_vector.ndim == 1
    assert original_matrix.ndim == 2
    assert singular_vector.shape[0] == original_matrix.shape[1]

    # Compare every row against the candidate direction (broadcast as a 1 x n row).
    similarities = F.cosine_similarity(original_matrix, singular_vector[None, :])
    opposing_rows = (similarities < 0).sum()

    # `2 * count > m` is the integer-exact form of `count > m / 2`.
    should_flip = 2 * opposing_rows > original_matrix.shape[0]
    return singular_vector * (-1 if should_flip else 1)

In [None]:
# Tests for adjust_singular_vector_sign: whichever sign the SVD happens to
# return, the adjusted vector should end up with the same, data-aligned sign.
test_matrix = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
test_singular_vector = torch.tensor([-0.4797, -0.5724, -0.6651])

for candidate in (test_singular_vector, -test_singular_vector):
    test_eq(
        adjust_singular_vector_sign(candidate, test_matrix),
        -test_singular_vector,
    )

In [None]:
#| export
def projection_matrix_for_rank_k_approximation(
    original_matrix: torch.Tensor, k: int
) -> torch.Tensor:
    """Return the matrix that orthogonally projects onto the subspace spanned by
    the top `k` right singular vectors of `original_matrix`.

    For a matrix A with linearly independent columns, the projection onto its
    column space is A (A^T A)^{-1} A^T. Derivation of the formula:
    https://ocw.mit.edu/courses/18-06sc-linear-algebra-fall-2011/00e9c8f0eafedeab21a3d079a17ed3d8_MIT18_06SCF11_Ses2.2sum.pdf
    Here the columns of A are right singular vectors, which are orthonormal.
    So A^T A = I and the projection simplifies to A A^T. No explicit matrix
    inverse is needed, and the result does not depend on the sign the SVD
    picks for each singular vector.

    Args:
        original_matrix: 2-D tensor of shape (m, n). Integer/bool input is
            converted to the default floating-point dtype before the SVD.
        k: Number of leading singular vectors to keep, with 1 <= k <= n.

    Returns:
        An (n, n) projection matrix P. `P @ x` is the component of `x` lying
        in the span of the top-k right singular vectors.

    Raises:
        AssertionError: if `original_matrix` is not 2-D or `k` is out of range.
    """
    assert original_matrix.ndim == 2
    assert k > 0 and k <= original_matrix.shape[1]

    # torch.linalg.svd only accepts floating-point or complex tensors.
    if not (original_matrix.is_floating_point() or original_matrix.is_complex()):
        original_matrix = original_matrix.to(torch.get_default_dtype())

    # Only the first k rows of Vh are used. The reduced SVD avoids materializing
    # the (m, m) U factor, which is prohibitive for tall matrices (many rows).
    # The full SVD is needed only when k > min(m, n), because the extra rows of
    # Vh (null-space directions) exist only in the full decomposition. Since
    # k <= n, that case implies m < n, so U stays small there.
    use_full = k > min(original_matrix.shape)
    _, _, Vh = torch.linalg.svd(original_matrix, full_matrices=use_full)

    # Rows of Vh are the conjugate-transposed right singular vectors. With
    # top = Vh[:k], the basis matrix is A = top^H, so A A^H = top^H @ top.
    # (.conj() is a no-op for real tensors.)
    top = Vh[:k]
    return top.conj().T @ top

In [None]:
# Tests for projection_matrix_for_rank_k_approximation

# A diagonal test matrix: its right singular vectors are (up to sign) the
# standard basis vectors of R^3, ordered by singular value 3 > 2 > 1.
test_matrix = torch.tensor([
    [3, 0, 0],
    [0, 2, 0],
    [0, 0, 1],
], dtype=torch.float32)

test_vector = torch.tensor([1, 2, 3], dtype=torch.float32)
e_1 = torch.tensor([1, 0, 0], dtype=torch.float32)
e_2 = torch.tensor([0, 1, 0], dtype=torch.float32)

# SVD results are floating-point, so compare within a tolerance, not exactly.

# Rank 1: the projection should land on the x-axis.
proj_matrix = projection_matrix_for_rank_k_approximation(test_matrix, 1)
projection = proj_matrix @ test_vector
test_close(projection, test_vector.dot(e_1) * e_1)

# Rank 2: the projection should land in the x-y plane.
proj_matrix = projection_matrix_for_rank_k_approximation(test_matrix, 2)
projection = proj_matrix @ test_vector
test_close(projection, test_vector.dot(e_1) * e_1 + test_vector.dot(e_2) * e_2)

# Every orthogonal projection matrix is idempotent and symmetric.
test_close(proj_matrix @ proj_matrix, proj_matrix)
test_close(proj_matrix, proj_matrix.T)

In [None]:
#| hide
# Run nbdev's export step so the `#| export` cells above are written out to the
# library module (presumably the one named by `#| default_exp`; see nbdev docs).
import nbdev; nbdev.nbdev_export()