# Metric Learning: Prototypical Network for Few-Shot Learning



In [None]:
# Download the Omniglot evaluation and background image archives from the
# brendenlake/omniglot GitHub repository into the current directory.
!wget https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip
!wget https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip

In [None]:
# Extract both downloaded archives (-qq keeps unzip quiet).
!unzip -qq images_background.zip
!unzip -qq images_evaluation.zip

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from os import listdir
from os.path import isfile, join
import os
from tqdm.notebook import tqdm
from tqdm.notebook import trange

import torch
from torch.utils.data import Dataset
from torch.utils.data.sampler import Sampler
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from skimage import io

from scipy import ndimage
import multiprocessing as mp
import os
import cv2

from sklearn.preprocessing import LabelEncoder

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## 1. Dataset

In [None]:
def read_alphabets(alphabet_directory_path, alphabet_directory_name):
    """
    Load every character image of one alphabet, augmented with rotations.

    Each image is resized to 28x28 and stored together with copies of itself
    rotated by 90, 180 and 270 degrees. Every (character, rotation) pair gets
    its own label of the form "<alphabet>_<character>_<angle>".

    Parameters
    ----------
    alphabet_directory_path: path of the alphabet directory. It is used as a
        plain string prefix, so it has to end with a path separator already.
    alphabet_directory_name: alphabet name, used as the prefix of the labels.

    Return
    -----------
    (images, labels): two numpy arrays with one entry per stored image.
    """
    angles = (90, 180, 270)
    samples, targets = [], []
    for character in os.listdir(alphabet_directory_path):
        character_dir = alphabet_directory_path + character + '/'
        label_stem = alphabet_directory_name + '_' + character + '_'
        for file_name in os.listdir(character_dir):
            upright = cv2.resize(cv2.imread(character_dir + file_name), (28, 28))
            # the unrotated image comes first, then its rotated copies
            samples.append(upright)
            targets.append(label_stem + '0')
            for angle in angles:
                samples.append(ndimage.rotate(upright, angle))
                targets.append(label_stem + str(angle))
    return np.array(samples), np.array(targets)

def read_images(base_directory):
    """
    Reads all the alphabets from the base_directory.

    Every entry of base_directory is treated as one alphabet directory and
    handed to read_alphabets inside a pool of worker processes, so the
    alphabets are loaded in parallel. The per-alphabet results are then joined
    in the order returned by os.listdir.

    Parameters
    ----------
    base_directory: directory that contains one sub-directory per alphabet.

    Return
    -----------
    (datax, datay): images and labels of all alphabets concatenated along the
        first axis, or (None, None) when base_directory has no entries.
    """
    directories = os.listdir(base_directory)
    if not directories:
        # Nothing to load: keep the (None, None) result callers may rely on.
        return None, None

    jobs = [(base_directory + '/' + directory + '/', directory)
            for directory in directories]

    # pool.apply blocks until each single call has finished, which makes the
    # loading serial. starmap hands all jobs to the workers at once and still
    # returns the results in job order. The with-block shuts the pool down.
    with mp.Pool(mp.cpu_count()) as pool:
        results = pool.starmap(read_alphabets, jobs)

    # Join once at the end instead of re-copying the growing array per alphabet.
    datax = np.concatenate([images for images, _ in results])
    datay = np.concatenate([labels for _, labels in results])
    return datax, datay

In [None]:
class OmniglotDataset(Dataset):
    """
    Image dataset with integer class labels.

    The images and their string labels are loaded with read_images from
    dataset_dir; the string labels are then mapped to integer ids with a
    LabelEncoder.
    """

    def __init__(self, dataset_dir):
        self.dataset_dir = str(dataset_dir)
        images, raw_labels = read_images(self.dataset_dir)
        self.images = images
        # replace the string labels by integer class ids
        self.labels = LabelEncoder().fit_transform(raw_labels)

    def __len__(self):
        # one label is stored per image
        return len(self.labels)

    def __getitem__(self, index):
        """Return (image, label): the image as a float tensor with its last axis moved first."""
        image = torch.tensor(self.images[index]).permute(2, 0, 1).float()
        target = torch.tensor(self.labels[index])
        return image, target

In [None]:
# Build the training and validation datasets from the extracted archives.
# The paths are absolute and hard-coded under /content.
train_dataset = OmniglotDataset('/content/images_background')
valid_dataset = OmniglotDataset('/content/images_evaluation')

In [None]:
class CustomSampler(Sampler):
    """
    Episodic batch sampler for few-shot learning.

    Every yielded batch (episode) holds n_support + n_query dataset indices
    for each of n_ways randomly chosen classes. The indices are laid out
    element-major: first element 1 of every chosen class, then element 2 of
    every chosen class, and so on, i.e. [x11, x21, ..., x12, x22, ...] where
    in x_ij, i is the class and j is the element number.

    Parameters
    ----------
    labels: class label of every dataset item (any array-like accepted by
        np.asarray). The labels do not have to be the integers 0..K-1.
    n_batch: number of episodes (batches) produced per epoch.
    n_ways: number of classes per episode.
    n_support: number of support samples per class.
    n_query: number of query samples per class.
    """

    def __init__(self, labels, n_batch, n_ways, n_support, n_query):
        self.n_batch = n_batch  # training episodes (number of batches per epoch)
        self.n_ways = n_ways  # number of classes per episode
        self.n_shots = n_support  # number of support vectors per class
        self.n_query = n_query  # number of query vectors per class
        self.n_elmts = n_support + n_query

        labels = np.asarray(labels)

        # Remember the real label values so that classes can be drawn by
        # position and then mapped back to the label they stand for.
        self.classes = np.unique(labels)
        self.indices_per_class = {idx: np.where(labels == idx)[0] for idx in self.classes}

    def __iter__(self):
        n_classes = len(self.classes)  # number of unique classes
        for _ in range(self.n_batch):
            episode = []

            # 1.1. choose random classes (as positions into self.classes)
            chosen = np.random.choice(n_classes, self.n_ways, replace=False)

            # 1.2. save indices of elements for the randomly chosen classes
            for position in chosen:
                # dataset indices of all elements of this class
                indices_k = self.indices_per_class[self.classes[position]]

                # choose random elements inside the class
                pos = np.random.choice(len(indices_k), self.n_elmts, replace=False)
                episode.append(indices_k[pos])

            # (n_elmts, n_ways) -> flat element-major array of dataset indices
            yield np.stack(episode, axis=-1).reshape(-1)

    def __len__(self):  # number of batches the iterator yields
        return self.n_batch

In [None]:
# Training episodes: 2000 per epoch, 60 classes each, 5 support + 5 query
# samples per class.
train_sampler = CustomSampler(labels = train_dataset.labels,
                                    n_batch = 2000,
                                    n_ways = 60, # n_way
                                    n_support = 5, # n_shots
                                    n_query = 5)

# Validation episodes: 60 per epoch, 5 classes each, 5 support + 5 query
# samples per class.
val_sampler = CustomSampler(labels = valid_dataset.labels, 
                                    n_batch = 60,
                                    n_ways = 5, 
                                    n_support = 5, 
                                    n_query = 5)

In [None]:
# The samplers are passed as batch_sampler, so every DataLoader batch is one
# episode produced by the corresponding sampler.
train_dataloader = DataLoader(dataset=train_dataset, 
                          batch_sampler=train_sampler,
                          num_workers=2, pin_memory=True)

val_dataloader = DataLoader(dataset=valid_dataset, 
                        batch_sampler=val_sampler,
                        num_workers=2, pin_memory=True)

## 2. Model

In [None]:
class Convnet(nn.Module):
    """
    Convolutional encoder that maps an image batch to flat embeddings.

    The encoder is a stack of four blocks, each one being
    Conv2d(kernel_size=3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2).

    Parameters
    ----------
    x_dim: input shape; only its first entry (the channel count) is used.
    hid_dim: number of channels produced by the first three blocks.
    z_dim: number of channels produced by the last block.
    """

    def __init__(self, x_dim=(3,28,28), hid_dim=64, z_dim=64):
        super().__init__()

        def conv_block(in_channels, out_channels):
            layers = [
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ]
            return nn.Sequential(*layers)

        # channel count before the first block and after each of the four blocks
        widths = [x_dim[0], hid_dim, hid_dim, hid_dim, z_dim]
        blocks = [conv_block(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        self.encoder = nn.Sequential(*blocks)

    def forward(self, x):
        features = self.encoder(x)
        batch_size = features.size(0)
        # flatten everything except the batch axis
        return features.view(batch_size, -1)

In [None]:
# Create the encoder with its default sizes and move it to the selected device.
model = Convnet().to(device)

## 3. Optimizer + Scheduler

In [None]:
# Adam with a learning rate of 1e-3. StepLR with step_size=1 and gamma=0.5
# halves the learning rate on every scheduler.step() call.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.5, last_epoch=-1)

## 4. Loss function

In [None]:
def calculate_dist(query, prototype):
    '''
    Squared euclidean distance between every query sample and every prototype.

    The rows of the query matrix are expected in element-major order, where
    qij stands for element j of class i, e.g. for 3 classes and 2 elements:
    Q = [q11,
         q21,
         q31,
         q12,
         q22,
         q32]

    Parameters
    ----------
    query: 2d tensor with shape [n_query*n_ways, emb_size]
    prototype: 2d tensor with shape [n_ways, emb_size]

    Return
    -----------
    distances: 2d tensor with shape [n_query*n_ways, n_ways]; row r holds the
        distances of query row r to each prototype c1, c2, ...:
        d(q11, c1), d(q11, c2), d(q11, c3)
        d(q21, c1), d(q21, c2), d(q21, c3)
        d(q31, c1), d(q31, c2), d(q31, c3)

        d(q12, c1), d(q12, c2), d(q12, c3)
        d(q22, c1), d(q22, c2), d(q22, c3)
        d(q32, c1), d(q32, c2), d(q32, c3)
    '''
    n_samples, emb_size = query.size(0), query.size(1)
    n_prototypes = prototype.size(0)

    # bring both operands to shape [n_samples, n_prototypes, emb_size]
    target_shape = (n_samples, n_prototypes, emb_size)
    expanded_query = query.unsqueeze(1).expand(*target_shape)
    expanded_prototype = prototype.unsqueeze(0).expand(*target_shape)

    difference = expanded_query - expanded_prototype
    return (difference ** 2).sum(dim=2)

In [None]:
class PrototypicalLoss(nn.Module):
    """
    Prototypical-network episode loss and accuracy.

    The embeddings of one episode are expected in the element-major order
    produced by CustomSampler: row r belongs to class r % n_ways, the first
    n_ways*n_shots rows are the support set and the remaining
    n_ways*n_query rows are the query set.
    """

    def __init__(self):
        super(PrototypicalLoss, self).__init__()

    def forward(self, data, labels, n_ways, n_shots, n_query):
        """
        Parameters
        ----------
        data: 2d tensor with shape [n_ways*(n_shots+n_query), emb_size]
        labels: accepted for interface compatibility and not read; the
            targets follow from the episode layout described on the class.
        n_ways: number of classes in the episode
        n_shots: number of support samples per class
        n_query: number of query samples per class

        Return
        -----------
        (loss, acc): mean negative log-probability of the true class over all
            query samples, and the fraction of query samples whose nearest
            prototype is the one of their own class (both 0-dim tensors).
        """
        # 1. Divide data on support and query
        p = n_ways * n_shots
        support = data[:p, :]
        query = data[p:, :]

        # 1.1. make shape [n_shots, n_ways, emb_size]
        support = support.reshape(n_shots, n_ways, -1)

        # 2. Compute prototype from support examples for each class
        prototype = support.mean(0)

        # 3. Compute euclidean distances between query samples and prototypes
        dist = calculate_dist(query, prototype)

        # 4. Calculate CrossEntropyLoss(-d(q,ck))
        log_p_y = F.log_softmax(-dist, dim=1)

        # Query rows are element-major, so the class pattern 0..n_ways-1
        # repeats once per query element. The targets are created on the same
        # device as the predictions to avoid a CPU/GPU mismatch.
        target_inds = torch.arange(n_ways, device=log_p_y.device).repeat(n_query)

        loss = F.nll_loss(log_p_y, target_inds)

        y_hat = log_p_y.argmax(dim=1)
        acc = y_hat.eq(target_inds).float().mean()

        return loss, acc

In [None]:
# Loss module; the train and evaluation functions read this global directly.
criterion = PrototypicalLoss()

## 5. Train

In [None]:
def train_singe_epoch(model,
                      train_dataloader, 
                      epoch, 
                      n_ways, 
                      n_shots, 
                      n_query):
    """
    Train the model for one epoch (one pass over train_dataloader).

    Uses the module-level `device`, `criterion` and `optimizer`.

    Parameters
    ----------
    model: embedding network to train
    train_dataloader: loader that yields one episode per batch
    epoch: epoch number, only used in the progress-bar description
    n_ways, n_shots, n_query: episode layout, forwarded to the criterion

    Return
    -----------
    (model, avg_loss, avg_acc): the trained model and the loss / accuracy
        averaged over all episodes as plain Python floats.
    """
    model.train()
    pbar = tqdm(train_dataloader, desc=f'Train (epoch = {epoch})', leave=False)

    total_loss = 0.0
    total_acc = 0.0
    for batch in pbar:

        data, _ = batch
        data = data.to(device)
        data = model(data)

        label = torch.arange(n_ways).repeat(n_query)
        label = label.to(device)

        loss, acc = criterion(data, label, n_ways, n_shots, n_query)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate plain floats: summing the loss tensors themselves would
        # keep the autograd graph of every episode alive until the epoch ends
        # and would return tensors that carry a grad_fn.
        total_loss += loss.item()
        total_acc += acc.item()

    avg_loss = total_loss / len(train_dataloader)
    print("\nAverage train loss: {}".format(avg_loss))

    avg_acc = total_acc / len(train_dataloader)
    print("\nAverage train accuracy: {}".format(avg_acc))

    return model, avg_loss, avg_acc

## 6. Evaluate

In [None]:
def evaluate_single_epoch(model, 
                          val_dataloader, 
                          epoch,
                          n_ways, 
                          n_shots, 
                          n_query):
    """
    Evaluate the model on one pass over val_dataloader.

    The model is put in eval mode and the whole pass runs without gradient
    tracking; no optimizer step is taken, so the validation data never
    influences the weights. Uses the module-level `device` and `criterion`.

    Parameters
    ----------
    model: embedding network to evaluate
    val_dataloader: loader that yields one episode per batch
    epoch: epoch number, only used in the progress-bar description
    n_ways, n_shots, n_query: episode layout, forwarded to the criterion

    Return
    -----------
    (model, avg_loss, avg_acc): the unchanged model and the loss / accuracy
        averaged over all episodes as plain Python floats.
    """
    model.eval()
    pbar = tqdm(val_dataloader, desc=f'Validation (epoch = {epoch})', leave=False)

    total_loss = 0.0
    total_acc = 0.0
    # Validation must not update the weights: no backward pass, no optimizer
    # step, and no autograd bookkeeping.
    with torch.no_grad():
        for batch in pbar:

            data, _ = batch
            data = data.to(device)
            data = model(data)

            label = torch.arange(n_ways).repeat(n_query)
            label = label.to(device)

            loss, acc = criterion(data, label, n_ways, n_shots, n_query)

            total_loss += loss.item()
            total_acc += acc.item()

    avg_loss = total_loss / len(val_dataloader)
    print("\nAverage validation loss: {}".format(avg_loss))

    avg_acc = total_acc / len(val_dataloader)
    print("\nAverage validation accuracy: {}".format(avg_acc))

    return model, avg_loss, avg_acc

## 7. Train loop

In [None]:
def train_model(model, 
                num_epochs, 
                train_dataloader, 
                valid_dataloader,
                train_sampler,
                val_sampler):
    """
    Run the full training loop and keep a log of the per-epoch metrics.

    Every epoch trains on train_dataloader, evaluates on valid_dataloader and
    then advances the module-level `scheduler`. The metrics dictionary is
    written to the file 'logs' after each epoch.

    Parameters
    ----------
    model: embedding network to train
    num_epochs: number of epochs to run
    train_dataloader, valid_dataloader: episode loaders for both phases
    train_sampler, val_sampler: samplers whose n_ways / n_shots / n_query
        describe the episode layout of the matching loader

    Return
    -----------
    model: the model after the last epoch
    """
    logs = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'best_acc': 0.0,
    }

    for epoch in trange(num_epochs, desc="Epoch"):
        model, train_loss, train_acc = train_singe_epoch(
            model,
            train_dataloader,
            epoch,
            train_sampler.n_ways,
            train_sampler.n_shots,
            train_sampler.n_query,
        )
        model, val_loss, val_acc = evaluate_single_epoch(
            model,
            valid_dataloader,
            epoch,
            val_sampler.n_ways,
            val_sampler.n_shots,
            val_sampler.n_query,
        )
        scheduler.step()

        epoch_metrics = (
            ('train_loss', train_loss),
            ('train_acc', train_acc),
            ('val_loss', val_loss),
            ('val_acc', val_acc),
        )
        for key, value in epoch_metrics:
            logs[key].append(value)

        # remember the best validation accuracy seen so far
        if logs['best_acc'] < val_acc:
            logs['best_acc'] = val_acc

        # persist the metrics after every epoch
        torch.save(logs, 'logs')

    return model

In [None]:
# Train for 5 epochs; the samplers are passed along so that the loops know
# the episode layout (n_ways / n_shots / n_query) of each loader.
model = train_model(model=model,
    num_epochs=5,
    train_dataloader=train_dataloader, 
    valid_dataloader=val_dataloader,
    train_sampler=train_sampler,
    val_sampler=val_sampler)

In [None]:
# Reload the metrics dictionary that train_model saved to the file 'logs'.
logs = torch.load("logs")

In [None]:
# Convert the logged values to plain floats first: they may be torch tensors
# (on the GPU or attached to the autograd graph), which matplotlib cannot
# turn into numpy arrays.
train_loss_curve = [float(value) for value in logs["train_loss"]]
plt.plot(np.arange(len(train_loss_curve)), train_loss_curve)
plt.title("Train Loss")
plt.xlabel("Epoch")
plt.ylabel("Prototypical Loss");

In [None]:
# Convert the logged values to plain floats first: they may be torch tensors
# (e.g. living on the GPU), which matplotlib cannot turn into numpy arrays.
train_acc_curve = [float(value) for value in logs["train_acc"]]
plt.plot(np.arange(len(train_acc_curve)), train_acc_curve)
plt.title("Train Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy");

In [None]:
# Convert the logged values to plain floats first: they may be torch tensors
# (on the GPU or attached to the autograd graph), which matplotlib cannot
# turn into numpy arrays.
val_loss_curve = [float(value) for value in logs["val_loss"]]
plt.plot(np.arange(len(val_loss_curve)), val_loss_curve)
plt.title("Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Prototypical Loss");

In [None]:
# Convert the logged values to plain floats first: they may be torch tensors
# (e.g. living on the GPU), which matplotlib cannot turn into numpy arrays.
val_acc_curve = [float(value) for value in logs["val_acc"]]
plt.plot(np.arange(len(val_acc_curve)), val_acc_curve)
plt.title("Validation Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy");