In [1]:
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data import Dataset
import numpy as np

In [2]:
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# Device switch: when True, the loader cell passes worker / pinned-memory options to DataLoader.
# NOTE(review): no visible cell moves the model or the batches to a GPU — confirm before enabling.
cuda = False

In [4]:
# Both splits go through the same preprocessing: convert the image to a tensor,
# then normalise it with the constants (0.1307,) and (0.3081,).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_ds = datasets.MNIST('data', train=True, download=True, transform=mnist_transform)
test_ds = datasets.MNIST('data', train=False, download=True, transform=mnist_transform)

In [5]:
class FlatMNIST(Dataset):
    """Dataset wrapper that returns each image as a flat vector of 28*28 values.

    The wrapped dataset must yield ``(image, label)`` pairs whose image holds
    exactly 28*28 elements; the label is passed through untouched.
    """

    def __init__(self, dataset):
        # Keep a handle on the wrapped dataset and cache its size once.
        self.dataset = dataset
        self.n = len(dataset)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        flat_image = image.view(28 * 28)
        return flat_image, label

In [6]:
# Wrap both MNIST splits so that samples arrive already flattened.
tr_ds, ts_ds = FlatMNIST(train_ds), FlatMNIST(test_ds)

In [7]:
batch_size = 64
#batch_size = 5 # for testing
kwargs = {'num_workers': 1, 'pin_memory': True}

# The worker / pinned-memory options are only forwarded when `cuda` is set;
# otherwise the loaders are built with DataLoader's defaults.
loader_opts = kwargs if cuda else {}
train_loader = torch.utils.data.DataLoader(tr_ds, batch_size=batch_size, shuffle=True, **loader_opts)
test_loader = torch.utils.data.DataLoader(ts_ds, batch_size=batch_size, shuffle=False, **loader_opts)
    

In [8]:
# Sanity check: softmax along dim=0 of a random 10x1 column, i.e. normalised over its 10 rows.
F.softmax(Variable(torch.randn(10,1)), dim=0)

Variable containing:
 0.3180
 0.0370
 0.1363
 0.0285
 0.0128
 0.0388
 0.1669
 0.1886
 0.0333
 0.0397
[torch.FloatTensor of size 10x1]

In [9]:
class NTree2(nn.Module):
    """Soft binary decision tree over flat inputs (default depth 2, i.e. 4 leaves).

    Inner nodes are numbered breadth-first: node 0 is the root and node ``i``
    has children ``2*i + 1`` (left) and ``2*i + 2`` (right).  Node ``i`` routes
    a sample to its left child with probability
    ``sigmoid(betas[i] * nodes[i](x))``.  Every leaf owns a learnable vector of
    class logits that is turned into a distribution with a softmax.

    Args:
        tree_depth: number of routing decisions between root and leaf (>= 1).
        n_classes: size of each leaf's class distribution.
        ni: number of input features per sample.
        lmbda: regularisation strength; stored on the module but not used by
            ``forward``.

    Raises:
        ValueError: if ``tree_depth`` is smaller than 1.
    """

    def __init__(self, tree_depth=2, n_classes=10, ni=28*28, lmbda=0.1):
        super().__init__()
        if tree_depth < 1:
            raise ValueError("tree_depth must be at least 1, got %r" % (tree_depth,))
        self.num_leaves = 2**tree_depth
        self.n_classes = n_classes
        self.num_nodes = self.num_leaves - 1
        self.tree_depth = tree_depth

        # regularization
        self.lmbda = lmbda
        self.nodes = nn.ModuleList([nn.Linear(ni, 1) for i in range(self.num_nodes)])
        self.leaves = nn.ParameterList([nn.Parameter(torch.randn(self.n_classes)) for i in range(self.num_leaves)])

        # inverse temperature filter
        self.betas = nn.ParameterList([nn.Parameter(torch.randn(1)) for i in range(self.num_nodes)])

    def _path_probabilities(self, p_x):
        """Multiply the routing probabilities along the root-to-leaf path of every leaf."""
        path_prob = []
        for leaf in range(self.num_leaves):
            node = 0
            prob = None
            for level in range(self.tree_depth):
                # The bits of the leaf index, most significant first, spell out
                # the turns taken from the root: 0 = left, 1 = right.
                go_right = (leaf >> (self.tree_depth - 1 - level)) & 1
                branch = (1 - p_x[node]) if go_right else p_x[node]
                prob = branch if prob is None else prob * branch
                node = 2 * node + 1 + go_right
            path_prob.append(prob)
        return path_prob

    def forward(self, x):
        """Route a batch through the tree.

        Args:
            x: batch of inputs with ``ni`` features per row.

        Returns:
            A tuple ``(leaf_dist, path_prob, p_x)``:
            ``leaf_dist`` is one class distribution per leaf, expanded to
            ``(batch, n_classes)``; ``path_prob`` is one ``(batch, 1)`` tensor
            per leaf with the probability of reaching it; ``p_x`` is one
            ``(batch, 1)`` tensor per inner node with its go-left probability.
        """
        bz = x.size()[0]
        softmax = nn.Softmax(dim=0)

        # create distributions at each leaf - store for later
        leaf_pcts = [softmax(leaf_param) for leaf_param in self.leaves]
        leaf_dist = [pct.expand(bz, self.n_classes) for pct in leaf_pcts]
        self.leaf_dist = leaf_dist

        # probabilities for inner nodes
        p_x = [torch.sigmoid(self.betas[i] * self.nodes[i](x)) for i in range(self.num_nodes)]

        # For the default depth this yields, in order:
        # p0*p1, p0*(1-p1), (1-p0)*p2, (1-p0)*(1-p2)
        path_prob = self._path_probabilities(p_x)

        # The original returned the undefined name `p_nodes` here.
        return leaf_dist, path_prob, p_x

In [10]:
class NTree3(nn.Module):
    """Soft binary decision tree over flat inputs (default depth 3, i.e. 8 leaves).

    Inner nodes are numbered breadth-first: node 0 is the root and node ``i``
    has children ``2*i + 1`` (left) and ``2*i + 2`` (right).  Node ``i`` routes
    a sample to its left child with probability
    ``sigmoid(betas[i] * nodes[i](x))``.  Every leaf owns a learnable vector of
    class logits that is turned into a distribution with a softmax.

    Args:
        tree_depth: number of routing decisions between root and leaf (>= 1).
        n_classes: size of each leaf's class distribution.
        ni: number of input features per sample.
        lmbda: regularisation strength; stored on the module but not used by
            ``forward``.

    Raises:
        ValueError: if ``tree_depth`` is smaller than 1.
    """

    def __init__(self, tree_depth=3, n_classes=10, ni=28*28, lmbda=0.1):
        super().__init__()
        if tree_depth < 1:
            raise ValueError("tree_depth must be at least 1, got %r" % (tree_depth,))
        self.num_leaves = 2**tree_depth
        self.n_classes = n_classes
        self.num_nodes = self.num_leaves - 1
        self.tree_depth = tree_depth

        # regularization
        self.lmbda = lmbda
        self.nodes = nn.ModuleList([nn.Linear(ni, 1) for i in range(self.num_nodes)])
        self.leaves = nn.ParameterList([nn.Parameter(torch.randn(self.n_classes)) for i in range(self.num_leaves)])

        # inverse temperature filter
        self.betas = nn.ParameterList([nn.Parameter(torch.randn(1)) for i in range(self.num_nodes)])

    def _path_probabilities(self, p_x):
        """Multiply the routing probabilities along the root-to-leaf path of every leaf."""
        path_prob = []
        for leaf in range(self.num_leaves):
            node = 0
            prob = None
            for level in range(self.tree_depth):
                # The bits of the leaf index, most significant first, spell out
                # the turns taken from the root: 0 = left, 1 = right.
                go_right = (leaf >> (self.tree_depth - 1 - level)) & 1
                branch = (1 - p_x[node]) if go_right else p_x[node]
                prob = branch if prob is None else prob * branch
                node = 2 * node + 1 + go_right
            path_prob.append(prob)
        return path_prob

    def forward(self, x):
        """Route a batch through the tree.

        Args:
            x: batch of inputs with ``ni`` features per row.

        Returns:
            A tuple ``(leaf_dist, path_prob, p_x)``:
            ``leaf_dist`` is one class distribution per leaf, expanded to
            ``(batch, n_classes)``; ``path_prob`` is one ``(batch, 1)`` tensor
            per leaf with the probability of reaching it; ``p_x`` is one
            ``(batch, 1)`` tensor per inner node with its go-left probability.
        """
        bz = x.size()[0]
        softmax = nn.Softmax(dim=0)

        # create distributions at each leaf - store for later
        leaf_pcts = [softmax(leaf_param) for leaf_param in self.leaves]
        leaf_dist = [pct.expand(bz, self.n_classes) for pct in leaf_pcts]
        self.leaf_dist = leaf_dist

        # probabilities for inner nodes
        p_x = [torch.sigmoid(self.betas[i] * self.nodes[i](x)) for i in range(self.num_nodes)]

        # For the default depth this reproduces the eight hand-written products,
        # e.g. leaf 2 -> p0*(1-p1)*p4 and leaf 5 -> (1-p0)*p2*(1-p5); other
        # depths no longer index past the end of p_x.
        path_prob = self._path_probabilities(p_x)

        return leaf_dist, path_prob, p_x

In [11]:
def bigot_leaf_loss(path_prob, leaf_dist, labels):
    """Path-weighted log-likelihood of the true labels under a single leaf.

    Computes ``sum_n path_prob[n] * log(leaf_dist[n, labels[n]])``.

    Args:
        path_prob: ``(batch, 1)`` probability of each sample reaching the leaf.
        leaf_dist: ``(batch, n_classes)`` class distribution of the leaf.
        labels: ``(batch,)`` integer (long) class indices.

    Returns:
        A scalar tensor; non-positive when the inputs are valid probabilities.
    """
    # Pick the probability of the true class directly.  The previous one-hot
    # mask multiplied log(q) by zero for every other class, which yields NaN
    # as soon as one of those probabilities is exactly 0 (0 * -inf), and it was
    # built as a CPU-only torch.FloatTensor.
    true_class_prob = leaf_dist.gather(1, labels.view(-1, 1))
    weighted_log_lik = torch.log(true_class_prob) * path_prob
    return torch.sum(weighted_log_lik)


def total_loss(path_probs, leaf_dists, labels):
    """Combine the per-leaf losses into the training objective.

    Each leaf contributes ``bigot_leaf_loss`` for its path probabilities and
    class distribution; the result is ``log(-sum)`` of those contributions.

    Args:
        path_probs: one path-probability tensor per leaf.
        leaf_dists: one class-distribution tensor per leaf, in the same order.
        labels: class indices for the batch.

    Returns:
        A scalar tensor holding the loss.
    """
    per_leaf = []
    for leaf_path_prob, leaf_class_dist in zip(path_probs, leaf_dists):
        per_leaf.append(bigot_leaf_loss(leaf_path_prob, leaf_class_dist, labels))
    summed = torch.sum(torch.stack(per_leaf))
    return torch.log(-summed)


def which_node(path_prob, n_leaves):
    """Find the most probable leaf for every sample.

    Args:
        path_prob: list with one ``(batch, 1)`` tensor per leaf.
        n_leaves: number of columns of the returned one-hot mask.

    Returns:
        ``(node_id, node_mask)``: the ``(batch, 1)`` index of the winning leaf
        and a float ``(batch, n_leaves)`` one-hot encoding of it.
    """
    stacked = torch.stack(path_prob)
    _, node_id = torch.max(stacked, dim=0)

    n_samples = path_prob[0].size()[0]
    node_mask = torch.FloatTensor(n_samples, n_leaves).zero_()
    node_mask.scatter_(1, node_id.data, 1)
    return node_id, node_mask


def which_class(path_prob, leaf_dist):
    """Predict a class per sample: the top class of its most probable leaf.

    Args:
        path_prob: list with one path-probability tensor per leaf.
        leaf_dist: list with one ``(batch, n_classes)`` distribution per leaf.

    Returns:
        A tensor with one predicted class index per sample.
    """
    _, node_mask = which_node(path_prob, len(leaf_dist))

    # Arg-max class of every leaf, transposed so rows are samples and columns leaves.
    leaf_argmax = torch.max(torch.stack(leaf_dist), dim=2)[1]
    best_class_by_leaf = torch.t(leaf_argmax)

    # The one-hot mask keeps only the winning leaf's class in each row.
    selected = Variable(node_mask.long()) * best_class_by_leaf
    return torch.sum(selected, dim=1)


def acc_calc(val_dl, model):
    """Evaluate ``model`` on every batch of ``val_dl``.

    Args:
        val_dl: iterable of ``(data, labels)`` batches.
        model: tree model whose forward pass returns
            ``(leaf_dist, path_prob, p_nodes)``.

    Returns:
        ``(accuracy, correct, total, final_dist)`` where ``final_dist`` counts
        how many samples were routed to each leaf.  ``accuracy`` is 0.0 and
        ``final_dist`` stays 0 when ``val_dl`` yields no batches.
    """
    # Remember the current mode: this function is called in the middle of
    # training and used to leave the model in eval mode afterwards.
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    final_dist = 0
    try:
        for data, labels in val_dl:
            leaf_dist, path_prob, _ = model(Variable(data))
            # Take the leaf count from the model output instead of assuming 8.
            n_leaves = len(path_prob)
            final_dist += which_node(path_prob, n_leaves)[1].sum(0)

            preds = which_class(path_prob, leaf_dist)
            match = labels.eq(preds.data)
            # int() keeps `correct` a plain Python int whether sum() returns a
            # number or a 0-dim tensor.
            correct += int(match.sum())
            total += match.size()[0]
    finally:
        model.train(was_training)
    accuracy = correct / total if total else 0.0
    return (accuracy, correct, total, final_dist)
    

In [12]:
# Manual iterator over the training loader, used below to pull single batches by hand.
train_dl = iter(train_loader)

### Try one batch of 64 images

In [13]:
# Pull one mini-batch with the builtin next(); the iterator's own .next() method
# was a Python 2-era alias that recent PyTorch releases no longer provide.
data, labels = next(train_dl)
data_var, labels_var = Variable(data), Variable(labels)

In [14]:
# Depth-3 tree (8 leaves). The NTree2() instance that used to be created on the
# line above was overwritten immediately; swap in NTree2() here to try the
# depth-2 variant instead.
model = NTree3()

In [15]:
# Start from a fresh iterator so this cell does not depend on how far an
# earlier one was consumed, then run a single batch through the model.
train_dl = iter(train_loader)
# Builtin next() instead of .next(), which current DataLoader iterators lack.
data, labels = next(train_dl)
data_var, labels_var = Variable(data), Variable(labels)
leaf_dist, path_prob, p_nodes = model(data_var)

In [16]:
# Only the last expression of a cell is displayed, so the recorded output is len(path_prob).
len(leaf_dist)
len(path_prob)

8

In [17]:
# The model returns one routing-probability tensor per inner node.
len(p_nodes)

7

In [18]:
# Dump the per-node routing probabilities for the batch (one 64x1 column per inner node).
p_nodes

[Variable containing:
  0.5243
  0.7274
  0.4359
  0.6828
  0.4383
  0.6025
  0.4270
  0.5395
  0.5404
  0.3634
  0.3062
  0.4253
  0.8027
  0.4718
  0.2910
  0.4861
  0.6226
  0.3500
  0.4169
  0.4069
  0.4033
  0.4691
  0.4266
  0.5239
  0.6072
  0.6877
  0.6229
  0.4558
  0.4724
  0.6173
  0.5637
  0.6454
  0.2830
  0.4617
  0.5207
  0.4583
  0.4228
  0.3241
  0.6258
  0.4920
  0.3861
  0.2906
  0.6850
  0.5055
  0.4674
  0.5462
  0.5243
  0.3371
  0.6368
  0.6639
  0.4933
  0.4415
  0.3350
  0.5018
  0.5203
  0.4005
  0.4713
  0.4937
  0.4421
  0.6957
  0.5062
  0.6551
  0.4072
  0.6278
 [torch.FloatTensor of size 64x1], Variable containing:
  0.4496
  0.4299
  0.7984
  0.4764
  0.6367
  0.6180
  0.3971
  0.4953
  0.6042
  0.5410
  0.5075
  0.3589
  0.6318
  0.5393
  0.6392
  0.5596
  0.7495
  0.5330
  0.8445
  0.7054
  0.8474
  0.5503
  0.6392
  0.7086
  0.7280
  0.5346
  0.5052
  0.7178
  0.5114
  0.3825
  0.4774
  0.4351
  0.6000
  0.5906
  0.7340
  0.7333
  0.5567
  0.4645
  0.

In [979]:
# Number of samples in the batch routed to each leaf; the leaf count comes from
# the model output rather than a hard-coded 8.
which_node(path_prob, len(path_prob))[1].sum(0)


  2
 10
 19
  6
 13
  6
  3
  5
[torch.FloatTensor of size 8]

In [983]:
model = NTree3()
n_epochs = 15
learning_rate = 0.0001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(n_epochs):
    # Iterating the loader directly starts a fresh (reshuffled) pass each epoch.
    for i, (data, labels) in enumerate(train_loader):
        bz = data.size()[0]
        data_var, labels_var = Variable(data), Variable(labels)
        optimizer.zero_grad()

        leaf_dist, path_prob, p_nodes = model(data_var)

        loss = total_loss(path_prob, leaf_dist, labels_var)
        loss.backward()
        optimizer.step()

        # Report test accuracy and the per-leaf routing histogram every 200 batches.
        if i % 200 == 0:
            acc, correct, total, dist = acc_calc(test_loader, model)
            # acc_calc switches the model to eval mode; go back to training mode.
            model.train()
            # float(loss) reads the scalar loss; `loss.data[0]` fails on 0-dim tensors.
            print('Ep: %d , %05d/60000,  L: %.03f, A: %.03f, dist %s' % (epoch, i*bz, float(loss), acc, dist.long().tolist()))

         

Ep: 0 , 00000/60000,  L: 5.181, A: 0.094, dist [1253, 472, 161, 104, 2951, 2903, 1017, 1139]
Ep: 0 , 12800/60000,  L: 5.020, A: 0.192, dist [116, 273, 0, 0, 2370, 0, 2521, 4720]
Ep: 0 , 25600/60000,  L: 4.953, A: 0.195, dist [7, 7, 0, 0, 2439, 591, 3259, 3697]
Ep: 0 , 38400/60000,  L: 4.860, A: 0.195, dist [0, 0, 0, 0, 2801, 1491, 3150, 2558]
Ep: 0 , 51200/60000,  L: 4.861, A: 0.207, dist [0, 0, 0, 0, 2897, 1683, 3128, 2292]
Ep: 1 , 00000/60000,  L: 4.814, A: 0.204, dist [0, 0, 0, 0, 2934, 1774, 3069, 2223]
Ep: 1 , 12800/60000,  L: 4.839, A: 0.203, dist [0, 0, 0, 0, 2992, 1813, 2986, 2209]
Ep: 1 , 25600/60000,  L: 4.844, A: 0.288, dist [0, 0, 0, 0, 3009, 1829, 3041, 2121]
Ep: 1 , 38400/60000,  L: 4.778, A: 0.287, dist [0, 0, 0, 0, 3034, 1859, 2998, 2109]
Ep: 1 , 51200/60000,  L: 4.747, A: 0.289, dist [0, 0, 0, 0, 2916, 1871, 3004, 2209]
Ep: 2 , 00000/60000,  L: 4.790, A: 0.389, dist [0, 0, 0, 0, 3040, 1900, 2975, 2085]
Ep: 2 , 12800/60000,  L: 4.764, A: 0.389, dist [0, 0, 0, 0, 3003, 1

In [941]:
# Leaf class distributions cached by the most recent forward pass; each leaf's
# distribution is expanded across the batch, hence the identical rows.
model.leaf_dist

[Variable containing:
  0.0291  0.0282  0.0724  0.0615  0.0158  0.3144  0.0237  0.0089  0.4373  0.0088
  0.0291  0.0282  0.0724  0.0615  0.0158  0.3144  0.0237  0.0089  0.4373  0.0088
  0.0291  0.0282  0.0724  0.0615  0.0158  0.3144  0.0237  0.0089  0.4373  0.0088
  0.0291  0.0282  0.0724  0.0615  0.0158  0.3144  0.0237  0.0089  0.4373  0.0088
  0.0291  0.0282  0.0724  0.0615  0.0158  0.3144  0.0237  0.0089  0.4373  0.0088
  0.0291  0.0282  0.0724  0.0615  0.0158  0.3144  0.0237  0.0089  0.4373  0.0088
  0.0291  0.0282  0.0724  0.0615  0.0158  0.3144  0.0237  0.0089  0.4373  0.0088
  0.0291  0.0282  0.0724  0.0615  0.0158  0.3144  0.0237  0.0089  0.4373  0.0088
  0.0291  0.0282  0.0724  0.0615  0.0158  0.3144  0.0237  0.0089  0.4373  0.0088
  0.0291  0.0282  0.0724  0.0615  0.0158  0.3144  0.0237  0.0089  0.4373  0.0088
  0.0291  0.0282  0.0724  0.0615  0.0158  0.3144  0.0237  0.0089  0.4373  0.0088
  0.0291  0.0282  0.0724  0.0615  0.0158  0.3144  0.0237  0.0089  0.4373  0.0088
  0.02

In [925]:
# Evaluate the current model on the test loader.
acc, correct, total, dist = acc_calc(test_loader, model)

In [926]:
# Number of correctly classified test samples.
correct

1936

In [927]:
# Number of test samples evaluated.
total

10000

In [936]:
# Per-leaf count of test samples routed to each leaf.
# NOTE(review): the recorded output has 4 entries although the model trained in
# this notebook (NTree3) has 8 leaves, and the execution counts are out of
# order — presumably a stale result from an earlier 4-leaf run; re-run to confirm.
list(dist.long().numpy())

[7586, 0, 0, 2414]

In [914]:
# Raw class logits of leaf 0 (forward applies a softmax to these).
model.leaves[0]

Parameter containing:
-2.0772
-0.6781
-1.6316
-1.6203
 0.3809
-1.5100
-1.1012
 1.9916
-1.4977
 1.8553
[torch.FloatTensor of size 10]

In [915]:
# Raw class logits of leaf 1 (forward applies a softmax to these).
model.leaves[1]

Parameter containing:
-2.2890
 2.0111
-0.4270
 0.0372
-1.9340
 0.5527
-1.2689
-0.9730
 1.6722
-1.1189
[torch.FloatTensor of size 10]