In [1]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import torch.nn.functional as F

In [2]:
# Display the installed PyTorch version so the outputs recorded in this
# notebook can be tied to a specific release.
torch.__version__

'1.10.1'

In [3]:
def hamming_score(y_pred, y_true):
    """
    Mean per-sample intersection-over-union of two multi-label indicator tensors.

    Parameters
    ----------
    y_pred, y_true: torch.Tensor
        Boolean (or integer 0/1) tensors of identical shape whose dimension 1
        indexes the labels, e.g. [b, num_labels]. They are combined with the
        bitwise ``&`` / ``|`` operators, so floating-point inputs are rejected
        by PyTorch.

    Returns
    -------
    torch.Tensor
        Scalar tensor: for every sample ``|pred & true| / |pred | true|``,
        averaged over all samples.

    Notes
    -----
    A sample for which both tensors are empty (union of size zero) counts as a
    perfect match (1.0). This is applied per sample, so a single empty sample
    no longer turns the whole batch score into NaN and then into 1.0. An empty
    batch also yields 1.0.
    """
    intersection = (y_pred & y_true).sum(dim=1)
    union = (y_pred | y_true).sum(dim=1)

    if union.numel() == 0:
        # Nothing to average: keep the "perfect score" convention, but place
        # the result on the same device as the inputs.
        return torch.tensor(1.0, device=union.device)

    float_dtype = torch.get_default_dtype()
    has_labels = union > 0
    # Clamp the denominator so empty samples never evaluate 0 / 0; their ratio
    # is discarded by the torch.where below anyway.
    ratio = intersection.to(float_dtype) / union.clamp(min=1).to(float_dtype)
    per_sample = torch.where(has_labels, ratio, torch.ones_like(ratio))
    return per_sample.mean()

In [4]:
class SVD(nn.Module):
    """
    Singular value decomposition layer.

    Computes a thin SVD ``A = U @ S @ V^T`` from the eigendecomposition of the
    Gram matrix ``A^T A``, using only batched tensor operations.

    Examples
    --------
    >>> A = torch.rand(126, 100, 20)
    >>> U, S, V = SVD()(A)
    >>> A_ = torch.matmul(U, torch.matmul(S, V.transpose(-1, -2)))
    >>> print(torch.dist(A_, A))
    """
    def __init__(self):
        super().__init__()

    def forward(self, A):
        """
        Decompose ``A``; thin wrapper around :meth:`svd_`.

        Inputs
        ------
        A: [b, m, n]

        Outputs
        -------
        U: [b, m, n]
        S: [b, n, n]
        V: [b, n, n]
        """
        return self.svd_(A)

    @staticmethod
    def svd_(A):
        """
        Thin SVD of a real matrix (or batch of matrices) via ``A^T A``.

        Parameters
        ----------
        A: torch.FloatTensor
            A real tensor of shape [b, m, n]. Any number of leading batch
            dimensions (including none) is accepted.

        Returns
        -------
        U: [b, m, n]
            Left singular vectors. Columns that belong to a zero singular
            value are set to zero, which keeps ``U @ S @ V^T == A`` valid for
            rank-deficient input.
        S: [b, n, n]
            Diagonal matrix of singular values in descending order.
        V: [b, n, n]
            Right singular vectors (orthonormal columns).

        Notes
        -----
        Working through ``A^T A`` squares the condition number of the problem,
        so small singular values are less accurate than those returned by a
        direct SVD routine.

        References
        ----------
        1. https://www.youtube.com/watch?v=pSbafxDHdgE&t=205s
        2. https://www2.math.ethz.ch/education/bachelor/lectures/hs2014/other/linalg_INFK/svdneu.pdf
        """
        gram = torch.matmul(A.transpose(-1, -2), A)

        # The Gram matrix is symmetric, so use the symmetric eigensolver: it
        # returns real eigenvalues and orthonormal eigenvectors directly.
        eigvals, eigvecs = torch.linalg.eigh(gram)

        # eigh sorts eigenvalues in ascending order; flip both outputs so the
        # singular values follow the usual descending convention.
        eigvals = torch.flip(eigvals, dims=[-1])
        V = torch.flip(eigvecs, dims=[-1])

        # Round-off can push a zero eigenvalue slightly negative; clamp so the
        # square root never produces NaN.
        sigma = torch.sqrt(eigvals.clamp(min=0))
        S = torch.diag_embed(sigma)

        # U = A V S^{-1}. S is diagonal, so scale the columns of A V instead of
        # inverting a matrix, and guard against division by a zero singular
        # value (which would make S non-invertible).
        nonzero = sigma > 0
        safe_sigma = torch.where(nonzero, sigma, torch.ones_like(sigma))
        AV = torch.matmul(A, V)
        U = torch.where(
            nonzero.unsqueeze(-2),
            AV / safe_sigma.unsqueeze(-2),
            torch.zeros_like(AV),
        )
        return U, S, V
            
def singular_value_cumsum(S):
    """
    Cumulative share of the singular-value total.

    Parameters
    ----------
    S: [b, n, n]
        S is a diagonal matrix whose off-diagonal entries are all equal to
        zero. Only the diagonal is read. An unbatched [n, n] matrix or extra
        leading batch dimensions are accepted too.

    Returns
    -------
    torch.Tensor of shape [b, n]
        Entry ``k`` is ``sum(diag[:k + 1]) / sum(diag)`` for each matrix.

    Notes
    -----
    If the diagonal of a matrix sums to zero the division is 0 / 0 and the
    corresponding entries are NaN.
    """
    singular_values = torch.diagonal(S, dim1=-2, dim2=-1)
    # Reduce over the last axis rather than a hard-coded ``dim=1`` so that the
    # function works for any number of batch dimensions.
    running_total = singular_values.cumsum(dim=-1)
    total = singular_values.sum(dim=-1, keepdim=True)
    return torch.div(running_total, total)

In [5]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs end-to-end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Batch of 16 random 512 x 10 matrices, created directly on the target device.
A = torch.rand(16, 512, 10, device=device)
U, S, V = SVD()(A)

In [6]:
# Show the shape of each factor, in the order U, S, V, on a single line.
print(*(factor.shape for factor in (U, S, V)))

torch.Size([16, 512, 10]) torch.Size([16, 10, 10]) torch.Size([16, 10, 10])


In [7]:
# Rebuild A from its factors as U (S V^T); the distance to the original
# measures how well the decomposition reconstructs the input.
A_ = U @ (S @ V.transpose(-2, -1))
torch.dist(A_, A)

tensor(0.0012, device='cuda:0')