In [1]:
from easyfsl.methods import PrototypicalNetworks
from easyfsl.datasets import CUSTOM
from easyfsl.samplers import TaskSampler
from easyfsl.utils import compute_prototypes

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
import torchvision.transforms as T

import albumentations as A
from albumentations.pytorch import ToTensorV2

import matplotlib.pyplot as plt
import numpy as np
import cv2
from PIL import Image
from tqdm import tqdm
import timm
import os
from glob import glob
import pandas as pd
from collections import defaultdict
import json
# Restrict this process to GPU 0. CUDA_VISIBLE_DEVICES only takes effect if it is
# set before CUDA is initialised. For debugging asynchronous CUDA errors,
# CUDA_LAUNCH_BLOCKING="1" can be added here the same way.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

# Model

In [None]:
class CNN(nn.Module):
    """Wrap a timm backbone so that it returns one flat feature vector per image.

    Args:
        name: timm model name forwarded to ``timm.create_model``.
        num_classes: forwarded to ``timm.create_model``. timm documents
            ``num_classes=0`` as building the model without a classifier head.
        pretrained: forwarded to ``timm.create_model``. Defaults to ``True``
            (original behaviour). Pass ``False`` when the weights are overwritten
            by a checkpoint right after construction, to skip the download.
        use_fc: initial value of the ``use_fc`` switch (see ``forward``).
    """

    def __init__(self, name='efficientnetv2_rw_m', num_classes=0, pretrained=True, use_fc=True):
        super().__init__()
        self.model = timm.create_model(name, pretrained=pretrained, num_classes=num_classes)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten(1)
        self.use_fc = use_fc

    def forward(self, x):
        """Return per-image features.

        If ``use_fc`` is True, return the wrapped timm model's full output
        ``self.model(x)``. Otherwise run ``forward_features``, global-average-pool
        the result and flatten it to ``(batch, channels)``.
        """
        if self.use_fc:
            return self.model(x)
        # assumes forward_features returns a (B, C, H, W) map, as CNN backbones
        # such as EfficientNet do -- TODO confirm for other architectures
        feature_map = self.model.forward_features(x)
        return self.flatten(self.pool(feature_map))

    def set_use_fc(self, x):
        """Switch ``forward`` between the full-model path (True) and the pooled ``forward_features`` path (False)."""
        self.use_fc = x

In [None]:
# Restore the trained few-shot model.
# `device` is defined here because this cell uses it; the config cell further
# down also sets it, but only after this cell has run on a fresh kernel.
device = "cuda"

# NOTE(review): the checkpoint folder says "effi_b0", but CNN() defaults to
# 'efficientnetv2_rw_m'. load_state_dict will fail if the architectures differ --
# confirm which backbone this checkpoint was trained with.
fs_model_path = './submission/fsl_224_effi_b0/2E_model.pt'
convolutional_network = CNN()
# Use the pooled forward_features path of the backbone (see CNN.forward).
convolutional_network.set_use_fc(False)
few_shot_classifier = PrototypicalNetworks(convolutional_network).to(device)

# map_location lets a checkpoint saved on another device load onto `device`.
fsm_checkpoint = torch.load(fs_model_path, map_location=device)
few_shot_classifier.load_state_dict(fsm_checkpoint["model_state_dict"])

# Dataset

In [2]:
# Test support set: write a specs file listing every class folder under
# ./data/fsl_img in numeric order ("class_roots" = folder paths,
# "class_names" = folder names). Presumably in the format easyfsl's dataset
# classes read -- confirm against the consumer of this JSON.
fsl = defaultdict(list)

path = './data/fsl_img/*'
# os.path.basename handles the host OS's separator(s); split('\\') only worked
# for Windows-style paths and made int() fail on POSIX.
label_list = sorted(glob(path), key=lambda p: int(os.path.basename(p)))

for label in label_list:
    fsl['class_roots'].append(label)
    fsl['class_names'].append(os.path.basename(label))

with open('data/fsl_test_supset.json', 'w') as f:
    json.dump(fsl, f, indent=2)

In [None]:
class CUSTOM_DATASET(Dataset):
    """Support-set dataset read from a directory of integer-named class folders.

    The class folders under ``root_path`` are ordered by their integer name.
    The slice ``[n_way_start:n_way_end]`` of that ordering is kept, and at most
    ``n_shot`` images are taken from each kept folder. Labels are re-indexed
    from 0 within the kept slice.

    Args:
        root_path: directory containing the class folders.
        n_way_start, n_way_end: slice bounds over the sorted class folders
            (``None`` behaves as in ordinary list slicing).
        n_shot: maximum number of images per class (``None`` keeps all).
        transform: optional callable used as ``transform(image=img)['image']``
            (albumentations style). If ``None``, the RGB numpy image is returned as-is.
    """

    def __init__(self, root_path, n_way_start=None, n_way_end=None, n_shot=None, transform=None):
        self.img_list, self.label_list = self.folder_split(root_path, n_way_start, n_way_end, n_shot)
        self.n_way_start = n_way_start
        self.n_way_end = n_way_end
        self.transform = transform

    def __len__(self):
        assert len(self.img_list) == len(self.label_list), 'Doesn\'t match between img length and label length'
        return len(self.img_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; ``label`` is a scalar long tensor.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        img_path = self.img_list[idx]
        img = cv2.imread(img_path)
        # cv2.imread reports a missing/unreadable file by returning None
        # rather than raising, which would otherwise surface as a cvtColor error.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            img = self.transform(image=img)['image']
        label = torch.tensor(self.label_list[idx], dtype=torch.long)

        return img, label

    def folder_split(self, root_path, n_way_start, n_way_end, n_shot):
        """Collect image paths and 0-based labels for the selected class folders.

        Returns:
            ``(img_list, label_list)``: parallel lists of image paths and labels.
        """
        # os.path.basename works with the host OS's separators; the former
        # split('\\') only parsed Windows-style paths.
        folders = sorted(glob(os.path.join(root_path, "*")), key=lambda p: int(os.path.basename(p)))
        folders = folders[n_way_start:n_way_end]

        img_list = []
        label_list = []
        for idx, folder in enumerate(folders):
            # glob order is filesystem-dependent; sort so the same n_shot images
            # are picked on every run, and glob each folder only once.
            images = sorted(glob(os.path.join(folder, "*")))[:n_shot]
            img_list.extend(images)
            label_list.extend([idx] * len(images))

        return img_list, label_list

In [None]:
# Image directory of the support classes (the config cell below reassigns
# `support_path`).
support_path = 'data/fsl_img'

# Support-image preprocessing. The output must be a float CHW tensor:
# without Normalize + ToTensorV2 the DataLoader would collate uint8 HWC numpy
# images, which a conv backbone cannot consume.
transform = A.Compose([
    A.Resize(250, 375),  # (height, width): H:W = 1:1.5
    # The RandomCrop branch is random, so prototypes differ between runs.
    A.OneOf([
        A.RandomCrop(224, 224),
        A.CenterCrop(224, 224)
    ], p=1),
    # NOTE(review): A.Normalize() defaults to ImageNet mean/std -- confirm this
    # matches the normalization used when the checkpoint was trained.
    A.Normalize(),
    ToTensorV2(),
])

In [None]:
# Support-set configuration.
device = "cuda"
n_way, n_shot = 5, 5  # classes per chunk, images per class
# NOTE(review): CUSTOM_DATASET globs `support_path/*` as integer-named class
# folders. The specs file written earlier is `data/fsl_test_supset.json` and the
# images it lists live in `data/fsl_img`. Confirm that a `data/fsl_test_supset/`
# directory of class folders really exists, or point this at the image folder.
support_path = 'data/fsl_test_supset'

# Prototype generation

In [None]:
# Embed the support set n_way classes at a time and stack the per-chunk
# prototypes into a single numpy array. This assumes compute_prototypes returns
# one row per label in ascending label order, so row i belongs to the i-th
# class folder in numeric order.
# The class count comes from the folder itself instead of a hard-coded 50.
n_classes = len(glob(os.path.join(support_path, "*")))
assert n_classes > 0, f"no class folders found under {support_path!r}"

few_shot_classifier.eval()
prototype_chunks = []
for n_way_start in tqdm(range(0, n_classes, n_way), desc="prototypes"):
    n_way_end = n_way_start + n_way

    custom_dataset = CUSTOM_DATASET(support_path, n_way_start, n_way_end, n_shot, transform)
    # A single batch holds the whole chunk (at most n_way classes x n_shot images).
    dataloader = DataLoader(custom_dataset, batch_size=(n_way_end - n_way_start) * n_shot)
    sup_img, sup_label = next(iter(dataloader))
    with torch.no_grad():
        sup_feature = few_shot_classifier.backbone(sup_img.to(device))
        # Keep labels on the same device as the features they index.
        chunk_prototypes = compute_prototypes(sup_feature, sup_label.to(device))
    prototype_chunks.append(chunk_prototypes.cpu().numpy())

prototypes = np.concatenate(prototype_chunks)