In [1]:
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from PIL import Image

In [2]:
# maximum number of epochs to train the model
# (a later cell reassigns max_epoch before the training loop runs)
max_epoch = 10

# number of iterations in each epoch
iter_per_epoch = 50

# number of samples in each iteration
batchSize = 64

# gpu option: 1 when CUDA is usable, else 0.
# Detected automatically so the notebook also runs on CPU-only machines;
# assign 0 by hand to force CPU execution.
gpu = 1 if torch.cuda.is_available() else 0

# learning rate used for the Adam optimizer
learn_rate = 0.01

# momentum parameter
# NOTE(review): this value is not passed to optim.Adam in the model-setup
# cell, so it currently has no effect on training.
momentum = 0.9

# weight decay
weightDecay = 0.0005

# The left patch is searched at 2*half_range + 1 locations
# in the right patch.
#
# If the left patch is 37 x 37, the right patch is
# 37 x (37 + 2*half_range).

half_range = 100

In [3]:
class Net(nn.Module):
    """Two-input matching network with one shared convolutional stack.

    Both inputs are pushed through the same nine stages. Every stage is an
    unpadded 5x5 convolution followed by batch normalization (eps=1e-3);
    the first eight stages are also followed by a ReLU. Nine unpadded 5x5
    convolutions shrink each spatial dimension by 36, so a 37x37 patch
    collapses to a single 64-channel descriptor.

    Args:
        nChannel: number of channels of the input patches.
        max_dips: number of candidate locations scored per sample; the
            right-hand features are reshaped to this width.
    """

    def __init__(self, nChannel, max_dips):
        super(Net, self).__init__()
        self.l_max_dips = max_dips

        # (in_channels, out_channels) of the nine 5x5 convolutions, in order.
        channel_plan = [(nChannel, 32), (32, 32), (32, 64)] + [(64, 64)] * 6

        # Registered as conv1/batchnorm1 ... conv9/batchnorm9, in that order.
        for stage, (c_in, c_out) in enumerate(channel_plan, 1):
            setattr(self, 'conv%d' % stage, nn.Conv2d(c_in, c_out, 5))
            setattr(self, 'batchnorm%d' % stage, nn.BatchNorm2d(c_out, 1e-3))

        self.logsoftmax = nn.LogSoftmax()

    def forward_pass(self, x):
        """Run one input through the nine conv + batch-norm stages."""
        # stages 1-8: convolution -> batch norm -> ReLU
        for stage in range(1, 9):
            conv = getattr(self, 'conv%d' % stage)
            norm = getattr(self, 'batchnorm%d' % stage)
            x = F.relu(norm(conv(x)))

        # stage 9: convolution -> batch norm, no ReLU
        return self.batchnorm9(self.conv9(x))

    def forward(self, x1, x2):
        """Score the left patch against every location of the right patch.

        Returns a tuple (x1, x2, x3):
            x1: left features viewed as (batch, 1, 64)
            x2: right features viewed as (batch, 64, max_dips)
            x3: log-softmax of the batched product x1 . x2,
                viewed as (batch, max_dips)
        """
        n_loc = self.l_max_dips

        # the same stack (shared weights) processes both patches
        left = self.forward_pass(x1)
        right = self.forward_pass(x2)

        # left descriptor as one 64-dimensional row vector per sample
        x1 = left.view(left.size(0), 1, 64)

        # right descriptors as a 64 x n_loc matrix per sample
        x2 = right.squeeze().view(right.size(0), 64, n_loc)

        # inner product of the left descriptor with every right location
        scores = x1.bmm(x2).view(x2.size(0), n_loc)

        # log-probabilities over the candidate locations
        x3 = self.logsoftmax(scores)

        return x1, x2, x3

In [4]:
# three pixel error
def loss_function(x3, t, w):
    """Weighted negative log-likelihood around each ground-truth location.

    For every sample the log-probabilities in a window centred on the target
    location are multiplied by the weights ``w`` and summed; the loss is the
    negated total over the batch (a sum, not a mean).

    Args:
        x3: (batch, n_locations) log-probabilities.
        t:  per-sample target index, read as ``t[i][0]``.
        w:  1-D weights of odd length; the centre weight applies to the
            target location and the window half-width is ``(len(w) - 1) // 2``
            (2 for the five weights used in this notebook).

    Returns:
        The summed loss, or the plain number 0 when the batch is empty.

    When the window would extend past the first or last location, the window
    and the weights are clipped together so the remaining weights still line
    up with their offsets. The previous fixed ``t-2 : t+3`` slice produced a
    truncated or empty slice there and failed on a size mismatch.
    """
    half = (w.size(0) - 1) // 2
    n_loc = x3.size(1)

    error = 0
    for i in range(x3.size(0)):
        center = int(t[i][0])

        # window [lo, hi) clipped to the valid score range
        lo = max(center - half, 0)
        hi = min(center + half + 1, n_loc)
        if lo >= hi:
            # target lies so far outside the scores that nothing overlaps
            continue

        # matching slice of the weights: w[0] belongs to location center - half
        w_lo = lo - (center - half)
        w_hi = w_lo + (hi - lo)

        # class_weight_y_i * log p_i(y_i)
        loss_sample = torch.mul(x3[i, lo:hi], w[w_lo:w_hi]).sum()

        error = error - loss_sample
    return error

In [5]:
# Matching network: 3 input channels, one score for each of the
# 2*half_range + 1 candidate locations.
model = Net(3, half_range*2+1)

# Weights applied by loss_function to the five scores around the target.
class_wts = Variable(torch.Tensor([1, 4, 10, 4, 1]))

# Move the model (and the loss weights) to the GPU *before* building the
# optimizer, as the torch.optim documentation recommends, so the optimizer
# is constructed from the parameters that are actually trained.
if gpu:
    model = model.cuda()
    class_wts = class_wts.cuda()

optimizer = optim.Adam(model.parameters(), lr=learn_rate, eps=1e-08, weight_decay=weightDecay)

In [6]:
# Overrides the max_epoch = 10 chosen in the configuration cell: the
# training loop further down runs for 5 epochs.
max_epoch = 5
# Reader used by the next cell to load the .t7 data files.
from torch.utils.serialization import load_lua
# Switch the model to training mode. train() is the cell's last expression,
# so the module it returns is echoed as the cell output.
model.train() # train mode.

Net (
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1))
  (batchnorm1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True)
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1))
  (batchnorm2): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True)
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (batchnorm3): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True)
  (conv4): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
  (batchnorm4): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True)
  (conv5): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
  (batchnorm5): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True)
  (conv6): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
  (batchnorm6): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True)
  (conv7): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
  (batchnorm7): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True)
  (conv8): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
  (batchnorm8): B

In [7]:
# Read the left patches, right patches and targets, in that order.
left_patches, right_patches, targets = map(
    load_lua,
    ('data/left_patches.t7', 'data/right_patches.t7', 'data/targets.t7'),
)

In [9]:
# Training loop.
#
# Samples are consumed strictly in order and never reused: epoch e, iteration
# k reads the batchSize samples starting at (e*iter_per_epoch + k)*batchSize.
# The data therefore has to hold at least max_epoch*iter_per_epoch*batchSize
# samples; a short final slice makes the view(batchSize, 1) below fail.
for epoch in range(max_epoch):
    train_loss = 0

    # index of the first sample used by this epoch
    epoch_start = epoch*iter_per_epoch*batchSize

    for _iter in range(iter_per_epoch):

        # zero the gradient buffers
        optimizer.zero_grad()

        # sample batch data
        id1 = epoch_start + _iter*batchSize
        id2 = id1 + batchSize
        left_batch = left_patches[id1:id2, :, :, :]
        right_batch = right_patches[id1:id2, :, :, :]
        t_batch = targets[id1:id2].view(batchSize,1).int()

        # convert to cuda if gpu available
        if gpu:
            left_batch = left_batch.cuda()
            right_batch = right_batch.cuda()
            t_batch = t_batch.cuda()

        # forward pass
        x1, x2, x3 = model(Variable(left_batch), Variable(right_batch))

        # compute loss (summed over the batch, not averaged)
        loss = loss_function(x3, t_batch, class_wts)

        # backward pass. compute gradients
        loss.backward()

        # update the weights
        optimizer.step()

        # NOTE(review): loss.data[0] is the old way to read a scalar loss;
        # newer PyTorch releases use loss.item() instead.
        train_loss+=loss.data[0]

    # Mean per-iteration loss of this epoch. Written as a single-argument
    # print call so it parses on both Python 2 and Python 3 and emits the
    # same space-separated text as the former print statement.
    print(' '.join(['Loss at epoch ', str(epoch), str(train_loss/iter_per_epoch)]))

Loss at epoch  0 6940.05801758
Loss at epoch  1 5446.15808594
Loss at epoch  2 5080.10720703
Loss at epoch  3 4862.2950293
Loss at epoch  4 4546.17910156
