In [1]:
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data import Dataset
import numpy as np

In [2]:
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
cuda = True

In [4]:
train_ds = datasets.MNIST('data', train=True, download=True, 
                       transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
test_ds = datasets.MNIST('data', train=False, download=True, 
                       transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))

In [5]:
class FlatMNIST(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        self.n = len(self.dataset)
        
    def __getitem__(self, idx):
        x, y = self.dataset[idx]
        return x.view(28*28), y

    def __len__(self): return self.n

In [6]:
tr_ds = FlatMNIST(train_ds)
ts_ds = FlatMNIST(test_ds)

In [7]:
batch_size = 64
#batch_size = 5 # for testing
kwargs = {'num_workers': 1, 'pin_memory': True} 
if cuda:
    train_loader = torch.utils.data.DataLoader(tr_ds, batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(ts_ds, batch_size=batch_size, shuffle=False, **kwargs)
else:
    train_loader = torch.utils.data.DataLoader(tr_ds, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(ts_ds, batch_size=batch_size, shuffle=False)
    

In [8]:
class NTree2(nn.Module):
    def __init__(self, tree_depth=2, n_classes=10, ni=28*28, lmbda = 0.1):
        super().__init__()
        self.num_leaves = 2**tree_depth
        self.n_classes = n_classes
        self.num_nodes = self.num_leaves -1
        self.tree_depth = tree_depth
        
        # regularization
        self.lmbda = lmbda
        self.nodes =  nn.ModuleList([nn.Linear(ni, 1) for i in range(self.num_nodes)])
        self.leaves = nn.ParameterList([nn.Parameter(torch.randn(self.n_classes)) for i in range(self.num_leaves)])
        
        # inverse temperature filter
        self.betas = nn.ParameterList([nn.Parameter(torch.randn(1)) for i in range(self.num_nodes)])
        
        
    def forward(self, x):
        bz = x.size()[0]
        sigmoid = nn.Sigmoid()
        softmax = nn.Softmax(dim=0)
        
        # create distributions at each leaf - store for later
        leaf_pcts = [softmax(leaf_param) for leaf_param in self.leaves]
        leaf_dist = [pct.expand(bz, self.n_classes) for pct in leaf_pcts]
        self.leaf_dist = leaf_dist
        
        # probabilities for inner nodes
        p_x = [sigmoid(self.betas[i]*self.nodes[i](x)) for i in range(self.num_nodes)]
    
        # manually doing 2 tiers
        path_prob = [p_x[0]*p_x[1], p_x[0]*(1-p_x[1]), (1-p_x[0])*p_x[2], (1-p_x[0])*(1-p_x[2])]
        
        return leaf_dist, path_prob, p_nodes

In [102]:
class tmp_mod(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = 1
    def forward(self, x):
        pass
    
tp = tmp_mod()
tp.cuda()

tmp_mod(
)

In [351]:
class NTree3(nn.Module):
    def __init__(self, tree_depth=3, n_classes=10, ni=28*28, lmbda = 0.1, on_cuda=False):
        super(NTree3,self).__init__()
        self.num_leaves = 2**tree_depth
        self.n_classes = n_classes
        self.num_nodes = self.num_leaves -1
        self.tree_depth = tree_depth
        self.on_cuda = on_cuda
        
        # regularization
        self.lmbda = lmbda

        
        leaf_params = [torch.randn(self.n_classes) for i in range(self.num_leaves)]
        beta_params = [torch.randn(1) for i in range(self.num_nodes)]
                
        if self.on_cuda==True:
            beta_params = [beta_param.cuda() for beta_param in beta_params]
            leaf_params = [leaf_param.cuda() for leaf_param in leaf_params]
        
        self.nodes =  nn.ModuleList([nn.Linear(ni, 1) for i in range(self.num_nodes)])
        self.leaves = nn.ParameterList([nn.Parameter(leaf_param) for leaf_param in leaf_params])
        
        # inverse temperature filter
        self.betas = nn.ParameterList([nn.Parameter(beta) for beta in beta_params])
        
        
    def forward(self, x):
        bz = x.size()[0]
        sigmoid = nn.Sigmoid()
        softmax = nn.Softmax(dim=0)
        
        # create distributions at each leaf - store for later
        leaf_pcts = [softmax(leaf_param) for leaf_param in self.leaves]
        leaf_dist = [pct.expand(bz, self.n_classes) for pct in leaf_pcts]
        self.leaf_dist = leaf_dist
        
        # probabilities of inner nodes
        p_x = [sigmoid(self.betas[i]*self.nodes[i](x)) for i in range(self.num_nodes)]
    
        # manually doing 3 tiers
        path_prob = [p_x[0]*p_x[1]*p_x[3],
                     p_x[0]*p_x[1]*(1-p_x[3]),
                     p_x[0]*(1-p_x[1])*p_x[4],
                     p_x[0]*(1-p_x[1])*(1-p_x[4]),
                     (1-p_x[0])*p_x[2]*p_x[5], 
                     (1-p_x[0])*p_x[2]*(1-p_x[5]),                      
                     (1-p_x[0])*(1-p_x[2])*p_x[6],
                     (1-p_x[0])*(1-p_x[2])*(1-p_x[6])                     
                    ]
        
        return leaf_dist, path_prob, p_x

In [352]:
def bigot_leaf_loss(path_prob, leaf_dist, labels, on_cuda):
    ymask = torch.FloatTensor(leaf_dist.size()).zero_()
    if on_cuda:
        ymask = ymask.cuda()    
    ymask.scatter_(1, labels.data.view(-1,1), 1)
    ymask = Variable(ymask)
    Tk_log_Qk = (torch.log(leaf_dist) * ymask).sum(1)
    PTQ = Tk_log_Qk[:,None]*path_prob
    return torch.sum(PTQ)


def total_loss(path_probs, leaf_dists, labels, on_cuda=False):
    L_x = [bigot_leaf_loss(path_prob, leaf_dist, labels, on_cuda) for path_prob, leaf_dist in zip(path_probs, leaf_dists)]
    return(torch.log(-torch.sum(torch.stack(L_x))))


def which_node(path_prob, n_leaves, on_cuda=False):
    node_id = torch.max(torch.stack(path_prob),dim=0)[1]    
    nodes_onehot = torch.FloatTensor(path_prob[0].size()[0], n_leaves).zero_()
    if on_cuda:
        node_id = node_id.cuda()
        nodes_onehot = nodes_onehot.cuda()
    node_mask = nodes_onehot.scatter_(1, node_id.data,1)
    return(node_id,node_mask)


def which_class(path_prob, leaf_dist, on_cuda=False):
    n_leaves = len(leaf_dist)
    node_id, node_mask = which_node(path_prob, n_leaves, on_cuda)
    max_class_per_node = torch.t(torch.max(torch.stack(leaf_dist),dim=2)[1])
    pred_class = torch.sum(Variable(node_mask.long())*max_class_per_node,dim=1)
    return(pred_class)


def acc_calc(val_dl, model, on_cuda=False):
    model.eval()
    val_ = iter(val_dl)
    correct = 0
    total = 0
    final_dist = 0
    n_leaves = 8
    for i, batch in enumerate(val_dl):
        data, labels = batch
        if on_cuda:
                data, labels = data.cuda(), labels.cuda()
        data_var, labels_var = Variable(data), Variable(labels)        
        leaf_dist, path_prob, p_nodes = model(data_var)
        final_dist += which_node(path_prob, n_leaves = 8, on_cuda=True)[1].sum(0)
        
        preds = which_class(path_prob, leaf_dist, on_cuda=on_cuda)
        match = labels.eq(preds.data)
        correct += match.sum()
        total += match.size()[0]
    return(correct/total, correct, total, final_dist) 
    

In [353]:
train_dl = iter(train_loader)

### Try one batch of 64 images

In [354]:
data, labels = train_dl.next()
data_var, labels_var = Variable(data), Variable(labels)

In [355]:
model = NTree2()
model3 = NTree3()

In [356]:
train_dl = iter(train_loader)
data, labels = train_dl.next()
data_var, labels_var = Variable(data), Variable(labels)
leaf_dist, path_prob, p_nodes = model(data_var)

In [357]:
len(leaf_dist)
len(path_prob)

4

In [358]:
len(p_nodes)

7

In [359]:
which_node(path_prob,8)[1].sum(0)


 16
 44
  4
  0
  0
  0
  0
  0
[torch.FloatTensor of size 8]

In [365]:
is_cuda = False
kwargs = {'num_workers': 1, 'pin_memory': True}

if is_cuda:
    train_loader = torch.utils.data.DataLoader(tr_ds, batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(ts_ds, batch_size=batch_size, shuffle=False, **kwargs)
else:
    train_loader = torch.utils.data.DataLoader(tr_ds, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(ts_ds, batch_size=batch_size, shuffle=False)

model = NTree3(tree_depth=3, n_classes=10, ni=28*28, lmbda = 0.1, on_cuda=is_cuda)

if is_cuda:
    model.cuda()
    
n_epochs = 15
learning_rate = 0.0001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

train_dl = iter(train_loader)
final_dist = 0

for epoch in range(n_epochs):
    running_loss = 0.0
    train_dl = iter(train_loader)
    
    for i, batch in enumerate(train_dl):
        data, labels = batch
        bz = data.size()[0]
        if is_cuda:
            data = data.cuda()
            labels = labels.cuda()
        data_var, labels_var = Variable(data), Variable(labels)
       
            
        optimizer.zero_grad()
        leaf_dist, path_prob, p_nodes = model(data_var)
        
        loss = total_loss(path_prob, leaf_dist, labels_var, on_cuda=is_cuda)
        loss.backward()
        optimizer.step()
        
        if i%200 == 0:
            acc, correct, total, dist = acc_calc(test_loader, model, on_cuda=is_cuda)
            print('Ep: %d , %05d/60000,  L: %.03f, A: %.03f, dist %s' % (epoch, i*bz ,loss.data[0], acc, list(dist.long().cpu().numpy())))

         

Ep: 0 , 00000/60000,  L: 5.150, A: 0.111, dist [652, 543, 874, 4088, 395, 1441, 865, 1142]
Ep: 0 , 12800/60000,  L: 4.960, A: 0.195, dist [2117, 0, 0, 5669, 890, 0, 2, 1322]
Ep: 0 , 25600/60000,  L: 4.852, A: 0.276, dist [1752, 4, 401, 3164, 2425, 0, 371, 1883]
Ep: 0 , 38400/60000,  L: 4.871, A: 0.336, dist [2066, 28, 1024, 2665, 1347, 0, 577, 2293]
Ep: 0 , 51200/60000,  L: 4.843, A: 0.347, dist [2146, 50, 1110, 2496, 1276, 0, 553, 2369]
Ep: 1 , 00000/60000,  L: 4.863, A: 0.350, dist [2096, 61, 1140, 2431, 1146, 0, 620, 2506]
Ep: 1 , 12800/60000,  L: 4.723, A: 0.349, dist [2067, 92, 1125, 2216, 1182, 0, 654, 2664]
Ep: 1 , 25600/60000,  L: 4.765, A: 0.354, dist [2135, 132, 1087, 2208, 1105, 0, 662, 2671]
Ep: 1 , 38400/60000,  L: 4.756, A: 0.425, dist [2087, 144, 1060, 2069, 1128, 0, 684, 2828]
Ep: 1 , 51200/60000,  L: 4.685, A: 0.433, dist [2167, 219, 1044, 2169, 1043, 0, 672, 2686]
Ep: 2 , 00000/60000,  L: 4.668, A: 0.436, dist [2175, 256, 1025, 2169, 1039, 0, 689, 2647]
Ep: 2 , 12800/

In [None]:
model.leaf_dist

In [None]:
acc, correct, total, dist = acc_calc(test_loader, model)

In [309]:
correct

539

In [927]:
total

10000

In [936]:
list(dist.long().numpy())

[7586, 0, 0, 2414]

In [914]:
model.leaves[0]

Parameter containing:
-2.0772
-0.6781
-1.6316
-1.6203
 0.3809
-1.5100
-1.1012
 1.9916
-1.4977
 1.8553
[torch.FloatTensor of size 10]

In [915]:
model.leaves[1]

Parameter containing:
-2.2890
 2.0111
-0.4270
 0.0372
-1.9340
 0.5527
-1.2689
-0.9730
 1.6722
-1.1189
[torch.FloatTensor of size 10]