In [1]:
import torch
import numpy as np
import pandas as pd
import torch.nn.functional as F

In [2]:
# PyTorch version the outputs in this notebook were produced with (torch.linalg is required below).
torch.version.__version__

'1.10.1'

In [3]:
def svd(A):
    """
    Thin singular value decomposition computed from the eigendecomposition of A^T A.

    Parameters
    ----------
    A: torch.FloatTensor
        A tensor of shape [b, m, n]. Any number of leading batch dimensions,
        including none (a plain [m, n] matrix), is accepted.

    Returns
    -------
    U: [b, m, n]
        Left singular vectors, U = A V S^{-1}. Columns that belong to numerically
        zero singular values (rank-deficient A, or m < n) are set to zero instead
        of producing inf/NaN.
    S: [b, n, n]
        Diagonal matrix of singular values, sorted in descending order.
    V: [b, n, n]
        Right singular vectors: orthonormal columns, the eigenvectors of A^T A.

    U @ S @ V^T reconstructs A to within the accuracy of this method.

    Notes
    -----
    Forming A^T A squares the condition number of A, so small singular values are
    computed less accurately than with ``torch.linalg.svd``. This function
    illustrates the construction described in the references.

    References
    ----------
    1. https://www.youtube.com/watch?v=pSbafxDHdgE&t=205s
    2. https://www2.math.ethz.ch/education/bachelor/lectures/hs2014/other/linalg_INFK/svdneu.pdf
    """
    ATA = torch.matmul(A.transpose(-1, -2), A)
    # A^T A is symmetric positive semi-definite. eigh returns real eigenvalues and
    # orthonormal eigenvectors (as columns), so no complex-to-real truncation or
    # re-normalization is needed.
    lv, V = torch.linalg.eigh(ATA)
    # eigh sorts eigenvalues in ascending order; flip to the usual descending SVD order.
    lv = lv.flip(-1)
    V = V.flip(-1)
    # Round-off can push zero eigenvalues slightly negative; clamp so sqrt stays real.
    s = torch.sqrt(lv.clamp(min=0))
    S = torch.diag_embed(s)
    # Eigenvalues at or below this threshold (relative to the largest one) are
    # treated as zero singular values.
    tol = lv[..., :1] * A.shape[-1] * torch.finfo(lv.dtype).eps
    nonzero = (lv > tol).unsqueeze(-2)
    # U = A V S^{-1}, computed as a column-wise division rather than inverting S
    # (which fails when S is singular). Columns for zero singular values become 0.
    AV = torch.matmul(A, V)
    s_safe = torch.where(nonzero, s.unsqueeze(-2), torch.ones_like(s).unsqueeze(-2))
    U = torch.where(nonzero, AV / s_safe, torch.zeros_like(AV))
    return U, S, V

In [4]:
# Seed the RNG so this example (and the reconstruction error shown below) is reproducible.
torch.manual_seed(0)
A = torch.randn(10, 4, 3)  # batch of 10 random 4x3 matrices
U, S, V = svd(A)

In [5]:
# Check the factors against the documented thin-SVD shapes:
# U: [b, m, n], S: [b, n, n], V: [b, n, n].
n_batch, n_rows, n_cols = A.shape
assert U.shape == (n_batch, n_rows, n_cols), U.shape
assert S.shape == (n_batch, n_cols, n_cols), S.shape
assert V.shape == (n_batch, n_cols, n_cols), V.shape
U.shape, S.shape, V.shape

(torch.Size([10, 4, 3]), torch.Size([10, 3, 3]), torch.Size([10, 3, 3]))

In [6]:
# Reconstruct A from its factors: A ≈ U S V^T.
A_ = torch.matmul(U, torch.matmul(S, V.transpose(-1, -2)))
reconstruction_error = torch.dist(A_, A)
# Fail loudly rather than only displaying a number. The tolerance is relative to
# ||A|| and loose enough for float32 round-off.
assert reconstruction_error <= 1e-4 * A.norm(), reconstruction_error
reconstruction_error

tensor(2.1643e-06)