# Appendix. Siamese neural networks training with one-shot learning evaluation

This notebook presents the paper ["Siamese Neural Networks for One-shot Image Recognition"](https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf) coded with PyTorch framework. 

In this part we train Siamese network on the Omniglot dataset to perform the classification task to distinguish two images of the same class or different classes.

Code is similar to `keras-oneshot`.

References:
- [paper](https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf)
- [omniglot](https://github.com/brendenlake/omniglot)
- [keras-oneshot](https://github.com/sorenbouma/keras-oneshot)


In [1]:
# https://ipython.org/ipython-doc/3/config/extensions/autoreload.html
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys
import numpy as np
import cv2

In [3]:
sys.path.append("..")

In [4]:
HAS_GPU = True

## Setup dataflow

In [5]:
from dataflow import OmniglotDataset, SameOrDifferentPairsBatchDataset
from common_utils.imgaug import RandomAffine, RandomApply
from common_utils.dataflow import TransformedDataset, OnGPUDataLoader
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader
import torch

In [6]:
seed = 12345
np.random.seed(seed)

OMNIGLOT_REPO_PATH='omniglot'

TRAIN_DATA_PATH = os.path.join(OMNIGLOT_REPO_PATH, 'python', 'images_background')
train_alphabets = !ls {TRAIN_DATA_PATH}
train_alphabets = list(train_alphabets)

TEST_DATA_PATH = os.path.join(OMNIGLOT_REPO_PATH, 'python', 'images_evaluation')
test_alphabets = !ls {TEST_DATA_PATH}
test_alphabets = list(test_alphabets)

assert len(train_alphabets) > 1 and len(test_alphabets) > 1, "%s \n %s" % (train_alphabets[0], test_alphabets[0])

train_alphabet_char_id_drawer_ids = {}
for a in train_alphabets:
    res = !ls "{os.path.join(TRAIN_DATA_PATH, a)}"
    char_ids = list(res)
    train_alphabet_char_id_drawer_ids[a] = {}
    for char_id in char_ids:
        res = !ls "{os.path.join(TRAIN_DATA_PATH, a, char_id)}"
        train_alphabet_char_id_drawer_ids[a][char_id] = [_id[:-4] for _id in list(res)]
        
        
test_alphabet_char_id_drawer_ids = {}
for a in test_alphabets:
    res = !ls "{os.path.join(TEST_DATA_PATH, a)}"
    char_ids = list(res)
    test_alphabet_char_id_drawer_ids[a] = {}
    for char_id in char_ids:
        res = !ls "{os.path.join(TEST_DATA_PATH, a, char_id)}"
        test_alphabet_char_id_drawer_ids[a][char_id] = [_id[:-4] for _id in list(res)]


# Sample 12 drawers out of 20
all_drawers_ids = np.arange(20) 
train_drawers_ids = np.random.choice(all_drawers_ids, size=12, replace=False)
# Sample 4 drawers out of remaining 8
val_drawers_ids = np.random.choice(list(set(all_drawers_ids) - set(train_drawers_ids)), size=8, replace=False)

def create_str_drawers_ids(drawers_ids):
    return ["_{0:0>2}".format(_id) for _id in drawers_ids]

train_drawers_ids = create_str_drawers_ids(train_drawers_ids)
val_drawers_ids = create_str_drawers_ids(val_drawers_ids)

train_ds = OmniglotDataset("Train", data_path=TRAIN_DATA_PATH, 
                           alphabet_char_id_drawers_ids=train_alphabet_char_id_drawer_ids, 
                           drawers_ids=train_drawers_ids)

val_ds = OmniglotDataset("Test", data_path=TEST_DATA_PATH, 
                         alphabet_char_id_drawers_ids=test_alphabet_char_id_drawer_ids, 
                         drawers_ids=val_drawers_ids)

In [7]:
train_data_aug = Compose([
    RandomApply(
        RandomAffine(rotation=(-10, 10), scale=(0.8, 1.2), translate=(-0.05, 0.05)),
        proba=0.5
    ),
    ToTensor()
])

test_data_aug = Compose([
    ToTensor()
])


batch_size = 64
nb_train_batches = 100
nb_val_batches = 100

train_batches = SameOrDifferentPairsBatchDataset(train_ds,
                                                 batch_size=batch_size,
                                                 nb_batches=nb_train_batches,
                                                 x_transforms=train_data_aug,
                                                 pin_memory=HAS_GPU, on_gpu=HAS_GPU)

val_batches = SameOrDifferentPairsBatchDataset(val_ds,
                                               batch_size=batch_size,
                                               nb_batches=nb_val_batches,
                                               x_transforms=test_data_aug,
                                               pin_memory=HAS_GPU, on_gpu=HAS_GPU)

In [8]:
for (x1, x2), y in train_batches:
    print(x1.size(), x2.size(), y.size())
    print(type(x1), type(x1), type(y))    
    break

torch.Size([64, 1, 105, 105]) torch.Size([64, 1, 105, 105]) torch.Size([64, 1])
<class 'torch.cuda.FloatTensor'> <class 'torch.cuda.FloatTensor'> <class 'torch.cuda.FloatTensor'>


In [None]:

#     def make_oneshot_task(self,N,s="val",language=None):
#         """Create pairs of test image, support set for testing N way one-shot learning. """
#         X=self.data[s]
#         n_classes, n_examples = X.shape[0],X.shape[1]
#         if language is not None:
#             low, high = self.categories[s][language]
#             if N > high - low:
#                 raise ValueError("This language ({}) has less than {} letters".format(language, N))
#             categories = rng.choice(range(low,high),size=(N,),replace=False)
#             indices = rng.randint(0,self.n_examples,size=(N,))
            
#         else:#if no language specified just pick a bunch of random letters
#             categories = rng.choice(range(n_classes),size=(N,),replace=False)            
#             indices = rng.randint(0,self.n_examples,size=(N,))
#         true_category = categories[0]
#         ex1, ex2 = rng.choice(n_examples,replace=False,size=(2,))
#         test_image = np.asarray([X[true_category,ex1,:,:]]*N).reshape(N,self.w,self.h,1)
#         support_set = X[categories,indices,:,:]
#         support_set[0,:,:] = X[true_category,ex2]
#         support_set = support_set.reshape(N,self.w,self.h,1)
#         targets = np.zeros((N,))
#         targets[0] = 1
#         targets, test_image, support_set = shuffle(targets, test_image, support_set)
#         pairs = [test_image,support_set]

#         return pairs, targets

In [None]:
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler, Sampler

from collections import defaultdict


class OneShotLearningDataset2(Dataset):
    
    def __init__(self, n_trials, n_classes, ds, class_indices=None,
                 x_transforms=None, y_transforms=None,
                 pin_memory=True, on_gpu=True):
        """
        """
        super(OneShotLearningDataset2, self).__init__()

        self.n_trials = n_trials                
        self.n_classes = n_classes        
        self.ds = ds
        self.pin_memory = pin_memory
        self.on_gpu = on_gpu
        
        self.x_transforms = x_transforms if x_transforms is not None else lambda x: x
        self.y_transforms = y_transforms if y_transforms is not None else lambda y: y
        
        if class_indices is None:
            # get mapping y_label -> indices
            class_indices = defaultdict(list)
            for i, (_, y) in enumerate(ds):
                class_indices[y].append(i)

        self.class_indices = class_indices
        self.classes = list(self.class_indices.keys())
        
    def __len__(self):
        return len(self.n_trials)
    
    def __getitem__(self, index):

        if index >= self.n_trials:
            raise IndexError()
            
        random_classes = np.random.choice(self.classes, size=(self.n_classes, ), replace=False)        
        
#         support_set_indices = []        
#         for indices in self.class_indices.values():
#             index = np.random.randint(len(indices))
#             support_set_indices.append(indices[index])
#         np.random.shuffle(support_set_indices)

        true_class = random_classes[0]
        n_samples = len(self.class_indices[true_class])
        index1 = np.random.randint(0, n_samples)
        test_x, _ = self.ds[self.class_indices[true_class][index1]]
                
        targets = np.zeros((self.n_classes,), dtype=np.int)
        targets[0] = 1
        targets = torch.from_numpy(targets)

        
        
#        categories = rng.choice(range(n_classes),size=(N,),replace=False)                    
#        indices = rng.randint(0,self.n_examples,size=(N,))

#         true_category = categories[0]
#         ex1, ex2 = rng.choice(n_examples,replace=False,size=(2,))
#         test_image = np.asarray([X[true_category,ex1,:,:]]*N).reshape(N,self.w,self.h,1)
#         support_set = X[categories,indices,:,:]
#         support_set[0,:,:] = X[true_category,ex2]
#         support_set = support_set.reshape(N,self.w,self.h,1)
#         targets = np.zeros((N,))
#         targets[0] = 1
#         targets, test_image, support_set = shuffle(targets, test_image, support_set)
#         pairs = [test_image,support_set]

        return pairs, targets        



## Setup model, loss function and optimisation algorithm

#### Weight regularization

L2 weights regularization: 

#### Loss function

Binary cross-entropy

In [9]:
from torch.autograd import Variable
from torch.nn import BCEWithLogitsLoss
from torch.nn.functional import sigmoid
from torch.optim import Adam, RMSprop, SGD
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau

In [10]:
from datetime import datetime
from common_utils.training_utils import train_one_epoch, validate, write_csv_log, write_conf_log, verbose_optimizer, save_checkpoint
from common_utils.training_utils import accuracy

In [11]:
from model import SiameseNetworks

In [12]:
siamese_net = SiameseNetworks(input_shape=(105, 105, 1))
if HAS_GPU and torch.cuda.is_available():
    siamese_net = siamese_net.cuda()

In [13]:
conf = {
    'weight_decay': 0.01,
    
    'lr_features': 0.00006,
    'lr_classifier': 0.00006,
    
    'n_epochs': 15,    
    'gamma': 0.99
}

In [14]:
def accuracy_logits(y_logits, y_true):
    y_pred = sigmoid(y_logits).data
    return accuracy(y_pred, y_true)

In [15]:
criterion = BCEWithLogitsLoss()
if HAS_GPU and torch.cuda.is_available():
    criterion = criterion.cuda()

In [16]:
# Test single forward pass and loss function computation
siamese_net.eval()
for i, ((batch_x1, batch_x2), batch_y) in enumerate(train_batches):
    
    batch_x1 = Variable(batch_x1, requires_grad=True)
    batch_x2 = Variable(batch_x2, requires_grad=True)    
    batch_y = Variable(batch_y)
    batch_y_logits = siamese_net(batch_x1, batch_x2)
    print(type(batch_y.data), type(batch_y_logits.data), batch_y.size(), batch_y_logits.size())    
    loss = criterion(batch_y_logits, batch_y)
    print("Loss : ", loss.data)
    
    print("Accuracy : ", accuracy_logits(batch_y_logits.data, batch_y.data))
    break

<class 'torch.cuda.FloatTensor'> <class 'torch.cuda.FloatTensor'> torch.Size([64, 1]) torch.Size([64, 1])
Loss :  
 0.7024
[torch.cuda.FloatTensor of size 1 (GPU 0)]

Accuracy :  0.453125


In [17]:
optimizer = Adam([{
    'params': siamese_net.net.features.parameters(),
    'lr': conf['lr_features'],    
}, {
    'params': siamese_net.classifier.parameters(),
    'lr': conf['lr_classifier']
}],
    weight_decay=conf['weight_decay']
)

Note that we define L2 regularization weights through optimizer API as `weight_decay` parameter, [ref](http://pytorch.org/docs/master/optim.html?highlight=adam#torch.optim.Adam)

In [18]:
# lr <- lr_init * gamma ** epoch
scheduler = ExponentialLR(optimizer, gamma=conf['gamma'])
onplateau_scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=2, verbose=True)

### Start training

In [27]:
now = datetime.now()
logs_path = os.path.join('logs', 'siamese_networks_verification_task_%s' % (now.strftime("%Y%m%d_%H%M")))
if not os.path.exists(logs_path):
    os.makedirs(logs_path)

In [28]:
write_conf_log(logs_path, "{}".format(conf))
write_conf_log(logs_path, verbose_optimizer(optimizer))

write_csv_log(logs_path, "epoch,train_loss,train_acc,val_loss,val_acc")

best_acc = 0.0
for epoch in range(conf['n_epochs']):
    scheduler.step()
    # Verbose learning rates:
    print(verbose_optimizer(optimizer))

    # train for one epoch
    ret = train_one_epoch(siamese_net, train_batches, 
                          criterion, optimizer,                                               
                          epoch, conf['n_epochs'], avg_metrics=[accuracy_logits,])
    if ret is None:
        break
    train_loss, train_acc = ret

    # evaluate on validation set
    ret = validate(siamese_net, val_batches, criterion, avg_metrics=[accuracy_logits, ])
    if ret is None:
        break
    val_loss, val_acc = ret
    
    onplateau_scheduler.step(val_loss)

    # Write a csv log file
    write_csv_log(logs_path, "%i,%f,%f,%f,%f" % (epoch, train_loss, train_acc, val_loss, val_acc))

    # remember best accuracy and save checkpoint
    if val_acc > best_acc:
        best_acc = max(val_acc, best_acc)
        save_checkpoint(logs_path, 'val_acc', 
                        {'epoch': epoch + 1,
                         'state_dict': siamese_net.state_dict(),
                         'val_acc': val_acc,           
                         'optimizer': optimizer.state_dict()})        

  0%|          | 0/100 [00:00<?, ?it/s]


Optimizer: Adam
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 6e-05
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 6e-05



Epoch: 1/15: 100%|##########| 100/100 [00:08<00:00, 11.24it/s, Loss 0.6124 | accuracy_logits 0.647]
100%|##########| 100/100 [00:04<00:00, 23.54it/s, Loss 0.6022 | accuracy_logits 0.673]
  0%|          | 0/100 [00:00<?, ?it/s]


Optimizer: Adam
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.94e-05
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.94e-05



Epoch: 2/15: 100%|##########| 100/100 [00:08<00:00, 11.18it/s, Loss 0.6002 | accuracy_logits 0.669]
100%|##########| 100/100 [00:04<00:00, 23.55it/s, Loss 0.6208 | accuracy_logits 0.666]
  0%|          | 0/100 [00:00<?, ?it/s]


Optimizer: Adam
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.8806e-05
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.8806e-05



Epoch: 3/15: 100%|##########| 100/100 [00:08<00:00, 11.20it/s, Loss 0.5781 | accuracy_logits 0.694]
100%|##########| 100/100 [00:04<00:00, 23.50it/s, Loss 0.6297 | accuracy_logits 0.690]
  0%|          | 0/100 [00:00<?, ?it/s]


Optimizer: Adam
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.821794e-05
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.821794e-05



Epoch: 4/15: 100%|##########| 100/100 [00:08<00:00, 11.11it/s, Loss 0.5659 | accuracy_logits 0.711]
100%|##########| 100/100 [00:04<00:00, 23.48it/s, Loss 0.5962 | accuracy_logits 0.700]
  0%|          | 0/100 [00:00<?, ?it/s]


Optimizer: Adam
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.7635760599999995e-05
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.7635760599999995e-05



Epoch: 5/15: 100%|##########| 100/100 [00:08<00:00, 11.17it/s, Loss 0.5512 | accuracy_logits 0.718]
100%|##########| 100/100 [00:04<00:00, 23.16it/s, Loss 0.5784 | accuracy_logits 0.718]
  0%|          | 0/100 [00:00<?, ?it/s]


Optimizer: Adam
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.7059402994e-05
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.7059402994e-05



Epoch: 6/15: 100%|##########| 100/100 [00:08<00:00, 11.08it/s, Loss 0.5576 | accuracy_logits 0.718]
100%|##########| 100/100 [00:04<00:00, 23.26it/s, Loss 0.5665 | accuracy_logits 0.728]
  0%|          | 0/100 [00:00<?, ?it/s]


Optimizer: Adam
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.6488808964060004e-05
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.6488808964060004e-05



Epoch: 7/15: 100%|##########| 100/100 [00:08<00:00, 11.17it/s, Loss 0.5297 | accuracy_logits 0.741]
100%|##########| 100/100 [00:04<00:00, 23.27it/s, Loss 0.5624 | accuracy_logits 0.731]
  0%|          | 0/100 [00:00<?, ?it/s]


Optimizer: Adam
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.5923920874419396e-05
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.5923920874419396e-05



Epoch: 8/15: 100%|##########| 100/100 [00:08<00:00, 11.08it/s, Loss 0.5059 | accuracy_logits 0.762]
100%|##########| 100/100 [00:04<00:00, 23.27it/s, Loss 0.5397 | accuracy_logits 0.745]
  0%|          | 0/100 [00:00<?, ?it/s]


Optimizer: Adam
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.536468166567521e-05
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.536468166567521e-05



Epoch: 9/15: 100%|##########| 100/100 [00:08<00:00, 11.17it/s, Loss 0.4868 | accuracy_logits 0.775]
100%|##########| 100/100 [00:04<00:00, 23.21it/s, Loss 0.5411 | accuracy_logits 0.750]
  0%|          | 0/100 [00:00<?, ?it/s]


Optimizer: Adam
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.4811034849018454e-05
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.4811034849018454e-05



Epoch: 10/15: 100%|##########| 100/100 [00:08<00:00, 11.15it/s, Loss 0.4973 | accuracy_logits 0.772]
100%|##########| 100/100 [00:04<00:00, 23.23it/s, Loss 0.5696 | accuracy_logits 0.741]
  0%|          | 0/100 [00:00<?, ?it/s]


Optimizer: Adam
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.4262924500528266e-05
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.4262924500528266e-05



Epoch: 11/15: 100%|##########| 100/100 [00:08<00:00, 11.14it/s, Loss 0.4720 | accuracy_logits 0.785]
100%|##########| 100/100 [00:04<00:00, 23.21it/s, Loss 0.5487 | accuracy_logits 0.760]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch    10: reducing learning rate of group 0 to 2.7131e-05.
Epoch    10: reducing learning rate of group 1 to 2.7131e-05.

Optimizer: Adam
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.372029525552299e-05
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.372029525552299e-05



Epoch: 12/15: 100%|##########| 100/100 [00:08<00:00, 11.12it/s, Loss 0.4681 | accuracy_logits 0.792]
100%|##########| 100/100 [00:04<00:00, 23.23it/s, Loss 0.4785 | accuracy_logits 0.793]
  0%|          | 0/100 [00:00<?, ?it/s]


Optimizer: Adam
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.3183092302967755e-05
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.3183092302967755e-05



Epoch: 13/15: 100%|##########| 100/100 [00:08<00:00, 11.19it/s, Loss 0.4511 | accuracy_logits 0.794]
100%|##########| 100/100 [00:04<00:00, 23.25it/s, Loss 0.5002 | accuracy_logits 0.778]
  0%|          | 0/100 [00:00<?, ?it/s]


Optimizer: Adam
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.2651261379938074e-05
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.2651261379938074e-05



Epoch: 14/15: 100%|##########| 100/100 [00:08<00:00, 11.18it/s, Loss 0.4438 | accuracy_logits 0.804]
100%|##########| 100/100 [00:04<00:00, 23.23it/s, Loss 0.4868 | accuracy_logits 0.791]
  0%|          | 0/100 [00:00<?, ?it/s]


Optimizer: Adam
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.2124748766138696e-05
- Param group: 
	weight_decay: 0.01
	eps: 1e-08
	initial_lr: 6e-05
	betas: (0.9, 0.999)
	lr: 5.2124748766138696e-05



Epoch: 15/15: 100%|##########| 100/100 [00:08<00:00, 11.08it/s, Loss 0.4358 | accuracy_logits 0.808]
100%|##########| 100/100 [00:04<00:00, 23.25it/s, Loss 0.5162 | accuracy_logits 0.776]

Epoch    14: reducing learning rate of group 0 to 2.6062e-05.
Epoch    14: reducing learning rate of group 1 to 2.6062e-05.





In [100]:
list_of_classes = list(class_indices.keys()) 
y_transform = lambda y: torch.LongTensor([list_of_classes.index(y)])

val_aug_ds = TransformedDataset(val_ds, x_transforms=test_data_aug, y_transforms=y_transform)

In [58]:
from torch.utils.data.sampler import Sampler

class RandomSupportSetSampler(Sampler):
    """
    Random support set samplers for one-shot learning
    """

    def __init__(self, class_indices, seed=None):
        """
        :params class_indices: dictionary key=targets of `ds`, values=indices of `ds` corresponding to target
            Number N of N-way evalution is defined by number of keys in the dictionary         
            It can be obtained from a dataset with something like:
            ```
                class_indices = defaultdict(list)
                for i, (_, y) in enumerate(val_ds):
                    class_indices[y].append(i)
            ```
        """
        assert isinstance(class_indices, dict)
        self.class_indices = class_indices
        self.seed = seed
        
    def __iter__(self):
        if self.seed is not None:
            np.random.seed(seed)
        support_set_indices = []        
        for indices in self.class_indices.values():
            index = np.random.randint(len(indices))
            support_set_indices.append(indices[index])
        np.random.shuffle(support_set_indices)
        return iter(support_set_indices)

    def __len__(self):
        return len(self.class_indices)

In [121]:
from collections import defaultdict

class_indices = defaultdict(list)
for i, (_, y) in enumerate(val_aug_ds):
    y = y.numpy()[0] 
    class_indices[y].append(i)

sampler = RandomSupportSetSampler(class_indices)

In [116]:
from torch.utils.data import  DataLoader


def generate_support_set(ds, class_indices, seed=None, **kwargs):
    assert isinstance(class_indices, dict)
    assert isinstance(ds, Dataset)
    
    sampler = RandomSupportSetSampler(class_indices, seed)    
    data_loader = DataLoader(ds, batch_size=len(sampler), sampler=sampler, **kwargs)   
    return data_loader


def generate_test_dataset(ds, n_classes):
    for x, y in ds:
        x
        yield x, y

In [117]:
support_set = generate_support_set(val_aug_ds, class_indices, num_workers=4, pin_memory=True)

In [118]:
for set_x, set_y in support_set:
    print(type(set_x), set_x.size(), type(set_y), set_y.size())
    break

<class 'torch.FloatTensor'> torch.Size([659, 1, 105, 105]) <class 'torch.LongTensor'> torch.Size([659, 1])


In [199]:
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler, Sampler

from collections import defaultdict


class OneShotLearningDataset(Dataset):
    
    def __init__(self, test_ds, val_ds, val_class_indices=None, n=None,
                 seed=None, on_gpu=True, **kwargs):
        """
        Dataset for one-shot learning. It contains lenght of `test_ds` elements and at index `i` 
        returns `(test_x, support_set_x), (test_y, support_set_y)` where `test_x`, `test_y` are N copies of 
        a single data given by `test_ds[i]`. Variables `support_set_x`, `support_set_y` are tensors of N elements 
        from `val_ds` belonging to N different classes.
         
        :params test_ds: dataset that provides test data.
        :params val_ds: dataset from which to select a support set
        :params val_class_indices: dictionary key=targets of `val_ds`, values=indices of `val_ds` corresponding to target
            Number N of N-way evalution is defined by number of keys             
        :params **kwargs: for DataLoader
        """
        super(OneShotLearningDataset, self).__init__()
        
        if val_class_indices is None:
            val_class_indices = defaultdict(list)            
            for i, (_, y) in enumerate(val_ds):
                if torch.is_tensor(y):
                    y = y.numpy()[0] 
                val_class_indices[y].append(i)

        assert isinstance(val_class_indices, dict)
        
        self.test_ds = test_ds
        self.val_ds = val_ds
        self.val_class_indices = val_class_indices
        self.on_gpu = on_gpu
        
        sampler = RandomSupportSetSampler(val_class_indices, seed)    
        self.support_set_ds = DataLoader(self.val_ds, batch_size=len(sampler), sampler=sampler, **kwargs)   
        
    def __len__(self):
        return len(self.test_ds)
    
    def __getitem__(self, index):        
        test_x, test_y = self.test_ds[index]

        for support_set_x, support_set_y in self.support_set_ds:
            break

        if self.on_gpu:
            test_x = test_x.cuda()
            test_y = test_y.cuda()
            support_set_x = support_set_x.cuda()    
            support_set_y = support_set_y.cuda()
                
        test_x = test_x.expand_as(support_set_x)
        test_y = test_y.expand_as(support_set_y)
            
        return (test_x, support_set_x), (test_y, support_set_y)

In [200]:
# ds = OneShotLearningDataset(val_aug_ds, val_aug_ds)    
# val_acc = 0.0
# for (test_x, support_set_x), (test_y, support_set_y) in ds:
    
#     test_x = Variable(test_x, volatile=True)
#     support_set_x = Variable(support_set_x, volatile=True)    
    
#     y_logits = siamese_net(test_x, support_set_x)
#     y_proba = sigmoid(y_logits).data

#     if len(y_proba.size()) > 1:
#         y_proba = y_proba.view(-1)        

#     y_proba_top1, index_top1 = y_proba.topk(k=1, largest=True, dim=0)
#     if index_top1.is_cuda:
#         index_top1 = index_top1.cpu()    

#     classes_top1 = support_set_y[index_top1[0], 0]

#     if test_y[0, 0] == classes_top1:
#         val_acc += 1

In [207]:
from common_utils.training_utils import get_tqdm

def oneshot_learning_validation(model, val_ds, n=20):
    
    model.eval()
    try:
        ds = OneShotLearningDataset(val_aug_ds, val_aug_ds)    
        val_acc = 0.0

        with get_tqdm(total=len(ds)) as pbar:
            for (test_x, support_set_x), (test_y, support_set_y) in ds:
                
                test_x = Variable(test_x, volatile=True)
                support_set_x = Variable(support_set_x, volatile=True)    

                y_logits = model(test_x, support_set_x)
                y_proba = sigmoid(y_logits).data

                if len(y_proba.size()) > 1:
                    y_proba = y_proba.view(-1)        

                y_proba_top1, index_top1 = y_proba.topk(k=1, largest=True, dim=0)
                if index_top1.is_cuda:
                    index_top1 = index_top1.cpu()    

                classes_top1 = support_set_y[index_top1[0], 0]

                if test_y[0, 0] == classes_top1:
                    val_acc += 1
                    
                prefix_str = "One-shot learning eval : "
                pbar.set_description_str(prefix_str, refresh=False)
                    
                post_fix_str = "Accuracy: {}".format(val_acc)
                pbar.set_postfix_str(post_fix_str, refresh=False)
                pbar.update(1)            
        
        val_acc /= 1.0 * len(ds)
        return val_acc                    
    except KeyboardInterrupt:
        return None


In [208]:
val_acc = oneshot_learning_validation(siamese_net, val_aug_ds)

One-shot learning eval :   2%|2         | 126/5272 [00:35<24:04,  3.56it/s, Accuracy: 1.0]


### Inference on testing dataset

In [25]:
from common_utils.training_utils import load_checkpoint
from glob import glob

In [26]:
best_model_filenames = glob(os.path.join(logs_path, "model_val_acc=*"))
assert len(best_model_filenames) == 1
load_checkpoint(best_model_filenames[0], siamese_net)

Load checkpoint: logs/siamese_networks_verification_task_20171126_1148/model_val_acc=0.7575.pth.tar


In [27]:
# evaluate on validation set
test_loss, test_acc = validate(siamese_net, test_batches, criterion, avg_metrics=[accuracy_logits, ])
test_loss, test_acc

100%|##########| 157/157 [00:04<00:00, 34.78it/s, Loss 0.5082 | accuracy_logits 0.754]


(0.5081654835700988, 0.754)

### Run training script

In [28]:
!python3 train_model_with_oneshot_eval.py


Optimizer: Adam
- Param group: 
	initial_lr: 6e-05
	eps: 1e-08
	lr: 6e-05
	betas: (0.9, 0.999)
	weight_decay: 0.011
- Param group: 
	initial_lr: 8e-05
	eps: 1e-08
	lr: 8e-05
	betas: (0.9, 0.999)
	weight_decay: 0.011

Epoch: 1/50: 100%|#| 468/468 [00:42<00:00, 10.94it/s, Loss 0.6231 | accuracy_logits 0.631]
100%|####| 156/156 [00:04<00:00, 35.05it/s, Loss 0.5816 | accuracy_logits 0.687]

Optimizer: Adam
- Param group: 
	initial_lr: 6e-05
	eps: 1e-08
	lr: 5.4000000000000005e-05
	betas: (0.9, 0.999)
	weight_decay: 0.011
- Param group: 
	initial_lr: 8e-05
	eps: 1e-08
	lr: 7.2e-05
	betas: (0.9, 0.999)
	weight_decay: 0.011

Epoch: 2/50: 100%|#| 468/468 [00:42<00:00, 11.03it/s, Loss 0.5697 | accuracy_logits 0.701]
100%|####| 156/156 [00:04<00:00, 35.79it/s, Loss 0.5615 | accuracy_logits 0.718]

Optimizer: Adam
- Param group: 
	initial_lr: 6e-05
	eps: 1e-08
	lr: 4.86e-05
	betas: (0.9, 0.999)
	weight_decay: 0.011
- Param group: 
	initial_lr: 8e-05
	eps: 1e-08
	lr: 6.48e-05
	betas: (0.9, 0.999)

Epoch: 20/50: 100%|#| 468/468 [00:43<00:00, 10.91it/s, Loss 0.2109 | accuracy_logits 0.927]
100%|####| 156/156 [00:04<00:00, 35.62it/s, Loss 0.3449 | accuracy_logits 0.847]

Optimizer: Adam
- Param group: 
	initial_lr: 6e-05
	eps: 1e-08
	lr: 7.294599275434162e-06
	betas: (0.9, 0.999)
	weight_decay: 0.011
- Param group: 
	initial_lr: 8e-05
	eps: 1e-08
	lr: 9.726132367245548e-06
	betas: (0.9, 0.999)
	weight_decay: 0.011

Epoch: 21/50: 100%|#| 468/468 [00:42<00:00, 11.04it/s, Loss 0.2073 | accuracy_logits 0.930]
100%|####| 156/156 [00:04<00:00, 35.62it/s, Loss 0.3419 | accuracy_logits 0.853]

Optimizer: Adam
- Param group: 
	initial_lr: 6e-05
	eps: 1e-08
	lr: 6.5651393478907455e-06
	betas: (0.9, 0.999)
	weight_decay: 0.011
- Param group: 
	initial_lr: 8e-05
	eps: 1e-08
	lr: 8.753519130520995e-06
	betas: (0.9, 0.999)
	weight_decay: 0.011

Epoch: 22/50: 100%|#| 468/468 [00:43<00:00, 10.81it/s, Loss 0.2062 | accuracy_logits 0.929]
100%|####| 156/156 [00:04<00:00, 35.65it/s, Loss 0.3448 | acc