<h3><u>Objective (loss) definitions for CCA</u></h3>

In [None]:
import torch

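<br>The following helper functions are defined:<br>

<b>attach_dim()</b>: reshapes a tensor by adding singleton (size-1) dimensions before and/or after its existing dimensions.<br>
<b>compute_matrix_power()</b>: raises a symmetric matrix to a real power through its eigendecomposition, discarding eigenvalues at or below `eps` for numerical stability. The CCA loss uses it to compute the inverse square roots of the covariance matrices, which are needed to measure the correlation between the two views.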

In [None]:
def attach_dim(v, n_dim_to_prepend=0, n_dim_to_append=0):
    """Return ``v`` reshaped with extra size-1 dimensions on either side.

    :param v: input tensor
    :param n_dim_to_prepend: number of singleton dimensions to add in front
    :param n_dim_to_append: number of singleton dimensions to add at the end
    :return: ``v`` reshaped to
        ``(1,) * n_dim_to_prepend + v.shape + (1,) * n_dim_to_append``
    """
    leading = torch.Size([1] * n_dim_to_prepend)
    trailing = torch.Size([1] * n_dim_to_append)
    new_shape = leading + v.shape + trailing
    return v.reshape(new_shape)


def compute_matrix_power(M, p, eps):
    """Raise a symmetric matrix to a real power via its eigendecomposition.

    Eigenpairs whose eigenvalue is not strictly greater than ``eps`` are
    discarded before the power is applied. For negative ``p`` this yields a
    pseudo-inverse power restricted to the retained eigenspace, and it avoids
    raising tiny or non-positive eigenvalues to a negative power.

    :param M: square symmetric matrix (only its upper triangle is read)
    :param p: exponent applied to each retained eigenvalue
    :param eps: eigenvalues ``<= eps`` are dropped
    :return: ``V_k @ diag(D_k ** p) @ V_k^T`` over the retained eigenpairs
        (an all-zero matrix if no eigenvalue exceeds ``eps``)
    """
    # torch.symeig has been removed from PyTorch; torch.linalg.eigh replaces it.
    # UPLO="U" keeps symeig's default of reading the upper triangle.
    eigvals, eigvecs = torch.linalg.eigh(M, UPLO="U")
    # Added to increase stability: keep only eigenvalues strictly above eps.
    keep = eigvals > eps
    eigvals = eigvals[keep]
    eigvecs = eigvecs[:, keep]
    # Scaling column j of V by D_j ** p equals V @ diag(D ** p), without
    # materialising the diagonal matrix.
    return torch.matmul(eigvecs * torch.pow(eigvals, p), eigvecs.t())



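<br>The CCA class implements the differentiable CCA loss: it stores the loss hyperparameters (number of latent dimensions, regularisation, epsilon), and its loss() method returns the negative canonical correlation between the outputs of two views.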



In [None]:
class CCA:
    """
    Differentiable CCA loss.

    ``loss()`` takes the outputs of each view's network and returns the
    negative total canonical correlation between them, so minimising the
    loss maximises the correlation.
    """

    def __init__(self, outdim_size: int, r: float = 0, eps: float = 0):
        """
        :param outdim_size: the number of latent dimensions; ``loss()`` sums
            at most this many of the largest canonical correlations
        :param r: regularisation as in regularised CCA; ``r * I`` is added to
            each view's covariance matrix
        :param eps: eigenvalues not strictly greater than this are discarded
            when inverting the covariances and when summing correlations
        """
        self.outdim_size = outdim_size
        self.r = r
        self.eps = eps

    def loss(self, H1, H2):
        """
        Compute the negative total canonical correlation of two views.

        :param H1: first view's outputs, shape ``(m, o1)`` (samples x features)
        :param H2: second view's outputs, shape ``(m, o2)``; must have the same
            number of samples ``m`` as ``H1`` (``m >= 2``) and the same
            dtype and device
        :return: scalar tensor ``-corr``. ``corr`` is the sum of the square
            roots of the (at most ``outdim_size``) largest eigenvalues above
            ``eps`` of ``T^T T``, where
            ``T = Sigma11^{-1/2} Sigma12 Sigma22^{-1/2}``.
        """
        # Work with features as rows and samples as columns.
        H1, H2 = H1.t(), H2.t()

        o1 = H1.size(0)
        o2 = H2.size(0)
        m = H1.size(1)

        # Centre each feature over the samples.
        H1bar = H1 - H1.mean(dim=1).unsqueeze(dim=1)
        H2bar = H2 - H2.mean(dim=1).unsqueeze(dim=1)

        # Covariance estimates with an (m - 1) denominator. The ridge term
        # r * I is built in the input's dtype: a hard-coded float64 identity
        # would promote Sigma11/Sigma22 to float64 for float32 inputs while
        # Sigma12 stays float32, and the matmuls below would then fail.
        scale = 1.0 / (m - 1)
        SigmaHat12 = scale * torch.matmul(H1bar, H2bar.t())
        SigmaHat11 = scale * torch.matmul(H1bar, H1bar.t()) + self.r * torch.eye(
            o1, dtype=H1.dtype, device=H1.device)
        SigmaHat22 = scale * torch.matmul(H2bar, H2bar.t()) + self.r * torch.eye(
            o2, dtype=H2.dtype, device=H2.device)

        SigmaHat11RootInv = compute_matrix_power(SigmaHat11, -0.5, self.eps)
        SigmaHat22RootInv = compute_matrix_power(SigmaHat22, -0.5, self.eps)

        Tval = torch.matmul(torch.matmul(SigmaHat11RootInv, SigmaHat12),
                            SigmaHat22RootInv)

        # The canonical correlations are the singular values of T, i.e. the
        # square roots of the eigenvalues of T^T T. torch.symeig has been
        # removed from PyTorch; eigvalsh is the eigenvalue-only replacement
        # (UPLO="U" matches symeig's default). Eigenvalues come back in
        # ascending order.
        TT = torch.matmul(Tval.t(), Tval)
        eigvals = torch.linalg.eigvalsh(TT, UPLO="U")
        # Keep only eigenvalues above eps so sqrt is applied to positive values;
        # boolean masking preserves the ascending order.
        eigvals = eigvals[eigvals > self.eps]
        # Just the top self.outdim_size correlations are used. Slice from the
        # end explicitly: `[-k:]` would wrongly select everything when k == 0.
        k = min(self.outdim_size, eigvals.numel())
        top = eigvals[eigvals.numel() - k:]
        corr = torch.sum(torch.sqrt(top))
        return -corr
