# Coding Exercise: Model-Agnostic Meta-Learning (MiniImageNet Dataset)

In this tutorial, we will implement Model-Agnostic Meta-Learning (MAML) to learn to classify MiniImageNet classes from only a few examples.
Recall that MAML consists of two nested loops:
1. An outer loop that learns parameters shared across all tasks (the initialization).
2. An inner loop that learns task-specific parameters by adapting that initialization to each task.

The MAML algorithm aims to learn parameters that can be adapted to new tasks using very few examples.

<img src="Images/parameters.png" width="600"/>


<img src="Images/maml_algo.png" width="500"/>

In [None]:
! pip install torchmeta

In [None]:
import torch
import torch.nn as nn
from torchmeta.modules import (MetaModule, MetaSequential, MetaConv2d, MetaConv3d,
                               MetaBatchNorm2d, MetaLinear)
from torchmeta.modules.utils import get_subdict
from collections import OrderedDict
import os
import torch.nn.functional as F
from tqdm import tqdm
import torchmeta
from torchmeta.utils.data import BatchMetaDataLoader 
from torchmeta.datasets import MiniImagenet
from torchmeta.transforms import Categorical, ClassSplitter, Rotation
from torchvision.transforms import Compose, Resize, ToTensor
from torchmeta.utils.data import BatchMetaDataLoader

# Load the MiniImageNet dataset

Download the MiniImageNet dataset: https://drive.google.com/file/d/1HkgrkAwukzEZA0TpO7010PkAOREb2Nuk/view


In [None]:
def conv3x3(in_channels, out_channels, **kwargs):
    """Build one convolutional block: 3x3 conv -> batch norm -> ReLU -> 2x2 max-pool.

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        **kwargs: forwarded unchanged to ``MetaConv2d``.

    Returns:
        A ``MetaSequential`` containing the four layers in that order. The
        max-pool halves the spatial resolution.
    """
    conv = MetaConv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs)
    # track_running_stats=False: no running mean/variance is kept, so the
    # layer normalises with the statistics of the current batch.
    norm = MetaBatchNorm2d(out_channels, momentum=1., track_running_stats=False)
    layers = [conv, norm, nn.ReLU(), nn.MaxPool2d(2)]
    return MetaSequential(*layers)

class ConvolutionalNeuralNetwork(MetaModule):
    """Four-block CNN classifier whose weights can be overridden per call.

    Each block is ``conv3x3`` (3x3 conv, batch norm, ReLU, 2x2 max-pool), so
    every block halves the spatial resolution. The flattened feature map feeds
    a linear classifier.

    Args:
        in_channels: channels of the input images (3 for RGB).
        out_features: number of output logits (classes per task).
        hidden_size: channels produced by every conv block.
        feature_size: length of the flattened feature vector that enters the
            classifier. Defaults to ``hidden_size``, which is only correct when
            the four max-pools reduce the input to 1x1 (e.g. 28x28 inputs:
            28 -> 14 -> 7 -> 3 -> 1). For 84x84 inputs the final map is 5x5
            (84 -> 42 -> 21 -> 10 -> 5), so pass ``hidden_size * 5 * 5``.
    """

    def __init__(self, in_channels, out_features, hidden_size=64, feature_size=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_features = out_features
        self.hidden_size = hidden_size
        if feature_size is None:
            feature_size = hidden_size
        self.feature_size = feature_size

        self.features = MetaSequential(
            conv3x3(in_channels, hidden_size),
            conv3x3(hidden_size, hidden_size),
            conv3x3(hidden_size, hidden_size),
            conv3x3(hidden_size, hidden_size)
        )

        self.classifier = MetaLinear(feature_size, out_features)

    def forward(self, inputs, params=None):
        """Compute class logits for a batch of images.

        Args:
            inputs: image batch; the first dimension is treated as the batch.
            params: optional OrderedDict of weights keyed by parameter name
                (e.g. the output of ``update_parameters``). When None, the
                module's own weights are used (torchmeta convention).

        Returns:
            Logits of shape ``(batch, out_features)``.
        """
        features = self.features(inputs, params=get_subdict(params, 'features'))
        # Flatten everything except the batch dimension.
        features = features.view((features.size(0), -1))
        logits = self.classifier(features, params=get_subdict(params, 'classifier'))
        return logits

In [None]:
def update_parameters(model, loss, step_size=0.5, first_order=False, params=None):
    """Take one gradient-descent step on ``loss`` and return the adapted weights.

    The model itself is not modified. The returned tensors are computed from
    the starting weights, so a later (outer) loss can back-propagate through
    them into the model's parameters.

    Args:
        model: a torchmeta ``MetaModule``; its meta-parameters are used as the
            starting weights when ``params`` is None.
        loss: scalar loss computed with the starting weights.
        step_size: inner-loop learning rate.
        first_order: if True, the gradients are not themselves differentiated
            (``create_graph=False``), i.e. first-order MAML.
        params: optional OrderedDict of starting weights (e.g. the output of a
            previous call) to take several inner-loop steps in a row.

    Returns:
        OrderedDict mapping each parameter name to ``param - step_size * grad``.
    """
    if params is None:
        params = OrderedDict(model.meta_named_parameters())

    grads = torch.autograd.grad(loss, list(params.values()),
                                create_graph=not first_order)

    return OrderedDict(
        (name, param - step_size * grad)
        for (name, param), grad in zip(params.items(), grads)
    )

def get_accuracy(logits, targets):
    """Return the fraction of rows whose highest-scoring class equals the target.

    Args:
        logits: scores with classes along the last dimension.
        targets: integer class labels matching ``logits`` without its last dim.

    Returns:
        A 0-dim float tensor with the mean accuracy.
    """
    predicted = torch.max(logits, dim=-1).indices
    correct = predicted.eq(targets)
    return correct.float().mean()

In [None]:
# Meta-training configuration. Each task is an N-way, K-shot problem: a support
# ("train") set with K examples from each of N classes, and a query ("test")
# set with further examples from the same N classes.
# Fall back to the CPU when no GPU is available instead of failing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_ways = 5          # N: classes per task
num_shots = 5         # K: support examples per class
num_test_shots = 15   # query examples per class
step_size = 0.4       # inner-loop learning rate
batch_size = 16       # tasks per meta-batch
num_batches = 2000    # number of meta-updates
first_order = True    # first-order MAML (no second derivatives)
hidden_size = 64
output_folder = './'
num_classes_per_task = num_ways

dataset = MiniImagenet("data",
                       # Number of ways
                       num_classes_per_task=num_classes_per_task,
                       # Resize images (square images become 28x28) and convert them to
                       # PyTorch tensors. 28x28 lets the CNN's four max-pools reduce the
                       # map to 1x1, which its classifier input size assumes.
                       transform=Compose([Resize(28), ToTensor()]),
                       # Transform the labels to integers 0..num_ways-1
                       target_transform=Categorical(num_classes=num_ways),
                       # Creates new virtual classes with rotated versions of the images
                       # (from Santoro et al., 2016).
                       # NOTE(review): this augmentation comes from the Omniglot setup;
                       # confirm it is intended for MiniImageNet.
                       class_augmentations=[Rotation([90, 180, 270])],
                       meta_train=True,
                       download=True)
dataset = ClassSplitter(dataset, shuffle=True,
                        num_train_per_class=num_shots,
                        num_test_per_class=num_test_shots)
dataloader = BatchMetaDataLoader(dataset, batch_size=batch_size, num_workers=4)

In [None]:
# Build the CNN for 3-channel (RGB) images with one logit per way and move it
# to the selected device (Module.to returns the module itself).
model = ConvolutionalNeuralNetwork(3, num_ways, hidden_size=hidden_size).to(device=device)
model.train()
# Outer-loop optimizer: updates the shared initialization.
meta_optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Meta-training loop. Each iteration draws a batch of tasks, adapts the model to
# each task with one inner gradient step on its support set, and updates the
# shared initialization with the query-set loss of the adapted weights.
with tqdm(dataloader, total=num_batches) as pbar:
    for batch_idx, batch in enumerate(pbar):
        model.zero_grad()

        train_inputs, train_targets = batch['train']
        train_inputs = train_inputs.to(device=device)
        train_targets = train_targets.to(device=device)

        test_inputs, test_targets = batch['test']
        test_inputs = test_inputs.to(device=device)
        test_targets = test_targets.to(device=device)

        # Average over the tasks actually present in this batch (first
        # dimension of the batched tensors), not the configured batch size.
        num_tasks = train_inputs.size(0)
        outer_loss = torch.tensor(0., device=device)
        accuracy = torch.tensor(0., device=device)

        for train_input, train_target, test_input, test_target in zip(
                train_inputs, train_targets, test_inputs, test_targets):
            # Inner loop: one gradient step on the task's support set.
            # update_parameters uses torch.autograd.grad, which does not
            # accumulate into .grad, so no zero_grad is needed here.
            train_logit = model(train_input)
            inner_loss = F.cross_entropy(train_logit, train_target)
            params = update_parameters(model, inner_loss,
                step_size=step_size, first_order=first_order)

            # Evaluate the adapted weights on the task's query set.
            test_logit = model(test_input, params=params)
            outer_loss += F.cross_entropy(test_logit, test_target)

            with torch.no_grad():
                accuracy += get_accuracy(test_logit, test_target)

        outer_loss.div_(num_tasks)
        accuracy.div_(num_tasks)

        # Outer loop: update the shared initialization.
        outer_loss.backward()
        meta_optimizer.step()

        pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
        # batch_idx is 0-based: stop after exactly num_batches meta-updates.
        if batch_idx + 1 >= num_batches:
            break

# Save the meta-learned initialization (the model's state_dict).
if output_folder is not None:
    os.makedirs(output_folder, exist_ok=True)
    # NOTE(review): the "shot" field is filled with num_classes_per_task (the
    # number of ways); it equals the support-set size only because both are
    # configured as 5.
    filename = os.path.join(output_folder, 'maml_miniimagenet_'
        '{0}shot_{1}way.pt'.format(num_classes_per_task, num_ways))
    with open(filename, 'wb') as f:
        torch.save(model.state_dict(), f)