In [1]:
import os

import random
import numpy as np

import torch
from torch import nn, optim
from torchvision.models import resnet50

import learn2learn as l2l
from learn2learn.data.transforms import (NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels)

In [2]:
# Reproducibility / hardware
seed = 54
cuda = True  # use a GPU if one is available

# Few-shot task shape: `ways` classes per task, `shots` adaptation samples per class
ways = 3
shots = 1

# MAML optimisation
meta_lr = 0.003          # outer-loop (Adam) learning rate
fast_lr = 0.5            # inner-loop adaptation learning rate
meta_batch_size = 32     # tasks whose gradients are averaged per outer step
adaptation_steps = 1     # inner-loop adapt() calls per task
num_iterations = 60000   # outer-loop iterations

In [3]:
# Seed every RNG in play so task sampling and weight updates are repeatable.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)

# Use the GPU only when requested AND at least one CUDA device is visible.
use_gpu = cuda and torch.cuda.device_count() > 0
if use_gpu:
    torch.cuda.manual_seed(seed)
device = torch.device('cuda' if use_gpu else 'cpu')

In [4]:
def accuracy(predictions, targets):
    """Return the fraction of samples whose arg-max score matches the target.

    Args:
        predictions: score tensor; the predicted class is taken along dim 1.
        targets: integer class labels, one per sample.

    Returns:
        A 0-dim float tensor in [0, 1].
    """
    predicted_labels = predictions.argmax(dim=1).view(targets.shape)
    num_correct = (predicted_labels == targets).sum().float()
    num_samples = targets.size(0)
    return num_correct / num_samples

In [5]:
# MAML inner loop: adapt a per-task clone, then evaluate it on held-out samples.
# (The outer loop, which backpropagates the returned error, is in the training cell.)
def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
    """Adapt `learner` to one task and evaluate it on the task's held-out samples.

    Args:
        batch: ``(data, labels)`` pair for a single task.
        learner: per-task clone (e.g. ``maml.clone()``) exposing ``adapt(loss)``
            for an inner-loop update.
        loss: criterion called as ``loss(predictions, labels)``.
        adaptation_steps: number of inner-loop ``adapt`` calls.
        shots: adaptation samples per class.
        ways: classes per task.
        device: device the batch is moved to.

    Returns:
        ``(evaluation_error, evaluation_accuracy)`` computed on the evaluation
        split. ``evaluation_error`` is returned as a tensor so the caller can
        call ``backward()`` on it for the meta-update.
    """
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Split the task into adaptation/evaluation sets: even positions
    # 0, 2, ..., 2*(shots*ways - 1) go to adaptation, the rest to evaluation.
    # NOTE(review): this gives each class `shots` samples per split only if
    # samples of one class are contiguous and there are 2*shots per class
    # (task samplers here use 2*shots samples + ConsecutiveLabels) — confirm
    # for any new task source.
    # The mask lives on the CPU, as before; indexing past data.size(0) raises IndexError.
    adaptation_mask = torch.zeros(data.size(0), dtype=torch.bool)
    adaptation_mask[torch.arange(shots * ways) * 2] = True
    evaluation_mask = ~adaptation_mask
    adaptation_data, adaptation_labels = data[adaptation_mask], labels[adaptation_mask]
    evaluation_data, evaluation_labels = data[evaluation_mask], labels[evaluation_mask]

    # Inner-loop adaptation on the adaptation split
    for _ in range(adaptation_steps):
        adaptation_error = loss(learner(adaptation_data), adaptation_labels)
        learner.adapt(adaptation_error)

    # Evaluate the adapted learner on the held-out split
    predictions = learner(evaluation_data)
    evaluation_error = loss(predictions, evaluation_labels)
    evaluation_accuracy = accuracy(predictions, evaluation_labels)
    return evaluation_error, evaluation_accuracy

In [6]:
# CIFAR-FS task samplers (train/validation/test) from learn2learn's benchmark
# interface: each task has `ways` classes with 2*shots samples per class,
# which fast_adapt splits into adaptation and evaluation halves.
tasksets = l2l.vision.benchmarks.get_tasksets(
    'cifarfs',
    train_ways=ways,
    train_samples=2 * shots,
    test_ways=ways,
    test_samples=2 * shots,
    root='../../datasets',
)

In [6]:
# Create the backbone and restore checkpointed weights.
# `resnet50()` builds a randomly initialised network in every torchvision
# version (the `pretrained=` argument is deprecated in newer releases).
model = resnet50()
model_path = './checkpoint/iter13616_acc0.604166679084301.pth'
# map_location lets a checkpoint saved from a GPU run load when `device` fell
# back to CPU (and vice versa); without it torch.load restores tensors onto
# the device they were saved from and fails if that device is unavailable.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
ckpt = torch.load(model_path, map_location=device)
model.load_state_dict(ckpt)
model.to(device)

# NOTE(review): resnet50's default head has 1000 outputs, while each task uses
# only `ways` labels, so the loss and argmax run over all 1000 logits. Confirm
# this is intended; the checkpoint's head shape must match before changing it.
maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)  # second-order MAML
opt = optim.Adam(maml.parameters(), meta_lr)
loss = nn.CrossEntropyLoss(reduction='mean')

In [None]:
# Start training: MAML outer loop.
log_interval = 100  # print metrics every N iterations (1 = every iteration)
validate = False    # also adapt/evaluate one validation task per training task (monitoring only)

for iteration in range(num_iterations):
    opt.zero_grad()
    meta_train_error = 0.0
    meta_train_accuracy = 0.0
    meta_valid_error = 0.0
    meta_valid_accuracy = 0.0

    for task in range(meta_batch_size):
        # Meta-training loss: adapt a fresh clone, then backprop the evaluation
        # loss into the meta-parameters (gradients accumulate across tasks).
        learner = maml.clone()
        batch = tasksets.train.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device)

        evaluation_error.backward()
        meta_train_error += evaluation_error.item()
        meta_train_accuracy += evaluation_accuracy.item()

        if validate:
            # Meta-validation loss: no backward(), so it does not affect the update.
            learner = maml.clone()
            batch = tasksets.validation.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device)
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

    # Print some metrics (rate-limited so 60k iterations don't flood the output)
    if iteration % log_interval == 0 or iteration == num_iterations - 1:
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_batch_size)
        print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
        if validate:
            print('Meta Valid Error', meta_valid_error / meta_batch_size)
            print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)

    # Average the accumulated gradients over the meta-batch and take an outer step.
    with torch.no_grad():
        for p in maml.parameters():
            if p.grad is not None:  # parameters untouched by backward() have no grad
                p.grad.mul_(1.0 / meta_batch_size)
    opt.step()

In [7]:
# Check the model on the mini-ImageNet test set.
# Note: this rebinds `tasksets`, replacing the CIFAR-FS samplers created earlier.
tasksets = l2l.vision.benchmarks.get_tasksets(
    'mini-imagenet',
    train_ways=ways,
    train_samples=2 * shots,
    test_ways=ways,
    test_samples=2 * shots,
    root='../../datasets',
)

In [8]:
# Run the test: for each sampled test task, adapt a fresh clone of the
# meta-learned model and score it on the task's held-out samples.
test_errors = []
test_accuracies = []
for _ in range(meta_batch_size):
    task_learner = maml.clone()
    task_batch = tasksets.test.sample()
    task_error, task_accuracy = fast_adapt(task_batch, task_learner, loss, adaptation_steps, shots, ways, device)
    test_errors.append(task_error.item())
    test_accuracies.append(task_accuracy.item())

# Summed in the same order as a running total, then averaged over tasks.
meta_test_error = sum(test_errors)
meta_test_accuracy = sum(test_accuracies)
print('Meta Test Error', meta_test_error / meta_batch_size)
print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)

Meta Test Error 3.1253299694508314
Meta Test Accuracy 0.3645833432674408


In [9]:
# Check the model on the Double MNIST test set.
import torchvision
from torchvision.transforms import transforms

path_data = '../../datasets/double_mnist/test'
test_dataset = torchvision.datasets.ImageFolder(root=path_data, transform=transforms.ToTensor())
mnist_test = l2l.data.MetaDataset(test_dataset)

# Each task: `ways` classes with 2*shots samples per class, which fast_adapt
# splits into adaptation and evaluation halves.
# NOTE(review): RemapLabels/ConsecutiveLabels presumably relabel to 0..ways-1
# and order samples by class, which fast_adapt's alternating split relies on.
task_transforms = [
    NWays(mnist_test, ways),
    KShots(mnist_test, 2 * shots),
    LoadData(mnist_test),
    RemapLabels(mnist_test),
    ConsecutiveLabels(mnist_test),
]
test_tasks = l2l.data.TaskDataset(mnist_test, task_transforms=task_transforms, num_tasks=20000)

In [10]:
# Run the test on Double MNIST: adapt a fresh clone per sampled task and
# average the held-out error/accuracy over the tasks.
meta_test_error = 0.0
meta_test_accuracy = 0.0
for _ in range(meta_batch_size):
    clone_for_task = maml.clone()
    task_error, task_acc = fast_adapt(test_tasks.sample(), clone_for_task, loss,
                                      adaptation_steps, shots, ways, device)
    meta_test_error += task_error.item()
    meta_test_accuracy += task_acc.item()

num_tasks_evaluated = meta_batch_size
print('Meta Test Error', meta_test_error / num_tasks_evaluated)
print('Meta Test Accuracy', meta_test_accuracy / num_tasks_evaluated)

Meta Test Error 1.9164721611887217
Meta Test Accuracy 0.3020833423361182
