In [1]:
import torch
import torch.nn as nn

In [None]:
class DAGMM(nn.Module):
    """Deep Autoencoding Gaussian Mixture Model.

    The module is made of three parts:

    * an encoder / decoder pair (``self.encoder``, ``self.decoder``),
    * an estimation network (``self.estim_net``) whose last ``Linear`` layer
      is followed by a softmax, i.e. it outputs soft mixture memberships,
    * a Gaussian mixture whose parameters (``self.phi``, ``self.mu``,
      ``self.cov``) are re-estimated from a batch by
      :meth:`compute_gmm_params` and used by :meth:`compute_energy`.

    The mixture parameters only exist after :meth:`compute_gmm_params`
    (or :meth:`con_loss`, which calls it) has been run at least once.
    """

    def __init__(self, enc_dims, estim_dims, types):
        """Build the autoencoder and the estimation network.

        Args:
            enc_dims: sequence of ``(in_features, out_features)`` pairs, one
                per encoder layer (the decoder mirrors them).
            estim_dims: per-layer specification of the estimation network,
                see :meth:`construct_estim_net`.
            types: per-layer type tags matching ``estim_dims``.
        """
        super().__init__()

        self.encoder, self.decoder = self.construct_ae(enc_dims)
        self.estim_net = self.construct_estim_net(estim_dims, types)

    def construct_ae(self, dimensions):
        """Create the encoder and the mirrored decoder.

        Every ``Linear`` layer except the last one of each network is
        followed by ``Tanh``.

        Args:
            dimensions: sequence of ``(in_features, out_features)`` pairs
                describing the encoder layers in order.

        Returns:
            Tuple ``(encoder, decoder)`` of ``nn.Sequential`` modules.
        """
        last = len(dimensions) - 1

        enc_layers = []
        for i, dims in enumerate(dimensions):
            enc_layers.append(nn.Linear(dims[0], dims[1]))
            if i < last:
                enc_layers.append(nn.Tanh())

        # The decoder walks the encoder layers backwards with in/out swapped.
        dec_layers = []
        for i, dims in enumerate(dimensions[::-1]):
            dec_layers.append(nn.Linear(dims[1], dims[0]))
            if i < last:
                dec_layers.append(nn.Tanh())

        return nn.Sequential(*enc_layers), nn.Sequential(*dec_layers)

    def construct_estim_net(self, dimensions, types):
        """Create the estimation network.

        Args:
            dimensions: one entry per layer. For a ``'Linear'`` layer the
                entry is an ``(in_features, out_features)`` pair; for a
                ``'drop'`` layer it is the dropout probability.
            types: one tag per layer, ``'Linear'`` or ``'drop'``. Entries
                with any other tag are skipped.

        Returns:
            An ``nn.Sequential``. A ``Linear`` at index 0 gets two extra
            input features (the two reconstruction features appended by
            :meth:`get_estimation_input`). A ``Linear`` in the last position
            is followed by a softmax, any other ``Linear`` by ``Tanh``.
        """
        layers = []
        last = len(dimensions) - 1
        for i, dims in enumerate(dimensions):
            if types[i] == 'Linear':
                # Only the very first layer sees the two appended features.
                in_features = dims[0] + 2 if i == 0 else dims[0]
                layers.append(nn.Linear(in_features, dims[1]))
                if i < last:
                    layers.append(nn.Tanh())
                else:
                    layers.append(nn.Softmax(dim=-1))
            elif types[i] == 'drop':
                layers.append(nn.Dropout(p=dims))

        return nn.Sequential(*layers)

    def compute_gmm_params(self, z, gamma):
        """Estimate the mixture parameters from a batch.

        Detached copies are stored on the module (``self.phi``, ``self.mu``,
        ``self.cov``) for later inference. The returned tensors are still
        attached to the autograd graph so that a loss built from them can
        back-propagate into ``gamma`` (and therefore into the estimation
        network) and into ``z``.

        Args:
            z: tensor of shape ``N x D`` (samples in the mixture space).
            gamma: tensor of shape ``N x K`` (soft memberships).

        Returns:
            Tuple ``(phi, mu, cov)`` with shapes ``K``, ``K x D`` and
            ``K x D x D``.
        """
        n_samples = gamma.size(0)
        sum_gamma = torch.sum(gamma, dim=0)  # K

        # Phi represents the importance (mixing weight) of each component.
        phi = sum_gamma / n_samples  # K

        # Mu holds the expected value of each component.
        mu = torch.sum(gamma.unsqueeze(-1) * z.unsqueeze(1), dim=0) / sum_gamma.unsqueeze(1)  # K x D

        # Sigma holds the covariance matrix of each component.
        z_mu = z.unsqueeze(1) - mu.unsqueeze(0)  # N x K x D
        z_outer = z_mu.unsqueeze(-1) * z_mu.unsqueeze(-2)  # N x K x D x D
        weights = gamma.unsqueeze(-1).unsqueeze(-1)  # N x K x 1 x 1
        cov = torch.sum(weights * z_outer, dim=0) / sum_gamma.unsqueeze(-1).unsqueeze(-1)  # K x D x D

        # Keep graph-free copies for inference; the graph-connected tensors
        # are handed back to the caller for use in the loss.
        self.phi = phi.detach()
        self.mu = mu.detach()
        self.cov = cov.detach()

        return phi, mu, cov

    def compute_energy(self, z, flag=True, phi=None, mu=None, cov=None):
        """Compute the sample energy (negative log-likelihood) under the GMM.

        Args:
            z: tensor of shape ``N x D``.
            flag: when true the mean energy over the batch is returned,
                otherwise the per-sample energies (shape ``N``).
            phi, mu, cov: optional mixture parameters to use instead of the
                ones stored on the module. Passing the graph-connected
                tensors returned by :meth:`compute_gmm_params` lets the
                result back-propagate through them.

        Returns:
            Tuple ``(energy, p_sigma)`` where ``p_sigma`` is the sum of the
            reciprocals of the (regularised) covariance diagonals.

        Raises:
            AttributeError: if no parameters are passed and
                :meth:`compute_gmm_params` has never been called.
        """
        phi = self.phi if phi is None else phi
        mu = self.mu if mu is None else mu
        cov = self.cov if cov is None else cov

        z_mu = z.unsqueeze(1) - mu.unsqueeze(0)  # N x K x D

        # Small diagonal term against numerical errors.
        # NOTE(review): 1e-12 is below float32 resolution for typical
        # covariance magnitudes, so it barely regularises in float32.
        eps = 1e-12
        eye = torch.eye(cov.size(-1), dtype=cov.dtype, device=cov.device)
        cov_reg = cov + (eye * eps)

        # Always invert the covariance that is actually in use, so the
        # inverse can never be out of sync with the log-determinant below.
        cov_inverse = torch.linalg.inv(cov_reg)  # K x D x D
        self.cov_inverse = cov_inverse.detach()

        # Log-determinant computed directly for numerical stability.
        log_det = 0.5 * torch.linalg.slogdet(2 * torch.pi * cov_reg).logabsdet  # K

        # Mahalanobis term: -0.5 * (z - mu)^T Sigma^-1 (z - mu).
        left = torch.sum(z_mu.unsqueeze(-1) * cov_inverse.unsqueeze(0), dim=-2)  # N x K x D
        mahalanobis = -0.5 * torch.sum(left * z_mu, dim=-1)  # N x K

        # Log-sum-exp over the components, shifted by the row maximum.
        exponent = torch.log(phi + eps).unsqueeze(0) + mahalanobis - log_det.unsqueeze(0)  # N x K
        max_val = torch.max(exponent, dim=1, keepdim=True)[0]  # N x 1
        exp_term = torch.exp(exponent - max_val)

        energy = -(max_val.squeeze(1) + torch.log(torch.sum(exp_term, dim=1) + eps))  # N

        # Regularisation penalty on small covariance diagonals for the loss.
        diagonals = torch.diagonal(cov_reg, dim1=-2, dim2=-1)
        p_sigma = torch.sum(1.0 / (diagonals + eps))

        if flag:
            return torch.mean(energy), p_sigma
        return energy, p_sigma

    def con_loss(self, x, x_rec, z, gamma, lam_energy, lam_cov):
        """Combined training loss.

        ``loss = mean((x - x_rec)^2) + lam_energy * energy + lam_cov * p_sigma``

        The mixture parameters are re-estimated from ``z`` and ``gamma`` and
        used with their autograd graph intact, so the energy and covariance
        terms produce gradients for the estimation network.

        Args:
            x: input batch.
            x_rec: reconstruction of ``x``.
            z: ``N x D`` samples in the mixture space.
            gamma: ``N x K`` soft memberships.
            lam_energy: weight of the mean energy term.
            lam_cov: weight of the covariance penalty.

        Returns:
            Scalar loss tensor.
        """
        rec_err = torch.mean((x - x_rec) ** 2)

        phi, mu, cov = self.compute_gmm_params(z, gamma)
        energy, p_sigma = self.compute_energy(z, phi=phi, mu=mu, cov=cov)

        return rec_err + lam_energy * energy + lam_cov * p_sigma

    def get_estimation_input(self, x, x_rec, z_c):
        """Build the input vector of the estimation network.

        Concatenates ``z_c`` with two reconstruction features computed along
        ``dim=1``: the relative Euclidean distance ``||x - x_rec|| / ||x||``
        and the cosine similarity between ``x`` and ``x_rec``.

        Returns:
            Tensor with two more columns than ``z_c``.
        """
        # Relative Euclidean distance; the denominator is clamped so that an
        # all-zero row does not produce NaN / inf.
        x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp_min(1e-12)
        euclidean_dist = torch.norm(x - x_rec, p=2, dim=1, keepdim=True) / x_norm

        # Cosine similarity
        cosine_sim = nn.functional.cosine_similarity(x, x_rec, dim=1).unsqueeze(1)

        return torch.cat([z_c, euclidean_dist, cosine_sim], dim=1)

    def predict(self, x):
        """Score ``x`` with the stored mixture parameters.

        Switches the module to eval mode (it is not switched back) and runs
        entirely without gradient tracking.

        Returns:
            Tuple ``(energy, p_sigma)`` with per-sample energies.
        """
        self.eval()
        with torch.no_grad():
            z = self.encoder(x)
            x_rec = self.decoder(z)
            est_in = self.get_estimation_input(x, x_rec, z)
            return self.compute_energy(est_in, False)

    def forward(self, X):
        """Run the autoencoder and the estimation network on ``X``.

        Returns:
            Tuple ``(x_rec, z, gamma)``: the reconstruction, the estimation
            network input built by :meth:`get_estimation_input`, and the
            estimation network output. The order matches the corresponding
            arguments of :meth:`con_loss`.
        """
        z_c = self.encoder(X)
        x_rec = self.decoder(z_c)
        z = self.get_estimation_input(X, x_rec, z_c)
        gamma = self.estim_net(z)
        return x_rec, z, gamma





