<h3><u>Deep Canonical Correlation Analysis: Implementation</u></h3>

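This notebook implements Deep Canonical Correlation Analysis (DCCA).
The libraries used come from PyTorch (torch):<br><br>
<b>matmul</b>: matrix multiplication of tensors.<br>
<b>optim</b>: optimizers for training; one Adam optimizer is created per encoder.<br>
<b>nn</b>: neural-network building blocks; each view's encoder is built as a network with the configuration given in the config file.<br>
<b>functional</b>: stateless functional versions of layers and loss functions; here it provides F.mse_loss.<br>
In addition, Config and compute_matrix_power are imported from the local DCCAE_repo package: the model configuration and a helper for matrix powers (used below for inverse square roots of covariance matrices).<br>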

In [None]:
from torch import nn
from torch import optim, matmul
from torch.nn import functional as F
from DCCAE_repo.configuration import Config
from DCCAE_repo.objectives import compute_matrix_power

<br>All of the deep architectures subclass PyTorch's nn.Module, so each defines a forward method, and each also provides:

<b>loss()</b>: computes the loss from the input views alone; the model outputs are computed internally, so loss(*inputs) evaluates the objective on model(*inputs).

Because every model exposes the same interface, they can all be wrapped in the same deep wrapper, which standardises the training and evaluation pipeline and makes the models easier to compare.<br>
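As a hedged illustration of why a shared interface helps, here is a hypothetical training loop that only ever calls update_weights() and loss(). ToyModel and fit are illustrative names, not part of this repository, and the toy objective stands in for a real DCCA model so that the sketch runs without PyTorch.

```python
from typing import List, Sequence, Tuple


class ToyModel:
    """Hypothetical stand-in for DCCA: exposes update_weights(*views) and loss(*views)."""

    def __init__(self) -> None:
        self.steps = 0

    def update_weights(self, *views: Sequence[float]) -> float:
        # A real model would take an optimizer step here; we only count steps.
        self.steps += 1
        return self.loss(*views)

    def loss(self, *views: Sequence[float]) -> float:
        # Toy objective: mean squared difference between two views,
        # shrinking as more steps are taken.
        a, b = views
        mse = sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)
        return mse / (1 + self.steps)


def fit(model, train_batches: Sequence[Tuple], val_batch: Tuple, epochs: int) -> List[float]:
    """Train any model with the shared interface; return one validation loss per epoch."""
    history = []
    for _ in range(epochs):
        for batch in train_batches:
            model.update_weights(*batch)
        history.append(model.loss(*val_batch))
    return history


train = [([1.0, 2.0], [1.5, 2.5]), ([0.0, 1.0], [0.5, 0.0])]
val = ([1.0, 1.0], [2.0, 2.0])
history = fit(ToyModel(), train, val, epochs=3)
```

Because fit never touches model internals, swapping ToyModel for any other class with the same two methods requires no change to the loop.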

<br><b>We use the following functions:</b>

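create_encoder(): builds the encoder for view i. It instantiates the i-th encoder class from the config with that view's hidden layer sizes, its input size and the shared number of latent dimensions, and casts the result to double precision.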

In [None]:
def create_encoder(config, i):
    encoder = config.encoder_models[i](config.hidden_layer_sizes[i], config.input_sizes[i], config.latent_dims).double()
    return encoder
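The sketch below shows the constructor signature that create_encoder assumes: encoder_class(hidden_layer_sizes, input_size, latent_dims). ToyMLPEncoder and the SimpleNamespace config are hypothetical stand-ins written in NumPy (float64 by default, playing the role of .double()); the real encoder classes come from config.encoder_models.

```python
from types import SimpleNamespace

import numpy as np


class ToyMLPEncoder:
    """Hypothetical encoder with the (hidden_layer_sizes, input_size, latent_dims) signature."""

    def __init__(self, hidden_layer_sizes, input_size, latent_dims, seed=0):
        rng = np.random.default_rng(seed)
        sizes = [input_size, *hidden_layer_sizes, latent_dims]
        # One weight matrix per consecutive pair of layer sizes
        self.weights = [rng.standard_normal((m, n)) / np.sqrt(m) for m, n in zip(sizes[:-1], sizes[1:])]

    def __call__(self, x):
        for w in self.weights[:-1]:
            x = np.maximum(x @ w, 0.0)  # ReLU hidden layers
        return x @ self.weights[-1]     # linear output layer


def create_encoder(config, i):
    # Same indexing pattern as the notebook: one list entry per view.
    return config.encoder_models[i](config.hidden_layer_sizes[i], config.input_sizes[i], config.latent_dims)


config = SimpleNamespace(
    encoder_models=[ToyMLPEncoder, ToyMLPEncoder],
    hidden_layer_sizes=[[16], [32, 8]],
    input_sizes=[10, 20],
    latent_dims=3,
)
encoders = [create_encoder(config, i) for i in range(len(config.encoder_models))]
z1 = encoders[0](np.ones((5, 10)))
z2 = encoders[1](np.ones((5, 20)))
```

Both views have different input sizes and depths, but every encoder maps to the same latent_dims, which is what lets the objective compare them.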

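<br>Here we define a class DCCA, containing the methods for multi-view analysis of non-linear data.<br><br>
<b>__init__:</b> Constructor. Builds one encoder and one Adam optimizer per view and the objective, and, depending on config.als, binds update_weights and loss to either the ALS methods or the standard ones.<br>
<b>encode:</b> Passes each view through its own encoder and returns a tuple of latent representations.<br>
<b>forward:</b> The PyTorch forward pass; it simply calls encode.<br>
<b>update_weights_tn:</b> Takes one gradient step on all encoders jointly, minimising tn_loss.<br>
<b>tn_loss:</b> Encodes the views and evaluates the configured objective on the latent representations.<br>
<b>update_weights_als:</b> Takes one Alternating Least Squares (ALS) step: each encoder is updated by its own optimizer on its own regression loss.<br>
<b>als_loss:</b> Updates the running covariances, then returns one least-squares loss per view: each view's output is regressed onto the whitened (multiplied by the inverse square root of its covariance), detached output of the other view.<br>
<b>als_loss_validation:</b> Computes the ALS loss for evaluation using the stored covariances, without updating them.<br>
<b>update_covariances:</b> Updates the running estimate of each view's covariance as an exponential moving average weighted by config.rho.<br>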
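The objective behind tn_loss is supplied by config.objective and is not shown in this notebook. Assuming it is the standard DCCA objective, the negative sum of canonical correlations between the two latent representations, a NumPy sketch looks like this; inv_sqrt and dcca_loss are hypothetical names, and inv_sqrt only stands in for what compute_matrix_power(..., -0.5, eps) is used for.

```python
import numpy as np


def inv_sqrt(c, eps=1e-6):
    # Inverse square root of a symmetric PSD matrix via eigendecomposition;
    # eps is added to the eigenvalues for numerical stability (an assumption).
    vals, vecs = np.linalg.eigh(c)
    return vecs @ np.diag(1.0 / np.sqrt(vals + eps)) @ vecs.T


def dcca_loss(z1, z2, eps=1e-6):
    """Negative sum of canonical correlations between two latent batches."""
    n = z1.shape[0]
    z1 = z1 - z1.mean(axis=0)
    z2 = z2 - z2.mean(axis=0)
    s11 = z1.T @ z1 / (n - 1)
    s22 = z2.T @ z2 / (n - 1)
    s12 = z1.T @ z2 / (n - 1)
    t = inv_sqrt(s11, eps) @ s12 @ inv_sqrt(s22, eps)
    # The singular values of T are the canonical correlations.
    return -np.linalg.svd(t, compute_uv=False).sum()


rng = np.random.default_rng(0)
z1 = rng.standard_normal((200, 2))
z2 = z1 @ np.array([[2.0, 0.5], [0.0, 1.0]])  # an invertible linear map of z1
z3 = rng.standard_normal((200, 2))            # independent of z1
loss_related = dcca_loss(z1, z2)
loss_independent = dcca_loss(z1, z3)
```

When one view is an invertible linear map of the other, every canonical correlation is 1, so with two latent dimensions the loss is close to -2; independent views give a value much closer to 0.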

In [None]:

class DCCA(nn.Module):

    def __init__(self, config: Config = Config):
        super(DCCA, self).__init__()
        views = len(config.encoder_models)
        self.config = config
        self.encoders = nn.ModuleList([create_encoder(config, i) for i in range(views)])
        self.objective = config.objective(config.latent_dims)
        self.optimizers = [optim.Adam(list(encoder.parameters()), lr=config.learning_rate) for encoder in self.encoders]
        self.covs = None
        if config.als:
            self.update_weights = self.update_weights_als
            self.loss = self.als_loss_validation
        else:
            self.update_weights = self.update_weights_tn
            self.loss = self.tn_loss

    def encode(self, *args):
        z = []
        for i, arg in enumerate(args):
            z.append(self.encoders[i](arg))
        return tuple(z)

    def forward(self, *args):
        z = self.encode(*args)
        return z

    def update_weights_tn(self, *args):
        # Clear gradients of every encoder, take one joint step on the shared objective
        for optimizer in self.optimizers:
            optimizer.zero_grad()
        loss = self.tn_loss(*args)
        loss.backward()
        for optimizer in self.optimizers:
            optimizer.step()
        return loss

    def tn_loss(self, *args):
        z = self(*args)
        return self.objective.loss(*z)

    def update_weights_als(self, *args):
        loss_1, loss_2 = self.als_loss(*args)
        self.optimizers[0].zero_grad()
        loss_1.backward()
        self.optimizers[0].step()
        self.optimizers[1].zero_grad()
        loss_2.backward()
        self.optimizers[1].step()
        return (loss_1 + loss_2) / 2 - self.config.latent_dims

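    def als_loss(self, *args):
        z = self(*args)
        self.update_covariances(*z)
        covariance_inv = [compute_matrix_power(cov, -0.5, self.config.eps) for cov in self.covs]
        # Whitened outputs, detached so they act as fixed regression targets
        preds = [matmul(z_i, covariance_inv[i]).detach() for i, z_i in enumerate(z)]

        # Least squares for each projection in the same manner as the linear case:
        # view i regresses onto the whitened output of the other view (two-view case only).
        # preds[-i] would pair view 0 with its own target, so index the other view explicitly.
        losses = [F.mse_loss(preds[1 - i], z_i) for i, z_i in enumerate(z)]
        return losses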

    def als_loss_validation(self, *args):
        z = self(*args)
        SigmaHat11RootInv = compute_matrix_power(self.covs[0], -0.5, self.config.eps)
        SigmaHat22RootInv = compute_matrix_power(self.covs[1], -0.5, self.config.eps)
        pred_1 = (z[0] @ SigmaHat11RootInv).detach()
        pred_2 = (z[1] @ SigmaHat22RootInv).detach()
        # Least squares for each projection in same manner as linear from before
        
        loss_1 = F.mse_loss(pred_1, z[1])
        loss_2 = F.mse_loss(pred_2, z[0])
        return (loss_1 + loss_2) / 2 - self.config.latent_dims

    def update_covariances(self, *args):
        b = args[0].shape[0]
        batch_covs = [z_i.T @ z_i / b for z_i in args]
        if self.covs is not None:
            # Exponential moving average; the old estimate is detached so gradients
            # only flow through the current batch
            self.covs = [(self.config.rho * self.covs[i]).detach() + (1 - self.config.rho) * batch_cov
                         for i, batch_cov in enumerate(batch_covs)]
        else:
            self.covs = batch_covs
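To make the ALS path concrete, here is a NumPy sketch of the two pieces it relies on: the exponential-moving-average covariance update from update_covariances and the cross-view whitened targets from als_loss. RunningCovariance, als_targets, inv_sqrt and mse are hypothetical names, and inv_sqrt only assumes what compute_matrix_power(cov, -0.5, eps) is used for here. NumPy has no autograd, so detaching is implicit and only the forward computation is shown.

```python
import numpy as np


def inv_sqrt(c, eps=1e-6):
    # Inverse square root of a symmetric PSD matrix (eps added for stability).
    vals, vecs = np.linalg.eigh(c)
    return vecs @ np.diag(1.0 / np.sqrt(vals + eps)) @ vecs.T


class RunningCovariance:
    """EMA of per-view second-moment matrices, mirroring update_covariances."""

    def __init__(self, rho):
        self.rho = rho
        self.covs = None

    def update(self, *zs):
        b = zs[0].shape[0]
        batch_covs = [z.T @ z / b for z in zs]
        if self.covs is None:
            self.covs = batch_covs
        else:
            # rho weights the old estimate, (1 - rho) the current batch
            self.covs = [self.rho * c + (1 - self.rho) * bc for c, bc in zip(self.covs, batch_covs)]
        return self.covs


def als_targets(running, z1, z2, eps=1e-6):
    """Return (target for view 1, target for view 2): the other view, whitened."""
    covs = running.update(z1, z2)
    w1, w2 = inv_sqrt(covs[0], eps), inv_sqrt(covs[1], eps)
    return z2 @ w2, z1 @ w1


def mse(a, b):
    return float(np.mean((a - b) ** 2))


rng = np.random.default_rng(0)
z1 = rng.standard_normal((100, 2))
z2 = z1 @ np.array([[1.0, 0.3], [0.2, 1.0]]) + 0.1 * rng.standard_normal((100, 2))
running = RunningCovariance(rho=0.9)
t1, t2 = als_targets(running, z1, z2)
loss_1, loss_2 = mse(t1, z1), mse(t2, z2)
```

Whitening gives each target an identity second-moment matrix, so each encoder is pushed towards outputs that match a decorrelated, unit-scale version of the other view rather than its raw scale.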
