**Section 1: Introduction**

This notebook implements the LSTM-based meta-learner for few-shot learning described in "Optimization as a Model for Few-Shot Learning" (https://openreview.net/forum?id=rJY0-Kcll). This section imports the required libraries, sets the computation device, and defines the hyperparameters and paths. 

In [None]:
! pip install torchsummary

In [None]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import os
import shutil
import glob
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# import torchvision.models as models

from torch.optim import Adam
from torch.utils import data
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
from PIL import Image
from sklearn.neighbors import KNeighborsClassifier

# from collections import OrderedDict
# from tensorboardX import SummaryWriter

# Pick the computation device once: CUDA when a GPU is visible, CPU otherwise.
gpu_available = torch.cuda.is_available()
device = torch.device("cuda" if gpu_available else "cpu")
print("GPU is available" if gpu_available else "GPU is not available, using CPU instead")

# HYPERPARAMETERS
NUM_EPOCH = 10  # default number of passes epoch_loop makes over an episode's training images
NUM_SHOT = 5    # 5-shot
# NUM_SHOT = 1    # 1-shot
NUM_EVAL = 3*NUM_SHOT  # evaluation (query) images drawn per class in each episode
NUM_CLASS = 5  # classes per episode (n-way)
BATCH_SIZE = NUM_CLASS*NUM_SHOT  # one batch covers the whole training set of an episode
NUM_EPI_TRAIN = 12000  # reference repo: 50000
NUM_EPI_TEST = 1000  # number of meta-test episodes
NUM_EPI_EVAL = 100  # number of episodes in each meta-validation pass
MOMENTUM = 0.95  # momentum of the Learner's BatchNorm2d layers
EPS = 1e-03  # eps of the Learner's BatchNorm2d layers
GRAD_CLIP = 0.25  # max gradient norm when clipping the meta-learner gradients
LR_DECAY = 0.0005  # NOTE(review): not referenced in the visible part of the notebook - confirm it is used
LR_INIT = 0.001  # learning rate of the Adam optimizer that trains the meta-learner
IMAGE_DIM = 64  # cifar-100, omniglot
# IMAGE_DIM = 96 # miniImageNet
INPUT_SIZE = 4  # first-layer LSTM input width: 2 preprocessed loss values + 2 preprocessed gradient values
HIDDEN_SIZE = 20  # hidden size of the first-layer LSTM
VAL_FREQ = 100  # run meta-validation every VAL_FREQ training episodes

# PATHS
# OUTPUT_DIR already ends with a slash, so sub-paths are appended to it directly.
OUTPUT_DIR = '/kaggle/working/'

# # CIFAR-100
# TRAIN_IMG_DIR = '/kaggle/input/cifar100-fs/cifar100/metatrain'
# TEST_IMG_DIR = '/kaggle/input/cifar100-fs/cifar100/metatest'
# VAL_IMG_DIR = '/kaggle/input/cifar100-fs/cifar100/metaval'

# # MiniImageNet
# TRAIN_IMG_DIR = '/kaggle/input/MiniImageNet/miniImagenet/train'
# TEST_IMG_DIR = '/kaggle/input/MiniImageNet/miniImagenet/test'
# VAL_IMG_DIR = '/kaggle/input/MiniImageNet/miniImagenet/val'

# Omniglot: download roots for the background and the evaluation images
BG_IMG_DIR = f'{OUTPUT_DIR}bg_images'
EV_IMG_DIR = f'{OUTPUT_DIR}ev_images'
for download_root in (BG_IMG_DIR, EV_IMG_DIR):
    os.makedirs(download_root, exist_ok=True)

# Download both Omniglot subsets; the returned dataset objects are not needed here
torchvision.datasets.Omniglot(root=BG_IMG_DIR, background=True, transform=None, target_transform=None, download=True)
torchvision.datasets.Omniglot(root=EV_IMG_DIR, background=False, transform=None, target_transform=None, download=True)

# From here on both names point at the extracted image folders
BG_IMG_DIR = f'{BG_IMG_DIR}/omniglot-py/images_background'
EV_IMG_DIR = f'{EV_IMG_DIR}/omniglot-py/images_evaluation'

# Flattened per-class folders used by the episode datasets
TRAIN_IMG_DIR = f'{OUTPUT_DIR}train'
TEST_IMG_DIR = f'{OUTPUT_DIR}test'
VAL_IMG_DIR = f'{OUTPUT_DIR}val'

LOG_DIR = f'{OUTPUT_DIR}tblogs'  # tensorboard logs
CHECKPOINT_DIR = f'{OUTPUT_DIR}models'  # model checkpoints
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

**Section 2: Class definitions**

This section defines the classes used for the training routine. They are:

* Learner: CNN4 architecture commonly used for Mini ImageNet

* MetaLearner: Two-layer LSTM architecture to learn the optimal parameters for Learner. First layer is a normal LSTM cell, and second layer is a customized MetaLSTMCell

* MetaLSTMCell: 2nd layer of LSTM meta-learner as proposed by the authors. The update equations for the parameters are:
$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c_t}$$
    * $c_t = \theta_t$ are the Learner NN parameters at time $t$
    * $\tilde{c_t} = -\nabla_{\theta_{t-1}} \mathcal{L}_t$ is the candidate cell state, also equal to the negative gradient of the loss at time $t$
    * $i_t = \alpha_t$ is the learning rate at time step $t$ and is given by the sigmoid function of:
    $$i_t = \sigma \,(\boldsymbol{W}_I \cdot [\nabla_{\theta_{t-1}} \mathcal{L}_t, \mathcal{L}_t, c_{t-1}, i_{t-1}] + \boldsymbol{b}_I)$$
    * $f_t$ is the forget gate strength at time step $t$ and is given by the sigmoid function of:
    $$f_t = \sigma \,(\boldsymbol{W}_F \cdot [\nabla_{\theta_{t-1}} \mathcal{L}_t, \mathcal{L}_t, c_{t-1}, f_{t-1}] + \boldsymbol{b}_F)$$

* CustomDataset: Dataset class that returns the image and its label as two tensors

* EpisodeDataset: Dataset class for training and evaluating few-shot learning models by episodes. Each episode will contain num_class classes, with num_shot training images from each class, and num_eval query images from each of the num_class classes.

* EpisodeSampler: Sampler class that generates episodes for few-shot learning.

In [None]:
class CustomDataset(Dataset):
    """
    Dataset over the image files of a single class.

    Every item is an ``(image, label)`` pair of tensors; the label is the same
    class index for all items.

    Args:
        images: Sequence of image file paths.
        label: Integer class index attached to every image.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, images, label, transform=None):
        self.images = images
        self.label = label
        self.transform = transform

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply the transform and return ``(image, label)`` tensors."""
        # convert() returns a loaded copy, so the file can be closed right away
        with Image.open(self.images[idx]) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # A transform ending in ToTensor already yields a tensor. Without one
        # (or with transform=None) the image is still a PIL image, which
        # torch.Tensor() cannot consume, so convert it explicitly.
        if isinstance(image, torch.Tensor):
            image_ten = image
        else:
            image_ten = transforms.ToTensor()(image)
        label_ten = torch.tensor(self.label)

        return image_ten, label_ten

    def __len__(self):
        """Number of image files in this dataset."""
        return len(self.images)

class EpisodeDataset(Dataset):
    """
    Class-indexed dataset used to assemble few-shot episodes.

    Every entry of ``root`` is treated as one class. Indexing the dataset with
    a class index returns one shuffled batch of up to ``num_shot + num_eval``
    ``(images, labels)`` drawn from that class.

    Args:
        root: Directory whose entries are the class folders.
        num_shot: Training images per class in an episode.
        num_eval: Evaluation images per class in an episode.
        transform: Optional transform forwarded to ``CustomDataset``.
    """

    def __init__(self, root, num_shot=5, num_eval=15, transform=None):
        # Class names in sorted order; a class index is a position in this list
        self.labels = sorted(os.listdir(root))

        # File paths of every class, in the same order as self.labels
        paths_per_class = [glob.glob(os.path.join(root, class_name, '*')) for class_name in self.labels]

        # One shuffling loader per class, each yielding num_shot + num_eval images at a time
        self.episode_loader = [
            DataLoader(
                CustomDataset(images=class_paths, label=class_idx, transform=transform),
                batch_size=num_shot + num_eval,
                shuffle=True,
                num_workers=0,
            )
            for class_idx, class_paths in enumerate(paths_per_class)
        ]

    def __len__(self):
        """Number of classes found under ``root``."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the first batch of a fresh (reshuffled) pass over class ``idx``."""
        class_loader = self.episode_loader[idx]
        return next(iter(class_loader))

class EpisodeSampler(Sampler):
    """
    Batch sampler that draws the classes of each few-shot episode.

    Each yielded item is a tensor of ``num_class`` distinct class indices taken
    from a random permutation of ``range(total_classes)``.

    Args:
        total_classes: Number of classes available to sample from.
        num_class: Number of classes per episode.
        num_episode: Number of episodes produced by one pass over the sampler.
    """

    def __init__(self, total_classes, num_class, num_episode):
        self.num_episode = num_episode
        self.num_class = num_class
        self.total_classes = total_classes

    def __len__(self):
        """Number of episodes in one pass."""
        return self.num_episode

    def __iter__(self):
        """Yield ``num_episode`` tensors of ``num_class`` class indices each."""
        episodes_drawn = 0
        while episodes_drawn < self.num_episode:
            shuffled_classes = torch.randperm(self.total_classes)
            yield shuffled_classes[:self.num_class]
            episodes_drawn += 1

In [None]:
# CLASS DEFINITIONS
class Learner(nn.Module):
    """
    Four-block convolutional classifier that is adapted inside every episode.

    Each block is a 3 x 3 convolution with 32 filters, batch normalization, a ReLU
    non-linearity and a 2 x 2 max-pooling. A final linear layer produces one score
    per class. The loss is the cross-entropy, i.e. the average negative
    log-probability assigned to the correct class.

    ``forward`` returns the raw class scores (logits). ``nn.CrossEntropyLoss``
    applies ``log_softmax`` itself, so feeding it softmax probabilities (as this
    class previously did) applies softmax twice and squashes the loss and its
    gradients. ``softmax_layer`` is kept for callers that need probabilities,
    e.g. ``learner.softmax_layer(learner(x))``; arg-max predictions are identical
    for logits and probabilities.

    Args:
        num_class: Number of output classes.
        image_size: Side length of the square input images. The four poolings
            reduce each side to ``image_size // 16``, which sizes the linear layer.
        momentum: ``momentum`` of every BatchNorm2d layer.
        eps: ``eps`` of every BatchNorm2d layer.
    """

    def __init__(self, num_class=NUM_CLASS, image_size=IMAGE_DIM, momentum=MOMENTUM, eps=EPS):
        super(Learner, self).__init__()
        self.image_size = image_size

        # Model: four conv blocks and one linear layer, applied in insertion order
        self.model = nn.ModuleDict({
            'conv1': nn.Conv2d(3, 32, kernel_size=3, padding=1),
            'norm1': nn.BatchNorm2d(32, momentum=momentum, eps=eps),
            'relu1': nn.ReLU(),
            'pool1': nn.MaxPool2d(kernel_size=2),

            'conv2': nn.Conv2d(32, 32, kernel_size=3, padding=1),
            'norm2': nn.BatchNorm2d(32, momentum=momentum, eps=eps),
            'relu2': nn.ReLU(),
            'pool2': nn.MaxPool2d(kernel_size=2),

            'conv3': nn.Conv2d(32, 32, kernel_size=3, padding=1),
            'norm3': nn.BatchNorm2d(32, momentum=momentum, eps=eps),
            'relu3': nn.ReLU(),
            'pool3': nn.MaxPool2d(kernel_size=2),

            'conv4': nn.Conv2d(32, 32, kernel_size=3, padding=1),
            'norm4': nn.BatchNorm2d(32, momentum=momentum, eps=eps),
            'relu4': nn.ReLU(),
            'pool4': nn.MaxPool2d(kernel_size=2),

            'flatten': nn.Flatten(),
            'linear': nn.Linear(32 * (image_size // 16) ** 2, num_class)  # 'pool4' out-channels * (image_size // 16) ** 2
        })

        # Softmax over the class dimension; not applied in forward (see class docstring)
        self.softmax_layer = nn.Softmax(dim=1)

        # Cross-entropy loss; expects the raw class scores returned by forward
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Run ``x`` through every layer and return the class logits ``[batch, num_class]``."""
        for layer in self.model.values():
            x = layer(x)
        return x

    def get_flat_params(self):
        """Return all parameters of ``self.model`` concatenated into one 1-D tensor."""
        return torch.cat([p.view(-1) for p in self.model.parameters()], 0)

    def copy_flat_params(self, cI):
        """
        Overwrite the model parameters in place with consecutive slices of ``cI``.

        ``cI`` is consumed in ``self.model.parameters()`` order, the same order
        used by ``get_flat_params``. The copy goes through ``param.data``, so it
        is not recorded by autograd.
        """
        idx = 0
        for param in self.model.parameters():
            flattened_len = param.numel()

            # Copy the matching slice of cI into the parameter tensor
            param.data.copy_(cI[idx: idx + flattened_len].view_as(param))
            idx += flattened_len

    def transfer_params(self, learner, cI):
        """
        Load ``learner``'s state and rebind weights and biases to slices of ``cI``.

        The state dict is loaded first so that the batch-norm running statistics
        of ``learner`` are copied. Weights and biases are then replaced by clones
        of slices of ``cI`` stored directly in ``_parameters``; these are plain
        tensors derived from ``cI`` rather than in-place copies, so they stay
        connected to the autograd graph that produced ``cI``.
        """
        # Copy the running mean/variance of the batch normalization layers from learner
        self.load_state_dict(learner.state_dict())

        # Replace the weights and biases of each parametrised module with slices of cI
        index = 0
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.BatchNorm2d, nn.Linear)):

                # Replace the weights with cI
                weight_len = module.weight.numel()
                module._parameters['weight'] = cI[index:index + weight_len].view_as(module.weight).clone()
                index += weight_len

                # Replace the biases with cI
                if module.bias is not None:
                    bias_len = module.bias.numel()
                    module._parameters['bias'] = cI[index:index + bias_len].view_as(module.bias).clone()
                    index += bias_len

    def erase_batch_stats(self):
        """Reset the running mean and variance of every batch normalization layer."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.reset_running_stats()
                
class MetaLSTMCell(nn.Module):
    r"""
    Second-layer meta-LSTM cell that implements the learned update rule

        c_t = sigmoid(f_t) * c_{t-1} - sigmoid(i_t) * grad_t

    where ``c`` holds the flattened learner parameters and ``grad_t`` is the
    gradient of the learner loss with respect to them.

    Args:
        input_size: Width of the first-layer LSTM output fed into this cell.
        hidden_size: Width of the forget/input gates.
        n_learner_params: Number of learner parameters (rows of ``cI``).
    """

    def __init__(self, input_size, hidden_size, n_learner_params):
        super(MetaLSTMCell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_learner_params = n_learner_params

        # Learned initial cell state, one row per learner parameter
        self.cI = nn.Parameter(torch.Tensor(n_learner_params, 1))

        # Weights and biases of forget gate; the 2 extra input rows are c_prev and f_prev
        self.WF = nn.Parameter(torch.Tensor(input_size + 2, hidden_size))
        self.bF = nn.Parameter(torch.Tensor(1, hidden_size))

        # Weights and biases of input gate; the 2 extra input rows are c_prev and i_prev
        self.WI = nn.Parameter(torch.Tensor(input_size + 2, hidden_size))
        self.bI = nn.Parameter(torch.Tensor(1, hidden_size))

        # The tensors above are allocated uninitialised; give them values
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every parameter uniformly in [-0.01, 0.01], then bias the gates."""
        for weight in self.parameters():
            nn.init.uniform_(weight, -0.01, 0.01)

        # Initialize the forget bias to a high value and the input bias to a low value
        # so that the model starts with gradient descent
        nn.init.uniform_(self.bF, 4, 6)
        nn.init.uniform_(self.bI, -6, -4)

    def init_cI(self, flat_params):
        """Copy the 1-D tensor ``flat_params`` into ``cI`` (as a column vector)."""
        self.cI.data.copy_(flat_params.unsqueeze(1))

    def forward(self, inputs, hx=None):
        """
        Compute one update of the cell state.

        Args:
            inputs: Pair ``(lstm1_hidden_output, grad)`` with the first-layer LSTM
                output ``[batch, input_size]`` and the gradients ``[batch, 1]``.
            hx: Previous ``[f_prev, i_prev, c_prev]``. When ``None`` the gates
                start at zero and the cell state starts at ``cI``.

        Returns:
            ``(c_next, [f_next, i_next, c_next])`` where ``f_next`` and ``i_next``
            are the gate pre-activations (before the sigmoid).
        """
        lstm1_hidden_output, grad = inputs
        batch, _ = lstm1_hidden_output.size()

        # Initialize the previous state if not provided
        if hx is None:
            f_prev = torch.zeros((batch, self.hidden_size)).to(self.WF.device)
            i_prev = torch.zeros((batch, self.hidden_size)).to(self.WI.device)
            c_prev = self.cI
        else:
            f_prev, i_prev, c_prev = hx

        # Forget gate from the first-layer output, previous cell state and previous forget gate
        f_next = torch.mm(torch.cat((lstm1_hidden_output, c_prev, f_prev), 1), self.WF) + self.bF.expand_as(f_prev)
        sig_f_next = torch.sigmoid(f_next)

        # Input gate from the first-layer output, previous cell state and previous input gate
        i_next = torch.mm(torch.cat((lstm1_hidden_output, c_prev, i_prev), 1), self.WI) + self.bI.expand_as(i_prev)
        sig_i_next = torch.sigmoid(i_next)

        # Update the cell state using the forget gate and input gate
        c_next = sig_f_next * c_prev - sig_i_next * grad

        # Pack the outputs into a list and return
        outputs = [f_next, i_next, c_next]
        return c_next, outputs


class MetaLearner(nn.Module):
    """
    Two-layer LSTM meta-learner.

    The first layer is a regular ``nn.LSTMCell`` fed with the preprocessed loss
    and gradients; the second layer is a ``MetaLSTMCell`` that combines the
    first-layer output with the raw gradients to produce the updated learner
    parameters.

    Args:
        input_size: Input width of the first-layer LSTM.
        hidden_size: Hidden size of the first-layer LSTM.
        n_learner_params: Number of learner parameters handled by the second layer.
    """
    def __init__(self, input_size, hidden_size, n_learner_params):
        super(MetaLearner, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_learner_params = n_learner_params

        # First Layer LSTM
        self.lstm1 = nn.LSTMCell(input_size, hidden_size)

        # Second Layer MetaLSTMCell: its input is the hidden output of lstm1 and its
        # gate width is 1, because cI is [n_learner_params, 1] and so is the updated cI
        self.lstm2 = MetaLSTMCell(hidden_size, 1, n_learner_params)

    def forward(self, inputs, hs=None):
        """
        Run one meta-learner step.

        Args:
            inputs: Triple ``(losses_prep, grads_prep, grads)``: the preprocessed
                loss, the preprocessed gradients and the raw gradients.
            hs: Previous state ``[lstm1_state, lstm2_state]`` as returned by an
                earlier call, or ``None`` to start from the initial state.

        Returns:
            ``(c_next, [(lstmhx, lstmcx), lstm2_outputs])`` where ``c_next`` is the
            third element of the second layer's output list (the updated cell
            state, returned unsqueezed) and the list is the state to pass as
            ``hs`` on the next call.
        """
        # Unpack preprocessed gradients and losses
        losses_prep, grads_prep, grads = inputs

        if hs is None:
            hs = [None, None]

        # Broadcast the loss to one row per gradient row and concatenate
        loss = losses_prep.expand_as(grads_prep)
        lstm1_input = torch.cat((loss, grads_prep), 1)

        # Feed inputs to first layer LSTM
        lstmhx, lstmcx = self.lstm1(lstm1_input, hs[0])

        # Feed output of first layer LSTM and regular gradient coordinates to second layer LSTM
        _, lstm2_outputs = self.lstm2([lstmhx, grads], hs[1])

        # The updated cell state is the last of the three second-layer outputs
        c_next = lstm2_outputs[2]

        return c_next, [(lstmhx, lstmcx), lstm2_outputs]

**Section 3: Function definitions**

Additional functions used in the train/validate/test routines are defined below:

* reshape_input_target: Structures the train and test input into specified array sizes

* get_flat_grads: Extracts flattened gradient information from a model, normalized by batch size

* scale_sign: Normalizes values of losses and gradients so that the meta-learner is able to use them properly during training. The preprocessing adjusts the scaling of gradients and losses, and separates magnitude and sign information (hence a pair of variables). The equation is as follows:

$$x \rightarrow \begin{cases}
    \left[\frac{\ln|x|}{p},\ \operatorname{sgn}(x)\right], &\text{if } |x| \geq e^{-p}\\
    \left[-1,\ e^p \,x\right], &\text{otherwise}
    \end{cases}$$

* epoch_loop: Trains the epoch_learner with $D_{train}$ datasets during meta-training

In [None]:
def reshape_input_target(episode, num_shot=5, num_eval=15, num_class=5):
    """
    Split an episode into flattened train/test inputs with matching targets.

    Along dimension 1 of ``episode`` the first ``num_shot`` entries are taken as
    training images and the rest as test images. The leading dimensions are
    merged while the last three dimensions of each image are kept. Targets are
    the class positions ``0 .. num_class - 1``, each repeated once per image.
    All four tensors are moved to the global ``device``.

    Returns:
        ``(train_input, train_target, test_input, test_target)``
    """
    image_shape = episode.shape[-3:]

    # Inputs: [num_class * num_shot, *image_shape] and [num_class * num_eval, *image_shape]
    train_input = episode[:, :num_shot].reshape(-1, *image_shape).to(device)
    test_input = episode[:, num_shot:].reshape(-1, *image_shape).to(device)

    # Targets: [num_class * num_shot] and [num_class * num_eval]
    class_positions = range(num_class)
    train_target = torch.LongTensor(np.repeat(class_positions, num_shot)).to(device)
    test_target = torch.LongTensor(np.repeat(class_positions, num_eval)).to(device)

    return train_input, train_target, test_input, test_target

def get_flat_grads(model, batch_size):
    """
    Return the gradients of all ``model`` parameters as one 1-D tensor.

    Each parameter's gradient is flattened and divided by ``batch_size``; the
    pieces are concatenated in ``model.parameters()`` order.
    """
    scaled_grads = [param.grad.data.view(-1) / batch_size for param in model.parameters()]
    return torch.cat(scaled_grads, 0)

def scale_sign(x):
    """
    Preprocess losses/gradients into a (magnitude, sign) pair per element.

    For ``|x| >= exp(-p)`` the pair is ``(log(|x|) / p, sign(x))``, otherwise it
    is ``(-1, exp(p) * x)``. A 1-D input of length ``n`` gives an ``[n, 2]`` output.
    """
    p = 10  # Paper: suggested value of p = 10 worked well

    magnitude = x.abs()

    # 1.0 where the element is large enough for the log branch, 0.0 otherwise
    is_large = (magnitude >= np.exp(-p)).float()
    is_small = 1 - is_large

    # Blend the two branches; 1e-8 keeps the log finite for zero entries
    scaled = is_large * torch.log(magnitude + 1e-8) / p + is_small * -1
    signed = is_large * torch.sign(x) + is_small * np.exp(p) * x

    return torch.stack((scaled, signed), 1)

def epoch_loop(epoch_learner, metalearner, train_input, train_target, num_epoch=NUM_EPOCH, batch_size=BATCH_SIZE):
    """
    Adapt the learner parameters to one episode's training set with the meta-learner.

    Starting from the meta-learner's initial cell state ``cI``, every step copies
    the current parameters into ``epoch_learner``, computes the loss and its
    gradients on one batch, and lets the meta-learner propose the next parameters.

    The meta-learner state returned by each step is fed into the next one.
    Previously the state was always ``None``, so every step restarted from the
    initial ``cI`` and empty LSTM states, and the returned parameters reflected
    only a single update.

    Args:
        epoch_learner: Learner used to compute losses and gradients.
        metalearner: MetaLearner that produces the parameter updates.
        train_input: Training images of the episode.
        train_target: Training labels of the episode.
        num_epoch: Number of passes over the training images.
        batch_size: Number of images per step.

    Returns:
        The parameters proposed by the meta-learner after the last step
        (the initial ``cI`` data if no step was taken).
    """
    # Initial learner parameters (detached) and empty meta-learner state
    cI = metalearner.lstm2.cI.data
    meta_state = None

    # Epoch loop
    for _ in range(num_epoch):

        # Batch loop
        for start in range(0, len(train_input), batch_size):

            # Extract training data (x) and labels (y) for each batch
            x = train_input[start:start + batch_size]
            y = train_target[start:start + batch_size]

            # Copy the current parameters into the learner (in place, outside autograd)
            epoch_learner.copy_flat_params(cI)

            # Compute predictions and loss
            output = epoch_learner(x)  # [batch_size, n_class]
            loss = epoch_learner.criterion(output, y)

            # Compute fresh gradients for this batch
            epoch_learner.zero_grad()
            loss.backward()

            # Extract gradient information as a flattened vector, normalized by batch size
            grads = get_flat_grads(epoch_learner, batch_size)

            # Preprocess gradients and loss information before feeding to metalearner
            grads_prep = scale_sign(grads)                    # [n_learner_params, 2]
            losses_prep = scale_sign(loss.data.unsqueeze(0))  # [1, 2]

            metalearner_input = [losses_prep, grads_prep, grads.unsqueeze(1)]

            # Next parameters from the meta-learner; keep its state for the next step
            cI, meta_state = metalearner(metalearner_input, meta_state)

    return cI

**Section 4: Data preparation**

This section preprocesses and loads the image data as follows:

* Resize to $64 \times 64$ or $96 \times 96$ pixels

* Random horizontal flip (train only; currently disabled for Omniglot)

* Normalization

* Meta-train/test/validation split of 964/347/312 classes

Original implementation also includes color jitter, which is excluded from this implementation for now due to the image size. Color will likely be a distinguishing feature for certain classes (e.g., apple).

In [None]:
# DATA PREPROCESSING
# Omniglot download and folder reshuffling
def _copy_class_folder(src, dst):
    """Copy the character folder ``src`` to ``dst`` unless ``dst`` already exists."""
    # shutil.copytree raises FileExistsError when the destination exists, which
    # made every re-run of this cell fail; skipping existing folders makes the
    # cell safe to run again.
    if not os.path.exists(dst):
        shutil.copytree(src, dst)

# Background set: every <alphabet>/<character> folder becomes the training class
# folder <train_root>/<alphabet>_<character>
root = BG_IMG_DIR
train_root = TRAIN_IMG_DIR
os.makedirs(train_root, exist_ok=True)

for language_folder in os.listdir(root):
    language_folder_path = os.path.join(root, language_folder)
    for character_folder in os.listdir(language_folder_path):
        current_folder = os.path.join(language_folder_path, character_folder)
        dst = os.path.join(train_root, language_folder + '_' + character_folder)
        _copy_class_folder(current_folder, dst)

# Evaluation set: the first 10 alphabets go to validation, the rest to test
root = EV_IMG_DIR
val_root = VAL_IMG_DIR
test_root = TEST_IMG_DIR
os.makedirs(val_root, exist_ok=True)
os.makedirs(test_root, exist_ok=True)

# NOTE(review): os.listdir returns entries in arbitrary order, so which alphabets
# land in val vs. test depends on the filesystem. Sorting would make the split
# deterministic but would change the current class counts.
for i, language_folder in enumerate(os.listdir(root)):
    language_folder_path = os.path.join(root, language_folder)
    split_root = val_root if i < 10 else test_root
    for character_folder in os.listdir(language_folder_path):
        current_folder = os.path.join(language_folder_path, character_folder)
        dst = os.path.join(split_root, language_folder + '_' + character_folder)
        _copy_class_folder(current_folder, dst)

# Compute no. of classes (= number of sub-directories of each split folder)
def _count_class_folders(split_dir):
    """Count the entries of ``split_dir`` that are directories."""
    return sum(1 for entry in os.listdir(split_dir) if os.path.isdir(os.path.join(split_dir, entry)))

num_folders_train = _count_class_folders(TRAIN_IMG_DIR)
num_folders_val = _count_class_folders(VAL_IMG_DIR)
num_folders_test = _count_class_folders(TEST_IMG_DIR)

# Per-channel normalization statistics shared by all splits
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

def _resize_tensor_normalize():
    """Build a fresh pipeline: resize to IMAGE_DIM x IMAGE_DIM, convert to tensor, normalize."""
    return transforms.Compose([
        transforms.Resize((IMAGE_DIM, IMAGE_DIM)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ])

# Transformation for training data
# (RandomHorizontalFlip and ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2)
#  are intentionally left out: the flip was removed for omniglot and the jitter is disabled.)
transform_train = _resize_tensor_normalize()

# Transformation for testing and validation data
transform_test_val = _resize_tensor_normalize()

In [None]:
def _build_episode_loader(img_dir, transform, num_episode, pin_memory):
    """Create the episode dataset for ``img_dir`` and a loader that yields ``num_episode`` episodes."""
    episode_set = EpisodeDataset(root=img_dir, num_shot=NUM_SHOT, num_eval=NUM_EVAL, transform=transform)
    episode_sampler = EpisodeSampler(len(episode_set), num_class=NUM_CLASS, num_episode=num_episode)
    episode_loader = DataLoader(episode_set, num_workers=2, pin_memory=pin_memory, batch_sampler=episode_sampler)
    return episode_set, episode_loader

# Load train data
train_set, train_loader = _build_episode_loader(TRAIN_IMG_DIR, transform_train, NUM_EPI_TRAIN, pin_memory=True)
print(f'Meta train sets created and loaded ({num_folders_train} classes)')

# Load validation data
val_set, val_loader = _build_episode_loader(VAL_IMG_DIR, transform_test_val, NUM_EPI_EVAL, pin_memory=False)
print(f'Meta val sets created and loaded ({num_folders_val} classes)')

# Load test data
test_set, test_loader = _build_episode_loader(TEST_IMG_DIR, transform_test_val, NUM_EPI_TEST, pin_memory=False)
print(f'Meta test sets created and loaded ({num_folders_test} classes)')

**Section 5: Initialize models and optimizer**

Prepare models and optimizer for training.

In [None]:
# MODEL INITIALIZATION
# Both learners share the same architecture settings
learner_settings = dict(num_class=NUM_CLASS, image_size=IMAGE_DIM, momentum=MOMENTUM, eps=EPS)
epoch_learner = Learner(**learner_settings).to(device)    # Loops over epochs
episode_learner = Learner(**learner_settings).to(device)  # Loops over episodes
print('LEARNER models created')
summary(epoch_learner, (3, IMAGE_DIM, IMAGE_DIM), BATCH_SIZE)

# The meta-learner's cell state starts from the epoch learner's freshly initialised parameters
initial_flat_params = epoch_learner.get_flat_params()
metalearner = MetaLearner(
    input_size=INPUT_SIZE,
    hidden_size=HIDDEN_SIZE,
    n_learner_params=initial_flat_params.size(0),
).to(device)
metalearner.lstm2.init_cI(initial_flat_params)
print('META-LEARNER model created')

# OPTIMIZER
# Only the meta-learner is optimised; the learners receive their parameters from it
optimizer = torch.optim.Adam(metalearner.parameters(), lr=LR_INIT)
print('Optimizers created')

**Section 6: Meta-Training**

A total of 12000 episodes were used for meta-training, 100 episodes for each meta-validation pass (run every 100 meta-training episodes), and 1000 episodes for meta-testing. Each episode comprises n_shot + n_eval images for each of the n_way classes. With 5 classes, 5 shots and 15 evaluation images per class, the batch size is 25, so each episode contains 25 training images and 75 testing images.

In [None]:
# MODEL TRAINING
# Set TRAIN_FLAG to 0 to skip training and go straight to meta-testing (next section)
TRAIN_FLAG = 1

# Draw a random seed for torch and print it so that the run can be identified later
seed = np.random.randint(1, 10000)
torch.manual_seed(seed)
print(f'Used seed : {seed}')

# Tensorboard writer; events are stored under LOG_DIR
tbwriter = SummaryWriter(log_dir=LOG_DIR)
print('TensorboardX summary writer created')

total_steps = 1
   
if TRAIN_FLAG: 
    
    print('Start training')
    train_acc_100 = []# Training accuracy per episodes
    train_acc_all = [] # Training accuracy per 100 episodes
    best_train_acc = 0
    best_train_ep = 0
    val_acc_all = [] # Average validation accuracy per episode at each set of Val
    best_val_acc = 0
    best_val_ep = 0
    
    # Episode loop (meta-train)
    for ep, (episode, _) in enumerate(train_loader): 
        
        # Reshape meta-train episode data
        train_input, train_target, test_input, test_target = reshape_input_target(
            episode,
            num_shot=NUM_SHOT,
            num_eval=NUM_EVAL,
            num_class=NUM_CLASS
        )

        # Erase batch stats for each episode
        epoch_learner.erase_batch_stats()
        episode_learner.erase_batch_stats()     
        # this is the reason in the original repo, author create a learner_wo_grad
        # by right retain_graph=True should solve 
        # https://stackoverflow.com/questions/46774641/what-does-the-parameter-retain-graph-mean-in-the-variables-backward-method
        # not working
        
        # Set to train mode
        epoch_learner.train()
        episode_learner.train()
        
        # Epoch loop
        cI = epoch_loop(
            epoch_learner, 
            metalearner, 
            train_input, 
            train_target, 
            num_epoch=NUM_EPOCH, 
            batch_size=BATCH_SIZE
        )       

        # Copy parameters of learner model from metalearner output cI after the last epoch of the episode
        episode_learner.transfer_params(epoch_learner, cI)

        # Compute predictions on test data
        output = episode_learner(test_input)
        
        # Update loss
        loss = episode_learner.criterion(output, test_target)
        
        # Update gradients
        optimizer.zero_grad()
        loss.backward()
        
        # Apply gradient clipping (to manage exploding gradients issue with LSTM)
        nn.utils.clip_grad_norm_(metalearner.parameters(), GRAD_CLIP)
        
        optimizer.step()
        
        # Compute accuracy
        with torch.no_grad():
            _, preds = torch.max(output, 1)
            episode_acc = torch.sum(preds == test_target).item()/len(test_target)
        train_acc_100.append(episode_acc)
        
        # Log the information and add to tensorboard
        if (ep+1) % 50 == 0:
            print('Episode: {}/{} \tTrainingLoss: {:.4f} \tEpisodeAcc: {:.2f}%'
                .format(ep + 1, NUM_EPI_TRAIN, loss.item(), episode_acc * 100))
            tbwriter.add_scalar('train_loss', loss.item())
            tbwriter.add_scalar('train_accuracy', episode_acc)
        
        # Show best accuracy
        if episode_acc > best_train_acc:
            best_train_acc = episode_acc
            best_train_ep = ep+1
            print('Best training accuracy is from training episode {}: {:.2f}%'
                  .format(best_train_ep, best_train_acc * 100))
                
        # Episode loop (meta-val)
        if (ep+1) % VAL_FREQ == 0:
            
            # Compute average training loss
            train_acc_all.append(np.mean(train_acc_100))
            train_acc_100 = []
            
            # Reset mean val accuracy and loss
            val_acc_total = 0
            val_loss_total = 0
            
            # Episode loop (meta-val)
            for val_ep, (val_episode, _) in enumerate(val_loader):
                
                # Reshape meta-val episode data
                train_input, train_target, test_input, test_target = reshape_input_target(
                    val_episode,
                    num_shot=NUM_SHOT,
                    num_eval=NUM_EVAL,
                    num_class=NUM_CLASS
                )
        
                # Erase batch stats for each episode
                epoch_learner.erase_batch_stats()
                episode_learner.erase_batch_stats()     

                # Set to train/eval mode
                epoch_learner.train()
                episode_learner.eval()  # No training for episode_learner - why?

                # Epoch loop
                cI = epoch_loop(
                    epoch_learner, 
                    metalearner, 
                    train_input, 
                    train_target, 
                    num_epoch=NUM_EPOCH, 
                    batch_size=BATCH_SIZE
                )       

                # Copy parameters of learner model from metalearner output cI after the last epoch of the episode
                episode_learner.transfer_params(epoch_learner, cI)

                # Compute predictions on test data
                output = episode_learner(test_input)

                # Update validation loss
                loss = episode_learner.criterion(output, test_target)

                # Compute validation episode accuracy
                with torch.no_grad():
                    _, preds = torch.max(output, 1)
                    val_episode_acc = torch.sum(preds == test_target).item()/len(test_target)
#                     val_episode_acc = val_episode_acc.item()/test_target.size()
                
                # Compute total validation loss and accuracy
                val_loss_total = val_loss_total + loss.item()
                val_acc_total = val_acc_total + val_episode_acc
            
            # Compute validation episode mean accuracy and loss and add to tensorboard
            val_episode_loss_mean = val_loss_total/(val_ep+1)
            val_episode_acc_mean = val_acc_total/(val_ep+1)
            val_acc_all.append(val_episode_acc_mean)
            tbwriter.add_scalar('val_loss', val_episode_loss_mean)
            tbwriter.add_scalar('val_accuracy', val_episode_acc_mean)

            # Show best accuracy and save best models
            if val_episode_acc_mean > best_val_acc:
                best_val_acc = val_episode_acc_mean
                best_val_ep = ep+1
                print('Best validation accuracy is from training episode {}: {:.2f}%'
                      .format(best_val_ep, best_val_acc * 100))

                # Save checkpoints - metalearner
                meta_checkpoint_path = os.path.join(CHECKPOINT_DIR, 'metalearner_best.pth')
                torch.save(metalearner.state_dict(), meta_checkpoint_path)
#                 print(f'Saved metalearner model weights to {meta_checkpoint_path}')

                # Save checkpoints - learner
                learner_checkpoint_path = os.path.join(CHECKPOINT_DIR, 'learner_best.pth')
                torch.save(episode_learner.state_dict(), learner_checkpoint_path)
#                 print(f'Saved learner model weights to {learner_checkpoint_path}')

                # Save checkpoints - optimizer
                optimizer_checkpoint_path = os.path.join(CHECKPOINT_DIR, 'optimizer_best.pth')
                torch.save(optimizer.state_dict(), optimizer_checkpoint_path)
#                 print(f'Saved optimizer weights to {optimizer_checkpoint_path}')
                    
                print(f'Saved model and optimizer weights to {CHECKPOINT_DIR}')
        
# Report the best training / validation accuracy recorded during meta-training,
# together with the training episode at which each was reached.
print('Best training accuracy is from training episode {}: {:.2f}%'
      .format(best_train_ep, best_train_acc * 100))
print('Best validation accuracy is from training episode {}: {:.2f}%'
      .format(best_val_ep, best_val_acc * 100))

# Figure 1: training and validation accuracy histories on shared axes.
# NOTE(review): the axis labels assume one history entry per 100 training
# episodes - confirm against the logging frequency used in the training loop.
plt.figure(1)
plt.plot(train_acc_all, label='Training')
plt.plot(val_acc_all, label='Validation')
plt.xlabel('Episodes (x100)')
plt.ylabel('Average accuracy per 100 episodes')
plt.title('Training and validation accuracy (0-1)')
plt.legend()

# Figure 2: training accuracy on its own axes
plt.figure(2)
plt.plot(train_acc_all, label='Training')
plt.xlabel('Episodes (x100)')
plt.ylabel('Average training accuracy per 100 episodes')  # stray ')' removed
plt.title('Training accuracy (0-1)')

# Figure 3: validation accuracy on its own axes
plt.figure(3)
plt.plot(val_acc_all, label='Validation')
plt.xlabel('Episodes (x100)')
plt.ylabel('Average validation accuracy per 100 episodes')  # stray ')' removed
plt.title('Validation accuracy (0-1)')

**Section 7: Meta-Testing**

1000 episodes were used for meta-testing. The mean, max, min, and standard deviation of test accuracy were computed to assess the meta-learner's performance.

In [None]:
# MODEL TEST
#
# For every meta-test episode: run the epoch loop on the episode's training
# split, transfer the resulting parameters (cI) into episode_learner, then score
# episode_learner on the episode's test split.

# Rebuild the best-checkpoint paths here. The training cell only assigns these
# variables inside its "new best validation accuracy" branch, so relying on them
# breaks this cell on a fresh kernel or when that branch was never taken.
meta_checkpoint_path = os.path.join(CHECKPOINT_DIR, 'metalearner_best.pth')
learner_checkpoint_path = os.path.join(CHECKPOINT_DIR, 'learner_best.pth')

# Load best validation round parameters
metalearner_state = torch.load(meta_checkpoint_path, map_location=device)
metalearner.load_state_dict(metalearner_state)
learner_state = torch.load(learner_checkpoint_path, map_location=device)
epoch_learner.load_state_dict(learner_state)
episode_learner.load_state_dict(learner_state)

# Per-episode accuracies and running loss. A list is used instead of an array
# pre-sized to NUM_EPI_TEST so the statistics always reflect the number of
# episodes the loader actually produced.
test_acc_all = []
test_loss_total = 0

# Episode loop (meta-test)
for test_ep, (test_episode, _) in enumerate(test_loader):

    # Reshape meta-test episode data
    train_input, train_target, test_input, test_target = reshape_input_target(
        test_episode,
        num_shot=NUM_SHOT,
        num_eval=NUM_EVAL,
        num_class=NUM_CLASS
    )

    # Erase batch stats for each episode
    epoch_learner.erase_batch_stats()
    episode_learner.erase_batch_stats()

    # Set to train/eval mode
    epoch_learner.train()
    episode_learner.eval()  # No training for episode_learner - why?

    # Epoch loop
    cI = epoch_loop(
        epoch_learner,
        metalearner,
        train_input,
        train_target,
        num_epoch=NUM_EPOCH,
        batch_size=BATCH_SIZE
    )

    # Copy parameters of learner model from metalearner output cI after the last epoch of the episode
    episode_learner.transfer_params(epoch_learner, cI)

    # Nothing is back-propagated at meta-test time, so the forward pass, loss and
    # accuracy are all computed without building an autograd graph.
    with torch.no_grad():
        # Compute predictions on test data
        output = episode_learner(test_input)

        # Compute test episode loss
        loss = episode_learner.criterion(output, test_target)

        # Compute test episode accuracy
        _, preds = torch.max(output, 1)
        test_episode_acc = torch.sum(preds == test_target).item()/len(test_target)

    # Record episode results
    test_acc_all.append(test_episode_acc)
    test_loss_total = test_loss_total + loss.item()

# Fail with a clear message instead of reporting statistics of an empty run
num_test_episodes = len(test_acc_all)
if num_test_episodes == 0:
    raise RuntimeError('test_loader produced no episodes; cannot compute meta-test statistics')

# Compute test episode accuracies and loss
test_acc_all = np.array(test_acc_all)
test_loss_mean = test_loss_total/num_test_episodes
test_acc_mean = np.mean(test_acc_all)
test_acc_max = np.max(test_acc_all)
test_acc_min = np.min(test_acc_all)
test_acc_std = np.std(test_acc_all)

print('Mean meta-test loss: {:.4f}'.format(test_loss_mean))
print('Mean meta-test accuracy: {:.2f}%'.format(test_acc_mean*100))
print('Max accuracy: {:.2f}% \tMin accuracy: {:.2f}% \tStd deviation: {:.2f}%'.format(test_acc_max*100, test_acc_min*100, test_acc_std*100))

# Distribution of per-episode test accuracy
plt.figure(4)
plt.hist(test_acc_all)
plt.xlabel('Test accuracy')
plt.ylabel('No. of episodes')
plt.title('Test accuracy - histogram')

**Section 8: Baseline setting**

The baseline accuracy was computed with a k-nearest-neighbors (KNN) classifier using `NUM_SHOT` neighbors, evaluated over 1000 episodes.

In [None]:
def compute_mean_features(support_set):
    """
    Averages every image over its own first axis and flattens the result.

    For channel-first images this is the per-pixel mean across channels.

    Args:
        support_set: array-like of equally shaped images, indexed by image along
            the first axis (e.g. [NUM_IMAGES, CHANNELS, IMAGE_DIM, IMAGE_DIM]).

    Returns:
        Array of shape [NUM_IMAGES, -1] with one flattened mean vector per image
        (IMAGE_DIM*IMAGE_DIM values for the example layout above).
    """
    images = np.asarray(support_set)

    # Axis 0 indexes the images, so axis 1 is each image's own first axis; one
    # vectorized reduction replaces the per-image Python loop.
    return np.mean(images, axis=1).reshape([len(images), -1])

def classify_episode(knn, train_input, train_target, test_input, test_target):
    """
    Scores a KNN classifier on a single episode using mean features.

    The classifier is fitted on the mean features of train_input (labelled by
    train_target) and then predicts a class for the mean features of test_input.

    Returns:
        Fraction of predictions that are equal to test_target.
    """
    # Mean features for both splits of the episode
    support_features = compute_mean_features(train_input)
    query_features = compute_mean_features(test_input)

    # Fit on the support split, then label the query split
    knn.fit(support_features, train_target)
    query_predictions = knn.predict(query_features)

    # Share of query predictions that match the targets
    is_correct = query_predictions == test_target
    return is_correct.mean()

# Baseline: classify every meta-test episode with KNN on mean features
knn = KNeighborsClassifier(n_neighbors=NUM_SHOT)
# knn = KNeighborsClassifier(n_neighbors=1)
baseline_test_acc_all = []

for test_ep, (test_episode, _) in enumerate(test_loader):

    # Reshape the episode, then move each piece to the CPU as a numpy array
    episode_tensors = reshape_input_target(
        test_episode,
        num_shot=NUM_SHOT,
        num_eval=NUM_EVAL,
        num_class=NUM_CLASS
    )
    train_input, train_target, test_input, test_target = (
        tensor.cpu().numpy() for tensor in episode_tensors
    )

    # Accuracy of the KNN classifier on this episode
    baseline_test_acc = classify_episode(knn, train_input, train_target, test_input, test_target)
    baseline_test_acc_all.append(baseline_test_acc)

# Summary statistics over all baseline episodes
baseline_test_acc_mean = np.mean(baseline_test_acc_all)
baseline_test_acc_max = np.max(baseline_test_acc_all)
baseline_test_acc_min = np.min(baseline_test_acc_all)
baseline_test_acc_std = np.std(baseline_test_acc_all)

print('Mean baseline test accuracy: {:.2f}%'.format(baseline_test_acc_mean*100))
print('Max accuracy: {:.2f}% \tMin accuracy: {:.2f}% \tStd deviation: {:.2f}%'
      .format(baseline_test_acc_max*100, baseline_test_acc_min*100, baseline_test_acc_std*100))

# Distribution of per-episode baseline accuracy
plt.figure(5)
plt.hist(baseline_test_acc_all)
plt.xlabel('Baseline test accuracy')
plt.ylabel('No. of episodes')
plt.title('Baseline test accuracy - histogram')

**Appendix: Debug code**

In [None]:
# Visualization checks

# # train_set - print first 5 episodes
# print(f"Total no. of classes: {len(train_set)}")

# counter = 1

# for data, label in test_set:
#     print(f"Data shape: {data.shape}")
#     print(f"Label shape: {label.shape}")
#     print(label)
    
#     for i in range(data.size()[0]):
#         disp_image = data[i]

#         # Convert the tensor to a numpy array
#         disp_image = disp_image.numpy()

#         # Transpose the dimensions of the image to match the expected format
#         disp_image = disp_image.transpose(1, 2, 0)

#         # Display the image using matplotlib
#         plt.imshow(disp_image)
#         plt.show()
    
#     counter = counter + 1
    
#     if counter > 5:
#         break

# # train_loader
# print(len(train_loader))

# for batch in train_loader:
#     data, labels = batch
#     print(f"Data shape: {data.shape}")
#     print(f"Labels shape: {labels.shape}")
#     print(labels)

# # Reshape function
# for eps, (episode, _) in enumerate(train_loader): 
#     train_input, train_target, test_input, test_target = reshape_input_target(episode, num_shot=1, num_eval=15, num_class=5)
#     print(train_target)
#     print(test_target)
#     break

# Print model parameters
# for name, param in learner.named_parameters():
#     if name == 'model.linear.weight':
#         print('Learner linear weights', param)
        
# for name, param in new_learner.named_parameters():
#     if name == 'model.linear.weight':
#         print('New learner linear weights', param)

In [None]:
# torch.autograd.set_detect_anomaly(True)

In [None]:
# os.listdir('/kaggle/working/models')