## Necessary Imports

In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import os, sys

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import random

## Siamese networks

In [5]:
from dataloader import SiameseTestData_ImageFolder, SiameseTrainData_ImageFolder
from networks import SiameseNet
from losses import ContrastiveLoss
import torchvision.models as models
import torch.optim.lr_scheduler as lr_scheduler
from utils import AverageMeter
import time

#### Hyperparameters
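- input_size: height and width (in pixels) that every image is resized to before it is fed to the model
- learning_rate: the learning rate for Adam
- epochs: total number of epochs to train
- sched_reset: number of epochs before the first warm restart of the cosine schedule (0 disables restarts)
- batch_size: size of the batch for the training data
- num_workers: number of worker processes for the dataloaders
- way: the n in the n-way evaluation; also used as the batch size of the validation loader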

In [6]:
# --- data ---
# Images are resized to input_size x input_size before entering the network.
input_size = 105
# Number of pairs per training batch.
batch_size = 256
# Worker processes used by each DataLoader.
num_workers = 4
# Evaluation is n-way; this is also the batch size of the validation loader.
way = 20

# --- optimisation ---
# Learning rate handed to the Adam optimizer.
learning_rate = 5e-4
# Total number of training epochs.
epochs = 100
# Epochs before the first warm restart of the cosine schedule; 0 disables restarts.
sched_reset = 0

In [7]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
# The DataLoaders pin host memory exactly when CUDA is in use.
pin_memory = cuda
device

device(type='cuda')

#### Data transforms and Load Data
- Train
    - Grayscale: ImageFolder returns RGB images, so convert them to 1-channel
    - Resize: Resize to input_size x input_size
    - RandomRotation: Rotate the image by a random angle of up to 10 degrees (horizontal flipping is not used)
    - ToTensor: Convert PIL image to tensor
    - Normalize the images
- Val: Same as train except for the random rotation

In [8]:
# Pre-processing pipelines keyed by split.
# Grayscale() leaves a single channel, so Normalize must receive exactly one
# mean and one std value; three-element tuples do not match a 1-channel tensor.
data_transforms = {
    'train': transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((input_size, input_size)),
        # Light augmentation: random rotation of up to 10 degrees.
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]),
    # Same pipeline as 'train' without the random rotation.
    'val': transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]),
}

Load the datasets using the generic ImageFolder class from pytorch.  
The train, valid, test split is done in `train_test_split.ipynb`

In [9]:
# One ImageFolder per split; the validation and test splits share the 'val' pipeline.
trainset = datasets.ImageFolder(
    root='./omniglot_data/changed/train',
    transform=data_transforms['train'],
)
valset = datasets.ImageFolder(
    root='./omniglot_data/changed/valid',
    transform=data_transforms['val'],
)
testset = datasets.ImageFolder(
    root='./omniglot_data/changed/test',
    transform=data_transforms['val'],
)

Convert the ImageFolder datasets to a Siamese form, i.e. datasets that return pairs of images.  
Classes are present in `dataloader.py`

In [10]:
# Wrap the plain ImageFolder datasets so that they yield image pairs.
train_siamese = SiameseTrainData_ImageFolder(trainset)
# `times` is set to the number of whole groups of `way` images in the validation set.
val_siamese = SiameseTestData_ImageFolder(valset, times=len(valset) // way)

Create the dataloaders

In [11]:
# Options shared by both loaders.
_loader_kwargs = dict(num_workers=num_workers, pin_memory=pin_memory)

# Training pairs are reshuffled every epoch.
trainloader = torch.utils.data.DataLoader(
    train_siamese, batch_size=batch_size, shuffle=True, **_loader_kwargs
)
# Validation keeps dataset order and uses `way` pairs per batch.
valloader = torch.utils.data.DataLoader(
    val_siamese, batch_size=way, shuffle=False, **_loader_kwargs
)

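#### Define the Siamese network: the convolutional layers follow the Koch et al. paper, but the head outputs a 128-dimensional embedding per image instead of a similarity score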

In [12]:
class KochNetv2(nn.Module):
    """Siamese network whose two branches share a single set of weights.

    An input batch goes through the convolutional stack ``features`` and the
    fully connected head ``fc``, which ends in a 128-unit linear layer.
    ``forward`` returns the head outputs for both inputs; no distance is
    computed inside the module. The shape comments assume 1x105x105 inputs.
    """

    def __init__(self):
        super().__init__()
        conv_layers = [
            nn.Conv2d(1, 64, kernel_size=10),     # 1x105x105 -> 64x96x96
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),          # -> 64x48x48
            nn.Conv2d(64, 128, kernel_size=7),    # -> 128x42x42
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),          # -> 128x21x21
            nn.Conv2d(128, 128, kernel_size=4),   # -> 128x18x18
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),          # -> 128x9x9
            nn.Conv2d(128, 256, kernel_size=4),   # -> 256x6x6
            nn.ReLU(inplace=True),
        ]
        self.features = nn.Sequential(*conv_layers)

        head_layers = [
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.Sigmoid(),
            nn.Linear(4096, 128),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward_one(self, x):
        """Run one batch of images through the shared branch."""
        feature_maps = self.features(x)
        # Collapse everything except the batch dimension before the linear head.
        flat = feature_maps.view(feature_maps.size(0), -1)
        return self.fc(flat)

    def forward(self, x1, x2):
        """Return the branch outputs for ``x1`` and ``x2`` as a tuple."""
        first = self.forward_one(x1)
        second = self.forward_one(x2)
        return first, second

#### Initialize the model and optimizer
- Use ADAM optimizer with previously defined learning rate

In [13]:
# Put the model on the target device before building the optimizer, so the
# optimizer is constructed from the parameters that are actually trained.
# (`device` is the CPU when CUDA is unavailable, which makes `.to` a no-op.)
model = KochNetv2().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

#### Define the loss function and (optionally) learning rate scheduler
- Contrastive loss on the two embeddings (used here in place of the binary cross-entropy loss from Koch et al.)
- Cosine annealing scheduler

In [14]:
from losses import ContrastiveLoss

In [15]:
# The cosine period is the restart interval when restarts are enabled
# (sched_reset != 0); otherwise it is the whole training run.
T_max = sched_reset if sched_reset != 0 else epochs
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)
criterion = ContrastiveLoss()

#### Functions to train and validate for 1 epoch

In [16]:
def train(train_loader, model, criterion, optimizer, epoch, device, debug=False, print_freq=200):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        train_loader: iterable yielding ``(imgs1, imgs2, targets)`` batches.
        model: module called as ``model(imgs1, imgs2)`` and returning two outputs.
        criterion: loss called as ``criterion(out1, out2, targets)``.
        optimizer: optimizer that is zeroed and stepped once per batch.
        epoch: epoch index; only used in the progress message.
        device: device the batch tensors are moved to.
        debug: when true, stop after the first batch.
        print_freq: a progress line is printed every ``print_freq`` batches.

    Returns:
        ``losses.avg`` of the meter that is updated with every batch's loss
        value and batch size.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for batch_idx, (imgs1, imgs2, targets) in enumerate(train_loader):
        # time spent waiting for the loader to deliver this batch
        data_time.update(time.time() - end)

        imgs1 = imgs1.to(device).float()
        imgs2 = imgs2.to(device).float()
        targets = targets.to(device).float()

        optimizer.zero_grad()
        out1, out2 = model(imgs1, imgs2)
        # The criterion receives both outputs directly, so no pairwise
        # distance is computed here.
        loss = criterion(out1, out2, targets)
        losses.update(loss.item(), imgs1.size(0))

        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                      epoch, batch_idx, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses))
        if debug:
            break

    return losses.avg

In [17]:
def validate(val_loader, model, epoch, device, print_freq=100):
    """Evaluate ``model`` on ``val_loader`` and return the fraction of correct trials.

    Every batch is treated as one trial: the pair with the smallest distance
    between its two outputs is the prediction, and the trial counts as correct
    when that pair sits at index 0 of the batch.
    NOTE(review): this assumes the dataset places the matching pair first in
    every batch - confirm against ``dataloader.py``.

    Args:
        val_loader: iterable yielding ``(imgs1, imgs2)`` batches.
        model: module called as ``model(imgs1, imgs2)`` and returning two outputs.
        epoch: epoch index; only used in the summary message.
        device: device the batch tensors are moved to.
        print_freq: accepted for signature compatibility; not used.

    Returns:
        ``correct / (correct + wrong)`` as a float, or ``0.0`` when the loader
        yields no batches.
    """
    batch_time = AverageMeter()
    accuracy = AverageMeter()
    # switch to evaluate mode
    model.eval()
    correct, wrong = 0, 0
    # Keeps the summary print below well-defined when the loader is empty.
    batch_idx = -1
    with torch.no_grad():
        end = time.time()
        for batch_idx, (imgs1, imgs2) in enumerate(val_loader):
            imgs1 = imgs1.to(device).float()
            imgs2 = imgs2.to(device).float()

            out1, out2 = model(imgs1, imgs2)
            output = F.pairwise_distance(out1, out2)
            pred = np.argmin(output.cpu().numpy())
            hit = bool(pred == 0)
            if hit:
                correct += 1
            else:
                wrong += 1

            # Record the outcome of this single trial. Feeding the running
            # accuracy weighted by the running count (as was done before)
            # makes the meter's average drift away from correct/total.
            accuracy.update(float(hit), 1)
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
    print('Test: [{0}][{1}/{2}]\t'
          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
          'Correct {correct} \t Wrong {wrong}\t'
          'Accuracy {acc.val:.3f} ({acc.avg:.3f})\t'.format(
              epoch, batch_idx, len(val_loader), batch_time=batch_time,
              correct=correct, wrong=wrong,
              acc=accuracy))

    # Computed from the counters so the result does not depend on how the
    # meter averages its inputs.
    total = correct + wrong
    return float(correct) / total if total else 0.0

### Train the model

In [18]:
train_losses = []
val_accs = []
epoch_time = AverageMeter()
# A checkpoint is only written once validation accuracy exceeds this value.
best_acc = 0.5
# Work on a local copy so the `sched_reset` hyperparameter is not overwritten
# and re-running this cell starts from the configured value again.
restart_period = sched_reset
# torch.save does not create missing directories.
os.makedirs('weights', exist_ok=True)
ep_end = time.time()
for e in range(epochs):
    print('*'*70)
    train_loss = train(trainloader, model, criterion, optimizer, e, device,
                       debug=False, print_freq=2)
    train_losses.append(train_loss)

    print('-'*70)
    val_acc = validate(valloader, model, e, device)
    print('Avg validation acc: {:.3f}'.format(val_acc))
    val_accs.append(val_acc)
    print('-'*70)

    # Checkpoints are only considered on every third epoch; an improvement on
    # any other epoch is neither saved nor recorded in best_acc.
    # NOTE(review): confirm that skipping those epochs is intended.
    if best_acc < val_acc and e % 3 == 0:
        best_acc = val_acc
        model_path = os.path.join('weights', 'kochnetv2_{}.pth'.format(e))
        torch.save(model.state_dict(), model_path)
        print('New model saved to {} with accuracy {:.3f}'.format(model_path, best_acc))

    epoch_time.update(time.time() - ep_end)
    ep_end = time.time()
    print('Epoch {}/{}\t'
          'Time {epoch_time.val:.3f} ({epoch_time.avg:.3f})'.format(e, epochs - 1, epoch_time=epoch_time))

    if restart_period != 0 and e % restart_period == 0 and e > 0:
        # Warm restart: double the period and start over with a fresh
        # optimizer and scheduler at the base learning rate.
        # NOTE(review): restarts fire at epochs p, 2p, 4p, ... while T_max is
        # set to the doubled period, so each cycle after the first is cut off
        # halfway through its cosine - confirm this is the intended schedule.
        print('$'*70)
        print('WARM RESTART')
        restart_period = restart_period * 2
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=restart_period)
    else:
        # Advance the schedule only after this epoch's optimizer steps;
        # stepping it first skips the initial learning-rate value.
        scheduler.step()

**********************************************************************
Epoch: [0][0/63]	Time 1.754 (1.754)	Data 0.846 (0.846)	Loss 1.6473 (1.6473)
Epoch: [0][2/63]	Time 0.911 (1.191)	Data 0.000 (0.282)	Loss 0.9975 (1.2172)
Epoch: [0][4/63]	Time 0.909 (1.078)	Data 0.000 (0.169)	Loss 1.1729 (1.2054)
Epoch: [0][6/63]	Time 0.908 (1.029)	Data 0.000 (0.121)	Loss 1.0452 (1.1832)
Epoch: [0][8/63]	Time 0.907 (1.002)	Data 0.000 (0.094)	Loss 1.1933 (1.1672)
Epoch: [0][10/63]	Time 0.907 (0.985)	Data 0.000 (0.077)	Loss 1.0264 (1.1409)
Epoch: [0][12/63]	Time 0.906 (0.972)	Data 0.000 (0.065)	Loss 1.0775 (1.1302)
Epoch: [0][14/63]	Time 0.905 (0.964)	Data 0.000 (0.057)	Loss 1.0063 (1.1150)
Epoch: [0][16/63]	Time 0.904 (0.957)	Data 0.000 (0.050)	Loss 1.0226 (1.1032)
Epoch: [0][18/63]	Time 0.906 (0.951)	Data 0.000 (0.045)	Loss 1.0070 (1.0928)
Epoch: [0][20/63]	Time 0.902 (0.947)	Data 0.000 (0.041)	Loss 1.0130 (1.0850)
Epoch: [0][22/63]	Time 0.906 (0.943)	Data 0.000 (0.037)	Loss 1.0088 (1.0784)
Epoch: [0]

Process Process-42:
Process Process-44:
Process Process-43:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Process Process-41:
Traceback (most recent call last):
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/var/anaconda3/envs/diss/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 106, in _worker_loop
    samples = collate_fn([dataset[i] for i i

Epoch: [5][10/63]	Time 0.902 (0.970)	Data 0.000 (0.083)	Loss 1.0144 (1.0209)


KeyboardInterrupt: 