# Distilling a Neural Network into a soft decision tree

https://arxiv.org/pdf/1711.09784.pdf

In [1]:
import os
import time

import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

### Setup our dataset loader

In [2]:
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline

class myNestedImgDataset(Dataset):
    def __init__(self, dir_path, transform=None, test=False):
        self.dir_path = dir_path
        self.transform = transform
        self.classes = [x for x in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path,x))]
        self.img_paths = []
        self.labels = []
        self.test = test
        if self.test:
            class_img_paths = [os.path.join(dir_path,x) for x in os.listdir(dir_path)]
            self.img_paths.extend(class_img_paths)
        else:
            for class_idx, folder_name in enumerate(self.classes):
                prefix = os.path.join(dir_path,folder_name)
                class_img_paths = [os.path.join(prefix,x) for x in os.listdir(prefix)]
                self.img_paths.extend(class_img_paths)
                self.labels.extend(np.ones(len(class_img_paths))*class_idx)
            
            self.labels = [int(x) for x in self.labels]
            
    def __len__(self):
        return len(self.img_paths)
    
    def __getitem__(self,idx):
        if self.test:
            return torch.FloatTensor(plt.imread(self.img_paths[idx])), None
        else:
            return torch.FloatTensor(plt.imread(self.img_paths[idx])), self.labels[idx]
    
    def show(self,idx):
        return plt.imshow(mpimg.imread(self.img_paths[idx]), cmap='Greys')
    

class Img2FlatVec(Dataset):
    def __init__(self, dataset):
        self.dataset= dataset
        self.n = len(self.dataset)
        samp_img, _ = self.dataset[1]
        self.h, self.w = samp_img.shape
    
    def __getitem__(self,idx):
        x, y = self.dataset[idx]
        x_flat = x 
        return x.view(self.h*self.w)/255, y

    def __len__(self): return self.n

We use soft binary decision trees trained with mini-batch gradient descent, where
each inner node i has a learned filter **`w_i`** and a bias **`b_i`**
, and each leaf node has a learned distribution **`Q`**. At each inner node, the probability of taking the
rightmost branch is:

$$p_i(x) = \sigma (xw_i + b_i)$$

where **`x`** is the input to the model and **$\sigma$** is the sigmoid logistic function.

For example if we have a 2 x 2 image, and 3 classes

In [3]:
w = torch.randn((2*2,3))
x = torch.randn((2*2,1))
b = torch.randn((2*2,1))

In [4]:
w,x

(
 -0.4307 -0.3382  0.8263
 -1.1472 -0.0171  1.4875
 -0.6736  0.3648  0.2261
  1.3679 -0.9086 -1.7713
 [torch.FloatTensor of size 4x3], 
 -2.1332
 -1.3903
  0.7608
 -0.5484
 [torch.FloatTensor of size 4x1])

In [5]:
x*w + b


-0.5247 -0.7219 -3.2061
 3.0537  1.4827 -0.6091
-0.2432  0.5469  0.4414
-1.2541 -0.0056  0.4675
[torch.FloatTensor of size 4x3]

In [226]:
F.sigmoid(x*w + b)

RuntimeError: mul() received an invalid combination of arguments - got (torch.FloatTensor), but expected one of:
 * (float other)
      didn't match because some of the arguments have invalid types: ([31;1mtorch.FloatTensor[0m)
 * (Variable other)
      didn't match because some of the arguments have invalid types: ([31;1mtorch.FloatTensor[0m)


here **`x`** is the input to the model and σ is the sigmoid logistic function.
This model is a hierarchical mixture of experts [Jordan and Jacobs, 1994], but
each expert is a actually a bigot who does not look at the data after training, and
therefore always produces the same distribution. The model learns a hierarchy of
filters that are used to assign each example to a particular bigot with a particular
path probability, and each bigot learns a simple, static distribution over the
possible output classes, **`k`**.

$$ Q_k^l = \frac{exp(\phi_k^l)}{\sum_{k'}{exp(\phi_k^l)}}$$

where `Q` denotes the probability distribution at the lth leaf, and each $\phi$. is a learned parameter at that leaf.

In [7]:
max_Q = torch.zeros(2*2,1)
max_Q


 0
 0
 0
 0
[torch.FloatTensor of size 4x1]

In [8]:
Q_init = torch.rand(3,1)
Q_init


 0.4271
 0.7438
 0.0892
[torch.FloatTensor of size 3x1]

In [9]:
Q_nn = nn.Parameter(Q_init.view(1,-1))
print(Q_nn)

Parameter containing:
 0.4271  0.7438  0.0892
[torch.FloatTensor of size 1x3]



In [10]:
sm = torch.nn.Softmax(dim=1)
sm(Q_nn.view(1,-1))

Variable containing:
 0.3240  0.4448  0.2311
[torch.FloatTensor of size 1x3]

## Prototyping a Soft Decision Tree

In [11]:
bz = 64
input_dim = 28*28
no_classes = 10
max_depth = 8
epochs = 4
lr = 0.01
lmbda = 0.1
momentum = 0.5
seed = 1
cuda = False
log_interval = 10

### Prototyping a `LeafNode`

The leaf node will give the softmax calculation for **$n$** number of classes. For tree consistency, it will also implement a `reset` method, and a `forward` method. 

<img src='https://upload.wikimedia.org/wikipedia/commons/thumb/c/c7/Lisc_lipy.jpg/220px-Lisc_lipy.jpg' />

In [169]:
class LeafNode():
    def __init__(self,batch_size, no_classes, cuda=False):
        self.param = torch.randn(no_classes)
        if cuda:
            self.param = self.param.cuda()
        self.param = nn.Parameter(self.param)
        self.batch_size = batch_size
        self.leaf = True
        self.softmax_func = nn.Softmax(1)
        
    def forward(self):
        return(self.softmax_func(self.param.view(1,-1)))
    
    def reset(self):
        pass
    
    def calc_prob(self, x, path_prob):
        # the x is kept there for consistency
        # is because its at the tail of the tree
        Q = self.forward()
        
        ## broad casting for every image in the batchsize, same probability for number of classes
        Q = Q.expand(self.batch_size, no_classes)
        return([[path_prob,Q]])


### Testing our `LeafNode Object`

In [170]:
# given the following parameters ...
batch_size = 6 # say 15 images per batch
no_classes = 3
input_dim = 2*2
path_prob = Variable(torch.ones(batch_size, 1))
x = torch.randn(batch_size,input_dim)    # approximating some images

# create a leaf node
ln = LeafNode(batch_size, no_classes, False)
print(' ============ Testing  leaf node softmax ============')
print(ln.forward())


print('============ Testing probability calculation per item in batch ============')
print(ln.calc_prob(x, path_prob))

Variable containing:
 0.4693  0.2613  0.2694
[torch.FloatTensor of size 1x3]

[[Variable containing:
 1
 1
 1
 1
 1
 1
[torch.FloatTensor of size 6x1]
, Variable containing:
 0.4693  0.2613  0.2694
 0.4693  0.2613  0.2694
 0.4693  0.2613  0.2694
 0.4693  0.2613  0.2694
 0.4693  0.2613  0.2694
 0.4693  0.2613  0.2694
[torch.FloatTensor of size 6x3]
]]


### A note on regularization

To avoid getting stuck at poor solutions during the training, we introduced a penalty term that encourage each internal node to make equal use of both left and right sub-tress. Without this penalty, the tree tended to get stuck on plateaus in which one or more of the internal nodes always assigned almost all the probability to one of its sub-tres and the gradient of the logistic for this decision was always very close to zero. 

The penalty is the cross entropy between teh diered average distribution (0.5, 0.5) for the two sub-tress and teh actual average distribution $\alpha$. Where $\alpha$ for a node $i$ is given by:

$$\alpha_i = \frac {\sum_x P^i(x)p_i(x)}{\sum_x P^i(x)}$$

where $P^i(x)$ is the path probability from teh root node to node i. The penalty summed over all internal nodes is then:

$$ C = -\lambda \sum 0.5 log(\alpha_i) + 0.5 log(1-\alpha_i)$$

#### On Lambda:

is a hyper-parameter that determines the strength of the penalty and is set prior to training. We found that we achieved better test accuracy results when the strength of the penalty decayed exponentially with the depth d of the
node in the tree so that it was proportional to $2^{−d}$

## Prototyping Inner node

<img src='http://mimi.kaktusteam.de/uploads/pics/simple_tree.jpg' />

In [171]:
class InnerNode():
    def __init__(self, depth, batch_size, input_dim, no_classes, lmbda,cuda=False,tree_depth=1):
        self.depth = depth
        self.tree_depth = tree_depth
        self.input_dim = input_dim
        self.no_classes = no_classes
        self.cuda = cuda
        self.batch_size = batch_size
        self.fc_layer = nn.Linear(self.input_dim,1)
        
        self.beta = torch.randn(1)
        if cuda:
            self.beta = self.beta.cuda()
        self.beta = nn.Parameter(self.beta)
        self.leaf = False
        self.prob = None
        
        # for regularization the strength of the penalty decays (as mentioned in the paper)
        self.lmbda = lmbda * 2** (-tree_depth)
        
        if depth > 1:
            # recursive part
            self.left = InnerNode(depth-1, batch_size, input_dim, no_classes, lmbda,cuda, tree_depth+1)
            self.right = InnerNode(depth-1, batch_size, input_dim, no_classes, lmbda,cuda, tree_depth+1)
        else:
            # when the depth is exactly 1, then only have 2 leaf nodes
            self.left = LeafNode(batch_size, no_classes, cuda)
            self.right = LeafNode(batch_size, no_classes, cuda)
            
        self.all_leaf_probs = []
        self.prob_dict = {}
        self.penalties = []
        
        
    def reset(self):
        self.all_leaf_probs = []
        self.penalties = []
        
        # recursively 
        self.left.reset()
        self.right.reset()
        
        
    def forward(self, x):  
        # this is the branch probability calculation
        return (F.sigmoid(self.beta*self.fc_layer(x)))
    
    def calc_prob(self, x, path_prob):
        # calculate the inner probability with sigmoid
        self.prob = self.forward(x)
        
        # store the current path probability
        self.path_prob = path_prob
        
        # pull the Q prob distributions from left and right leaves
        left_probs = self.left.calc_prob(x, path_prob*(1-self.prob))
        right_probs = self.right.calc_prob(x, path_prob*(self.prob))
        
        # append them to master list
        self.all_leaf_probs.extend(left_probs)
        self.all_leaf_probs.extend(right_probs)
        
        # return only the leaf prob distributions
        return (self.all_leaf_probs)
      
    def select_next(self, x):
        # the probability is defined as probability of the right side
        prob = self.forward(x)
        
        if prob < 0.5: 
            return(self.left, prob)
        else:
            return(self.right, prob)
    
    
    def get_penalty(self):
        alpha_num = torch.sum(self.path_prob*self.prob) 
        alpha_den = torch.sum(self.path_prob)
        alpha = alpha_num / alpha_den
        C_i = -self.lmbda * 0.5 * (torch.log(alpha) + torch.log(1-alpha))
        
        self.penalties.append(C_i)
        if not self.left.leaf:
            left_C_i = self.left.get_penalty()
            right_C_i = self.right.get_penalty()
            self.penalties.extend(left_C_i)
            self.penalties.extend(right_C_i)
        return (self.penalties)
    
    def collect_params(self):
        self.module_list = []
        self.param_list = []
        self.module_list.append(self.fc_layer)
        self.param_list.append(self.beta)        
        if self.left.leaf:
            self.param_list.append(self.left.param)
        else:
            mod, params = self.left.collect_params()
            self.module_list.extend(mod)
            self.param_list.extend(params)
            
        if self.right.leaf:
            self.param_list.append(self.right.param)
        else:
            mod, params = self.right.collect_params()
            self.module_list.extend(mod)
            self.param_list.extend(params)
        
        return(self.module_list, self.param_list)

### Testing out our inner node class

In [172]:
from pprint import pprint
# given the following parameters ...
batch_size = 8 # say 15 images per batch
no_classes = 10
input_dim = 2*2
lmbda = 0.01
depth = 1
path_prob = Variable(torch.ones(batch_size, 1))
x = Variable(torch.randn(batch_size,input_dim)) # approximating some images

inner_N = InnerNode(depth, batch_size, input_dim, no_classes,lmbda)
print('=============== Probabilities and Path Probabilities =============== ')
res = inner_N.calc_prob(x, path_prob)
print('Number of leafs: %d' % len(res))
#print(res)
total = Variable(torch.zeros((batch_size,1)))
for row in res:
    print('=== path Prob ===')
    print(row[0])
    print('=== Q Dist ===')
    print(row[1])
    total += row[0]
print('== checking total probability for each of the img in the batch (rows)')
print(total)

print('=============== Penalities =============== ')
print(inner_N.get_penalty())

Number of leafs: 2
=== path Prob ===
Variable containing:
 0.4057
 0.3627
 0.5319
 0.5258
 0.4195
 0.3670
 0.5704
 0.3298
[torch.FloatTensor of size 8x1]

=== Q Dist ===
Variable containing:
 0.1501  0.0525  0.0990  0.1238  0.0192  0.0679  0.0708  0.1899  0.1823  0.0445
 0.1501  0.0525  0.0990  0.1238  0.0192  0.0679  0.0708  0.1899  0.1823  0.0445
 0.1501  0.0525  0.0990  0.1238  0.0192  0.0679  0.0708  0.1899  0.1823  0.0445
 0.1501  0.0525  0.0990  0.1238  0.0192  0.0679  0.0708  0.1899  0.1823  0.0445
 0.1501  0.0525  0.0990  0.1238  0.0192  0.0679  0.0708  0.1899  0.1823  0.0445
 0.1501  0.0525  0.0990  0.1238  0.0192  0.0679  0.0708  0.1899  0.1823  0.0445
 0.1501  0.0525  0.0990  0.1238  0.0192  0.0679  0.0708  0.1899  0.1823  0.0445
 0.1501  0.0525  0.0990  0.1238  0.0192  0.0679  0.0708  0.1899  0.1823  0.0445
[torch.FloatTensor of size 8x10]

=== path Prob ===
Variable containing:
 0.5943
 0.6373
 0.4681
 0.4742
 0.5805
 0.6330
 0.4296
 0.6702
[torch.FloatTensor of size 8x1]


### Defining a Loss Function

In [173]:
bz = 3
mnist_trn = myNestedImgDataset('/Users/timlee/data/MNIST/trn/')
trn_vec = Img2FlatVec(mnist_trn)
trn_dl = DataLoader(trn_vec, batch_size=bz, shuffle=True, num_workers=4)
x_test, y_test = iter(trn_dl).next()
x_var = Variable(x_test)
y_var = y_test

In [174]:
y_test


 2
 4
 3
[torch.LongTensor of size 3]

In [182]:
def target2onehot(batch_size, no_classes, y):  
    """
    takes in a single vector 
    """
    template = torch.FloatTensor(batch_size, no_classes)
    template.zero_()
    template = Variable(template)
    if type(y) == torch.autograd.variable.Variable:
        target = y
    else:
        target = Variable(y)
    template.scatter_(1,target.view(-1,1),1)
    
    # was getting nan's in some of the small numbers
    template[template != template] = 0
    return template

bz = 3
no_classes = 10
t2o_test = target2onehot(bz, no_classes, y_test)
print(y_test)
torch.sum(t2o_test,dim=0)


 2
 4
 3
[torch.LongTensor of size 3]



Variable containing:
 0
 0
 1
 1
 1
 0
 0
 0
 0
 0
[torch.FloatTensor of size 10]

In [183]:
tr = Variable(torch.randn(no_classes))
sm = nn.Softmax(1)
Q_sample = sm(tr.view(1,-1))
print(Q_sample)
Q_batch_sample = Q_sample.expand(bz, no_classes)
print(Q_batch_sample)

Variable containing:
 0.2440  0.0179  0.0367  0.0441  0.0869  0.2843  0.0697  0.0750  0.0745  0.0669
[torch.FloatTensor of size 1x10]

Variable containing:
 0.2440  0.0179  0.0367  0.0441  0.0869  0.2843  0.0697  0.0750  0.0745  0.0669
 0.2440  0.0179  0.0367  0.0441  0.0869  0.2843  0.0697  0.0750  0.0745  0.0669
 0.2440  0.0179  0.0367  0.0441  0.0869  0.2843  0.0697  0.0750  0.0745  0.0669
[torch.FloatTensor of size 3x10]



In [184]:
t2o_test

Variable containing:
    0     0     1     0     0     0     0     0     0     0
    0     0     0     0     1     0     0     0     0     0
    0     0     0     1     0     0     0     0     0     0
[torch.FloatTensor of size 3x10]

In [189]:
print(Q_batch_sample.shape, t2o_test.shape)

bb = t2o_test.view(bz,1,no_classes)
#print(bb)

aa = Q_batch_sample.contiguous().view(bz, no_classes, 1)
#print(aa)

torch.bmm(bb, torch.log(bb))

torch.Size([3, 10]) torch.Size([3, 10])


RuntimeError: invalid argument 2: wrong matrix size, batch1: 1x10, batch2: 1x10 at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:1638

In [190]:
path_prob = Variable(torch.ones(batch_size, 1))
resr = path_prob * torch.bmm(bb, aa)
resr

Variable containing:
(0 ,.,.) = 
1.00000e-02 *
   3.6699
   3.6699
   3.6699
   3.6699
   3.6699
   3.6699
   3.6699
   3.6699

(1 ,.,.) = 
1.00000e-02 *
   8.6950
   8.6950
   8.6950
   8.6950
   8.6950
   8.6950
   8.6950
   8.6950

(2 ,.,.) = 
1.00000e-02 *
   4.4149
   4.4149
   4.4149
   4.4149
   4.4149
   4.4149
   4.4149
   4.4149
[torch.FloatTensor of size 3x8x1]

In [191]:
resr.mean()

Variable containing:
1.00000e-02 *
  5.5933
[torch.FloatTensor of size 1]

In [198]:
def node_loss( y, Q, path_prob, batch_size, no_classes):
    """
    takes in target label y (batch size x 1)
    and probability dist (batch size x class)

    returns a (batch size x 1) which is the loss constant per batch size
    """
    target = target2onehot(batch_size, no_classes, y)
    
    T_k = torch.sum(target,dim=0).view(-1,1)
    print(T_k)
    logQ_k = torch.log(Q).view(batch_size, no_classes, 1)
    print(logQ_k)
    TQ = torch.bmm(T_k,logQ_k).view(-1,1)
    print(TQ)
    return(TQ*path_prob)


path_prob = Variable(torch.ones(bz, 1))
node_loss(y_test, Q_batch_sample, path_prob, bz, no_classes)

Variable containing:
    0
    0
    1
    1
    1
    0
    0
    0
    0
    0
[torch.FloatTensor of size 10x1]

Variable containing:
(0 ,.,.) = 
 -1.4108
 -4.0234
 -3.3050
 -3.1202
 -2.4424
 -1.2576
 -2.6641
 -2.5905
 -2.5970
 -2.7048

(1 ,.,.) = 
 -1.4108
 -4.0234
 -3.3050
 -3.1202
 -2.4424
 -1.2576
 -2.6641
 -2.5905
 -2.5970
 -2.7048

(2 ,.,.) = 
 -1.4108
 -4.0234
 -3.3050
 -3.1202
 -2.4424
 -1.2576
 -2.6641
 -2.5905
 -2.5970
 -2.7048
[torch.FloatTensor of size 3x10x1]



RuntimeError: invalid argument 1: expected 3D tensor, got 2D at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:1630

### Prototyping the decision tree

The loss function:

$$L(x) = -log(\sum_{leaf nodes} P^l(x)\sum_k T_k log{Q_k^l})$$

- $Q_k^l$ - represents the distribution of probabilities for `k` classes at `l` leaf
- $T_k$ - represents the target class (think one hot encoding)
- $P^l(x)$ - represents the path probability at that leaf (this will be treated as compound prob)


1. **`Q_k`** - the distribution matrix size is `batch_size x classes`
2. **`T_k`** - the target matrix size is `batch size x classes`
3. **`path_prob`** - the path probability is `batch size x 1` per node

#### Notes on Path Probability:



In [221]:
class SoftDTree(nn.Module):
    def __init__(self, 
                 batch_size = 64, 
                 input_dim = 28*28, 
                 no_classes = 10, 
                 max_depth=8,
                 epochs = 4,
                 lr = 0.01,
                 lmbda = 0.1,
                 momentum = 0.5,
                 seed = 1,
                 cuda = False,
                 log_interval = 10
                ):
        
        # parameters
        super(SoftDTree, self).__init__()
        self.batch_size = batch_size
        self.input_dim = input_dim
        self.no_classes = no_classes
        self.max_depth = max_depth
        self.epochs = epochs
        self.lr = lr
        self.lmbda = lmbda
        self.momentum = momentum
        self.seed = seed
        self.cuda = cuda
        self.log_interval = log_interval

        #setup target structures
        self.root = InnerNode(self.max_depth, self.batch_size, self.input_dim, self.no_classes, self.lmbda)
       
        
        # collects all the parameters to optimize from the nested nodes
        self.collect_params() 

        # training objects
        self.optimizer = optim.SGD(self.parameters(), 
                                   lr=lr, 
                                   momentum=momentum)
        
        self.initialize()
        self.test_acc = []
        self.best_accuracy = 0.0
        torch.manual_seed(seed)
        
        if cuda:
            torch.cuda.manual_seed(seed)

        
        
    def initialize(self, batch_size = None):
        if batch_size is None:
            batch_size = self.batch_size
        self.target_onehot = torch.FloatTensor(batch_size, self.no_classes)
        self.target_onehot = Variable(self.target_onehot)
        self.path_prob_init = Variable(torch.ones(batch_size,1))
        if cuda:
            self.target_onehot = self.target_onehot.cuda()
            self.path_prob_init = self.path_prob_init.cuda()
            
        
    
    
    def target2onehot(self, y):  
        """
        takes in a single vector 
        """
        template = torch.FloatTensor(self.batch_size, self.no_classes)
        template.zero_()        
        template = Variable(template)
        if type(y) == torch.autograd.variable.Variable:
            target = y
        else:
            target = Variable(y)
        template.scatter_(1,target.view(-1,1),1)
        template[template != template] = 0
        return template
    
    
    def node_loss(self, y, Q, path_prob):
        """
        takes in target label y (batch size x 1)
        and probability dist (batch size x class)
        
        returns a (batch size x 1) which is the loss constant per batch size
        """
        target = self.target2onehot(y)
        T_k = target.view(self.batch_size, 1, self.no_classes)
        logQ_k = torch.log(Q).view(self.batch_size, self.no_classes, 1)
        TQ = torch.bmm(T_k,logQ_k).view(-1,1)
        return(TQ*path_prob)
    
    
    def most_prob_Q(self, list_prob_n_Q):
        """
        takes in a list of (path_prob, Q)
        
        based on the path_prob
        """
        # will store the max node probability per batch
        max_prob = [-1. for _ in range(self.batch_size)]
        
        # will store the most likely distribution
        max_Q = [torch.zeros(self.no_classes) for _ in range(self.batch_size)]        
        for (path_prob, Q) in list_prob_n_Q:
            path_prob_numpy = path_prob.cpu().data.numpy().reshape(-1)
            for i in range(self.batch_size):
                if max_prob[i] < path_prob_numpy[i]:
                    max_prob[i] = path_prob_numpy[i]
                    max_Q[i] = Q[i]

        return(max_prob, max_Q)
    
    
    def calc_loss(self, x, y):

        all_leaf_probs = self.root.calc_prob(x, self.path_prob_init)
        
        # based on the path (max prob), get the distribution 
        max_path_prob, max_Q = self.most_prob_Q(all_leaf_probs)        

        total_loss = torch.mean(torch.sum(torch.stack([self.node_loss(y, Q, path_prob) for path_prob, Q in all_leaf_probs]), dim=0))
        total_C = torch.sum(torch.stack(self.root.get_penalty()))
        
        output = torch.stack(max_Q)
        self.root.reset()
        
        return(-total_loss + total_C, output)
    
    
    def collect_params(self):
        self.nn_module_list = nn.ModuleList()
        self.nn_param_list = nn.ParameterList()
        mod, params = self.root.collect_params()
        self.nn_module_list.extend(mod)
        self.nn_param_list.extend(params)
        
        
    def train_(self, train_loader, epoch):
        for epoch_idx in range(1, epoch+1):
            self.train()
            self.initialize()
            for batch_idx, (data, target) in enumerate(train_loader):
                correct = 0
                if self.cuda:
                    data, target = data.cuda()
                    
                # if the batch size doesn't match (uneven division)
                if not y_test.shape[0] == self.batch_size:
                    self.initialize(batch_size)
                    
                data = data.view(self.batch_size, -1)
                data = Variable(data)

                self.optimizer.zero_grad()

                loss, output = self.calc_loss(data, target)
                loss.backward(retain_variables=True)

                self.optimizer.step()

                pred = output.data.max(1)[1]

                correct += pred.eq(target).cpu().sum()
                accuracy = 100. * correct / len(data)

                if batch_idx % self.log_interval == 0:

                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, Accuracy: {}/{} ({:.4f}%)'.format(
                        epoch_idx, batch_idx * len(data), len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), loss.data[0],
                        correct, len(data),
                        accuracy))

   

In [222]:
bz = 64
input_dim = 28*28
no_classes = 10
max_depth = 2
epochs = 4 
lr = 0.01
lmbda = 0.1
momentum = 0.5
seed = 1
cuda = False
log_interval = 10



In [223]:
from torchvision import datasets, transforms

try:
    os.makedirs('./data')
except:
    print('directory ./data already exists')

kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=bz, shuffle=True, **kwargs)


test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=bz, shuffle=True, **kwargs)

directory ./data already exists


In [224]:
model = SoftDTree(bz, input_dim, no_classes, max_depth, epochs, lr, lmbda, momentum, seed, False, log_interval)

In [225]:
model.train_(train_loader, 4)





RuntimeError: size mismatch, m1: [64 x 392], m2: [784 x 1] at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:1416

In [212]:
# all_leaf_probs = model.root.calc_prob(x_var, model.path_prob_init)
# loss = 0.
# list_ = []
# for path_prob, Q in all_leaf_probs:
#     #print(path_prob.shape, Q.shape)
#     #print(path_prob)
#     #print(model.node_loss(y_var, Q, path_prob))
#     loss += model.node_loss(y_var, Q, path_prob)
#     list_.append(model.node_loss(y_var, Q, path_prob))
# print(torch.sum(torch.stack(list_),dim=0))
# print(loss)

In [213]:
# model.calc_loss(x_var, y_var)

In [214]:
model.train_(train_loader, 4)





RuntimeError: size mismatch, m1: [64 x 49], m2: [784 x 1] at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:1416