In [None]:
# reference: https://github.com/sicara/easy-few-shot-learning/tree/master

In [None]:
import glob
import json
import os
from copy import deepcopy
from typing import Tuple

import numpy as np
import timm
import torch
from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tqdm.auto import tqdm

In [None]:
# Pretrained timm checkpoint (DINOv2 small ViT, patch size 14 per its name),
# instantiated for 224x224 inputs. It is later wrapped as the feature
# extractor of the few-shot classifier.
dinov2_backbone = timm.create_model(
    model_name='vit_small_patch14_dinov2.lvd142m', pretrained=True, img_size=224
)

In [None]:
# Preprocessing for the "clean" view of an image: 3x224x224 input, normalised
# with the channel statistics below. `is_training` is left at timm's default.
_clean_transform_kwargs = dict(
    input_size=(3, 224, 224),
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)
transform = create_transform(**_clean_transform_kwargs)

In [None]:
# Preprocessing for the augmented view of an image: same input size and
# normalisation as the clean view, but built in training mode with the
# 'rand-m9-mstd0.5' auto-augment policy.
_augmented_transform_kwargs = dict(
    input_size=(3, 224, 224),
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    is_training=True,
    auto_augment='rand-m9-mstd0.5',
)
aug_transform = create_transform(**_augmented_transform_kwargs)

In [None]:
class PrototypicalNetworks(nn.Module):
    """Prototypical-network classifier on top of an arbitrary feature extractor.

    Every class is represented by the mean embedding ("prototype") of its
    support images. A query image is scored against a class by the negative
    Euclidean distance between its embedding and that class's prototype.
    """

    def __init__(self, backbone: nn.Module):
        """
        Args:
            backbone: module that maps a batch of images to a batch of
                embeddings (first axis = image index).
        """
        super().__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """Score query images against the prototypes of the support set.

        Args:
            support_images: batch of support images passed to the backbone.
            support_labels: one integer label per support image. Labels must
                be exactly 0 .. n_way - 1, because prototypes are built by
                iterating over `range(n_way)`.
            query_images: batch of query images passed to the backbone.

        Returns:
            Tensor with one row per query image and one column per class;
            entry (i, j) is minus the Euclidean distance between query i and
            the prototype of class j (higher means more similar).
        """
        # Call the module itself rather than `.forward(...)`: going through
        # `__call__` is what runs any hooks registered on the backbone.
        z_support = self.backbone(support_images)
        z_query = self.backbone(query_images)

        # Number of distinct classes present in this task.
        n_way = len(torch.unique(support_labels))

        # One prototype per class: the mean embedding of its support images.
        z_proto = torch.stack(
            [z_support[support_labels == label].mean(0) for label in range(n_way)]
        )

        # Negated so that the closest prototype gets the highest score.
        return -torch.cdist(z_query, z_proto)

In [None]:
# Wrap the backbone in the prototypical classifier and place it on the GPU.
model = PrototypicalNetworks(backbone=dinov2_backbone).to("cuda")

In [None]:
def my_transform(x):
    """Return two preprocessed views of `x`, stacked on a new leading axis.

    Index 0 holds `transform(x)` and index 1 holds `aug_transform(x)`.
    """
    clean_view = transform(x)
    augmented_view = aug_transform(x)
    return torch.stack([clean_view, augmented_view])

In [None]:
# Every sample is returned as a stacked (clean, augmented) pair of views.
# NOTE(review): both splits read the same folder, so `test_set` is not held
# out from the training data — confirm this is intended.
_MINIIMAGENET_ROOT = "./miniimagenet"
train_set = ImageFolder(_MINIIMAGENET_ROOT, transform=my_transform)
test_set = ImageFolder(_MINIIMAGENET_ROOT, transform=my_transform)

In [None]:
# Episode layout shared by the training and evaluation samplers.
N_WAY = 4  # classes per task
N_SHOT = 5  # support images per class
N_QUERY = 10  # query images per class
N_EVALUATION_TASKS = 10

# TaskSampler asks the dataset for the label of every sample. ImageFolder
# already stores them in `targets`; iterating the dataset instead would decode
# and transform every image just to read its label.
test_set.get_labels = lambda: test_set.targets
test_sampler = TaskSampler(
    test_set, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_EVALUATION_TASKS
)

# Each batch produced by the sampler is one complete few-shot task.
test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=4,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

In [None]:
N_TRAINING_EPISODES = 40000
N_VALIDATION_TASKS = 100

# TaskSampler asks the dataset for the label of every sample. ImageFolder
# already stores them in `targets`; iterating the dataset instead would decode
# and transform every image just to read its label.
train_set.get_labels = lambda: train_set.targets
train_sampler = TaskSampler(
    train_set, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_TRAINING_EPISODES
)

# Each batch produced by the sampler is one complete training episode.
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=4,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)

In [None]:
# Episodic training objective and optimiser; all parameters of `model`
# (backbone included) are handed to Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)


def fit(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> float:
    """Run one optimisation step of `model` on a single episode.

    All four tensors are moved to the GPU, the queries are scored against the
    support set, and the cross-entropy between those scores and
    `query_labels` is back-propagated.

    Returns:
        The episode's loss as a Python float.
    """
    gpu_support_images = support_images.cuda()
    gpu_support_labels = support_labels.cuda()
    gpu_query_images = query_images.cuda()
    gpu_query_labels = query_labels.cuda()

    optimizer.zero_grad()
    episode_scores = model(gpu_support_images, gpu_support_labels, gpu_query_images)
    episode_loss = criterion(episode_scores, gpu_query_labels)
    episode_loss.backward()
    optimizer.step()

    return episode_loss.item()

In [None]:
# Episodic training: one optimisation step per sampled task.
log_update_frequency = 10

all_loss = []
model.train()
with tqdm(enumerate(train_loader), total=len(train_loader)) as tqdm_train:
    for episode_index, episode in tqdm_train:
        support_images, support_labels, query_images, query_labels, _ = episode

        # Axis 1 holds the two views of each image; training uses only the
        # first one (index 0).
        support_images = support_images[:, 0]
        query_images = query_images[:, 0]

        loss_value = fit(support_images, support_labels, query_images, query_labels)
        all_loss.append(loss_value)

        # Refresh the progress-bar loss every `log_update_frequency` episodes.
        if episode_index % log_update_frequency == 0:
            tqdm_train.set_postfix(loss=sliding_average(all_loss, log_update_frequency))

In [None]:
def evaluate_on_one_task(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> Tuple[int, int]:
    """Fine-tune a throwaway copy of `model` on one task, then classify its queries.

    The copy is adapted for five SGD steps using only the support set: the
    first view of each support image builds the prototypes and the second view
    of the same images is classified against them. The adapted copy then
    scores the first view of the query images.

    Args:
        support_images: support images with the two views of each image
            stacked along axis 1 (index 0 and index 1 are both read).
        support_labels: one label per support image.
        query_images: query images with views stacked along axis 1; only
            index 0 is used.
        query_labels: one label per query image.

    Returns:
        `(n_correct, n_queries)`: the number of correctly classified query
        images and the total number of query images.
    """
    # Select the views and move every tensor to the GPU once, instead of
    # repeating the host-to-device copy on each fine-tuning step.
    origin_support_images = support_images[:, 0].cuda()
    augment_support_images = support_images[:, 1].cuda()
    query_images = query_images[:, 0].cuda()
    support_labels = support_labels.cuda()
    query_labels = query_labels.cuda()

    # Adapt a copy so the shared `model` is left untouched by evaluation.
    ft_model = deepcopy(model)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(ft_model.parameters(), lr=1e-5)
    for _ in range(5):
        optimizer.zero_grad()
        # Second views act as pseudo-queries; they share the support labels.
        classification_scores = ft_model(
            origin_support_images, support_labels, augment_support_images
        )
        loss = criterion(classification_scores, support_labels)
        loss.backward()
        optimizer.step()

    ft_model.eval()
    with torch.no_grad():
        query_scores = ft_model(origin_support_images, support_labels, query_images)
        predicted_labels = query_scores.argmax(dim=1)
        n_correct = (predicted_labels == query_labels).sum().item()

    return n_correct, len(query_labels)


def evaluate(data_loader: DataLoader):
    """Evaluate on every task of `data_loader` and print the overall accuracy.

    Each batch of the loader is one task; it is handed to
    `evaluate_on_one_task`, and the correct / total query counts are summed
    over all tasks.

    Args:
        data_loader: loader whose batches unpack into
            `(support_images, support_labels, query_images, query_labels, class_ids)`.
    """
    total_predictions = 0
    correct_predictions = 0

    # The fifth item of each batch (class ids) is not needed for scoring.
    for support_images, support_labels, query_images, query_labels, _ in tqdm(
        data_loader, total=len(data_loader)
    ):
        correct, total = evaluate_on_one_task(
            support_images, support_labels, query_images, query_labels
        )

        total_predictions += total
        correct_predictions += correct

    # Guard against an empty loader, which would otherwise divide by zero.
    if total_predictions == 0:
        print(f"Model tested on {len(data_loader)} tasks. No query predictions were made.")
        return

    print(
        f"Model tested on {len(data_loader)} tasks. Accuracy: {(100 * correct_predictions/total_predictions):.2f}%"
    )

In [None]:
# 4-way 5-shot evaluation on the rabbit-breed images.
test_set = ImageFolder(
    root="./rabbit_breed",
    transform=my_transform,
)
# ImageFolder already stores every label in `targets`, so no image has to be
# decoded just to list the labels.
test_set.get_labels = lambda: test_set.targets
# The sampler must be built from the dataset that was just loaded, not from a
# sampler left over from an earlier cell.
test_sampler = TaskSampler(
    test_set, n_way=4, n_shot=5, n_query=10, n_tasks=100,
)

test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=4,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

evaluate(test_loader)

In [None]:
# 4-way 1-shot evaluation on the rabbit-breed images.
test_set = ImageFolder(
    root="./rabbit_breed",
    transform=my_transform,
)
# ImageFolder already stores every label in `targets`, so no image has to be
# decoded just to list the labels.
test_set.get_labels = lambda: test_set.targets
# The sampler must be built from the dataset that was just loaded, not from a
# sampler left over from an earlier cell.
test_sampler = TaskSampler(
    test_set, n_way=4, n_shot=1, n_query=10, n_tasks=100,
)

test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=4,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

evaluate(test_loader)

In [None]:
# 5-way 1-shot evaluation on the CUB_200_2011 images.
test_set = ImageFolder(
    root="./CUB_200_2011",
    transform=my_transform,
)
# ImageFolder already stores every label in `targets`, so no image has to be
# decoded just to list the labels.
test_set.get_labels = lambda: test_set.targets
# The sampler must be built from the dataset that was just loaded, not from a
# sampler left over from an earlier cell.
test_sampler = TaskSampler(
    test_set, n_way=5, n_shot=1, n_query=10, n_tasks=100,
)

test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=4,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

evaluate(test_loader)

In [None]:
# 5-way 5-shot evaluation on the CUB_200_2011 images.
test_set = ImageFolder(
    root="./CUB_200_2011",
    transform=my_transform,
)
# ImageFolder already stores every label in `targets`, so no image has to be
# decoded just to list the labels.
test_set.get_labels = lambda: test_set.targets
# The sampler must be built from the dataset that was just loaded, not from a
# sampler left over from an earlier cell.
test_sampler = TaskSampler(
    test_set, n_way=5, n_shot=5, n_query=10, n_tasks=100,
)

test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=4,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

evaluate(test_loader)