In [1]:
"""
@file metaprior_analysis.ipynb
@author Ryan Missel

This notebook is a deep dive into the code written for the MetaPrior, checking that each component
is functioning as it should, specifically how the codes are sampled and built at each iteration.

In non-binary classification tasks the model has trouble converging, which suggests that the code
build-up may be incorrect. The classifier often first converges to linear decision boundaries between
the classes and cannot cleanly converge to the cross-shaped boundary needed for four classes. The
3-class problem also takes many iterations to converge.
"""
import time
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as func
from tqdm.notebook import tqdm

from IPython.display import clear_output
import matplotlib.pyplot as plt

from scipy.stats import entropy
from metaprior.metautils import get_act, plot_metaspace_mean, plot_metaspace_var, plot_weight_correlations
from torch.distributions import Normal, kl_divergence as kl

In [2]:
class MetaPrior(nn.Module):
    """
    MLP whose weights and biases are sampled from a hyperprior network driven by per-unit latent codes.

    Every unit of every layer owns a latent code of size ``code_dim`` drawn from a learnable Gaussian
    (mean ``code_mu``, scale ``code_var``). A weight's code is the concatenation of the codes of the two
    units it connects; a bias's code is its unit's code concatenated with zeros. The shared ``hyperprior``
    maps each code to a latent vector from which ``mean_net`` / ``var_net`` produce the mean and scale
    used to sample that parameter on every forward pass.
    """

    def __init__(self, layer_sizes=(1, 100, 1), mean=False, local=False, activation='linear', hyperprior_dim=32, code_dim=2):
        """
        :param layer_sizes: number of units per layer, input layer first
        :param mean: stored as ``self.mean``; not read by any method of this class
        :param local: if True, concatenate the ``embedder`` output onto every weight/bias code in ``forward``
        :param activation: name passed to ``get_act`` for the final layer (hidden layers use 'leaky_relu')
        :param hyperprior_dim: width of the hyperprior output fed to ``mean_net`` / ``var_net``
        :param code_dim: size of each unit's latent code
        """
        super().__init__()
        # Copy so later mutation of the caller's sequence cannot change the architecture
        self.layer_sizes = list(layer_sizes)
        self.input_dim = self.layer_sizes[0]
        self.code_dim = code_dim

        self.local = local
        self.mean = mean

        n_layers = len(self.layer_sizes) - 1
        self.acts = [get_act('leaky_relu') if i < n_layers - 1 else get_act(activation)
                     for i in range(n_layers)]

        # Maps the input to a per-sample code (only consumed by forward when local=True)
        self.embedder = nn.Sequential(
            nn.Linear(self.input_dim, 20),
            nn.LeakyReLU(),
            nn.Linear(20, code_dim)
        )

        self.weight_codes = []
        self.bias_codes = []

        # Learnable mean and scale of every unit's code: one [layer size, code_dim] tensor per layer
        self.code_mu = nn.ParameterList([
            nn.Parameter(torch.zeros([lsize, code_dim]), requires_grad=True)
            for lsize in self.layer_sizes
        ])
        self.code_var = nn.ParameterList([
            nn.Parameter(torch.ones([lsize, code_dim]), requires_grad=True)
            for lsize in self.layer_sizes
        ])

        # Placeholder codes; replaced every time generate_weight_codes() runs
        self.codes = [torch.randn([lsize, code_dim]) for lsize in self.layer_sizes]

        # Hyperprior network mapping a (2 * code_dim) weight/bias code to the parameter's distribution
        self.hyperprior = nn.Sequential(
            nn.Linear(code_dim * 2, 64),
            nn.LeakyReLU(),
            nn.Linear(64, hyperprior_dim)
        )

        self.mean_net = nn.Linear(hyperprior_dim, 1)
        self.var_net = nn.Linear(hyperprior_dim, 1)

    def _sample_codes(self):
        """Draw one code per unit: code_mu + eps * code_var, with eps ~ N(0, 1) (code_var acts as a scale)."""
        return [mu + torch.randn_like(mu) * scale for mu, scale in zip(self.code_mu, self.code_var)]

    def _build_weight_codes(self):
        """Rebuild ``self.weight_codes`` and ``self.bias_codes`` from the current ``self.codes``."""
        self.weight_codes = []
        self.bias_codes = []

        for idx in range(len(self.layer_sizes) - 1):
            n_in, n_out = self.layer_sizes[idx], self.layer_sizes[idx + 1]
            # Row r = i * n_out + j pairs input unit i with output unit j, which matches the
            # [n_in, n_out] view taken when the weights are sampled.
            src = self.codes[idx].unsqueeze(1).repeat(1, n_out, 1).view([-1, self.code_dim])
            dst = self.codes[idx + 1].unsqueeze(0).repeat(n_in, 1, 1).view([-1, self.code_dim])
            self.weight_codes.append(torch.cat((src, dst), dim=1))

            # Bias codes: the unit's own code padded with zeros to the weight-code width
            unit_codes = self.codes[idx + 1]
            self.bias_codes.append(torch.cat((unit_codes, torch.zeros_like(unit_codes)), dim=1))

    def _sample_layer(self, lidx, weight_code, bias_code):
        """
        Sample the weight matrix and bias vector of layer ``lidx`` from the hyperprior.

        The ``var_net`` output is used directly as the noise scale (not exponentiated).
        :return: (w of shape [layer_sizes[lidx], layer_sizes[lidx + 1]], b of shape [layer_sizes[lidx + 1]])
        """
        latent_w = self.hyperprior(weight_code)
        latent_b = self.hyperprior(bias_code)

        w_mu, w_scale = self.mean_net(latent_w), self.var_net(latent_w)
        w = (w_mu + torch.randn_like(w_mu) * w_scale).view([self.layer_sizes[lidx], self.layer_sizes[lidx + 1]])

        b_mu, b_scale = self.mean_net(latent_b), self.var_net(latent_b)
        # view(-1) keeps b 1-D even for a single-unit layer (squeeze() would make it 0-D)
        b = (b_mu + torch.randn_like(b_mu) * b_scale).view(-1)
        return w, b

    def generate_weight_codes(self):
        """Sample a fresh code for every unit and rebuild ``weight_codes`` / ``bias_codes`` from them."""
        self.codes = self._sample_codes()
        self._build_weight_codes()

    def kl_z_term(self):
        """
        KL divergence of the meta-variable code distributions from the prior N(0, 1).

        ``code_var`` is used as a scale. Its absolute value is taken because sampling
        ``mu + eps * s`` with symmetric eps gives the same distribution for ``s`` and ``-s``,
        whereas ``Normal`` rejects non-positive scales.
        :return: sum of the KL values over every element of every unit's code
        """
        mus = torch.cat([cmu.view([-1]) for cmu in self.code_mu])
        scales = torch.cat([cvar.view([-1]) for cvar in self.code_var]).abs()

        q = Normal(mus, scales)
        prior = Normal(torch.zeros_like(mus), torch.ones_like(mus))
        return kl(q, prior).sum()

    def forward(self, x, perturb=False):
        """
        Sample fresh codes, generate each layer's weights and biases from them, and run x through the network.

        :param x: input batch
        :param perturb: if True, add 2 * N(0, 1) noise to the second half (the output-unit part) of one
                        randomly chosen first-layer weight code before sampling
        :return: network output after the final activation
        """
        self.generate_weight_codes()
        local_code = self.embedder(x)

        # Perturb one meta-var to test function draws
        if perturb:
            indice = np.random.randint(0, self.weight_codes[0].shape[0], 1)
            ref = self.weight_codes[0]
            shift = torch.cat((
                torch.zeros(self.code_dim, dtype=ref.dtype, device=ref.device),
                2 * torch.randn(self.code_dim, dtype=ref.dtype, device=ref.device),
            ))
            self.weight_codes[0][indice] += shift

        for lidx in range(len(self.layer_sizes) - 1):
            weight_code, bias_code = self.weight_codes[lidx], self.bias_codes[lidx]
            if self.local:
                # NOTE(review): local_code has code_dim columns, so these inputs become 3 * code_dim wide
                # while the hyperprior's first layer expects 2 * code_dim, and the row counts (weight codes
                # vs. batch) generally differ too - this path looks shape-incompatible as written; TODO confirm.
                weight_code = torch.cat((weight_code, local_code), dim=1)
                bias_code = torch.cat((bias_code, local_code), dim=1)

            w, b = self._sample_layer(lidx, weight_code, bias_code)

            x = func.linear(x, w.T, b)
            x = self.acts[lidx](x)

        return x

    def bce_predict(self, x):
        """Threshold the network outputs at 0.5 into binary class labels (a list of 0/1 ints)."""
        pred = self.forward(x)
        return [0 if t < 0.5 else 1 for t in pred]

    def ce_predict(self, x):
        """Return the argmax over the output dimension as a numpy array of class predictions."""
        pred = self.forward(x).detach().cpu().numpy()
        return np.argmax(pred, axis=1)

    def weight_correlations(self, x, indice, layer, shift):
        """
        Run x through the network with one unit's code pinned and return that unit's incoming weights.

        All codes are freshly sampled. Then unit ``indice`` of layer ``layer + 1`` has its code set to
        ``code_mu + shift``. Weights and biases are sampled exactly as in ``forward`` (the ``var_net``
        output is used directly as the scale), so the returned weights come from the same distribution
        ``forward`` samples from.

        :param x: input batch
        :param indice: index of the unit in layer ``layer + 1`` whose code is shifted
        :param layer: index of the weight matrix to inspect (connects layer ``layer`` to ``layer + 1``)
        :param shift: offset added to the unit's code mean; anything ``torch.as_tensor`` accepts
        :return: (network output, weights from every unit of layer ``layer`` into unit ``indice``)
        """
        self.codes = self._sample_codes()

        target_mu = self.code_mu[layer + 1]
        # as_tensor (unlike torch.Tensor) never treats an int as a size for an uninitialised tensor
        self.codes[layer + 1][indice] = target_mu[indice] + torch.as_tensor(
            shift, dtype=target_mu.dtype, device=target_mu.device)

        self._build_weight_codes()

        w_out = None
        for lidx in range(len(self.layer_sizes) - 1):
            w, b = self._sample_layer(lidx, self.weight_codes[lidx], self.bias_codes[lidx])

            # Keep the sampled weight matrix of the layer being inspected
            if lidx == layer:
                w_out = w

            x = func.linear(x, w.T, b)
            x = self.acts[lidx](x)

        return x, w_out[:, indice]

In [3]:
# A 1 -> 100 -> 3 network used to inspect how its weight and bias codes are assembled
net = MetaPrior([1, 100, 3])

In [7]:
# Resample all codes, then list the hidden -> output weight codes one row at a time
net.generate_weight_codes()
for weight_code in net.weight_codes[1]:
    print(weight_code)

tensor([ 0.3884, -0.5718, -0.6217,  1.7311], grad_fn=<UnbindBackward>)
tensor([ 0.3884, -0.5718, -0.2133, -0.3903], grad_fn=<UnbindBackward>)
tensor([ 0.3884, -0.5718,  2.9750, -0.6791], grad_fn=<UnbindBackward>)
tensor([-1.4302,  0.0160, -0.6217,  1.7311], grad_fn=<UnbindBackward>)
tensor([-1.4302,  0.0160, -0.2133, -0.3903], grad_fn=<UnbindBackward>)
tensor([-1.4302,  0.0160,  2.9750, -0.6791], grad_fn=<UnbindBackward>)
tensor([-0.1901, -0.4486, -0.6217,  1.7311], grad_fn=<UnbindBackward>)
tensor([-0.1901, -0.4486, -0.2133, -0.3903], grad_fn=<UnbindBackward>)
tensor([-0.1901, -0.4486,  2.9750, -0.6791], grad_fn=<UnbindBackward>)
tensor([ 0.8306,  0.4466, -0.6217,  1.7311], grad_fn=<UnbindBackward>)
tensor([ 0.8306,  0.4466, -0.2133, -0.3903], grad_fn=<UnbindBackward>)
tensor([ 0.8306,  0.4466,  2.9750, -0.6791], grad_fn=<UnbindBackward>)
tensor([ 0.5195, -0.0507, -0.6217,  1.7311], grad_fn=<UnbindBackward>)
tensor([ 0.5195, -0.0507, -0.2133, -0.3903], grad_fn=<UnbindBackward>)
tensor

In [8]:
# Output-layer bias codes: each unit's own code followed by zeros
for bias_code in net.bias_codes[-1]:
    print(bias_code)

tensor([-0.6217,  1.7311,  0.0000,  0.0000], grad_fn=<UnbindBackward>)
tensor([-0.2133, -0.3903,  0.0000,  0.0000], grad_fn=<UnbindBackward>)
tensor([ 2.9750, -0.6791,  0.0000,  0.0000], grad_fn=<UnbindBackward>)


In [20]:
"""
Handles building the weight codes and draw samples from i
:return:
"""
# Define initial distributions
layer_sizes = [2, 100, 3]
code_dim = 2
code_mu = [torch.zeros([lsize, code_dim]) for lsize in layer_sizes]
code_var = [torch.ones([lsize, code_dim]) for lsize in layer_sizes]

# Sample the codes array
codes = [code_mu[i] + torch.randn_like(code_mu[i]) * code_var[i] for i in range(len(layer_sizes))]

weight_codes = []
bias_codes = []

# Loop between the layers and generate their weight codes by concatenating each units' latent var.
# Units in the smaller layer need to be duplicated to the size of the next layer in order to perform
# easy concatenation between their latent variables
# for idx in range(len(layer_sizes) - 1):
#     temp = codes[idx].unsqueeze(1).repeat(1, layer_sizes[idx + 1], 1).view([-1, code_dim])
#     temp2 = codes[idx + 1].unsqueeze(0).repeat(layer_sizes[idx], 1, 1).view([-1, code_dim])
#     concated = torch.cat((temp, temp2), dim=1)

#     weight_codes.append(concated)

#     # Generate bias codes (concatenation is just with a zeros vector)
#     bias_codes.append(torch.cat((codes[idx + 1], torch.zeros_like(codes[idx + 1])), dim=1))
    
    
for idx in range(len(layer_sizes) - 1):
    temp = codes[idx].unsqueeze(1).repeat(1, layer_sizes[idx + 1], 1).view([-1, code_dim])
    temp2 = codes[idx + 1].unsqueeze(0).repeat(layer_sizes[idx], 1, 1).view([-1, code_dim])
    concated = torch.cat((temp2, temp), dim=1)

    weight_codes.append(concated)

    # Generate bias codes (concatenation is just with a zeros vector)
    bias_codes.append(torch.cat((torch.zeros_like(codes[idx + 1]), codes[idx + 1]), dim=1))

In [21]:
# First-layer weight codes, each laid out as [output-unit code, input-unit code]
for row in weight_codes[0]:
    print(row)

tensor([ 0.6190,  0.0557, -2.7507,  1.2097])
tensor([-0.3273, -1.2928, -2.7507,  1.2097])
tensor([-0.1604,  0.1112, -2.7507,  1.2097])
tensor([-0.6080, -0.7631, -2.7507,  1.2097])
tensor([-1.1110,  0.6588, -2.7507,  1.2097])
tensor([-0.8877,  1.3068, -2.7507,  1.2097])
tensor([-2.7106, -2.0531, -2.7507,  1.2097])
tensor([ 0.8343,  0.6821, -2.7507,  1.2097])
tensor([-1.2697, -0.3515, -2.7507,  1.2097])
tensor([-0.0925,  0.0214, -2.7507,  1.2097])
tensor([-0.0405, -0.7963, -2.7507,  1.2097])
tensor([-1.8382,  0.4193, -2.7507,  1.2097])
tensor([-0.3724, -1.0349, -2.7507,  1.2097])
tensor([ 0.3239,  0.4600, -2.7507,  1.2097])
tensor([ 0.0571, -0.2020, -2.7507,  1.2097])
tensor([ 0.0103,  0.5091, -2.7507,  1.2097])
tensor([-0.9231,  0.5619, -2.7507,  1.2097])
tensor([-0.0176, -1.7578, -2.7507,  1.2097])
tensor([ 0.5695, -0.5468, -2.7507,  1.2097])
tensor([ 0.6417, -1.0731, -2.7507,  1.2097])
tensor([ 1.2620, -0.3106, -2.7507,  1.2097])
tensor([-0.5430, -0.0192, -2.7507,  1.2097])
tensor([ 1