# Verification task or Siamese neural networks training

This notebook presents the paper ["Siamese Neural Networks for One-shot Image Recognition"](https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf), implemented with the PyTorch framework.

In this part we train a Siamese network on the Omniglot dataset for the verification task: classifying whether two images belong to the same class or to different classes.


References:
- [paper](https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf)
- [omniglot](https://github.com/brendenlake/omniglot)
- [keras-oneshot](https://github.com/sorenbouma/keras-oneshot)


In [1]:
# https://ipython.org/ipython-doc/3/config/extensions/autoreload.html
# Reload imported modules automatically before each cell is executed, so that
# edits to the local project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys
import numpy as np
import cv2

In [3]:
# Make modules located in the parent directory importable from this notebook.
sys.path.append("..")

In [19]:
# Manual switch for the guarded GPU code paths: they are taken only when this
# is True and torch.cuda.is_available() also reports a usable device.
HAS_GPU = False

## Setup dataflow

In [21]:
from dataflow import OmniglotDataset, SameOrDifferentPairsDataset, PairTransformedDataset
from common_utils.imgaug import RandomAffine, RandomApply
from common_utils.dataflow import TransformedDataset, OnGPUDataLoader
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader
import torch

In [11]:
# Fix the global NumPy RNG so that the random choices made below are reproducible.
np.random.seed(12345)

OMNIGLOT_REPO_PATH='omniglot'

TRAIN_DATA_PATH = os.path.join(OMNIGLOT_REPO_PATH, 'python', 'images_background')
TEST_DATA_PATH = os.path.join(OMNIGLOT_REPO_PATH, 'python', 'images_evaluation')


def _list_visible(path):
    """Return the sorted names of the non-hidden entries of directory `path`.

    Pure-Python replacement for `!ls`: entries starting with a dot are skipped
    and names are sorted by code point. Unlike the shell call, a missing
    directory raises an OSError instead of returning the `ls` error text as data.
    """
    return sorted(name for name in os.listdir(path) if not name.startswith('.'))


def _index_alphabets(data_path, alphabets):
    """Build the mapping {alphabet: {char_id: [file name without extension, ...]}}.

    `data_path` is the folder containing one sub-folder per alphabet; each
    alphabet folder contains one sub-folder per character.
    """
    index = {}
    for alphabet in alphabets:
        alphabet_path = os.path.join(data_path, alphabet)
        index[alphabet] = {}
        for char_id in _list_visible(alphabet_path):
            file_names = _list_visible(os.path.join(alphabet_path, char_id))
            # Drop the file extension (whatever its length) to keep the bare id.
            index[alphabet][char_id] = [os.path.splitext(name)[0] for name in file_names]
    return index


train_alphabets = _list_visible(TRAIN_DATA_PATH)
test_alphabets = _list_visible(TEST_DATA_PATH)

# The message prints the whole lists, so it cannot itself fail on an empty listing.
assert len(train_alphabets) > 1 and len(test_alphabets) > 1, \
    "Expected several alphabets, got train=%s \n test=%s" % (train_alphabets, test_alphabets)

train_alphabet_char_id_drawer_ids = _index_alphabets(TRAIN_DATA_PATH, train_alphabets)
test_alphabet_char_id_drawer_ids = _index_alphabets(TEST_DATA_PATH, test_alphabets)


# Split the 20 drawer indices into disjoint train / validation / test groups
# of 12 / 4 / 4 drawers.
# NOTE(review): the indices run 0..19 and are later formatted as "_00".."_19";
# confirm that this matches the drawer suffixes used in the image file names.
all_drawers_ids = np.arange(20)
# Sample 12 drawers out of 20
train_drawers_ids = np.random.choice(all_drawers_ids, size=12, replace=False)
# np.setdiff1d returns the left-over ids as a sorted array, so the seeded draw
# below does not depend on the iteration order of a Python set.
remaining_drawers_ids = np.setdiff1d(all_drawers_ids, train_drawers_ids)
# Sample 4 drawers out of remaining 8
val_drawers_ids = np.random.choice(remaining_drawers_ids, size=4, replace=False)
# Whatever is neither train nor validation becomes the test group.
test_drawers_ids = np.setdiff1d(remaining_drawers_ids, val_drawers_ids)

def create_str_drawers_ids(drawers_ids):
    """Format each drawer id as an underscore followed by the id left-padded with zeros to width 2.

    Example: [0, 5, 12] -> ["_00", "_05", "_12"].
    """
    formatted = []
    for drawer_id in drawers_ids:
        formatted.append("_" + format(drawer_id, "0>2"))
    return formatted

# Replace the integer drawer indices of every split by their "_NN" string form.
train_drawers_ids, val_drawers_ids, test_drawers_ids = [
    create_str_drawers_ids(ids)
    for ids in (train_drawers_ids, val_drawers_ids, test_drawers_ids)
]

# Training dataset: built from the training folder and the training drawers.
train_ds = OmniglotDataset(
    "Train",
    data_path=TRAIN_DATA_PATH,
    alphabet_char_id_drawers_ids=train_alphabet_char_id_drawer_ids,
    drawers_ids=train_drawers_ids,
)

# Validation and test datasets: both read the evaluation folder and differ
# only in the group of drawers they are given.
val_ds = OmniglotDataset(
    "Test",
    data_path=TEST_DATA_PATH,
    alphabet_char_id_drawers_ids=test_alphabet_char_id_drawer_ids,
    drawers_ids=val_drawers_ids,
)

test_ds = OmniglotDataset(
    "Test",
    data_path=TEST_DATA_PATH,
    alphabet_char_id_drawers_ids=test_alphabet_char_id_drawer_ids,
    drawers_ids=test_drawers_ids,
)

In [12]:
# Wrap each dataset into a dataset of image pairs:
# 30000 pairs for training, 10000 each for validation and test.
train_pairs = SameOrDifferentPairsDataset(train_ds, nb_pairs=30000)
val_pairs = SameOrDifferentPairsDataset(val_ds, nb_pairs=10000)
test_pairs = SameOrDifferentPairsDataset(test_ds, nb_pairs=10000)

# Display the number of pairs in each split.
len(train_pairs), len(val_pairs), len(test_pairs)

(30000, 10000, 10000)

In [15]:
# Training pipeline: a random affine transform (rotation, scale, translation)
# wrapped in RandomApply with proba=0.5, followed by conversion to a tensor.
random_affine = RandomAffine(rotation=(-10, 10), scale=(0.8, 1.2), translate=(-0.05, 0.05))
train_data_aug = Compose([RandomApply(random_affine, proba=0.5), ToTensor()])

# Validation / test pipeline: conversion to a tensor only, no augmentation.
test_data_aug = Compose([ToTensor()])

# Attach the pipelines to the pair datasets through their x_transforms argument.
train_aug_pairs = PairTransformedDataset(train_pairs, x_transforms=train_data_aug)
val_aug_pairs = PairTransformedDataset(val_pairs, x_transforms=test_data_aug)
test_aug_pairs = PairTransformedDataset(test_pairs, x_transforms=test_data_aug)

In [24]:
batch_size = 32

# The GPU code path is used only when it is explicitly enabled AND CUDA is available.
use_gpu = HAS_GPU and torch.cuda.is_available()
_DataLoader = OnGPUDataLoader if use_gpu else DataLoader

# Settings shared by the three loaders. Pinned host memory is requested only on
# the GPU path, and consistently for every split (the training loader used to
# omit it while validation/test asked for it even on CPU).
loader_kwargs = dict(batch_size=batch_size, num_workers=5, pin_memory=use_gpu)

# Training and validation batches are shuffled and drop the last incomplete batch.
train_batches = _DataLoader(train_aug_pairs, shuffle=True, drop_last=True, **loader_kwargs)
val_batches = _DataLoader(val_aug_pairs, shuffle=True, drop_last=True, **loader_kwargs)

# Test batches keep the dataset order and every sample.
test_batches = _DataLoader(test_aug_pairs, shuffle=False, drop_last=False, **loader_kwargs)


# Display the number of batches per split.
len(train_batches), len(val_batches), len(test_batches)

(937, 312, 313)

In [25]:
# Sanity check: fetch the first training sample and display the shape of its image.
x, y = train_ds[0]
x.shape

(105, 105, 1)

## Setup model, loss function and optimisation algorithm

In [4]:
from model import SiameseNetworks

In [26]:
# Build the network for 105x105 single-channel inputs.
siamese_net = SiameseNetworks(input_shape=(105, 105, 1))
# Move it to the GPU only when the GPU path is enabled and CUDA is available.
siamese_net = siamese_net.cuda() if HAS_GPU and torch.cuda.is_available() else siamese_net

#### Weight regularization

L2 weight regularization is applied through the optimizer's `weight_decay` parameter.

#### Loss function

Binary cross-entropy

In [45]:
from torch.nn import CrossEntropyLoss
from torch.optim import RMSprop
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau

In [46]:
from torch.backends import cudnn
# Turn on cuDNN benchmark mode for the rest of the session.
cudnn.benchmark = True

In [47]:
from common.dataflow import OnGPUDataLoader

In [48]:
from datetime import datetime
from common.training_utils import train_one_epoch, validate, write_csv_log, write_conf_log, verbose_optimizer, save_checkpoint

In [56]:
# Re-create the network from scratch and move it to the GPU unconditionally:
# unlike the earlier guarded cell, this ignores HAS_GPU and requires CUDA.
siamese_net = SiameseNetworks(input_shape=(105, 105, 1)).cuda()

RuntimeError: cuda runtime error (59) : device-side assert triggered at /pytorch/torch/lib/THC/generic/THCTensorCopy.c:18

In [49]:
# Hyper-parameters of the training run.
conf = dict(
    # Optimizer settings: weight decay and one learning rate per parameter group
    weight_decay=0.01,
    lr_features=0.001,
    lr_classifier=0.01,

    # Training length and batch size
    n_epochs=10,
    batch_size=32,

    # Decay factor of the exponential learning-rate scheduler
    gamma=0.99,
)

In [50]:
# Cross-entropy loss, moved to the GPU unconditionally (requires CUDA).
criterion = CrossEntropyLoss().cuda()

In [51]:
# Two parameter groups so that the feature extractor and the classifier can be
# trained with different learning rates.
features_group = {
    'params': siamese_net.net.features.parameters(),
    'lr': conf['lr_features'],
}
classifier_group = {
    'params': siamese_net.classifier.parameters(),
    'lr': conf['lr_classifier'],
}

# eps and weight_decay are shared by both groups.
optimizer = RMSprop(
    [features_group, classifier_group],
    eps=1.0,
    weight_decay=conf['weight_decay'],
)

In [52]:
# Exponential decay applied once per epoch: lr <- lr_init * gamma ** epoch
scheduler = ExponentialLR(optimizer, gamma=conf['gamma'])
# Second scheduler on the same optimizer, configured with factor=0.5 and patience=2;
# the training loop steps it with the validation loss.
# NOTE(review): two schedulers drive the same optimizer - confirm that they
# compose as intended with the installed PyTorch version.
onplateau_scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=2, verbose=True)

Start training

In [53]:
# One log folder per run, named after the start time (minute resolution).
now = datetime.now()
run_stamp = now.strftime("%Y%m%d_%H%M")
logs_path = os.path.join('logs', 'seamese_networks_verification_task_%s' % run_stamp)

# Create the folder (and missing parents) unless the path already exists.
logs_path_exists = os.path.exists(logs_path)
if not logs_path_exists:
    os.makedirs(logs_path)

In [55]:
# Record the run configuration and the optimizer settings next to the logs.
write_conf_log(logs_path, "{}".format(conf))
write_conf_log(logs_path, verbose_optimizer(optimizer))

# Rebuild the loaders with the configured batch size. This section uses the
# GPU loader unconditionally, so it requires CUDA.
train_batches = OnGPUDataLoader(train_aug_pairs, batch_size=conf['batch_size'], 
                                shuffle=True, num_workers=8, 
                                pin_memory=True, drop_last=True)

# Validation must read the transformed dataset (ToTensor applied), like training;
# the raw `val_pairs` dataset has no transform attached.
val_batches = OnGPUDataLoader(val_aug_pairs, batch_size=conf['batch_size'], 
                              shuffle=True, num_workers=8, 
                              pin_memory=True, drop_last=True)

write_csv_log(logs_path, "epoch,train_loss,train_prec1,val_loss,val_prec1")

best_prec1 = 0.0
for epoch in range(conf['n_epochs']):
    # NOTE(review): stepping the scheduler at the start of the epoch follows the
    # older PyTorch convention; recent versions expect it after the optimizer
    # updates - confirm against the installed version.
    scheduler.step()
    # Verbose learning rates:
    print(verbose_optimizer(optimizer))

    # train for one epoch
    train_loss, train_prec1 = train_one_epoch(siamese_net, train_batches, 
                                              criterion, optimizer, 
                                              epoch, conf['n_epochs'])
    # `assert a, b` only tests `a` and uses `b` as the message; check both values.
    # Presumably the helper returns None values when the epoch fails - verify.
    assert train_loss is not None and train_prec1 is not None, \
        "train_one_epoch returned no metrics"

    # evaluate on validation set
    val_loss, val_prec1 = validate(siamese_net, val_batches, criterion)
    assert val_loss is not None and val_prec1 is not None, \
        "validate returned no metrics"
    onplateau_scheduler.step(val_loss)

    # Write a csv log file
    write_csv_log(logs_path, "%i,%f,%f,%f,%f" % (epoch, train_loss, train_prec1, val_loss, val_prec1))

    # remember best prec@1 and save checkpoint
    if val_prec1 > best_prec1:
        best_prec1 = val_prec1
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': siamese_net.state_dict(),
            'val_prec1': val_prec1,
            'optimizer': optimizer.state_dict()
        })

  0%|          | 0/937 [00:00<?, ?it/s]


Optimizer: RMSprop
Optimizer parameters: 
- Param group: 
	lr: 0.00099
	initial_lr: 0.001
	alpha: 0.99
	weight_decay: 0.01
	momentum: 0
	eps: 1.0
	centered: False
- Param group: 
	lr: 0.0099
	initial_lr: 0.01
	alpha: 0.99
	weight_decay: 0.01
	momentum: 0
	eps: 1.0
	centered: False






RuntimeError: cuda runtime error (59) : device-side assert triggered at /pytorch/torch/lib/THC/generic/THCTensorCopy.c:126

In [60]:
# Show the current GPU status (utilisation, memory usage, running processes).
!nvidia-smi

Sat Nov 18 01:04:07 2017       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 375.66                 Driver Version: 375.66                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce GTX 108...  Off  | 0000:02:00.0     Off |                  N/A |
| 31%   55C    P2    78W / 250W |   1221MiB / 11170MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID  Type  Process name                               Usage    

### One-shot learning evaluation

