<a href="https://colab.research.google.com/github/shazzad-hasan/few-shot-learning/blob/main/siamese_neural_network/siamese_omniglot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# clone the repository; later cells change into its siamese_neural_network folder
# and import the helper_* modules from there
!git clone https://github.com/shazzad-hasan/few-shot-learning.git

In [None]:
# mount Google Drive at /content/drive (Colab only)
# NOTE(review): no other cell in this notebook reads from the mount -- confirm it is still needed
from google.colab import drive
drive.mount("/content/drive")

In [None]:
# work from the notebook's folder so the local helper modules and ./data resolve
%cd /content/few-shot-learning/siamese_neural_network

In [None]:
# list the working directory contents as a quick sanity check
!ls

In [None]:
# import required libraries
import torch
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split 
from torch.utils.data.sampler import SubsetRandomSampler 

import os
import numpy as np
import random
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# import local helper functions
from helper_dataset import OmniglotDataset, nWayOneShotValidSet
from helper_train import train
from helper_evaluate import test

In [None]:
# detect GPU support once; the flag drives both the message and the device choice
train_on_gpu = torch.cuda.is_available()
print("CUDA is available!" if train_on_gpu else "CUDA is not available")

# device on which the model is placed later in the notebook
device = torch.device('cuda' if train_on_gpu else 'cpu')

In [None]:
# download both Omniglot splits into ./data.
# background=True (the torchvision default) selects the "background" set,
# background=False the "evaluation" set.
# NOTE: train_data and test_data are reassigned by later cells, so these
# objects mainly serve to put the image folders on disk.
train_data = datasets.Omniglot(
    root="./data",
    background=True,
    download=True,
    transform=None,
)
test_data = datasets.Omniglot(
    root="./data",
    background=False,
    download=True,
    transform=None,
)

In [None]:
# image folder of the Omniglot evaluation split (trailing slash is required
# because paths are built by plain string concatenation)
# NOTE(review): both the training pairs and the n-way test set are later built
# from this same folder -- confirm that sharing characters between them is intended
root_dir = '/content/few-shot-learning/siamese_neural_network/data/omniglot-py/images_evaluation/'

# collect [top_level_folder, [entries inside it]] pairs, ignoring hidden entries
categories = []
for folder in os.listdir(root_dir):
  if folder.startswith('.'):
    continue
  categories.append([folder, os.listdir(root_dir + folder)])

In [None]:
# size passed to OmniglotDataset; random_split below requires it to equal
# len(omniglot_data)
data_size = 10000
# choose percentage of training data for validation
valid_pct = 0.2
valid_size = int(valid_pct * data_size)
train_size = data_size - valid_size

# only conversion to tensors, no augmentation
transform = transforms.Compose(
    [transforms.ToTensor()])

omniglot_data = OmniglotDataset(categories, root_dir, data_size, transform)
train_data, valid_data = random_split(omniglot_data, [train_size, valid_size])

## define dataloader parameters
# number of subprocesses to use for data loading
num_workers = 0
# how many samples per batch to load
batch_size = 128
# reshuffle the training subset at every epoch; without shuffle=True the
# DataLoader replays the same batch order each epoch
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
# validation order does not matter, keep it fixed
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [None]:
# display the number of samples held out for validation
valid_size

In [None]:
# show, side by side, the first pair of the first batch whose leading label is 1.0
for img1, img2, label in train_loader:
  if label[0] != 1.0:
    continue
  # channel 0 of the first sample of each image batch, left then right
  for position, batch in enumerate((img1, img2), start=1):
    plt.subplot(1, 2, position)
    plt.imshow(batch[0][0])
  break

In [None]:
# settings of the n-way one-shot evaluation set
test_size = 5000
n_way = 20

# loader settings for evaluation (these overwrite the training values of
# batch_size and num_workers set in an earlier cell)
batch_size = 1
num_workers = 0

test_data = nWayOneShotValidSet(categories, root_dir, test_size, n_way, transform)
test_loader = DataLoader(
    test_data,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [None]:
# report the size of every split
for caption, split in (
    ("Number of training images: ", train_data),
    ("Number of validation images: ", valid_data),
    ("Number of test images: ", test_data),
):
  print(caption, len(split))

# peek at a single training batch to report the tensor shape of the first image slot
first_batch = next(iter(train_loader), None)
if first_batch is not None:
  img, _, _ = first_batch
  print("Image batch dimension: ", img.shape)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class SiameseNetwork(nn.Module):
  """Twin network that scores how similar two single-channel images are.

  Both inputs pass through the same convolutional stack and the same fully
  connected layer (shared weights). The element-wise absolute difference of
  the two 4096-dimensional embeddings is mapped to one raw score per pair;
  no sigmoid is applied to that score here.

  fc1 expects 256*6*6 features per image, which the layer arithmetic yields
  for e.g. 105x105 inputs. Other input sizes raise a shape error in fc1.
  """

  def __init__(self):
    super().__init__()

    # Conv2d(in_channels, out_channels, kernel_size)
    self.conv1 = nn.Conv2d(1, 64, 10)
    self.conv2 = nn.Conv2d(64, 128, 7)
    self.conv3 = nn.Conv2d(128, 128, 4)
    self.conv4 = nn.Conv2d(128, 256, 4)

    # one batch-norm layer per convolution, matching its output channels
    self.bn1 = nn.BatchNorm2d(64)
    self.bn2 = nn.BatchNorm2d(128)
    self.bn3 = nn.BatchNorm2d(128)
    self.bn4 = nn.BatchNorm2d(256)

    # NOTE(review): the dropout layers are registered but never applied in
    # convolution()/forward(); kept so the module layout stays unchanged
    self.dropout1 = nn.Dropout(0.1)
    self.dropout2 = nn.Dropout(0.5)

    self.fc1 = nn.Linear(256*6*6, 4096)
    self.fc2 = nn.Linear(4096, 1)

  def convolution(self, x):
    """Run the shared conv -> batch-norm -> ReLU stack on a batch of images.

    Max pooling (2x2) follows the first three stages; the fourth stage is
    not pooled. Returns the resulting 256-channel feature map.
    """
    x = F.relu(self.bn1(self.conv1(x)))
    x = F.max_pool2d(x, (2,2))
    x = F.relu(self.bn2(self.conv2(x)))
    x = F.max_pool2d(x, (2,2))
    x = F.relu(self.bn3(self.conv3(x)))
    x = F.max_pool2d(x, (2,2))
    x = F.relu(self.bn4(self.conv4(x)))

    return x

  def _embed(self, x):
    """Return the sigmoid-activated fc1 embedding for one batch of images."""
    x = self.convolution(x)
    # flatten everything except the batch dimension; unlike
    # view(-1, 256*6*6) this cannot silently fold surplus features into the
    # batch dimension when the input size does not match what fc1 expects
    x = torch.flatten(x, 1)
    return torch.sigmoid(self.fc1(x))

  def forward(self, x1, x2):
    """Return one raw similarity score per (x1[i], x2[i]) pair, shape (batch, 1)."""
    x1 = self._embed(x1)
    x2 = self._embed(x2)

    # component-wise L1 distance between the two embeddings
    x = torch.abs(x1 - x2)
    x = self.fc2(x)

    return x

# build the network and move its parameters to the selected device
# (nn.Module.to returns the module itself, so the call can be chained)
model = SiameseNetwork().to(device)
print(model)

In [None]:
import torch.optim as optim

# loss: binary cross-entropy computed on raw logits (the sigmoid is applied
# inside the loss, which matches a model that outputs an unnormalised score)
criterion = nn.BCEWithLogitsLoss()

# optimizer: Adam over all model parameters
params = model.parameters()
optimizer = optim.Adam(params, lr=1e-3)

In [None]:
# number of passes requested from the training helper
num_epochs = 5

# run the local training helper; it returns the training and validation losses
# that the next cell plots against the epoch number
train_losses, valid_losses = train(
    model,
    train_loader,
    valid_loader,
    num_epochs,
    criterion,
    optimizer,
    device,
)

In [None]:
# plot training and validation loss for each epoch
epochs = range(1, num_epochs+1)

# (values, matplotlib format string, legend entry) for each curve
for series, fmt, name in (
    (train_losses, 'bo', "Training loss"),
    (valid_losses, 'b', "Validation loss"),
):
  plt.plot(epochs, series, fmt, label=name)

plt.xlabel('epochs')
plt.ylabel('loss')
plt.title('Training and validation loss')
plt.legend(loc='upper right')
plt.show()

In [None]:
# evaluate the trained model with the local test helper on the 20-way one-shot loader
# NOTE(review): unlike train(), no device is passed here -- confirm helper_evaluate.test handles device placement
test(model, test_loader)