In [None]:
dataset = DeepLesion(mode='test')

spl = dataset[0]
print(spl['y'])
plt.imshow(spl['x'])

root_dir='/home/Surya/Meta-Learning-Deep-Leision/'
mode = 'test'

with open( root_dir + 'X_' + mode + '.p', 'rb' ) as f:
    X = pickle.load(f)
    X = np.transpose(X, (0,3,1,2))
    
print(X.shape)

In [1]:
import os
import argparse
import pickle
import gc
import math
import random
from collections import namedtuple

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import learn2learn as l2l
from learn2learn.data import MetaDataset, TaskDataset
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels

In [2]:
class DeepLesion(Dataset):

    def __init__(self, mode, root_dir='/home/Surya/Meta-Learning-Deep-Leision/', transform=None):
        """
        Args:
            pickle_file (string): Path to the file with datasets.
            mode (string): Train or Val or Test.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        x = 'X_train.p'
        y = 'Y_train.p'
        with open( root_dir + 'X_' + mode + '.p', 'rb' ) as f:
            self.X = pickle.load(f)
        with open( root_dir + 'Y_' + mode + '.p', 'rb' ) as f:
            self.Y = pickle.load(f)
        self.X = np.float32(self.X)
        self.X = np.transpose(self.X, (0,3,1,2))
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.Y)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        x = self.X[idx]
        y = self.Y[idx]

        if self.transform:
            sample = self.transform(x)

        return np.float32(x), int(y)

In [3]:
def pairwise_distances_logits(a, b):
    n = a.shape[0]
    m = b.shape[0]
    logits = -((a.unsqueeze(1).expand(n, m, -1) -
                b.unsqueeze(0).expand(n, m, -1))**2).sum(dim=2)
    return logits


def accuracy(predictions, targets):
    predictions = predictions.argmax(dim=1).view(targets.shape)
    return (predictions == targets).sum().float() / targets.size(0)


class Convnet(nn.Module):

    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        self.encoder = l2l.vision.models.ConvBase(output_size=z_dim,
                                                  hidden=hid_dim,
                                                  channels=x_dim)
        self.out_channels = 1600

    def forward(self, x):
        x = self.encoder(x)
        return x.view(x.size(0), -1)


def fast_adapt(model, batch, ways, shot, query_num, metric=None, device=None):
    if metric is None:
        metric = pairwise_distances_logits
    if device is None:
        device = model.device()
    data, labels = batch
    data = data.to(device)
    labels = labels.to(device)
    n_items = shot * ways

    # Sort data samples by labels
    # TODO: Can this be replaced by ConsecutiveLabels ?
    sort = torch.sort(labels)
    data = data.squeeze(0)[sort.indices].squeeze(0)
    labels = labels.squeeze(0)[sort.indices].squeeze(0)

    # Compute support and query embeddings
    embeddings = model(data)
    support_indices = np.zeros(data.size(0), dtype=bool)
    selection = np.arange(ways) * (shot + query_num)
    for offset in range(shot):
        support_indices[selection + offset] = True
    query_indices = torch.from_numpy(~support_indices)
    support_indices = torch.from_numpy(support_indices)
    support = embeddings[support_indices]
    support = support.reshape(ways, shot, -1).mean(dim=1)
    query = embeddings[query_indices]
    labels = labels[query_indices].long()

    logits = pairwise_distances_logits(query, support)
    loss = F.cross_entropy(logits, labels)
    acc = accuracy(logits, labels)
    return loss, acc

In [4]:
gc.collect()

print('Loading Training Data')
train_dataset = DeepLesion(mode='test')
print('Loading Validation Data')
valid_dataset = DeepLesion(mode='test')
print('Loading Test Data')
test_dataset = DeepLesion(mode='test')

Loading Training Data
Loading Validation Data
Loading Test Data


In [5]:
parser = argparse.ArgumentParser()
parser.add_argument('--max-epoch', type=int, default=10)
parser.add_argument('--shot', type=int, default=3)
parser.add_argument('--test-way', type=int, default=5)
parser.add_argument('--test-shot', type=int, default=3)
parser.add_argument('--test-query', type=int, default=1)
parser.add_argument('--train-query', type=int, default=15)
parser.add_argument('--train-way', type=int, default=5)
parser.add_argument('--gpu', default=1)
args = parser.parse_args('')

In [8]:
device = torch.device('cpu')
if args.gpu and torch.cuda.device_count():
    print("Using gpu")
    torch.cuda.manual_seed(43)
    device = torch.device('cuda')

for N in [2,3,5]:
    for k in [1,3,5]:
        print('k = ' + str(k) + '; N = ' + str(N))
        args.shot = k
        args.test_shot = k
        args.train_way = N
        args.test_way = N
        
        model = Convnet()
        model.to(device)
        
        print(model)

        # print('Loading Training Data')
        # train_dataset = DeepLesion(mode='train')
        # print('Loading Validation Data')
        # valid_dataset = DeepLesion(mode='val')
        # print('Loading Test Data')
        # test_dataset = DeepLesion(mode='test')

        train_dataset = l2l.data.MetaDataset(train_dataset)
        train_transforms = [
            NWays(train_dataset, args.train_way),
            KShots(train_dataset, args.train_query + args.shot),
            LoadData(train_dataset),
            RemapLabels(train_dataset),
        ]
        train_tasks = l2l.data.TaskDataset(train_dataset, task_transforms=train_transforms)
        train_loader = DataLoader(train_tasks, pin_memory=True, shuffle=True)

        valid_dataset = l2l.data.MetaDataset(valid_dataset)
        valid_transforms = [
            NWays(valid_dataset, args.test_way),
            KShots(valid_dataset, args.test_query + args.test_shot),
            LoadData(valid_dataset),
            RemapLabels(valid_dataset),
        ]

        num_tasks_val = math.floor(len(valid_dataset)/(args.test_way*(args.test_query + args.test_shot)))
        valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                           task_transforms=valid_transforms,
                                           num_tasks= num_tasks_val)
        valid_loader = DataLoader(valid_tasks, pin_memory=True, shuffle=True)

        test_dataset = l2l.data.MetaDataset(test_dataset)
        test_transforms = [
            NWays(test_dataset, args.test_way),
            KShots(test_dataset, args.test_query + args.test_shot),
            LoadData(test_dataset),
            RemapLabels(test_dataset),
        ]

        num_tasks_test = math.floor(len(test_dataset)/(args.test_way*(args.test_query + args.test_shot)))
        test_tasks = l2l.data.TaskDataset(test_dataset,
                                          task_transforms=test_transforms,
                                          num_tasks=num_tasks_test)
        test_loader = DataLoader(test_tasks, pin_memory=True, shuffle=True)

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=20, gamma=0.5)

        for epoch in range(1, args.max_epoch + 1):
            model.train()

            loss_ctr = 0
            n_loss = 0
            n_acc = 0

            for i in range(100):
                batch = next(iter(train_loader))

                loss, acc = fast_adapt(model,
                                       batch,
                                       args.train_way,
                                       args.shot,
                                       args.train_query,
                                       metric=pairwise_distances_logits,
                                       device=device)

                loss_ctr += 1
                n_loss += loss.item()
                n_acc += acc

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            lr_scheduler.step()

            print('epoch {}, train, loss={:.4f} acc={:.4f}'.format(
                epoch, n_loss/loss_ctr, n_acc/loss_ctr))

            model.eval()

            loss_ctr = 0
            n_loss = 0
            n_acc = 0
            for i, batch in enumerate(valid_loader):
                loss, acc = fast_adapt(model,
                                       batch,
                                       args.test_way,
                                       args.test_shot,
                                       args.test_query,
                                       metric=pairwise_distances_logits,
                                       device=device)

                loss_ctr += 1
                n_loss += loss.item()
                n_acc += acc

            print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(
                epoch, n_loss/loss_ctr, n_acc/loss_ctr))

        loss_ctr = 0
        n_acc = 0

        for i, batch in enumerate(test_loader, 1):
            loss, acc = fast_adapt(model,
                                   batch,
                                   args.test_way,
                                   args.test_shot,
                                   args.test_query,
                                   metric=pairwise_distances_logits,
                                   device=device)
            loss_ctr += 1
            n_acc += acc
            print('batch {}: {:.2f}({:.2f})'.format(
                i, n_acc/loss_ctr * 100, acc * 100))

Using gpu
k = 1; N = 2
Convnet(
  (encoder): ConvBase(
    (0): ConvBlock(
      (normalize): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (1): ConvBlock(
      (normalize): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (2): ConvBlock(
      (normalize): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (3): ConvBlock(
      (normalize): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
  )
)


NameError: name 'train_dataset' is not defined

In [12]:
print(len(train_dataset))
print(len(valid_dataset))
print(len(test_dataset))


2994
268
277


### 2. FC MAML Benchmark

In [2]:
class FlatDeepLesion(Dataset):

    def __init__(self, mode, root_dir='/home/Surya/Meta-Learning-Deep-Leision/', transform=None):
        """
        Args:
            pickle_file (string): Path to the file with datasets.
            mode (string): Train or Val or Test.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        x = 'X_train.p'
        y = 'Y_train.p'
        with open( root_dir + 'X_' + mode + '.p', 'rb' ) as f:
            self.X = pickle.load(f)
        with open( root_dir + 'Y_' + mode + '.p', 'rb' ) as f:
            self.Y = pickle.load(f)
        self.X = np.float32(self.X)
        self.X = np.transpose(self.X, (0,3,1,2))
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.Y)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        x = self.X[idx]
        N = x.shape[0]
        x = x.reshape((N,-1))
        y = self.Y[idx]

        if self.transform:
            x = self.transform(x)

        return np.float32(x), int(y)

In [7]:
gc.collect()

print('Loading Training Data')
train_dataset = FlatDeepLesion(mode='train')
print('Loading Validation Data')
valid_dataset = FlatDeepLesion(mode='val')
print('Loading Test Data')
test_dataset = FlatDeepLesion(mode='test')

Loading Training Data
Loading Validation Data
Loading Test Data


In [8]:
def accuracy(predictions, targets):
    predictions = predictions.argmax(dim=1).view(targets.shape)
    return (predictions == targets).sum().float() / targets.size(0)


def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evalutation sets
    adaptation_indices = np.zeros(data.size(0), dtype=bool)
    adaptation_indices[np.arange(shots*ways) * 2] = True
    evaluation_indices = torch.from_numpy(~adaptation_indices)
    adaptation_indices = torch.from_numpy(adaptation_indices)
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]

    # Adapt the model
    for step in range(adaptation_steps):
        train_error = loss(learner(adaptation_data), adaptation_labels)
        learner.adapt(train_error)

    # Evaluate the adapted model
    predictions = learner(evaluation_data)
    valid_error = loss(predictions, evaluation_labels)
    valid_accuracy = accuracy(predictions, evaluation_labels)
    return valid_error, valid_accuracy

In [9]:
ways=5
shots=5
meta_lr=0.003
fast_lr=0.5
meta_batch_size=math.floor(len(test_dataset)/(ways*shots*2))
adaptation_steps=1
num_iterations=2000
cuda=True
seed=42

NameError: name 'test_dataset' is not defined

In [10]:
model = l2l.vision.models.OmniglotFC(512 ** 2 * 3, ways)
print(model)

OmniglotFC(
  (features): Sequential(
    (0): Flatten()
    (1): Sequential(
      (0): LinearBlock(
        (relu): ReLU()
        (normalize): BatchNorm1d(256, eps=0.001, momentum=0.999, affine=True, track_running_stats=False)
        (linear): Linear(in_features=786432, out_features=256, bias=True)
      )
      (1): LinearBlock(
        (relu): ReLU()
        (normalize): BatchNorm1d(128, eps=0.001, momentum=0.999, affine=True, track_running_stats=False)
        (linear): Linear(in_features=256, out_features=128, bias=True)
      )
      (2): LinearBlock(
        (relu): ReLU()
        (normalize): BatchNorm1d(64, eps=0.001, momentum=0.999, affine=True, track_running_stats=False)
        (linear): Linear(in_features=128, out_features=64, bias=True)
      )
      (3): LinearBlock(
        (relu): ReLU()
        (normalize): BatchNorm1d(64, eps=0.001, momentum=0.999, affine=True, track_running_stats=False)
        (linear): Linear(in_features=64, out_features=64, bias=True)
      )


In [13]:
import tensorflow as tf

layers = [
tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(2**7, input_shape = (K,D), return_sequences=True)),
# tf.keras.layers.Dense(2**config.log_LSTM_HIDDEN_UNITS, activation='softmax', return_sequences=True),
tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(NUM_CLASSES, activation="softmax")),    
tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
]

meta_model = tf.keras.Sequential(layers)

print(meta_model.summary())

NameError: name 'K' is not defined

In [14]:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
device = torch.device('cpu')
if cuda and torch.cuda.device_count():
    torch.cuda.manual_seed(seed)
    device = torch.device('cuda')

for ways in [5,3,2]:
    for shots in [1,3,5]:
        meta_batch_size=math.floor(len(test_dataset)/(ways*shots*2))
        print('Ways: ', ways)
        print('Shots: ', shots)
        print('Meta Size: ', meta_batch_size)
        

    # Create Tasksets using the benchmark interface
        train_dataset = l2l.data.MetaDataset(train_dataset)
        train_transforms = [
            NWays(train_dataset, ways),
            KShots(train_dataset, 2*shots),
            LoadData(train_dataset),
            RemapLabels(train_dataset),
            ConsecutiveLabels(train_dataset),
        ]

        valid_dataset = l2l.data.MetaDataset(valid_dataset)
        valid_transforms = [
            NWays(valid_dataset, ways),
            KShots(valid_dataset, 2*shots),
            LoadData(valid_dataset),
            ConsecutiveLabels(valid_dataset),
            RemapLabels(valid_dataset),
        ]

        test_dataset = l2l.data.MetaDataset(test_dataset)
        test_transforms = [
            NWays(test_dataset, ways),
            KShots(test_dataset, 2*shots),
            LoadData(test_dataset),
            RemapLabels(test_dataset),
            ConsecutiveLabels(test_dataset),
        ]

        # Instantiate the tasksets
        train_tasks = l2l.data.TaskDataset(
            dataset=train_dataset,
            task_transforms=train_transforms,
            num_tasks=-1,
        )
        validation_tasks = l2l.data.TaskDataset(
            dataset=valid_dataset,
            task_transforms=valid_transforms,
            num_tasks=-1,
        )
        test_tasks = l2l.data.TaskDataset(
            dataset=test_dataset,
            task_transforms=test_transforms,
            num_tasks=-1,
        )

        BenchmarkTasksets = namedtuple('BenchmarkTasksets', ('train', 'validation', 'test'))
        tasksets = BenchmarkTasksets(train_tasks, validation_tasks, test_tasks)


        #model = l2l.vision.models.MiniImagenetCNN(ways)
        # Create model
        model = l2l.vision.models.OmniglotFC(512 ** 2 * 3, ways)
        model.to(device)
        maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
        opt = optim.Adam(maml.parameters(), meta_lr)
        loss = nn.CrossEntropyLoss(reduction='mean')

        for iteration in range(num_iterations):
            opt.zero_grad()
            meta_train_error = 0.0
            meta_train_accuracy = 0.0
            meta_valid_error = 0.0
            meta_valid_accuracy = 0.0
            for task in range(meta_batch_size):
                # Compute meta-training loss
                learner = maml.clone()
                batch = tasksets.train.sample()
                evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                                   learner,
                                                                   loss,
                                                                   adaptation_steps,
                                                                   shots,
                                                                   ways,
                                                                   device)
                evaluation_error.backward()
                meta_train_error += evaluation_error.item()
                meta_train_accuracy += evaluation_accuracy.item()

                # Compute meta-validation loss
                learner = maml.clone()
                batch = tasksets.validation.sample()
                evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                                   learner,
                                                                   loss,
                                                                   adaptation_steps,
                                                                   shots,
                                                                   ways,
                                                                   device)
                meta_valid_error += evaluation_error.item()
                meta_valid_accuracy += evaluation_accuracy.item()

            # Print some metrics
            print('\n')
            print('Iteration', iteration)
            print('Meta Train Error', meta_train_error / meta_batch_size)
            print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
            print('Meta Valid Error', meta_valid_error / meta_batch_size)
            print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)

            # Average the accumulated gradients and optimize
            for p in maml.parameters():
                p.grad.data.mul_(1.0 / meta_batch_size)
            opt.step()

        meta_test_error = 0.0
        meta_test_accuracy = 0.0
        for task in range(meta_batch_size):
            # Compute meta-testing loss
            learner = maml.clone()
            batch = tasksets.test.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                               learner,
                                                               loss,
                                                               adaptation_steps,
                                                               shots,
                                                               ways,
                                                               device)
            meta_test_error += evaluation_error.item()
            meta_test_accuracy += evaluation_accuracy.item()
        print('Meta Test Error', meta_test_error / meta_batch_size)
        print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)

Ways:  5
Shots:  1
Meta Size:  27


Iteration 0
Meta Train Error 1.6108689617227625
Meta Train Accuracy 0.24444444919074024
Meta Valid Error 1.6053319816236142
Meta Valid Accuracy 0.2518518567085266


Iteration 1
Meta Train Error 1.5452510206787675
Meta Train Accuracy 0.23703704167295386
Meta Valid Error 1.6306430763668485
Meta Valid Accuracy 0.2370370422248487


Iteration 2
Meta Train Error 1.5805925263298883
Meta Train Accuracy 0.2518518567085266
Meta Valid Error 1.6403859500531797
Meta Valid Accuracy 0.20740741104991348


Iteration 3
Meta Train Error 1.4399870060108326
Meta Train Accuracy 0.385185193684366
Meta Valid Error 1.5512998854672466
Meta Valid Accuracy 0.2592592631225233


Iteration 4
Meta Train Error 1.5132488497981318
Meta Train Accuracy 0.2888888959531431
Meta Valid Error 1.621858036076581
Meta Valid Accuracy 0.2518518572604215


Iteration 5
Meta Train Error 1.608691899864762
Meta Train Accuracy 0.28888889540124824
Meta Valid Error 1.5454569966704756
Meta Valid Accuracy 



Iteration 50
Meta Train Error 1.395808948410882
Meta Train Accuracy 0.40000000982372846
Meta Valid Error 1.53866974512736
Meta Valid Accuracy 0.3481481555435393


Iteration 51
Meta Train Error 1.4358262574231182
Meta Train Accuracy 0.39259260175404725
Meta Valid Error 1.4690927929348416
Meta Valid Accuracy 0.33333334216365107


Iteration 52
Meta Train Error 1.4720825177651864
Meta Train Accuracy 0.31111111905839706
Meta Valid Error 1.6276521329526548
Meta Valid Accuracy 0.26666667174409936


Iteration 53
Meta Train Error 1.5844547307049786
Meta Train Accuracy 0.26666667174409936
Meta Valid Error 1.4940660021923207
Meta Valid Accuracy 0.3407407480257529


Iteration 54
Meta Train Error 1.5238940848244562
Meta Train Accuracy 0.33333334105986134
Meta Valid Error 1.5592752209416143
Meta Valid Accuracy 0.32592593409396986


Iteration 55
Meta Train Error 1.418464148486102
Meta Train Accuracy 0.37037038030447783
Meta Valid Error 1.391467880319666
Meta Valid Accuracy 0.4000000109275182


Iter

Meta Test Error 1.508516828219096
Meta Test Accuracy 0.3407407491295426
Ways:  5
Shots:  3
Meta Size:  9


Iteration 0
Meta Train Error 1.5961785581376817
Meta Train Accuracy 0.22222223298417199
Meta Valid Error 1.5736786656909518
Meta Valid Accuracy 0.2592592736085256


Iteration 1
Meta Train Error 1.5293483601676092
Meta Train Accuracy 0.32592594623565674
Meta Valid Error 1.4694450431399875
Meta Valid Accuracy 0.3111111306481891


Iteration 2
Meta Train Error 1.4983340369330511
Meta Train Accuracy 0.31111112899250454
Meta Valid Error 1.5293911430570815
Meta Valid Accuracy 0.2814814978175693


Iteration 3
Meta Train Error 1.4289744297663372
Meta Train Accuracy 0.31851853761408067
Meta Valid Error 1.5229289399252997
Meta Valid Accuracy 0.37037039465374416


Iteration 4
Meta Train Error 1.5276050832536485
Meta Train Accuracy 0.2888889014720917
Meta Valid Error 1.439652509159512
Meta Valid Accuracy 0.3629629843764835


Iteration 5
Meta Train Error 1.3972584803899128
Meta Train Accuracy 0



Iteration 49
Meta Train Error 1.2958056661817763
Meta Train Accuracy 0.43703706396950615
Meta Valid Error 1.3691450622346666
Meta Valid Accuracy 0.37037039796511334


Iteration 50
Meta Train Error 1.328812109099494
Meta Train Accuracy 0.4148148364490933
Meta Valid Error 1.3362724317444696
Meta Valid Accuracy 0.38518520527415806


Iteration 51
Meta Train Error 1.2159553501341078
Meta Train Accuracy 0.48148150576485527
Meta Valid Error 1.299242032898797
Meta Valid Accuracy 0.40740742948320174


Iteration 52
Meta Train Error 1.2113478713565402
Meta Train Accuracy 0.4296296603149838
Meta Valid Error 1.493204567167494
Meta Valid Accuracy 0.34074076513449353


Iteration 53
Meta Train Error 1.234691580136617
Meta Train Accuracy 0.4444444709353977
Meta Valid Error 1.344341066148546
Meta Valid Accuracy 0.4074074327945709


Iteration 54
Meta Train Error 1.3573199113210042
Meta Train Accuracy 0.3481481687890159
Meta Valid Error 1.5160019132826064
Meta Valid Accuracy 0.32592594457997215


Iterat



Iteration 99
Meta Train Error 1.123640278975169
Meta Train Accuracy 0.4888889127307468
Meta Valid Error 1.3618729246987238
Meta Valid Accuracy 0.4222222500377231
Meta Test Error 1.3559336264928181
Meta Test Accuracy 0.41481483976046246
Ways:  5
Shots:  5
Meta Size:  5


Iteration 0
Meta Train Error 1.580819296836853
Meta Train Accuracy 0.23999999314546586
Meta Valid Error 1.5619542121887207
Meta Valid Accuracy 0.23199999779462815


Iteration 1
Meta Train Error 1.4210361242294312
Meta Train Accuracy 0.2639999896287918
Meta Valid Error 1.5435815572738647
Meta Valid Accuracy 0.24799998998641967


Iteration 2
Meta Train Error 1.4931968688964843
Meta Train Accuracy 0.36799998581409454
Meta Valid Error 1.5365561246871948
Meta Valid Accuracy 0.32799998819828036


Iteration 3
Meta Train Error 1.442399787902832
Meta Train Accuracy 0.37599999010562896
Meta Valid Error 1.4095865488052368
Meta Valid Accuracy 0.3839999854564667


Iteration 4
Meta Train Error 1.4469724416732788
Meta Train Accuracy



Iteration 48
Meta Train Error 1.3777262687683105
Meta Train Accuracy 0.3439999908208847
Meta Valid Error 1.4384215116500854
Meta Valid Accuracy 0.34399998784065244


Iteration 49
Meta Train Error 1.4078989505767823
Meta Train Accuracy 0.44799998998641966
Meta Valid Error 1.3935215711593627
Meta Valid Accuracy 0.407999986410141


Iteration 50
Meta Train Error 1.2903289794921875
Meta Train Accuracy 0.40799999237060547
Meta Valid Error 1.3204124689102172
Meta Valid Accuracy 0.46399998664855957


Iteration 51
Meta Train Error 1.2656408309936524
Meta Train Accuracy 0.407999986410141
Meta Valid Error 1.2985406160354613
Meta Valid Accuracy 0.4159999847412109


Iteration 52
Meta Train Error 1.2518980026245117
Meta Train Accuracy 0.44799998998641966
Meta Valid Error 1.1995737075805664
Meta Valid Accuracy 0.5039999902248382


Iteration 53
Meta Train Error 1.3223745703697205
Meta Train Accuracy 0.43999998867511747
Meta Valid Error 1.3441490411758423
Meta Valid Accuracy 0.39999998807907106


Ite



Iteration 98
Meta Train Error 1.2700043678283692
Meta Train Accuracy 0.43999999165534975
Meta Valid Error 1.3559139728546143
Meta Valid Accuracy 0.4159999907016754


Iteration 99
Meta Train Error 1.0241880655288695
Meta Train Accuracy 0.551999980211258
Meta Valid Error 1.2814533948898315
Meta Valid Accuracy 0.3839999854564667
Meta Test Error 1.3854875564575195
Meta Test Accuracy 0.39999998509883883
Ways:  3
Shots:  1
Meta Size:  46


Iteration 0
Meta Train Error 1.098006854886594
Meta Train Accuracy 0.42028986435869464
Meta Valid Error 1.1098157996716707
Meta Valid Accuracy 0.3623188500818999


Iteration 1
Meta Train Error 1.1111362161843672
Meta Train Accuracy 0.4130434873311416
Meta Valid Error 1.1160558384397756
Meta Valid Accuracy 0.3913043562484824


Iteration 2
Meta Train Error 1.1341944168443265
Meta Train Accuracy 0.39855073392391205
Meta Valid Error 1.0231785372547482
Meta Valid Accuracy 0.4420289954413538


Iteration 3
Meta Train Error 1.0750326805788537
Meta Train Accuracy



Iteration 47
Meta Train Error 1.1289451212986656
Meta Train Accuracy 0.4565217462570771
Meta Valid Error 1.2665272193110508
Meta Valid Accuracy 0.3768116015454997


Iteration 48
Meta Train Error 0.9461905969225842
Meta Train Accuracy 0.6014492816251257
Meta Valid Error 1.231677022965058
Meta Valid Accuracy 0.3768116015454997


Iteration 49
Meta Train Error 1.1165373156899991
Meta Train Accuracy 0.44927536987740063
Meta Valid Error 1.0556075035229973
Meta Valid Accuracy 0.44202899284984754


Iteration 50
Meta Train Error 1.0351593669341959
Meta Train Accuracy 0.5144927624775015
Meta Valid Error 1.023693039689375
Meta Valid Accuracy 0.5072463854499485


Iteration 51
Meta Train Error 1.1309264710415965
Meta Train Accuracy 0.4710145003121832
Meta Valid Error 1.0710555881909702
Meta Valid Accuracy 0.5000000097181486


Iteration 52
Meta Train Error 1.1100151882223461
Meta Train Accuracy 0.4637681245803833
Meta Valid Error 1.2469742304604987
Meta Valid Accuracy 0.42028986241506494


Iterati



Iteration 97
Meta Train Error 1.0848563404186913
Meta Train Accuracy 0.4565217456092005
Meta Valid Error 1.1206089510865833
Meta Valid Accuracy 0.3913043543048527


Iteration 98
Meta Train Error 0.997960424941519
Meta Train Accuracy 0.46376812328463013
Meta Valid Error 1.1749493259450663
Meta Valid Accuracy 0.4130434853875119


Iteration 99
Meta Train Error 0.952579910988393
Meta Train Accuracy 0.5362318929122842
Meta Valid Error 1.0780960865642712
Meta Valid Accuracy 0.42028986371081806
Meta Test Error 1.0438084576440894
Meta Test Accuracy 0.4565217482007068
Ways:  3
Shots:  3
Meta Size:  15


Iteration 0
Meta Train Error 1.1199597676595052
Meta Train Accuracy 0.37777778605620066
Meta Valid Error 1.087990109125773
Meta Valid Accuracy 0.42962963779767355


Iteration 1
Meta Train Error 1.0430028319358826
Meta Train Accuracy 0.4000000109275182
Meta Valid Error 1.0956261197725932
Meta Valid Accuracy 0.35555556416511536


Iteration 2
Meta Train Error 1.124479643503825
Meta Train Accuracy



Iteration 47
Meta Train Error 0.9514055609703064
Meta Train Accuracy 0.5555555681387584
Meta Valid Error 0.9930595278739929
Meta Valid Accuracy 0.5555555681387584


Iteration 48
Meta Train Error 0.8204127232233683
Meta Train Accuracy 0.5777777954936028
Meta Valid Error 0.7907057444254557
Meta Valid Accuracy 0.637037051220735


Iteration 49
Meta Train Error 0.7989013195037842
Meta Train Accuracy 0.600000015894572
Meta Valid Error 0.8624431133270264
Meta Valid Accuracy 0.6000000139077505


Iteration 50
Meta Train Error 0.8744948883851369
Meta Train Accuracy 0.5925926074385643
Meta Valid Error 0.8431953191757202
Meta Valid Accuracy 0.6148148278395335


Iteration 51
Meta Train Error 0.7988727947076162
Meta Train Accuracy 0.5555555701255799
Meta Valid Error 0.9893866459528605
Meta Valid Accuracy 0.562962978084882


Iteration 52
Meta Train Error 0.9857060710589091
Meta Train Accuracy 0.5333333497246107
Meta Valid Error 0.9961845556894938
Meta Valid Accuracy 0.5481481616695721


Iteration 5



Iteration 97
Meta Train Error 0.8615330259005228
Meta Train Accuracy 0.6074074228604635
Meta Valid Error 0.7494766116142273
Meta Valid Accuracy 0.7259259422620138


Iteration 98
Meta Train Error 0.7169415682554245
Meta Train Accuracy 0.6444444581866264
Meta Valid Error 0.9122810602188111
Meta Valid Accuracy 0.5481481671333313


Iteration 99
Meta Train Error 0.8099295447270075
Meta Train Accuracy 0.6444444601734479
Meta Valid Error 0.9513514359792073
Meta Valid Accuracy 0.5481481666366259
Meta Test Error 0.8464218556880951
Meta Test Accuracy 0.5925926069418589
Ways:  3
Shots:  5
Meta Size:  9


Iteration 0
Meta Train Error 1.0185536212391324
Meta Train Accuracy 0.4296296603149838
Meta Valid Error 1.0308734642134771
Meta Valid Accuracy 0.40000002334515256


Iteration 1
Meta Train Error 1.0958260695139568
Meta Train Accuracy 0.40740742948320174
Meta Valid Error 1.0521658658981323
Meta Valid Accuracy 0.3481481687890159


Iteration 2
Meta Train Error 0.8707956406805251
Meta Train Accuracy



Iteration 47
Meta Train Error 0.9070533613363901
Meta Train Accuracy 0.555555585357878
Meta Valid Error 0.8187835415204366
Meta Valid Accuracy 0.6222222546736399


Iteration 48
Meta Train Error 0.79213813940684
Meta Train Accuracy 0.614814837773641
Meta Valid Error 1.007195274035136
Meta Valid Accuracy 0.4740740954875946


Iteration 49
Meta Train Error 0.7417821486790975
Meta Train Accuracy 0.5925926168759664
Meta Valid Error 0.8229788276884291
Meta Valid Accuracy 0.6518518792258369


Iteration 50
Meta Train Error 0.7523793445693122
Meta Train Accuracy 0.614814837773641
Meta Valid Error 0.8438833488358392
Meta Valid Accuracy 0.5481481717692481


Iteration 51
Meta Train Error 0.7207683556609683
Meta Train Accuracy 0.6370370652940538
Meta Valid Error 1.010984288321601
Meta Valid Accuracy 0.49629631969663834


Iteration 52
Meta Train Error 0.8609422809547849
Meta Train Accuracy 0.5777778062555525
Meta Valid Error 0.9393146965238783
Meta Valid Accuracy 0.5037037233511606


Iteration 53
M



Iteration 97
Meta Train Error 0.8692252337932587
Meta Train Accuracy 0.607407444053226
Meta Valid Error 0.9160533679856194
Meta Valid Accuracy 0.5111111402511597


Iteration 98
Meta Train Error 0.8155793150266012
Meta Train Accuracy 0.66666669315762
Meta Valid Error 0.9392208523220487
Meta Valid Accuracy 0.5111111369397905


Iteration 99
Meta Train Error 0.7157675441768434
Meta Train Accuracy 0.7407407826847501
Meta Valid Error 0.9290616777208116
Meta Valid Accuracy 0.548148168457879
Meta Test Error 0.7253471844726138
Meta Test Accuracy 0.6592592928144667
Ways:  2
Shots:  1
Meta Size:  69


Iteration 0
Meta Train Error 0.7106100649073503
Meta Train Accuracy 0.5289855072463768
Meta Valid Error 0.7013807892799377
Meta Valid Accuracy 0.5289855072463768


Iteration 1
Meta Train Error 0.6927685703056446
Meta Train Accuracy 0.5652173913043478
Meta Valid Error 0.8768792020669882
Meta Valid Accuracy 0.391304347826087


Iteration 2
Meta Train Error 0.8511852357482564
Meta Train Accuracy 0.615



Iteration 47
Meta Train Error 0.7139653509509736
Meta Train Accuracy 0.6159420289855072
Meta Valid Error 0.6358842356265455
Meta Valid Accuracy 0.7028985507246377


Iteration 48
Meta Train Error 0.43224339370710263
Meta Train Accuracy 0.8043478260869565
Meta Valid Error 0.8703150999718818
Meta Valid Accuracy 0.5434782608695652


Iteration 49
Meta Train Error 0.7311138625162236
Meta Train Accuracy 0.6014492753623188
Meta Valid Error 0.8990362990593564
Meta Valid Accuracy 0.5797101449275363


Iteration 50
Meta Train Error 0.7709504074376562
Meta Train Accuracy 0.6086956521739131
Meta Valid Error 0.8075427283951337
Meta Valid Accuracy 0.6304347826086957


Iteration 51
Meta Train Error 0.5505222791023012
Meta Train Accuracy 0.7391304347826086
Meta Valid Error 0.7594670207928056
Meta Valid Accuracy 0.6231884057971014


Iteration 52
Meta Train Error 0.7435126069231309
Meta Train Accuracy 0.6014492753623188
Meta Valid Error 0.780369239008945
Meta Valid Accuracy 0.6159420289855072


Iteratio



Iteration 97
Meta Train Error 0.605715376743372
Meta Train Accuracy 0.6666666666666666
Meta Valid Error 0.6564150856456895
Meta Valid Accuracy 0.6521739130434783


Iteration 98
Meta Train Error 0.6534770876169205
Meta Train Accuracy 0.6594202898550725
Meta Valid Error 0.539904176864935
Meta Valid Accuracy 0.7101449275362319


Iteration 99
Meta Train Error 0.7824574193876722
Meta Train Accuracy 0.5652173913043478
Meta Valid Error 0.6572894641886586
Meta Valid Accuracy 0.6739130434782609
Meta Test Error 0.5641824234870897
Meta Test Accuracy 0.6739130434782609
Ways:  2
Shots:  3
Meta Size:  23


Iteration 0
Meta Train Error 0.7325779497623444
Meta Train Accuracy 0.4637681278197662
Meta Valid Error 0.6904495928598486
Meta Valid Accuracy 0.543478271883467


Iteration 1
Meta Train Error 0.7271579180074774
Meta Train Accuracy 0.5579710265864497
Meta Valid Error 0.7057691449704377
Meta Valid Accuracy 0.5797101602606152


Iteration 2
Meta Train Error 0.7711160296331281
Meta Train Accuracy 0.5



Iteration 47
Meta Train Error 0.5838246032919573
Meta Train Accuracy 0.695652186870575
Meta Valid Error 0.5500210413466329
Meta Valid Accuracy 0.7608695794706759


Iteration 48
Meta Train Error 0.6199042984972829
Meta Train Accuracy 0.7173913244319998
Meta Valid Error 0.6243631697219351
Meta Valid Accuracy 0.6666666807039924


Iteration 49
Meta Train Error 0.564367519772571
Meta Train Accuracy 0.717391312122345
Meta Valid Error 0.554003676318604
Meta Valid Accuracy 0.6739130590272986


Iteration 50
Meta Train Error 0.4760736557452575
Meta Train Accuracy 0.7608695807664291
Meta Valid Error 0.6062190292969994
Meta Valid Accuracy 0.6884058091951453


Iteration 51
Meta Train Error 0.5230347089793371
Meta Train Accuracy 0.7536232030909994
Meta Valid Error 0.5647186342140903
Meta Valid Accuracy 0.6811594334633454


Iteration 52
Meta Train Error 0.5743667020098023
Meta Train Accuracy 0.7101449474044468
Meta Valid Error 0.614584444657616
Meta Valid Accuracy 0.710144940925681


Iteration 53
M



Iteration 97
Meta Train Error 0.5577445214857226
Meta Train Accuracy 0.6594203082115754
Meta Valid Error 0.5131224565531897
Meta Valid Accuracy 0.7753623400045477


Iteration 98
Meta Train Error 0.5297474563121796
Meta Train Accuracy 0.6884058085472687
Meta Valid Error 0.46853881912386935
Meta Valid Accuracy 0.7753623367651649


Iteration 99
Meta Train Error 0.5729666008897449
Meta Train Accuracy 0.652173928592516
Meta Valid Error 0.4418108599341434
Meta Valid Accuracy 0.8260869681835175
Meta Test Error 0.5244723828914373
Meta Test Accuracy 0.7101449428693108
Ways:  2
Shots:  5
Meta Size:  13


Iteration 0
Meta Train Error 0.7059394808915945
Meta Train Accuracy 0.507692317549999
Meta Valid Error 0.7037360484783466
Meta Valid Accuracy 0.5000000011462432


Iteration 1
Meta Train Error 0.6437220092003162
Meta Train Accuracy 0.5692307765667255
Meta Valid Error 0.6083054336217734
Meta Valid Accuracy 0.6538461607236129


Iteration 2
Meta Train Error 0.6148322132917551
Meta Train Accuracy 0



Iteration 47
Meta Train Error 0.4516201251401351
Meta Train Accuracy 0.7769230856345251
Meta Valid Error 0.5760735385119915
Meta Valid Accuracy 0.6846153925244625


Iteration 48
Meta Train Error 0.5599043678778869
Meta Train Accuracy 0.684615393097584
Meta Valid Error 0.6584001458608187
Meta Valid Accuracy 0.6230769294958848


Iteration 49
Meta Train Error 0.5287547673170383
Meta Train Accuracy 0.7000000087114481
Meta Valid Error 0.6174061825642219
Meta Valid Accuracy 0.6615384679574233


Iteration 50
Meta Train Error 0.39083434039583576
Meta Train Accuracy 0.8076923225934689
Meta Valid Error 0.6999218165874481
Meta Valid Accuracy 0.6076923196132367


Iteration 51
Meta Train Error 0.4690372720360756
Meta Train Accuracy 0.7692307784007146
Meta Valid Error 0.6029691604467539
Meta Valid Accuracy 0.6846153999750431


Iteration 52
Meta Train Error 0.4277935386277162
Meta Train Accuracy 0.8076923214472257
Meta Valid Error 0.7042737213464884
Meta Valid Accuracy 0.6461538557822888


Iteratio



Iteration 97
Meta Train Error 0.5148247976142627
Meta Train Accuracy 0.7461538566992834
Meta Valid Error 0.4354109483269545
Meta Valid Accuracy 0.7923077046871185


Iteration 98
Meta Train Error 0.5664230410296184
Meta Train Accuracy 0.6384615474022352
Meta Valid Error 0.4705989532745801
Meta Valid Accuracy 0.7615384688744178


Iteration 99
Meta Train Error 0.6072888523340225
Meta Train Accuracy 0.6692307774837201
Meta Valid Error 0.4332858071877406
Meta Valid Accuracy 0.8230769404998193
Meta Test Error 0.40032745553896976
Meta Test Accuracy 0.7692307772544714


### CNN MAML Benchmark

In [10]:
def accuracy(predictions, targets):
    predictions = predictions.argmax(dim=1).view(targets.shape)
    return (predictions == targets).sum().float() / targets.size(0)


def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evalutation sets
    adaptation_indices = np.zeros(data.size(0), dtype=bool)
    adaptation_indices[np.arange(shots*ways) * 2] = True
    evaluation_indices = torch.from_numpy(~adaptation_indices)
    adaptation_indices = torch.from_numpy(adaptation_indices)
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]

    # Adapt the model
    for step in range(adaptation_steps):
        adaptation_error = loss(learner(adaptation_data), adaptation_labels)
        learner.adapt(adaptation_error)

    # Evaluate the adapted model
    predictions = learner(evaluation_data)
    evaluation_error = loss(predictions, evaluation_labels)
    evaluation_accuracy = accuracy(predictions, evaluation_labels)
    return evaluation_error, evaluation_accuracy

In [None]:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
device = torch.device('cpu')
if cuda and torch.cuda.device_count():
    torch.cuda.manual_seed(seed)
    device = torch.device('cuda')

# Create Tasksets using the benchmark interface
train_dataset = l2l.data.MetaDataset(train_dataset)
train_transforms = [
    NWays(train_dataset, ways),
    KShots(train_dataset, 2*shots),
    LoadData(train_dataset),
    RemapLabels(train_dataset),
    ConsecutiveLabels(train_dataset),
]

valid_dataset = l2l.data.MetaDataset(valid_dataset)
valid_transforms = [
    NWays(valid_dataset, ways),
    KShots(valid_dataset, 2*shots),
    LoadData(valid_dataset),
    ConsecutiveLabels(valid_dataset),
    RemapLabels(valid_dataset),
]

test_dataset = l2l.data.MetaDataset(test_dataset)
test_transforms = [
    NWays(test_dataset, ways),
    KShots(test_dataset, 2*shots),
    LoadData(test_dataset),
    RemapLabels(test_dataset),
    ConsecutiveLabels(test_dataset),
]

# Instantiate the tasksets
train_tasks = l2l.data.TaskDataset(
    dataset=train_dataset,
    task_transforms=train_transforms,
    num_tasks=-1,
)
validation_tasks = l2l.data.TaskDataset(
    dataset=valid_dataset,
    task_transforms=valid_transforms,
    num_tasks=-1,
)
test_tasks = l2l.data.TaskDataset(
    dataset=test_dataset,
    task_transforms=test_transforms,
    num_tasks=-1,
)

BenchmarkTasksets = namedtuple('BenchmarkTasksets', ('train', 'validation', 'test'))
tasksets = BenchmarkTasksets(train_tasks, validation_tasks, test_tasks)

# Create model
#model = l2l.vision.models.MiniImagenetCNN(ways)
model = l2l.vision.models.OmniglotFC(512 ** 2 * 3, ways)
model.to(device)
maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
opt = optim.Adam(maml.parameters(), meta_lr)
loss = nn.CrossEntropyLoss(reduction='mean')

for iteration in range(num_iterations):
    opt.zero_grad()
    meta_train_error = 0.0
    meta_train_accuracy = 0.0
    meta_valid_error = 0.0
    meta_valid_accuracy = 0.0
    for task in range(meta_batch_size):
        # Compute meta-training loss
        learner = maml.clone()
        batch = tasksets.train.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                           learner,
                                                           loss,
                                                           adaptation_steps,
                                                           shots,
                                                           ways,
                                                           device)
        evaluation_error.backward()
        meta_train_error += evaluation_error.item()
        meta_train_accuracy += evaluation_accuracy.item()

        # Compute meta-validation loss
        learner = maml.clone()
        batch = tasksets.validation.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                           learner,
                                                           loss,
                                                           adaptation_steps,
                                                           shots,
                                                           ways,
                                                           device)
        meta_valid_error += evaluation_error.item()
        meta_valid_accuracy += evaluation_accuracy.item()

    # Print some metrics
    print('\n')
    print('Iteration', iteration)
    print('Meta Train Error', meta_train_error / meta_batch_size)
    print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
    print('Meta Valid Error', meta_valid_error / meta_batch_size)
    print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)

    # Average the accumulated gradients and optimize
    for p in maml.parameters():
        p.grad.data.mul_(1.0 / meta_batch_size)
    opt.step()

meta_test_error = 0.0
meta_test_accuracy = 0.0
for task in range(meta_batch_size):
    # Compute meta-testing loss
    learner = maml.clone()
    batch = tasksets.test.sample()
    evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                       learner,
                                                       loss,
                                                       adaptation_steps,
                                                       shots,
                                                       ways,
                                                       device)
    meta_test_error += evaluation_error.item()
    meta_test_accuracy += evaluation_accuracy.item()
print('Meta Test Error', meta_test_error / meta_batch_size)
print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)