In [0]:
import torch
import torch.nn as nn
import numpy as np
import math
import torch.nn.functional as F
import argparse
from torchvision import datasets, transforms
from torch.autograd import Variable
import torch.optim as optim
import math

class ChoiceNet(nn.Module):
  def __init__(self, backbone, y_dim, num_mixture, feature_dim, logSigmaZval, tau_inv, pi1_bias):
    super(ChoiceNet,self).__init__()
    self.backbone = backbone
    self.y_dim = y_dim
    self.num_mixture = num_mixture
    self.feature_dim = feature_dim
    self.logSigmaZval = logSigmaZval
    self.tau_inv = tau_inv
    self.pi1_bias = pi1_bias
    self.rho_ref = 1
    #self.USE_GAT = USE_GAT
    
    
    self.fc_feature_dim = nn.Linear(7*7*64, self.feature_dim)
    self.fc_num_mixture = nn.Linear(self.feature_dim, num_mixture)

    self.fc_var_raw = nn.Linear(self.feature_dim, self.y_dim)

    self.fc_pi_logits = nn.Linear(self.feature_dim, self.num_mixture)

  def forward(self, x):
    x = self.backbone(x)
    x = x.view(x.size(0),-1) # flatten
    self.feature = self.fc_feature_dim(x) # feature, h
    #print(self.feature.size())
    
    rho_raw = self.fc_num_mixture(self.feature)
    rho_temp = F.sigmoid(rho_raw)
    rho = torch.cat([rho_temp[:, 0:1]*0.0 + self.rho_ref, rho_temp[:, 1:]], axis=1) # rho(h)=rho1~rhoK, rho_ref=1

    Q = self.feature_dim
    num_data = x.size()[0]

    #make_sample
    muW_tile, muZ_tile, sigmaW_tile, sigmaZ_tile = self.make_sample(Q, num_data) 
    
    # cholesky #[K*N*Q*D] #W_bar
    #branch_2
    samplerList = self.cholesky(self.num_mixture, Q, rho, num_data, muW_tile, sigmaW_tile, muZ_tile, sigmaZ_tile)
    wSample = samplerList.permute(1,3,0,2) #[N*D*K*Q]

    #K mean mixtures
    wTemp = wSample.contiguous().view(num_data, self.num_mixture*self.y_dim, Q)
    featRsh = self.feature.view(num_data, Q, 1)
    _mu = torch.matmul(wTemp, featRsh) #[N*DK*1]
    mu = _mu.view(num_data, self.y_dim, self.num_mixture)

    ### Add bias to mu (after)

    #K var mixtures #(6) 
    #branch_3
    logvar_raw = self.fc_var_raw(self.feature) #[N*D]
    var_raw = torch.exp(logvar_raw)
    var_tile = var_raw.unsqueeze(-1).repeat(1, 1, self.num_mixture) #N*D*K
    rho_tile = rho.unsqueeze(1).repeat(1, self.y_dim, 1)#N,D,K
    tau_inv = self.tau_inv
    var = (1.0 - torch.pow(rho_tile,2))*var_tile + tau_inv

    # Weight allocation probability pi [N*K] # pi_k = softmax()_k
    #branch_1
    pi_logits = self.fc_pi_logits(self.feature) #[N*K]
    pi_temp = F.softmax(pi_logits, dim=1)
    pi_temp = torch.cat((pi_temp[:, 0:1] + self.pi1_bias, pi_temp[:, 1:]), axis=1)
    pi = F.softmax(pi_temp, dim=1)

    return rho, mu, var, pi
    
  def make_sample(self, Q, num_data):
    N = num_data

    muW = torch.nn.init.normal_(torch.empty(Q, self.y_dim), mean = 0.0, std = 0.1)
    #muW = torch.normal(std=0.1,size=(Q, self.y_dim)) # Q*D
    logSigmaW = torch.nn.init.constant_(torch.empty(Q, self.y_dim), -3.0)
    
    muZ = torch.zeros(Q, self.y_dim) # Q*D
    logSigmaZ = torch.nn.init.constant_(torch.empty(Q, self.y_dim), self.logSigmaZval)
    
    muW_tile = muW.unsqueeze(0).repeat(N,1,1) # N*Q*D
    sigmaW_tile = torch.exp(logSigmaW.unsqueeze(0).repeat(N,1,1)) #N*Q*D
    
    muZ_tile = muZ.unsqueeze(0).repeat(N,1,1)
    sigmaZ_tile = torch.exp(logSigmaZ.unsqueeze(0).repeat(N,1,1))

    return muW_tile, muZ_tile, sigmaW_tile, sigmaZ_tile

  def cholesky(self, num_mixture, Q, rho, num_data, muW_tile, sigmaW_tile, muZ_tile, sigmaZ_tile):
    samplerList = []
    for mix_idx in range(self.num_mixture):
      rho_j = rho[:, mix_idx : mix_idx+1] # N*1
      rho_tile = rho_j.unsqueeze(-1).repeat(1, Q, self.y_dim) # N*Q*D
      
      epsW = torch.randn(num_data, Q, self.y_dim, dtype=torch.float) #mean=0, std=1
      W = muW_tile + torch.sqrt(sigmaW_tile)*epsW
      
      epsZ = torch.randn(num_data, Q, self.y_dim, dtype=torch.float)
      Z = muZ_tile + torch.sqrt(sigmaZ_tile)*epsZ

      #Cholesky
      Y = rho_tile*muW_tile + (1.0 - torch.pow(rho_tile,2))\
                               *(rho_tile*torch.sqrt(sigmaZ_tile)/torch.sqrt(sigmaW_tile)\
                                 *(W - muW_tile) + Z*torch.sqrt(1 - torch.pow(rho_tile,2)))
      
      samplerList.append(Y)
    return torch.stack(samplerList)


    #muW_tile = 

  

In [0]:
def make_layers(in_channels, h_dim, filter_size, max_pools, activation, batch_norm =False):
    layers = []

    #mnist
    for h_idx in range(len(h_dim)):
      fs = filter_size[h_idx]
      hidden = h_dim[h_idx]
      conv2d = nn.Conv2d(in_channels, hidden, kernel_size = fs, padding = 1)
      if batch_norm:
        layers += [conv2d, nn.BatchNorm2d(hidden), activation]
      else:
        layers += [conv2d, activation]
      in_channels = hidden
     
      max_pool = max_pools[h_idx]
      if max_pool > 1:
        layers += [nn.MaxPool2d(max_pool)]
    return nn.Sequential(*layers)


def ChoiceNet_Mnist():
  return ChoiceNet(make_layers(in_channels, h_dim, filter_size, max_pools, activation, batch_norm ),
                   y_dim, num_mixture, feature_dim, logSigmaZval, tau_inv, pi1_bias)

In [0]:
def train(epoch):
  model.train()
  for batch_idx, (data,target) in enumerate(train_loader):
    if args['cuda']:
      data, target = data.cuda(), target.cuda()
    
    temp = torch.LongTensor(1000,1).random_() % 10
    y_onehot = torch.FloatTensor(1000,10)

    y_onehot.zero_()
    y_onehot.scatter_(1, temp, 1)
    
    data, target = Variable(data), Variable(y_onehot)
    
    #GRAD_CLIP = True, USE_SGD = False

    optimizer.zero_grad()

    rho, mu, var, pi = model(data)

    loss, acc = MDNloss(len(data), rho, mu, var, pi, target, y_dim=10, num_mixture=10, logsumexp_coef= 1e-2, kl_reg_coef= 1e-4).forward()
    loss.requires_grad=True
    loss.backward()

    #print(acc)

    optimizer.step()

    if batch_idx%args['log_interval'] == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
          epoch, batch_idx * len(data), len(train_loader.dataset),
          100. *batch_idx/len(train_loader), loss.data
      ))


In [0]:
def test():
  model.eval()
  test_loss = 0
  correct = 0
  for data,target in test_loader:
    if args['cuda']:
      data, target = data.cuda(), target.cuda()
    
    temp = torch.LongTensor(1000,1).random_() % 10
    y_onehot = torch.FloatTensor(1000,10)

    y_onehot.zero_()
    y_onehot.scatter_(1, temp, 1)
    
    data, target = Variable(data), Variable(y_onehot)

    rho, mu, var, pi = model(data)

    loss , acc = MDNloss(len(data), rho, mu, var, pi, target, y_dim=10, num_mixture=10, logsumexp_coef= 1e-2, kl_reg_coef= 1e-4).forward()
    test_loss += loss
    correct += acc#accuracy(pi,mu, len(data) ,target, y_dim=10)
  
  loss /= len(test_loader.dataset)
  print('\nTest set : Average loss : {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
      loss, correct, len(test_loader.dataset),
      100.*correct/len(test_loader.dataset)))



In [0]:

class MDNloss(nn.Module):
  def __init__(self, num_data, rho, mu, var, pi, target, y_dim, num_mixture, logsumexp_coef, kl_reg_coef):
    super(MDNloss, self).__init__()
    self.num_data = num_data
    self.rho = Variable(rho) #N*K
    self.mu = Variable(mu) #N*D*K
    self.var = Variable(var) #N*D*K
    self.pi = Variable(pi) #N*K
    self.target = Variable(target)
    self.y_dim = y_dim
    self.num_mixture = num_mixture
    self.logsumexp_coef = logsumexp_coef
    self.kl_reg_coef = kl_reg_coef

    self.yhat = self.mu + torch.sqrt(self.var)*torch.randn(self.num_data, self.y_dim, self.num_mixture)

  def forward(self):
    target_tile = self.target.unsqueeze(-1).repeat(1, 1, self.num_mixture)# N*D*K
    pi_tile = self.pi.unsqueeze(1).repeat(1, self.y_dim, 1) # N*D*K

    yhat_normalized = F.softmax(self.yhat, dim=1)
    _loss_fit = torch.sum(-pi_tile*yhat_normalized*target_tile, axis=[1,2])
    loss_fit = torch.mean(_loss_fit)

    _loss_reg = self.pi*torch.logsumexp(self.yhat,axis=[1])
    __loss_reg = torch.sum(_loss_reg,axis=[1])
    loss_reg = self.logsumexp_coef*torch.mean(__loss_reg)

    _eps = 1e-8
    _kl_reg = self.kl_reg_coef*torch.sum(self.rho*(torch.log(self.pi+_eps) - torch.log(self.rho+_eps)), axis=1)
    kl_reg = torch.mean(_kl_reg)
    # prob = self.pi*self.g_p(self.var,self.mu,self.target)
    # nll = -torch.log(torch.sum(prob,dim=1))

    acc = self.acc()
 
    return torch.mean(loss_fit + loss_reg + kl_reg), acc

  def acc(self):
    y = self.yhat[:,:,0] #N*D
    #print('y',torch.argmax(y,dim=1))
    #print('target',torch.argmax(self.target,dim=1))
    acc = (torch.argmax(y,dim=1) == torch.argmax(self.target,dim=1)).sum().item()/y.size()[0]
    print(acc)
    return acc
    #self.target #N*D


    
''' 

def accuracy(pi,mu, N ,target, y_dim):
  max_idx = torch.argmax(pi, axis=1)#n
  max_idx = 0*torch.ones_like(max_idx)

  _mesh = torch.meshgrid(torch.arange(0,y_dim), torch.arange(0,N)) 
  mesh = [_mesh[1], _mesh[0]]
  coords = torch.stack([torch.transpose(gv,1,0) for gv in mesh]+  # N,D,2
                       [max_idx.unsqueeze(-1).repeat(1, y_dim).view(N, y_dim)], axis=2)
  #mu_bar = torch.Tensor(mu, coords)
  #print(mu)
  mu_bar = mu[:,:,:1]#n,d,10      n,d,3
  _corr = torch.equal(torch.argmax(mu_bar, 1), torch.argmax(target, 1))   #n,d=>n
  corr = torch.sum(torch.argmax(mu_bar, 1) == torch.argmax(target, 1))
  return corr/N
''' 
  


In [0]:
args = {}
kwargs = {}
args['batch_size'] = 1000
args['test_batch_size'] = 1000
args['epochs'] = 10
args['lr'] = 0.01
args['momentum'] = 0.5

args['seed'] = 1
args['log_interval'] = 10
args['cuda'] = False

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                                                 transforms.ToTensor(),
                                                 transforms.Normalize((0.1307,), (0.3081,))
                   ])),
                   batch_size = args['batch_size'], shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True,
                   transform=transforms.Compose([
                                                 transforms.ToTensor(),
                                                 transforms.Normalize((0.1307,), (0.3081,))
                   ])),
                   batch_size = args['test_batch_size'], shuffle=True, **kwargs)


In [35]:

model = ChoiceNet(make_layers(in_channels=1, h_dim=[64,64], filter_size=[3,3], max_pools=[2,2], activation=torch.nn.ReLU(), batch_norm=True ),y_dim=10, num_mixture=10, feature_dim=256, logSigmaZval=-2, tau_inv=1e-4, pi1_bias=0.0)

optimizer = optim.Adam(model.parameters(), lr=args['lr'], weight_decay=1e-5)

for epoch in range(1, args['epochs']+1):
  train(epoch)
  test()



0.101
0.094
0.095
0.086
0.107
0.099
0.09
0.113
0.11
0.109
0.107
0.095
0.104
0.099
0.092
0.088
0.095
0.127
0.095
0.095
0.103
0.096
0.103
0.096
0.096
0.098
0.112
0.075
0.105
0.093
0.107
0.117
0.117
0.084
0.109
0.096
0.085
0.089
0.11
0.1
0.113
0.089
0.1
0.083
0.099
0.09
0.107
0.09
0.092
0.092
0.082
0.096
0.102
0.087
0.1
0.092
0.105
0.105
0.078
0.107
0.098
0.094
0.114
0.102
0.098
0.114
0.086
0.114
0.095
0.108

Test set : Average loss : -0.0000, Accuracy: 1.023/10000 (0%)

0.112
0.093
0.12
0.099
0.108
0.102
0.119
0.09
0.091
0.095
0.09
0.114
0.106
0.09
0.085
0.097
0.119
0.089
0.095
0.117
0.104
0.133
0.086
0.099
0.104
0.098
0.087
0.103
0.109
0.1
0.114
0.097
0.082
0.106
0.099
0.115
0.1
0.094
0.104
0.094
0.097
0.099
0.107
0.11
0.092
0.101
0.104
0.113
0.106
0.104
0.104
0.101
0.105
0.124
0.089
0.108
0.115
0.09
0.1
0.086
0.114
0.095
0.094
0.081
0.092
0.104
0.111
0.117
0.098
0.096

Test set : Average loss : -0.0000, Accuracy: 1.002/10000 (0%)

0.114
0.103
0.13
0.105
0.108
0.113
0.096
0.091
0.115
0.

KeyboardInterrupt: ignored