In [1]:
# Test phase (1-gpu)
import os
import argparse
import numpy as np
import torch.backends.cudnn as cudnn

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import learn2learn as l2l
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels
from torchvision.models import resnet50
from torchvision import transforms

In [2]:
# hyperparameters
max_epoch = 200
test_shot = 1 # 1shot
test_way = 3
test_query = 15

In [3]:
# get pretrained model (loss 기준 best model 확인)
model_path = './checkpoint/epoch50_loss1.414059302210808.pth'
model = resnet50(pretrained = False)
ckpt = torch.load(model_path)
corrected_dict = {k.replace('module.', '') : v for k, v in ckpt.items() }
model.load_state_dict(corrected_dict)
model = model.cuda()

print(f'>> check the trained model path : {model_path}')

>> check the trained model path : ./checkpoint/epoch50_loss1.414059302210808.pth


In [4]:
# metric
def pairwise_distances_logits(a, b):
    n = a.shape[0]
    m = b.shape[0]
    logits = -((a.unsqueeze(1).expand(n, m, -1) - b.unsqueeze(0).expand(n, m, -1))**2).sum(dim=2)
    return logits

def accuracy(predictions, targets):
    predictions = predictions.argmax(dim=1).view(targets.shape)
    return (predictions == targets).sum().float() / targets.size(0)

In [5]:
def fast_adapt(model, batch, ways, shot, query_num, metric=None, device=None):
    
    if metric is None:
        metric = pairwise_distances_logits

    #device = model.cuda()
    data, labels = batch
    data = data.cuda()
    labels = labels.cuda()
    n_items = shot * ways

    sort = torch.sort(labels)
    data = data.squeeze(0)[sort.indices].squeeze(0)
    labels = labels.squeeze(0)[sort.indices].squeeze(0)

    embeddings = model(data)
    support_indices = np.zeros(data.size(0), dtype=bool)
    selection = np.arange(ways) * (shot + query_num)
    for offset in range(shot):
        support_indices[selection + offset] = True
    query_indices = torch.from_numpy(~support_indices)
    support_indices = torch.from_numpy(support_indices)
    support = embeddings[support_indices]
    support = support.reshape(ways, shot, -1).mean(dim=1)
    query = embeddings[query_indices]
    labels = labels[query_indices].long()

    logits = pairwise_distances_logits(query, support)
    loss = F.cross_entropy(logits, labels)
    acc = accuracy(logits, labels)
    return loss, acc

In [8]:
# get dataset
# dataset 정의
dataset_nm = 'mnist' # miniImagenet, mnist

if dataset_nm == 'cifarfs' : 
    path_data = '../../datasets'
    test_dataset = l2l.vision.datasets.CIFARFS(root=path_data, mode='test', transform = transforms.ToTensor(), download = True)

    # dataloader 정의
    test_dataset = l2l.data.MetaDataset(test_dataset)
    test_transforms = [
            NWays(test_dataset, test_way),
            KShots(test_dataset, test_query + test_shot),
            LoadData(test_dataset),
            RemapLabels(test_dataset),
    ]
    test_tasks = l2l.data.TaskDataset(test_dataset,task_transforms=test_transforms,num_tasks=100000)
    test_loader = DataLoader(test_tasks, pin_memory=True, shuffle=True)
    
elif dataset_nm == 'miniImagenet' :
    path_data = '../../datasets'
    test_dataset = l2l.vision.datasets.MiniImagenet(root=path_data, mode='test', download = True) # transform = transforms.ToTensor() # check
    
    # dataloader 정의
    test_dataset = l2l.data.MetaDataset(test_dataset)
    test_transforms = [
            NWays(test_dataset, test_way),
            KShots(test_dataset, test_query + test_shot),
            LoadData(test_dataset),
            RemapLabels(test_dataset),
    ]
    test_tasks = l2l.data.TaskDataset(test_dataset,task_transforms=test_transforms,num_tasks=100000)
    test_loader = DataLoader(test_tasks, pin_memory=True, shuffle=True)
    
elif dataset_nm == 'mnist' :
    path_data = '../../datasets/double_mnist/test'
    test_dataset = torchvision.datasets.ImageFolder(root = path_data, transform = transforms.ToTensor())
    test_dataset = l2l.data.MetaDataset(test_dataset)
    test_transforms = [
            NWays(test_dataset, test_way),
            KShots(test_dataset, test_query + test_shot),
            LoadData(test_dataset),
            RemapLabels(test_dataset),
    ]
    
    test_tasks = l2l.data.TaskDataset(test_dataset, task_transforms = test_transforms, num_tasks = 100000)
    test_loader = DataLoader(test_tasks, pin_memory = True, shuffle = True)

else :
    raise NotImplementedError()

In [None]:
# check the test result
loss_ctr = 0.0
n_loss = 0.0
n_acc = 0.0
for i, batch in enumerate(test_loader, 1):
    loss, acc = fast_adapt(model, batch, test_way, test_shot, test_query, metric = pairwise_distances_logits, device = None)
    loss_ctr += 1
    n_acc += acc
    if i % 500 == 0 :
        print('batch {}: {:.2f}({:.2f})'.format(i, n_acc/loss_ctr * 100, acc * 100))
        
print(f'Test Accuracy : {n_acc / loss_ctr * 100}')

batch 500: 36.24(28.89)
batch 1000: 35.85(44.44)
