In [1]:
import numpy as np
import random
from collections import defaultdict

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset

# Use the GPU when one is visible, otherwise fall back to the CPU so that
# Restart & Run All still works on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Episode sampling (random), DataLoader shuffling and weight init (torch) are all
# stochastic; seed every generator once so that runs are repeatable.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
class OmniglotDataset(Dataset):
  """Class-indexed view of torchvision's Omniglot for episodic training.

  Item `idx` is a single tensor stacking `nsupport` support images followed by
  `nquery` query images, all drawn at random (without replacement) from class
  `idx`. Batching several items with a DataLoader therefore yields one episode
  whose number of ways equals the batch size.

  Args:
    nsupport: number of support images returned per class.
    nquery: number of query images returned per class.
    Train: forwarded as `background=` to `torchvision.datasets.Omniglot`.
  """

  def __init__(self, nsupport: int, nquery: int, Train: bool):
    assert nsupport + nquery <= 20, "nsupport + nquery cannot be more than 20"

    self.nsupport = nsupport
    self.nquery = nquery

    self.transform = transforms.Compose([ transforms.ToTensor(), transforms.Resize((28,28)) ])
    self.images_dataset = torchvision.datasets.Omniglot(root="./data", download = True, background = Train, transform = self.transform)

    # Map each class label to the positions of its images in `images_dataset`.
    # NOTE: this pass loads and transforms every image once, so construction is slow.
    self.label_to_idx = defaultdict(list)

    for i, datapoint in enumerate(self.images_dataset):
      label = datapoint[1]
      self.label_to_idx[label].append(i)

    # Derive the class count from the index just built instead of hard-coding
    # one number per split.
    self.num_classes = len(self.label_to_idx)

  def __len__(self):
    """Number of classes (one dataset item per class)."""
    return self.num_classes

  def __getitem__(self, idx):
    """Return a [nsupport + nquery, C, H, W] tensor of images from class `idx`.

    The first `nsupport` entries are the support images, the rest the queries.

    Raises:
      IndexError: if `idx` is outside `[0, len(self))`.
      ValueError: (from `random.sample`) if the class holds fewer than
        `nsupport + nquery` images.
    """
    # Without this guard the defaultdict would silently create an empty entry
    # and the failure would only surface later inside torch.stack.
    if not 0 <= idx < self.num_classes:
      raise IndexError(f"class index {idx} is out of range for {self.num_classes} classes")

    # random.sample draws without replacement and leaves the stored index list
    # untouched (random.shuffle would reorder it in place on every access).
    chosen = random.sample(self.label_to_idx[idx], self.nsupport + self.nquery)

    # ToTensor output lies in [0, 1]; `1 - x` inverts the pixel intensities.
    combined_set = torch.stack([1 - self.images_dataset[i][0] for i in chosen])

    return combined_set.to(device)

In [3]:
class ConvolutionBlock(nn.Module):
  """3x3 convolution (padding 1) -> batch norm -> ReLU -> 2x2 max-pool.

  Args:
    in_channels: number of channels of the incoming feature map.
    out_channels: number of channels produced by the convolution.
  """

  def __init__(self, in_channels, out_channels = 64):
    super().__init__()

    # Layers are listed in execution order; Sequential keeps them at indices 0..3.
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1),
        nn.BatchNorm2d(num_features = out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size = 2),
    ]
    self.convBlock = nn.Sequential(*layers)

  def forward(self, input):
    """Apply the four layers to `input` and return the pooled feature map."""
    features = self.convBlock(input)
    return features

class FewShotLearning(nn.Module):
  """Embedding network: four stacked ConvolutionBlocks followed by a flatten.

  Args:
    in_channels: channels of the input images (1 for grayscale).
    out_channels: channels produced by every ConvolutionBlock.

  The output has shape [batch, out_channels * H' * W'], where H' x W' is the
  spatial size left after the four blocks.
  """

  def __init__(self, in_channels = 1, out_channels = 64):
    super().__init__()

    # `out_channels` must be forwarded to every block: relying on the block's own
    # default (64) makes any other value build layers whose channel counts do
    # not line up.
    self.convBlocks = nn.Sequential(
        ConvolutionBlock(in_channels = in_channels, out_channels = out_channels),
        ConvolutionBlock(in_channels = out_channels, out_channels = out_channels),
        ConvolutionBlock(in_channels = out_channels, out_channels = out_channels),
        ConvolutionBlock(in_channels = out_channels, out_channels = out_channels),
        nn.Flatten(start_dim = 1, end_dim = -1)
    )

  def forward(self, input):
    """Return one flattened embedding vector per image in `input`."""
    return self.convBlocks(input)

In [4]:
def episode(model, images, nsupport, nquery):
  """
  Run one prototype-based few-shot episode and return its loss and accuracy.

  For every class, the embeddings of its first `nsupport` images are averaged
  into a prototype. Each remaining (query) embedding is scored by a softmax
  over its negative Euclidean distances to all prototypes.

  Args:
    model: callable mapping a batch of images to one embedding per image.
    images: tensor of shape [nways, nsupport + nquery, *image_dims],
      e.g. [nways, nsupport + nquery, 1, 28, 28].
    nsupport: number of leading images per class used as support.
    nquery: number of trailing images per class used as queries.

  Returns:
    (loss, accuracy): scalar tensors; `loss` is the mean negative log-likelihood
    over all queries, `accuracy` the fraction of correctly classified queries.

  Raises:
    ValueError: if `images.shape[1] != nsupport + nquery`.
  """

  nways, per_class = images.shape[0], images.shape[1]
  if per_class != nsupport + nquery:
    raise ValueError(
        f"expected {nsupport + nquery} images per class (nsupport + nquery), got {per_class}"
    )

  # embedding the images: fold classes into the batch axis, embed, then unfold.
  # The embedding width is read from the model output instead of assuming 64.
  embeddings = model(images.reshape(-1, *images.shape[2:]))
  embeddings = embeddings.reshape(nways, per_class, -1)

  # extracting the support images and computing the prototype
  prototypes = embeddings[:, :nsupport].mean(dim = 1)

  # extracting the query images, flattened class-major: [nways * nquery, D]
  queries = embeddings[:, nsupport:].reshape(nways * nquery, -1)

  # computing the distance between query images and prototype
  distance = torch.cdist(queries, prototypes)

  # Query row k belongs to class k // nquery; build the labels on the same
  # device as the distances rather than on a module-level global.
  target = torch.arange(nways, device = distance.device).repeat_interleave(nquery)

  prediction = F.log_softmax(-distance, dim = -1)
  predicted_class = prediction.argmax(dim = -1)

  loss = F.nll_loss(prediction, target)
  accuracy = torch.eq(predicted_class, target).float().mean()

  return loss, accuracy

In [5]:
def test(model, testing_dataloader, nsupport, nquery):
  total_accuracy = 0

  for datapoint in testing_dataloader:
    _, accuracy = episode(model, datapoint, nsupport, nquery)
    total_accuracy += accuracy
  
  return total_accuracy/len(testing_dataloader)

def train(model, optimizer, scheduler, training_dataloader, testing_dataloader, nsupport, nquery, epochs = 50):
  """Train `model` episodically and report test accuracy after every epoch.

  Args:
    model: embedding network to optimise.
    optimizer: optimizer bound to `model`'s parameters.
    scheduler: learning-rate scheduler; stepped once per episode (not per epoch).
    training_dataloader: iterable of training episodes; must support `len()`.
    testing_dataloader: iterable of evaluation episodes, passed to `test`.
    nsupport: number of support images per class in every episode.
    nquery: number of query images per class in every episode.
    epochs: number of passes over `training_dataloader`.

  The model is left in train mode when the function returns.
  """
  print(f"There are going to be {len(training_dataloader)} episodes in one epoch of training.")
  print()

  # Make sure the first epoch really runs in train mode, even if an earlier
  # notebook cell left the model in eval mode.
  model.train()

  for epoch in range(epochs):
    loss_every_epoch = []
    accuracy_every_epoch = []

    for datapoint in training_dataloader:
      loss, accuracy = episode(model, datapoint, nsupport, nquery)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      # per-episode step: the scheduler's step size is counted in episodes
      scheduler.step()

      loss_every_epoch.append(loss.item())
      accuracy_every_epoch.append(accuracy.item())

    print(f"Epoch: {epoch} \nTrain Loss: {np.round(np.mean(loss_every_epoch), decimals = 6)} \t Accuracy: {np.round(np.mean(accuracy_every_epoch)*100, decimals = 3)}")

    # evaluate in eval mode, then switch back for the next epoch
    model.eval()
    testing_accuracy = test(model, testing_dataloader, nsupport, nquery)
    model.train()

    print(f"Test Accuracy: {np.round(testing_accuracy.item()*100, decimals = 3)}")
    print()

In [6]:
# Episode layout: per class, 1 support image and 5 query images.
nsupport, nquery = 1, 5

# One dataset per Omniglot split (Train=True / Train=False), same episode layout.
training_dataset = OmniglotDataset(nsupport, nquery, Train = True)
testing_dataset = OmniglotDataset(nsupport, nquery, Train = False)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Number of classes ("ways") per episode; each DataLoader batch is one episode.
training_nways, testing_nways = 5, 5

# shuffle=True draws a fresh random grouping of classes into episodes on every pass.
# With the default drop_last=False, the final batch holds fewer classes whenever
# the class count is not a multiple of the batch size.
training_dataloader = DataLoader(dataset = training_dataset, batch_size = training_nways, shuffle = True)
testing_dataloader = DataLoader(dataset = testing_dataset, batch_size = testing_nways, shuffle = True)

In [9]:
# Embedding network, Adam at lr 1e-3, and a StepLR that halves the learning
# rate every 2000 scheduler steps.
model = FewShotLearning().to(device)
optimizer = optim.Adam(model.parameters(), lr = 0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 2000, gamma = 0.5, last_epoch = -1)

train(
    model,
    optimizer,
    scheduler,
    training_dataloader,
    testing_dataloader,
    nsupport = nsupport,
    nquery = nquery,
)

There are going to be 193 episodes in one epoch of training.

Epoch: 0 
Train Loss: 0.466694 	 Accuracy: 84.943
Test Accuracy: 88.053%

Epoch: 1 
Train Loss: 0.253091 	 Accuracy: 91.684
Test Accuracy: 90.197%

Epoch: 2 
Train Loss: 0.188804 	 Accuracy: 94.093
Test Accuracy: 90.076%

Epoch: 3 
Train Loss: 0.184587 	 Accuracy: 94.508
Test Accuracy: 91.159%

Epoch: 4 
Train Loss: 0.157052 	 Accuracy: 95.368
Test Accuracy: 94.454%

Epoch: 5 
Train Loss: 0.14489 	 Accuracy: 95.751
Test Accuracy: 93.053%

Epoch: 6 
Train Loss: 0.12511 	 Accuracy: 96.389
Test Accuracy: 93.75%

Epoch: 7 
Train Loss: 0.128075 	 Accuracy: 95.938
Test Accuracy: 92.023%

Epoch: 8 
Train Loss: 0.11653 	 Accuracy: 96.285
Test Accuracy: 93.212%

Epoch: 9 
Train Loss: 0.094023 	 Accuracy: 97.036
Test Accuracy: 95.379%

Epoch: 10 
Train Loss: 0.083858 	 Accuracy: 97.368
Test Accuracy: 94.758%

Epoch: 11 
Train Loss: 0.076848 	 Accuracy: 97.762
Test Accuracy: 94.076%

Epoch: 12 
Train Loss: 0.069271 	 Accuracy: 97.731
T