## Imports

In [1]:
import os
import random
import time

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

## Omniglot Dataset class
This class is derived from the `Dataset` class. It is responsible for downloading, extracting and loading the Omniglot dataset. It also supports indexing to obtain image pairs and preparing image sets for n-way validation.

In [2]:
class Omniglot(Dataset):
    """Omniglot background images served as random pairs for Siamese training.

    After construction, ``self.data`` maps ``alphabet -> character -> list of
    PIL images`` (each converted to mode ``'L'``). Indexing returns a pair of
    transformed images plus a similarity target; ``make_n_way_sets`` builds
    image groups for n-way validation.
    """

    def __init__(self):
        """ Initializes the Omniglot dataset. Downloads, extracts and loads the data in memory"""
        super(Omniglot, self).__init__()
        # NOTE: this seeds NumPy only; the pair sampling in this class draws
        # from the stdlib `random` module, which is not seeded here.
        np.random.seed(0)

        self.__acquire_dataset()
        self.transform = transforms.Compose([transforms.RandomAffine(15), transforms.ToTensor()])
        self.data = self.__load_data()
        print("Dataset loaded in memory!")

    def __download_dataset(self):
        """ Downloads the Omniglot dataset archives into the current working directory"""
        print("Downloading the Omniglot dataset...")
        os.system('wget https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip')
        os.system('wget https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_evaluation.zip')

    def __extract_dataset(self):
        """ Extracts the Omniglot dataset archives found in the current working directory"""
        print("Extracting the Omniglot dataset...")
        os.system('unzip images_background')
        os.system('unzip images_evaluation')

    def __acquire_dataset(self):
        """ Downloads and extracts the Omniglot dataset, skipping steps already done.

        The existence checks look under ``/content`` while the download and
        extract commands run in the current working directory, so this only
        lines up when the working directory is ``/content``.
        """
        print("Acquiring the Omniglot dataset...")

        if os.path.exists('/content/images_background'):
            print("Dataset downloaded and extracted!")
            return

        # Only download when the archive is not already present.
        if not os.path.isfile('/content/images_background.zip'):
            self.__download_dataset()
        self.__extract_dataset()

        print("Dataset downloaded and extracted!")

    def __load_data(self):
        """ Loads every image under ``/content/images_background/`` into memory.

            Returns: dict mapping alphabet name -> dict mapping character name
                     -> list of PIL images converted to mode 'L'
        """
        print("Loading the dataset in memory...")
        data_path = '/content/images_background/'
        data = {}

        for alphabet in os.listdir(data_path):
            alphabet_path = os.path.join(data_path, alphabet)
            data[alphabet] = {}

            for character in os.listdir(alphabet_path):
                character_path = os.path.join(alphabet_path, character)
                imgs = []
                for file_name in os.listdir(character_path):
                    # convert() returns an independent in-memory image, so the
                    # file handle can be released right away.
                    with Image.open(os.path.join(character_path, file_name)) as img:
                        imgs.append(img.convert('L'))

                data[alphabet][character] = imgs

        return data

    def __len__(self):
        """ Returns the total number of images over all alphabets and characters"""
        return sum(
            len(imgs)
            for characters in self.data.values()
            for imgs in characters.values()
        )

    def get_num_classes(self):
        """ Returns the number of character classes over all alphabets"""
        return sum(len(characters) for characters in self.data.values())

    def __getitem__(self, index):
        """ Samples a random image pair; only the parity of ``index`` is used.

            Takes: index (int)
            Returns: (img1, img2, target) where both images went through
                     ``self.transform`` and target is a float32 tensor of shape (1,):
                     - even index: two draws from the same character, target 1.0
                       (the same image may be drawn twice)
                     - odd index: one image from each of two different
                       alphabets, target 0.0
            Raises: ValueError on an odd index when fewer than two alphabets are loaded
        """
        alphabets = list(self.data)

        if index % 2 == 0:
            target = torch.tensor([1.0], dtype=torch.float32)
            alphabet = random.choice(alphabets)
            character = random.choice(list(self.data[alphabet]))
            imgs = self.data[alphabet][character]

            img1 = random.choice(imgs)
            img2 = random.choice(imgs)
        else:
            target = torch.tensor([0.0], dtype=torch.float32)

            # Two distinct alphabets; raises instead of looping forever when
            # only one alphabet is available.
            alphabet_1, alphabet_2 = random.sample(alphabets, 2)

            char1 = random.choice(list(self.data[alphabet_1]))
            char2 = random.choice(list(self.data[alphabet_2]))

            img1 = random.choice(self.data[alphabet_1][char1])
            img2 = random.choice(self.data[alphabet_2][char2])

        img1 = self.transform(img1)
        img2 = self.transform(img2)

        return img1, img2, target

    # Functions for validation
    def make_n_way_sets(self, n=3):
        """ Prepares lists of images for n-way validation
            Takes: n (int) as argument
            Returns: a list with one entry per character class; each entry is a list of n+1 images,
                        with the first two images belonging to the same class,
                        while the remaining n-1 images are drawn from other classes
                     The targets for similarity should be 1 for the 2nd image, and 0 for the rest
                        (1st image is the image being validated)
            Raises: ValueError if a class has fewer than two images, or if distractors are
                        requested (n > 1) while fewer than two classes are loaded"""
        if n > 1 and self.get_num_classes() < 2:
            raise ValueError("n-way sets with distractors need at least two character classes")

        alphabets = list(self.data)
        images = []

        for alphabet in self.data:
            for character in self.data[alphabet]:
                class_images = self.data[alphabet][character]
                if len(class_images) < 2:
                    raise ValueError(
                        "class %s/%s needs at least two images" % (alphabet, character))

                # Two different entries of the same class: query image and its match.
                current_image, other_image = random.sample(class_images, 2)
                current_image_set = [current_image, other_image]

                for _ in range(n - 1):
                    # Resample alphabet and character together until the class
                    # differs from the one being validated.
                    while True:
                        random_alphabet = random.choice(alphabets)
                        random_character = random.choice(list(self.data[random_alphabet]))
                        if (random_alphabet, random_character) != (alphabet, character):
                            break

                    current_image_set.append(
                        random.choice(self.data[random_alphabet][random_character]))

                images.append(current_image_set)

        return images


## Siamese Neural Network
This class defines the architecture of the Siamese Network used for One-Shot learning

In [3]:
class Siamese_Net(nn.Module):
    """Siamese network: one shared encoder applied to both inputs, then a linear score.

    The encoder is four convolutions (the first three each followed by 2x2
    max-pooling) and a sigmoid-activated fully connected layer of width 4096.
    The score is a single logit computed from the element-wise absolute
    difference of the two encodings.
    """

    def __init__(self):
        super().__init__()

        # Shared convolutional feature extractor (single-channel input).
        feature_layers = [
            nn.Conv2d(1, 64, 10),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 7),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 128, 4),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 4),
            nn.ReLU(),
        ]
        self.conv_block = nn.Sequential(*feature_layers)

        # The flattened feature map must have 9216 values per sample
        # (e.g. 256 * 6 * 6, which is what a 105x105 input produces).
        embedding_layers = [nn.Linear(9216, 4096), nn.Sigmoid()]
        self.linear = nn.Sequential(*embedding_layers)

        # Maps the encoding distance to one similarity logit.
        self.output = nn.Linear(4096, 1)

    def forward_pass(self, inp):
        """Encode one batch of images into 4096-dimensional vectors."""
        features = self.conv_block(inp)
        batch = features.size(0)
        flat = features.view(batch, -1)
        return self.linear(flat)

    def forward(self, inp1, inp2):
        """Return one similarity logit per pair (no sigmoid applied here)."""
        encoded_1 = self.forward_pass(inp1)
        encoded_2 = self.forward_pass(inp2)
        distance = torch.abs(encoded_1 - encoded_2)
        return self.output(distance)

## Training the Network

### Loading Data

In [0]:
# Loader settings.
batch_size = 128
num_workers = 2

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Building the dataset downloads/extracts it if needed and loads it in memory.
train_dataset = Omniglot()

# No shuffling: the dataset already samples its pairs at random per item.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

### Training

In [0]:
# Binary cross-entropy on raw logits, averaged over the batch.
# reduction='mean' replaces the deprecated size_average=True.
criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
model = Siamese_Net()
# Follow the device chosen when the data was loaded instead of assuming CUDA,
# so the cell also runs on CPU-only hosts.
model.to(device)
model.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.00006)
optimizer.zero_grad()

time_start = time.time()

for epoch in range(100):
    run_loss = 0.0
    print ("Epoch: ", epoch)
    # batch_id starts at 1 so that the "every 10 batches" check below fires
    # after the 10th, 20th, ... batch.
    for batch_id, (img1, img2, label) in enumerate(train_loader, 1):
        # Tensors take part in autograd directly; no Variable wrapper needed.
        img1, img2, label = img1.to(device), img2.to(device), label.to(device)
        optimizer.zero_grad()
        # Call the module itself (not .forward) so that hooks are honoured.
        output = model(img1, img2)
        loss = criterion(output, label)
        run_loss += loss.item()
        loss.backward()
        optimizer.step()

        if batch_id % 10 == 0:
            # Reports the loss summed over the last 10 batches.
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, batch_id, run_loss))
            run_loss = 0.0

In [0]:
# Persist the trained weights; the network trained above is bound to `model`.
torch.save(model.state_dict(), "trained_model.pt")