In [None]:
import os
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
import logging

from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader
from torchmeta.utils.gradient_based import gradient_update_parameters

In [None]:
import torch.nn as nn
from torchmeta.modules import MetaModule, MetaLinear 

In [None]:
import matplotlib.pyplot as plt

In [None]:
def conv3x3(in_channels, out_channels, **kwargs):
    """Build one feature-extraction stage: 3x3 conv -> batch norm -> ReLU -> 2x2 max-pool.

    The stage is assembled from standard `torch.nn` layers because the
    feature extractor does not require adaptation (compare with
    `examples/maml/model.py`). Any extra keyword arguments are forwarded
    to `nn.Conv2d`.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs),
        # No running statistics are tracked for the normalisation layer.
        nn.BatchNorm2d(out_channels, momentum=1., track_running_stats=False),
        nn.ReLU(),
        nn.MaxPool2d(2),
    ]
    return nn.Sequential(*layers)

In [None]:
class ConvolutionalNeuralNetwork(MetaModule):
    """Four-stage convolutional feature extractor followed by a linear head.

    The feature extractor is built from regular `torch.nn` layers, while the
    head is a `MetaLinear`; only the head accepts externally supplied
    parameters in `forward` (the ANIL setting).
    """

    def __init__(self, in_channels, out_features, hidden_size=64):
        super().__init__()
        self.in_channels = in_channels
        self.out_features = out_features
        self.hidden_size = hidden_size

        # Channel widths at the boundaries of the four conv stages.
        stage_widths = [in_channels] + [hidden_size] * 4
        self.features = nn.Sequential(*[
            conv3x3(width_in, width_out)
            for width_in, width_out in zip(stage_widths[:-1], stage_widths[1:])
        ])

        # Only the last (linear) layer is used for adaptation in ANIL
        self.classifier = MetaLinear(hidden_size, out_features)

    def forward(self, inputs, params=None):
        """Return classification logits for `inputs`.

        `params`, when given, is narrowed to its 'classifier' sub-dictionary
        and passed to the linear head; the feature extractor always uses its
        own parameters.
        """
        feature_maps = self.features(inputs)
        # Flatten everything except the leading (example) dimension.
        flat = feature_maps.view(feature_maps.size(0), -1)
        classifier_params = self.get_subdict(params, 'classifier')
        return self.classifier(flat, params=classifier_params)


In [None]:
from collections import OrderedDict

In [None]:
def get_accuracy(logits, targets):
    """Return the fraction of query points whose top-scoring class matches the target.

    Parameters
    ----------
    logits : `torch.FloatTensor` instance
        Model outputs on the query points, of shape
        `(num_examples, num_classes)`.

    targets : `torch.LongTensor` instance
        Ground-truth class indices of the query points, of shape
        `(num_examples,)`.

    Returns
    -------
    accuracy : `torch.FloatTensor` instance
        Scalar tensor holding the mean accuracy on the query points.
    """
    # Predicted class = index of the largest logit along the class dimension.
    predictions = logits.argmax(dim=-1)
    is_correct = predictions.eq(targets)
    return is_correct.float().mean()

In [None]:
class ARGS:
    """Experiment configuration, exposed as class-level attributes."""

    # Data location and few-shot task layout.
    folder = "data"
    num_shots = 1
    num_ways = 20
    download = False

    # Data loading.
    batch_size = 16
    num_workers = 0

    # Model width.
    hidden_size = 64

    # Meta-training schedule and inner-loop update settings.
    num_batches = 500
    step_size = 0.4
    first_order = False

    # Use the GPU when one is available, otherwise fall back to the CPU so the
    # notebook still runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Instantiate the configuration; every setting is a class-level attribute of ARGS.
args = ARGS()

In [None]:
# Omniglot meta-training tasks: `num_ways` classes per task with `num_shots`
# support examples and 15 query examples per class.
dataset = omniglot(
    args.folder,
    ways=args.num_ways,
    shots=args.num_shots,
    test_shots=15,
    shuffle=True,
    meta_train=True,
    download=args.download,
)

In [None]:
#dataset[(0, 1, 2, 3, 4)]

In [None]:
# Batches of tasks drawn from the Omniglot meta-dataset.
dataloader = BatchMetaDataLoader(
    dataset,
    shuffle=True,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
)

In [None]:
# Single-channel input, one output logit per way.
model = ConvolutionalNeuralNetwork(
    in_channels=1,
    out_features=args.num_ways,
    hidden_size=args.hidden_size,
)
model.to(device=args.device)
model.train()

# Outer-loop (meta) optimizer over all model parameters.
meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Meta-training loop (Omniglot).
# For each task in a batch, the model takes one inner update on the support
# ("train") split through `gradient_update_parameters`; the loss of the updated
# parameters on the query ("test") split is accumulated into the outer loss
# that drives the meta-optimizer step. The adapted `params` are only consumed
# by the classifier head inside `model.forward`.
acc_list = []
with tqdm(dataloader, total=args.num_batches) as pbar:
    # Pairing with range() caps the run at exactly `args.num_batches` batches
    # and never pulls an extra batch from the loader.
    for batch_idx, batch in zip(range(args.num_batches), pbar):
        model.zero_grad()

        train_inputs, train_targets = batch['train']
        train_inputs = train_inputs.to(device=args.device)
        train_targets = train_targets.to(device=args.device)

        test_inputs, test_targets = batch['test']
        test_inputs = test_inputs.to(device=args.device)
        test_targets = test_targets.to(device=args.device)

        # Normalise by the number of tasks actually present in this batch
        # rather than assuming it always equals `args.batch_size`.
        num_tasks = train_inputs.size(0)

        outer_loss = torch.tensor(0., device=args.device)
        accuracy = torch.tensor(0., device=args.device)
        for train_input, train_target, test_input, test_target in zip(
                train_inputs, train_targets, test_inputs, test_targets):
            # Inner loop: loss on the support set of this task.
            train_logit = model(train_input)
            inner_loss = F.cross_entropy(train_logit, train_target)

            model.zero_grad()
            params = gradient_update_parameters(model,
                                                inner_loss,
                                                step_size=args.step_size,
                                                first_order=args.first_order)

            # Outer loop: evaluate the adapted parameters on the query set.
            test_logit = model(test_input, params=params)
            outer_loss += F.cross_entropy(test_logit, test_target)

            with torch.no_grad():
                accuracy += get_accuracy(test_logit, test_target)

        outer_loss.div_(num_tasks)
        accuracy.div_(num_tasks)

        outer_loss.backward()
        meta_optimizer.step()

        pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))

        # One mean query accuracy per processed batch.
        acc_list.append(accuracy.item())

In [None]:
# Plot the per-batch query accuracy, limited to the configured number of
# meta-training batches (instead of a hard-coded count).
plt.plot(acc_list[:args.num_batches])
plt.xlabel("meta-batch")
plt.ylabel("query accuracy")
plt.show()

In [None]:
# Display the recorded accuracies (one float appended per processed batch).
acc_list

In [None]:
#import numpy as np

In [None]:
#np.savetxt("record/anil_5_5.csv", acc_list, delimiter=",")

In [None]:
# Load the pretrained checkpoint onto the configured device so loading does not
# depend on the device the file was saved from.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
pretrained_dict = torch.load("10_31.pth", map_location=args.device)

In [None]:
# State dict of the current model; its keys decide which checkpoint entries are kept.
model_dict = model.state_dict()

In [None]:
# Keep only the checkpoint entries whose names also exist in the current model.
pretrained_dict = dict(
    (name, tensor)
    for name, tensor in pretrained_dict.items()
    if name in model_dict
)
# Overlay the retained entries on the model's own state; keys absent from the
# checkpoint keep their current values.
model_dict.update(pretrained_dict)
# Push the merged state back into the model.
model.load_state_dict(model_dict)

In [None]:
# Meta-training loop, run after the pretrained weights were loaded into `model`.
# For each task in a batch, the model takes one inner update on the support
# ("train") split through `gradient_update_parameters`; the loss of the updated
# parameters on the query ("test") split is accumulated into the outer loss
# that drives the meta-optimizer step.
# NOTE(review): `meta_optimizer` is the instance created before the first
# training run, so its optimizer state carries over — confirm this is intended.
acc_list2 = []
with tqdm(dataloader, total=args.num_batches) as pbar:
    # Pairing with range() caps the run at exactly `args.num_batches` batches
    # and never pulls an extra batch from the loader.
    for batch_idx, batch in zip(range(args.num_batches), pbar):
        model.zero_grad()

        train_inputs, train_targets = batch['train']
        train_inputs = train_inputs.to(device=args.device)
        train_targets = train_targets.to(device=args.device)

        test_inputs, test_targets = batch['test']
        test_inputs = test_inputs.to(device=args.device)
        test_targets = test_targets.to(device=args.device)

        # Normalise by the number of tasks actually present in this batch
        # rather than assuming it always equals `args.batch_size`.
        num_tasks = train_inputs.size(0)

        outer_loss = torch.tensor(0., device=args.device)
        accuracy = torch.tensor(0., device=args.device)
        for train_input, train_target, test_input, test_target in zip(
                train_inputs, train_targets, test_inputs, test_targets):
            # Inner loop: loss on the support set of this task.
            train_logit = model(train_input)
            inner_loss = F.cross_entropy(train_logit, train_target)

            model.zero_grad()
            params = gradient_update_parameters(model,
                                                inner_loss,
                                                step_size=args.step_size,
                                                first_order=args.first_order)

            # Outer loop: evaluate the adapted parameters on the query set.
            test_logit = model(test_input, params=params)
            outer_loss += F.cross_entropy(test_logit, test_target)

            with torch.no_grad():
                accuracy += get_accuracy(test_logit, test_target)

        outer_loss.div_(num_tasks)
        accuracy.div_(num_tasks)

        outer_loss.backward()
        meta_optimizer.step()

        pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))

        # One mean query accuracy per processed batch.
        acc_list2.append(accuracy.item())

In [None]:
# Plot the per-batch query accuracy of the run that started from the pretrained
# weights, limited to the configured number of batches like the other plots.
plt.plot(acc_list2[:args.num_batches])
plt.xlabel("meta-batch")
plt.ylabel("query accuracy")
plt.show()

In [None]:
# Display the recorded accuracies (one float appended per processed batch).
acc_list2

In [None]:
# from torchmeta.datasets.helpers import omniglot
# from torchmeta.utils.data import BatchMetaDataLoader

# dataset = omniglot("data", ways=20, shots=5, test_shots=15, meta_train=True, download=True)
# dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)

# for batch in dataloader:
#     train_inputs, train_targets = batch["train"]
#     print('Train inputs shape: {0}'.format(train_inputs.shape))    # (16, 25, 1, 28, 28)
#     print('Train targets shape: {0}'.format(train_targets.shape))  # (16, 25)

#     test_inputs, test_targets = batch["test"]
#     print('Test inputs shape: {0}'.format(test_inputs.shape))      # (16, 75, 1, 28, 28)
#     print('Test targets shape: {0}'.format(test_targets.shape))    # (16, 75)
#     break

In [None]:
from torchmeta.datasets.helpers import doublemnist
from torchmeta.utils.data import BatchMetaDataLoader

In [None]:
from torchvision.transforms import Compose, Resize, ToTensor, Grayscale

In [None]:
# Image preprocessing: convert to grayscale, resize to 28, then convert to a tensor.
kwargs = dict(
    transform=Compose([Grayscale(), Resize(28), ToTensor()]),
)

In [None]:
# Double MNIST meta-training tasks; images go through the transform stored in `kwargs`.
# NOTE(review): unlike the Omniglot dataset cell, no `test_shots` is passed here,
# so the helper's default applies — confirm that is intended.
dataset = doublemnist(
    args.folder,
    ways=args.num_ways,
    shots=args.num_shots,
    shuffle=True,
    meta_train=True,
    download=True,
    transform=kwargs["transform"],
)

In [None]:
# dataset = miniimagenet(args.folder,
#                    shots=args.num_shots,
#                    ways=args.num_ways,
#                    shuffle=True,
#                    test_shots=15,
#                    meta_train=True,
#                    download=args.download,
#                    transform = kwargs["transform"]
#                       )

In [None]:
# Build the Double MNIST loader from the shared settings in `args`, so its batch
# size agrees with the value used elsewhere in the notebook.
# NOTE(review): shuffling is enabled to match the Omniglot loader, because this
# loader also feeds the training loop below. With shuffle=False torchmeta is
# expected to enumerate class combinations in a fixed sequential order, which
# would make consecutive tasks nearly identical — confirm against the torchmeta docs.
dataloader = BatchMetaDataLoader(dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers)

# Peek at a single batch to sanity-check the tensor shapes.
# Expected layout (assumption, TODO confirm): inputs are
# (batch_size, examples_per_task, channels, height, width) and targets are
# (batch_size, examples_per_task).
for batch in dataloader:
    train_inputs, train_targets = batch["train"]
    print('Train inputs shape: {0}'.format(train_inputs.shape))
    print('Train targets shape: {0}'.format(train_targets.shape))

    test_inputs, test_targets = batch["test"]
    print('Test inputs shape: {0}'.format(test_inputs.shape))
    print('Test targets shape: {0}'.format(test_targets.shape))
    break

In [None]:
# Show one support image (index [2][12] of the inspected batch), moving the
# channel axis last for matplotlib.
# detach().cpu() replaces the legacy `.data` access and keeps this cell working
# even when `train_inputs` has already been moved to the GPU by a training cell.
plt.imshow(train_inputs[2][12].permute(1, 2, 0).detach().cpu().numpy(), cmap="gray")
plt.show()

In [None]:
# Fresh model for the Double MNIST run: single-channel input, one logit per way.
model = ConvolutionalNeuralNetwork(
    in_channels=1,
    out_features=args.num_ways,
    hidden_size=args.hidden_size,
)
model.to(device=args.device)
model.train()

# New outer-loop (meta) optimizer bound to the new model's parameters.
meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Meta-training loop (Double MNIST).
# For each task in a batch, the model takes one inner update on the support
# ("train") split through `gradient_update_parameters`; the loss of the updated
# parameters on the query ("test") split is accumulated into the outer loss
# that drives the meta-optimizer step.
acc_list = []
with tqdm(dataloader, total=args.num_batches) as pbar:
    # Pairing with range() caps the run at exactly `args.num_batches` batches
    # and never pulls an extra batch from the loader.
    for batch_idx, batch in zip(range(args.num_batches), pbar):
        model.zero_grad()

        train_inputs, train_targets = batch['train']
        train_inputs = train_inputs.to(device=args.device)
        train_targets = train_targets.to(device=args.device)

        test_inputs, test_targets = batch['test']
        test_inputs = test_inputs.to(device=args.device)
        test_targets = test_targets.to(device=args.device)

        # Normalise by the number of tasks actually present in this batch
        # rather than assuming it always equals `args.batch_size`.
        num_tasks = train_inputs.size(0)

        outer_loss = torch.tensor(0., device=args.device)
        accuracy = torch.tensor(0., device=args.device)
        for train_input, train_target, test_input, test_target in zip(
                train_inputs, train_targets, test_inputs, test_targets):
            # Inner loop: loss on the support set of this task.
            train_logit = model(train_input)
            inner_loss = F.cross_entropy(train_logit, train_target)

            model.zero_grad()
            params = gradient_update_parameters(model,
                                                inner_loss,
                                                step_size=args.step_size,
                                                first_order=args.first_order)

            # Outer loop: evaluate the adapted parameters on the query set.
            test_logit = model(test_input, params=params)
            outer_loss += F.cross_entropy(test_logit, test_target)

            with torch.no_grad():
                accuracy += get_accuracy(test_logit, test_target)

        outer_loss.div_(num_tasks)
        accuracy.div_(num_tasks)

        outer_loss.backward()
        meta_optimizer.step()

        pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))

        # One mean query accuracy per processed batch.
        acc_list.append(accuracy.item())

In [None]:
# Plot the per-batch query accuracy, limited to the configured number of
# meta-training batches (instead of a hard-coded count).
plt.plot(acc_list[:args.num_batches])
plt.xlabel("meta-batch")
plt.ylabel("query accuracy")
plt.show()

In [None]:
# Display the recorded accuracies (one float appended per processed batch).
acc_list