In [1]:
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import learn2learn as l2l
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels
from torchvision.models import resnet50
from torchvision import transforms

In [8]:
# Episode configuration for N-way K-shot prototypical-network training.
max_epoch = 200  # number of training epochs

# Training episodes: 3 classes, 1 support ("shot") and 15 query samples per class.
train_way, train_shot, train_query = 3, 1, 15

# Validation / test episodes use the same 3-way, 1-shot, 15-query layout.
test_way, test_shot, test_query = 3, 1, 15

# Used as an on/off flag by the device-setup cell (`if gpu and ...`): 0 keeps the CPU device.
gpu = 0

In [9]:
# Prototype logits: negative squared Euclidean distance between queries and prototypes.
def pairwise_distances_logits(a, b):
    """Return an (n, m) tensor whose [i, j] entry is -||a[i] - b[j]||^2.

    Args:
        a: 2-D tensor of shape (n, d), e.g. query embeddings.
        b: 2-D tensor of shape (m, d), e.g. class prototypes.

    Returns:
        Tensor of shape (n, m); larger values mean closer rows, so it can be fed
        directly to ``argmax`` / ``cross_entropy`` as logits.
    """
    # (n, 1, d) - (1, m, d) broadcasts to (n, m, d), then reduce over the feature axis.
    diff = a.unsqueeze(1) - b.unsqueeze(0)
    return -(diff ** 2).sum(dim=2)

In [10]:
# Metric: classification accuracy from logits.
def accuracy(predictions, targets):
    """Fraction of rows whose arg-max column equals the corresponding target.

    Args:
        predictions: 2-D tensor of scores, one row per sample.
        targets: tensor of integer labels with one entry per row of ``predictions``.

    Returns:
        0-dim float tensor in [0, 1].
    """
    hits = predictions.argmax(dim=1).view(targets.shape).eq(targets)
    return hits.sum().float() / targets.size(0)

In [11]:
def fast_adapt(model, batch, ways, shot, query_num, metric=None, device=None):
    """Run one prototypical-network episode and return its loss and accuracy.

    Samples are sorted by label. The first ``shot`` samples of each class form the
    support set and are averaged into one prototype per class. The remaining
    ``query_num`` samples per class are then scored against the prototypes with
    ``metric``.

    Args:
        model: embedding network applied to the episode's data.
        batch: ``(data, labels)`` pair from a DataLoader. A leading batch
            dimension of size 1 is squeezed away (assumes batch_size=1).
        ways: number of classes in the episode.
        shot: support samples per class.
        query_num: query samples per class.
        metric: callable ``(query, support) -> logits``; defaults to
            ``pairwise_distances_logits``.
        device: device to move the batch to. Defaults to the device of the
            model's first parameter, or the CPU if the model has no parameters.

    Returns:
        ``(loss, acc)``: cross-entropy over the query logits and the query accuracy.
    """
    if metric is None:
        metric = pairwise_distances_logits
    if device is None:
        # nn.Module has no .device() method; infer the device from the weights.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    data, labels = batch
    data = data.to(device)
    labels = labels.to(device)

    # Group samples by class so that class c occupies positions
    # [c * (shot + query_num), (c + 1) * (shot + query_num)).
    # Assumes every class contributes exactly shot + query_num samples.
    sort = torch.sort(labels)
    data = data.squeeze(0)[sort.indices].squeeze(0)
    labels = labels.squeeze(0)[sort.indices].squeeze(0)

    embeddings = model(data)

    # Mark the first `shot` positions of every class block as support samples.
    support_mask = torch.zeros(data.size(0), dtype=torch.bool)
    class_starts = torch.arange(ways) * (shot + query_num)
    for offset in range(shot):
        support_mask[class_starts + offset] = True
    query_mask = ~support_mask

    # One prototype per class: the mean of that class's support embeddings.
    support = embeddings[support_mask].reshape(ways, shot, -1).mean(dim=1)
    query = embeddings[query_mask]
    labels = labels[query_mask].long()

    # Use the caller-supplied metric rather than a hard-coded one.
    logits = metric(query, support)
    loss = F.cross_entropy(logits, labels)
    acc = accuracy(logits, labels)
    return loss, acc

In [12]:
# Seed every device's RNG before the model is built so weight init is reproducible.
torch.manual_seed(43)

# `gpu` is an on/off flag: a truthy value selects CUDA when a GPU is present.
device = torch.device('cpu')
if gpu and torch.cuda.device_count():
    device = torch.device('cuda')
    print(f"Using {device} gpu")

# ResNet-50 serves as the embedding network. Place it on the same device that
# fast_adapt moves the data to; an unconditional .cuda() would mismatch a CPU device.
model = resnet50().to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [16]:
# Dataset definitions: the CIFAR-FS train/validation/test splits as tensors (downloaded if missing).
path_data = '../../datasets'
train_dataset, valid_dataset, test_dataset = (
    l2l.vision.datasets.CIFARFS(root=path_data, mode=split, transform=transforms.ToTensor(), download=True)
    for split in ('train', 'validation', 'test')
)

In [17]:
# Dataloader definitions: each split becomes a stream of N-way K-shot episodes.
def make_episode_loader(dataset, ways, shots, queries, num_tasks=None):
    """Build the learn2learn episode pipeline for one dataset split.

    Each task samples ``ways`` classes and ``shots + queries`` examples per class,
    loads them, and remaps their labels (``RemapLabels``).

    Args:
        dataset: a dataset split. It is wrapped in ``l2l.data.MetaDataset``
            unless it already is one, so re-running the cell is safe.
        ways: number of classes per task.
        shots: support examples per class.
        queries: query examples per class.
        num_tasks: value forwarded to ``TaskDataset(num_tasks=...)``. ``None``
            keeps the TaskDataset default.

    Returns:
        ``(meta_dataset, task_transforms, tasks, loader)``.
    """
    if isinstance(dataset, l2l.data.MetaDataset):
        meta_dataset = dataset
    else:
        meta_dataset = l2l.data.MetaDataset(dataset)
    task_transforms = [
        NWays(meta_dataset, ways),
        KShots(meta_dataset, queries + shots),
        LoadData(meta_dataset),
        RemapLabels(meta_dataset),
    ]
    task_kwargs = {} if num_tasks is None else {'num_tasks': num_tasks}
    tasks = l2l.data.TaskDataset(meta_dataset, task_transforms=task_transforms, **task_kwargs)
    loader = DataLoader(tasks, pin_memory=True, shuffle=True)
    return meta_dataset, task_transforms, tasks, loader


train_dataset, train_transforms, train_tasks, train_loader = make_episode_loader(
    train_dataset, ways=train_way, shots=train_shot, queries=train_query)
valid_dataset, valid_transforms, valid_tasks, valid_loader = make_episode_loader(
    valid_dataset, ways=test_way, shots=test_shot, queries=test_query, num_tasks=200)
test_dataset, test_transforms, test_tasks, test_loader = make_episode_loader(
    test_dataset, ways=test_way, shots=test_shot, queries=test_query, num_tasks=2000)

In [11]:
# Adam over every backbone parameter. StepLR halves the learning rate every 20
# scheduler steps; the training loop calls lr_scheduler.step() once per epoch.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.5, step_size=20)

In [None]:
for epoch in range(1, max_epoch + 1):
    ## start training phase: a fixed budget of 100 episodes per epoch
    model.train()
    loss_ctr = 0
    n_loss = 0
    n_acc = 0

    for _ in range(100):
        # Take the first batch of a fresh iterator, i.e. one task per step
        # (DataLoader's default batch_size is 1).
        batch = next(iter(train_loader))
        loss, acc = fast_adapt(model, batch, train_way, train_shot, train_query, metric = pairwise_distances_logits, device = device)
        loss_ctr += 1
        n_loss += loss.item()
        n_acc += acc.item()  # accumulate a Python float, not a (possibly GPU) tensor
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    lr_scheduler.step()
    print('epoch {}, train, loss={:.4f} acc={:.4f}'.format(epoch, n_loss/loss_ctr, n_acc/loss_ctr))

    ## start validation phase: evaluation only, so skip building autograd graphs
    model.eval()
    loss_ctr = 0
    n_loss = 0
    n_acc = 0
    with torch.no_grad():
        for batch in valid_loader:
            loss, acc = fast_adapt(model, batch, test_way, test_shot, test_query, metric = pairwise_distances_logits, device = device)
            loss_ctr += 1
            n_loss += loss.item()
            n_acc += acc.item()

    print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, n_loss/loss_ctr, n_acc/loss_ctr))

In [None]:
## test phase: evaluate on held-out test tasks.
# Counters are reset here so the running accuracy covers only test tasks,
# not values left over from the last validation pass.
model.eval()
loss_ctr = 0
n_acc = 0
with torch.no_grad():
    for i, batch in enumerate(test_loader, 1):
        loss, acc = fast_adapt(model, batch, test_way, test_shot, test_query, metric = pairwise_distances_logits, device = device)
        loss_ctr += 1
        n_acc += acc.item()
        # running mean accuracy (%) followed by this task's accuracy (%)
        print('batch {}: {:.2f}({:.2f})'.format(i, n_acc/loss_ctr * 100, acc.item() * 100))