# In [0]:
import torch
import torch.nn as nn
import numpy as np
import math
import torch.nn.functional as F
class ChoiceNet(nn.Module):
    """Mixture head placed on top of a feature extractor.

    ``forward`` maps a batch of N inputs to K = ``num_mixture`` components
    per output dimension (D = ``y_dim``): means ``mu`` [N, D, K], variances
    ``var`` [N, D, K] and mixture weights ``pi`` [N, K].  The mean of every
    component is the product of the Q = ``feature_dim`` dimensional feature
    with a weight matrix sampled in :meth:`cholesky` from the learnable
    ``muW`` / ``logSigmaW`` and a per-sample correlation ``rho``.

    Args:
        feature: module applied to the raw input; its flattened output must
            contain ``7 * 7 * 64`` values per sample.
        y_dim: output dimension D.
        num_mixture: number of mixture components K.
        feature_dim: width Q of the feature vector fed to every head.
        logSigmaZval: constant log-variance used for the ``Z`` samples.
        tau_inv: constant added to every predicted variance.
        pi1_bias: value added to the first mixture weight before the final
            softmax.
    """

    def __init__(self, feature, y_dim, num_mixture, feature_dim, logSigmaZval, tau_inv, pi1_bias):
        # nn.Module must be initialised before sub-modules or parameters are
        # assigned, otherwise the attribute assignments below raise.
        super().__init__()
        self.feature = feature
        self.y_dim = y_dim
        self.num_mixture = num_mixture
        self.feature_dim = feature_dim
        self.logSigmaZval = logSigmaZval
        self.tau_inv = tau_inv
        self.pi1_bias = pi1_bias

        # Input width is hard-coded: `feature` must emit 7*7*64 values per sample.
        self.fc_feature_dim = nn.Linear(7 * 7 * 64, self.feature_dim)
        self.fc_num_mixture = nn.Linear(self.feature_dim, self.num_mixture)
        self.fc_var_raw = nn.Linear(self.feature_dim, self.y_dim)
        self.fc_pi_logits = nn.Linear(self.feature_dim, self.num_mixture)

        # Mean and log-variance of the sampled weights.  They are registered
        # as parameters so they are trained, saved, and moved between devices
        # with the module instead of being re-drawn on every forward pass.
        self.muW = nn.Parameter(
            torch.empty(self.feature_dim, self.y_dim).normal_(mean=0.0, std=0.1)
        )
        self.logSigmaW = nn.Parameter(
            torch.full((self.feature_dim, self.y_dim), -3.0)
        )

    def forward(self, x, rho_ref=1.0):
        """Compute the mixture parameters for a batch.

        Args:
            x: input batch accepted by ``self.feature``.
            rho_ref: correlation assigned to the first mixture component
                (column 0 of ``rho``).  Defaults to 1.0.

        Returns:
            Tuple ``(mu, var, pi)`` with shapes [N, D, K], [N, D, K], [N, K].
        """
        x = self.feature(x)
        x = torch.flatten(x, 1)
        num_data = x.size(0)
        Q = self.feature_dim

        feature = self.fc_feature_dim(x)  # [N, Q]

        # Correlations: column 0 is pinned to rho_ref, the rest are learned.
        rho_temp = torch.sigmoid(self.fc_num_mixture(feature))  # [N, K]
        # The pinned column is a graph-free constant: with rho_ref == 1 the
        # sqrt(1 - rho^2) term in `cholesky` has an infinite derivative, and
        # routing it through `rho_temp * 0.0` would inject NaN gradients.
        rho_first = torch.ones_like(rho_temp[:, 0:1]) * rho_ref
        rho = torch.cat([rho_first, rho_temp[:, 1:]], dim=1)  # [N, K]

        muW_tile, muZ_tile, sigmaW_tile, sigmaZ_tile = self.make_sample(Q, num_data)
        samples = self.cholesky(
            self.num_mixture, Q, rho, num_data,
            muW_tile, sigmaW_tile, muZ_tile, sigmaZ_tile,
        )  # [K, N, Q, D]

        # K mean mixtures.  permute() yields a non-contiguous tensor, hence
        # reshape() rather than view().
        w_sample = samples.permute(1, 3, 0, 2)  # [N, D, K, Q]
        w_flat = w_sample.reshape(num_data, self.y_dim * self.num_mixture, Q)  # [N, DK, Q]
        mu = torch.matmul(w_flat, feature.unsqueeze(-1))  # [N, DK, 1]
        mu = mu.reshape(num_data, self.y_dim, self.num_mixture)  # [N, D, K]

        # K variance mixtures; broadcasting replaces the explicit tiling.
        var_raw = torch.exp(self.fc_var_raw(feature))  # [N, D]
        var = (1.0 - torch.pow(rho.unsqueeze(1), 2)) * var_raw.unsqueeze(2) + self.tau_inv

        # Mixture weights: softmax, bias on the first component, softmax again.
        pi_temp = F.softmax(self.fc_pi_logits(feature), dim=1)  # [N, K]
        pi_temp = torch.cat([pi_temp[:, 0:1] + self.pi1_bias, pi_temp[:, 1:]], dim=1)
        pi = F.softmax(pi_temp, dim=1)

        return mu, var, pi

    def cholesky(self, num_mixture, Q, rho, num_data, muW_tile, sigmaW_tile, muZ_tile, sigmaZ_tile):
        """Draw one correlated weight sample per mixture component.

        Args:
            num_mixture: number of components K to sample.
            Q: feature dimension.
            rho: [N, K] correlation of every component for every sample.
            num_data: batch size N.
            muW_tile, sigmaW_tile: mean / variance of ``W``, each [N, Q, D].
            muZ_tile, sigmaZ_tile: mean / variance of ``Z``, each [N, Q, D].

        Returns:
            Tensor of shape [K, N, Q, D] holding the stacked samples.
        """
        sample_shape = (num_data, Q, self.y_dim)
        # Noise is created on the same device/dtype as the weights.
        noise_kwargs = {"dtype": muW_tile.dtype, "device": muW_tile.device}
        std_w = torch.sqrt(sigmaW_tile)
        std_z = torch.sqrt(sigmaZ_tile)

        samples = []
        for mix_idx in range(num_mixture):
            # [N, 1, 1]; broadcasts against the [N, Q, D] tiles.
            rho_tile = rho[:, mix_idx, None, None]

            W = muW_tile + std_w * torch.randn(sample_shape, **noise_kwargs)
            Z = muZ_tile + std_z * torch.randn(sample_shape, **noise_kwargs)

            one_minus_rho_sq = 1.0 - torch.pow(rho_tile, 2)
            Y = rho_tile * muW_tile + one_minus_rho_sq * (
                rho_tile * std_z / std_w * (W - muW_tile)
                + Z * torch.sqrt(one_minus_rho_sq)
            )
            samples.append(Y)
        return torch.stack(samples)

    def make_sample(self, Q, num_data):
        """Tile the weight statistics over the batch.

        Args:
            Q: feature dimension; must equal ``self.feature_dim`` because
                ``muW`` / ``logSigmaW`` are stored with that size.
            num_data: batch size N.

        Returns:
            ``(muW_tile, muZ_tile, sigmaW_tile, sigmaZ_tile)``, each of shape
            [N, Q, D].  The ``W`` tiles are views of the parameters; the
            ``Z`` tiles are constants (mean 0, variance
            ``exp(logSigmaZval)``).

        Raises:
            ValueError: if ``Q`` differs from ``self.feature_dim``.
        """
        if Q != self.feature_dim:
            raise ValueError(
                f"Q ({Q}) must equal feature_dim ({self.feature_dim})"
            )
        tile_shape = (num_data, Q, self.y_dim)

        muW_tile = self.muW.unsqueeze(0).expand(tile_shape)
        sigmaW_tile = torch.exp(self.logSigmaW).unsqueeze(0).expand(tile_shape)

        muZ_tile = self.muW.new_zeros(tile_shape)
        sigmaZ_tile = torch.exp(self.muW.new_full(tile_shape, float(self.logSigmaZval)))

        return muW_tile, muZ_tile, sigmaW_tile, sigmaZ_tile

    

    #muW_tile = 
def make_layers(self, in_channels, h_dim, filter_size, max_pools, activation, batch_norm=False):
    """Build a convolutional stack as an ``nn.Sequential``.

    For every entry of ``h_dim`` the stack gets a ``Conv2d`` (padding 1),
    an optional ``BatchNorm2d``, a ``ReLU`` and, when the matching
    ``max_pools`` entry is greater than 1, a ``MaxPool2d`` with that kernel
    size and stride.

    Args:
        self: unused; kept so the existing positional signature is unchanged.
        in_channels: number of channels of the input to the first conv.
        h_dim: output channels of each conv layer.
        filter_size: kernel size of each conv layer (indexed like ``h_dim``).
        max_pools: pooling factor after each conv layer (indexed like
            ``h_dim``); values <= 1 add no pooling layer.
        activation: currently unused -- ``nn.ReLU`` is always inserted.
        batch_norm: insert ``BatchNorm2d`` after every conv when true.

    Returns:
        ``nn.Sequential`` containing the assembled layers.
    """
    layers = []

    for h_idx, hidden in enumerate(h_dim):
        conv2d = nn.Conv2d(in_channels, hidden, kernel_size=filter_size[h_idx], padding=1)
        layers.append(conv2d)
        if batch_norm:
            layers.append(nn.BatchNorm2d(hidden))
        layers.append(nn.ReLU())
        # The next conv consumes what this one produced.
        in_channels = hidden

        max_pool = max_pools[h_idx]
        if max_pool > 1:
            # The class is spelled MaxPool2d; `nn.Maxpool2d` does not exist.
            layers.append(nn.MaxPool2d(max_pool, max_pool))
    return nn.Sequential(*layers)


def ChoiceNet_Mnist():
  """Build a ChoiceNet whose feature extractor comes from make_layers.

  Takes no arguments; every value used below is read from the enclosing
  (module) scope.
  """
  # NOTE(review): in_channels, h_dim, filter_size, max_pools, activation,
  # batch_norm and num_mixture are not defined in the visible part of this
  # file -- presumably set in another notebook cell; confirm.
  # NOTE(review): make_layers declares a leading `self` parameter, so the six
  # positional arguments here are shifted by one slot (in_channels -> self,
  # h_dim -> in_channels, ...).
  # NOTE(review): ChoiceNet.__init__ requires seven arguments (feature, y_dim,
  # num_mixture, feature_dim, logSigmaZval, tau_inv, pi1_bias) but only two
  # are passed, so this call raises TypeError as written.
  return ChoiceNet(make_layers(in_channels, h_dim, filter_size, max_pools, activation, batch_norm ),
                   num_mixture)