In [None]:
import PIL
import torchvision.transforms as transforms
import math
import copy
import pandas as pd

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import backbone
from methods.baselinetrain import BaselineTrain
from io_utils import parse_args, get_resume_file, get_best_file, get_assigned_file
from datasets import miniImageNet_few_shot, ISIC_few_shot, EuroSAT_few_shot, CropDisease_few_shot, Chest_few_shot
from data.datamgr import SimpleDataManager, SetDataManager

from sklearn.cluster import KMeans
from sklearn.metrics.cluster import v_measure_score

In [None]:
# ---- Experiment configuration ----
# Image size handed to every data manager created in this notebook.
image_size = 224

# Few-shot episode settings.
n_way, n_support, n_query = 5, 5, 15
few_shot_params = {'n_way': n_way, 'n_support': n_support}

# Which pretrained checkpoint family to analyse, and where the logs live.
model, method = 'ResNet10', 'baseline'
pretrained_dataset = 'miniImageNet'
save_dir = './logs'

# Loader over the source dataset, with augmentation switched off.
source_datamgr = miniImageNet_few_shot.SimpleDataManager(image_size, batch_size=128)
source_loader = source_datamgr.get_data_loader(aug=False)

In [None]:
target_dataset = 'ISIC'

# Map each supported target name to (dataset module, extra get_data_loader kwargs).
# Only miniImageNet passes train=False; the other datasets use the loader defaults.
_target_dataset_specs = {
    'miniImageNet': (miniImageNet_few_shot, {'train': False}),
    'CropDisease': (CropDisease_few_shot, {}),
    'EuroSAT': (EuroSAT_few_shot, {}),
    'ISIC': (ISIC_few_shot, {}),
    'ChestX': (Chest_few_shot, {}),
}

if target_dataset not in _target_dataset_specs:
    # Fail here with a clear message instead of a NameError on `target_loader`
    # in a later cell.
    raise ValueError(
        'Unknown target_dataset %r; expected one of %s'
        % (target_dataset, sorted(_target_dataset_specs))
    )

_target_module, _target_loader_kwargs = _target_dataset_specs[target_dataset]
target_datamgr = _target_module.SimpleDataManager(image_size, batch_size=128)
target_loader = target_datamgr.get_data_loader(aug=False, **_target_loader_kwargs)

# Measure: clustering quality of source and target representations

In [None]:
def get_pretrained_model(model, method, pretrained_dataset, save_dir, type_dir, dataset_name, epoch=999):
    """Build a BaselineTrain model, load a saved checkpoint into it and freeze it.

    Args:
        model: Model name; used to build the checkpoint directory path.
        method: Training method name; passed to the backbone constructor and
            used in the checkpoint directory path.
        pretrained_dataset: Dataset the checkpoint was trained on. Must be
            'miniImageNet' (64 classes) or 'tieredImageNet' (351 classes).
        save_dir: Root directory that contains the ``checkpoints`` folder.
        type_dir: Checkpoint sub-directory. When it contains ``'type1'`` the
            checkpoint file is looked up without ``dataset_name``.
        dataset_name: Forwarded to ``get_assigned_file`` for every
            ``type_dir`` that does not contain ``'type1'``.
        epoch: Checkpoint epoch forwarded to ``get_assigned_file``.

    Returns:
        The loaded model, moved to CUDA, in eval mode, with ``requires_grad``
        disabled on every parameter.

    Raises:
        ValueError: If ``pretrained_dataset`` is not a supported name.
    """
    num_classes_by_dataset = {'miniImageNet': 64, 'tieredImageNet': 351}
    if pretrained_dataset not in num_classes_by_dataset:
        # Previously this fell through and crashed later on an unbound `num_classes`.
        raise ValueError(
            'Unsupported pretrained_dataset %r; expected one of %s'
            % (pretrained_dataset, sorted(num_classes_by_dataset))
        )
    num_classes = num_classes_by_dataset[pretrained_dataset]

    # NOTE(review): a ResNet10 backbone is built whatever `model` says; `model`
    # only influences the checkpoint directory below.
    backbone_net = backbone.ResNet10(method=method, track_bn=True, reinit_bn_stats=False)
    pretrained_model = BaselineTrain(backbone_net, num_classes, loss_type='softmax')

    checkpoint_dir = '%s/checkpoints/%s/%s_%s/%s/' % (save_dir, pretrained_dataset, model, method, type_dir)
    if 'type1' in type_dir:
        modelfile = get_assigned_file(checkpoint_dir, epoch)
    else:
        modelfile = get_assigned_file(checkpoint_dir, epoch, dataset_name=dataset_name)

    # Load onto the CPU first so the result does not depend on which device the
    # checkpoint was saved from; the model is moved to the GPU right after.
    state = torch.load(modelfile, map_location='cpu')['state']

    pretrained_model.load_state_dict(state, strict=True)
    pretrained_model.cuda()
    pretrained_model.eval()

    # The model is only used for feature extraction, so freeze every parameter.
    for p in pretrained_model.parameters():
        p.requires_grad = False

    return pretrained_model

In [None]:
def get_repr_label(pretrained_model, loader):
    """Extract features and labels for every batch yielded by ``loader``.

    Args:
        pretrained_model: Module exposing a ``feature(image)`` method. It is
            switched to eval mode before extraction.
        loader: Iterable yielding ``(image, label)`` batches of tensors.

    Returns:
        A tuple ``(features, labels)``: the per-batch outputs of
        ``pretrained_model.feature`` (moved to the CPU) and the per-batch
        labels, each concatenated along dimension 0.
    """
    pretrained_model.eval()

    # Run on whatever device the model lives on instead of hard-coding CUDA.
    first_param = next(pretrained_model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    repr_lst = []
    label_lst = []

    # Inference only: never build an autograd graph, even if the caller did not
    # freeze the model's parameters.
    with torch.no_grad():
        # Passing the loader directly lets tqdm read its length for the progress bar.
        for image, label in tqdm(loader):
            last_repr = pretrained_model.feature(image.to(device))

            repr_lst.append(last_repr.cpu())
            label_lst.append(label)

    return torch.cat(repr_lst, dim=0), torch.cat(label_lst, dim=0)

def get_clustering_measure(features, target):
    """Compute within-class and between-class spread of ``features``.

    Args:
        features: Float tensor whose first dimension indexes samples.
        target: 1-D tensor of class labels, one per sample. Labels may be any
            values; they do not have to be ``0..C-1``.

    Returns:
        A tuple ``(sigma_within, sigma_btw, ratio)`` of Python floats:
        ``sigma_within`` is the mean distance of each sample to its class mean,
        ``sigma_btw`` is the mean distance of each class mean to the global
        mean, and ``ratio`` is ``sigma_within / sigma_btw``. With a single
        class ``sigma_btw`` is 0 and the ratio is therefore not finite.

    Raises:
        ValueError: If ``target`` is empty.
    """
    n_samples = target.size(0)
    if n_samples == 0:
        raise ValueError('get_clustering_measure needs at least one sample')

    # Iterate over the labels that actually occur; range(C) would produce empty
    # classes (and NaN means) whenever labels are not exactly 0..C-1.
    classes = target.unique()
    n_classes = len(classes)

    mu = torch.mean(features, dim=0)

    sigma_within = features.new_zeros(())
    sigma_btw = features.new_zeros(())

    for c in classes:
        members = features[target == c]
        class_mu = torch.mean(members, dim=0)

        sigma_btw = sigma_btw + torch.norm(class_mu - mu)
        # One norm per sample, computed in a single vectorised call.
        deviations = (members - class_mu).reshape(members.size(0), -1)
        sigma_within = sigma_within + torch.norm(deviations, dim=1).sum()

    sigma_within = sigma_within / n_samples
    sigma_btw = sigma_btw / n_classes

    return sigma_within.item(), sigma_btw.item(), (sigma_within / sigma_btw).item()

def get_vmeasure(repr_np, label_np):
    """Cluster ``repr_np`` with k-means and score the clusters against ``label_np``.

    The number of clusters equals the number of distinct values in
    ``label_np``; k-means is seeded with ``random_state=0``.

    Returns:
        The V-measure between ``label_np`` and the k-means assignments,
        multiplied by 100.
    """
    n_clusters = np.unique(label_np).size
    clusterer = KMeans(n_clusters=n_clusters, random_state=0)
    clusterer.fit(repr_np)
    score = v_measure_score(label_np, clusterer.labels_)
    return 100 * score

def get_cos(out, y):
    """Mean pairwise cosine similarity for same-label and different-label pairs.

    Args:
        out: 2-D tensor with one feature vector per row.
        y: Tensor of labels, one per row of ``out``.

    Returns:
        A tuple ``(same_mean, diff_mean)`` of Python floats. ``same_mean``
        averages the cosine similarity over all ordered pairs of distinct
        samples that share a label; ``diff_mean`` averages over all ordered
        pairs with different labels. A mean is ``nan`` when its group contains
        no pairs.
    """
    cos = torch.nn.CosineSimilarity()
    n = len(out)

    # Fill the similarity matrix row by row to keep peak memory at O(n * d).
    cos_mtx = torch.zeros([n, n])
    for i in range(n):
        cos_mtx[i] = cos(out, out[i].view(1, -1))

    same_mtx = (y.view(1, -1) == y.view(-1, 1)).float().cpu()
    # Built before the diagonal is cleared, so self-pairs are in neither group.
    diff_mtx = 1 - same_mtx
    diag = torch.arange(n)
    same_mtx[diag, diag] = 0.

    # Divide by the number of pairs in each mask. Counting non-zero products
    # instead would drop pairs whose cosine is exactly 0 from the denominator
    # and bias the mean.
    n_same = same_mtx.sum().item()
    n_diff = diff_mtx.sum().item()

    same_total = torch.sum(cos_mtx * same_mtx).item()
    diff_total = torch.sum(cos_mtx * diff_mtx).item()

    same_mean = same_total / n_same if n_same else float('nan')
    diff_mean = diff_total / n_diff if n_diff else float('nan')
    return same_mean, diff_mean

In [None]:
# Checkpoint variants to compare; one set of measurements is collected per variant.
type_dirs = (
    ['type1_strong']
    + ['type4_strong_gamma%d' % gamma for gamma in (125, 250, 375, 500, 625, 750, 875)]
    + ['type3_strong']
)

source_clustering_measure_lst, target_clustering_measure_lst = [], []
source_v_measure_lst, target_v_measure_lst = [], []

for type_dir in type_dirs:
    pretrained_model = get_pretrained_model(
        model, method, pretrained_dataset, save_dir, type_dir,
        dataset_name=target_dataset,
    )

    # Extract features for both domains (source first, then target).
    source_repr, source_label = get_repr_label(pretrained_model, source_loader)
    source_repr_np, source_label_np = source_repr.numpy(), source_label.numpy()
    target_repr, target_label = get_repr_label(pretrained_model, target_loader)
    target_repr_np, target_label_np = target_repr.numpy(), target_label.numpy()

    # Only the within/between ratio (third return value) is recorded.
    _, _, source_clustering_measure = get_clustering_measure(source_repr, source_label)
    _, _, target_clustering_measure = get_clustering_measure(target_repr, target_label)
    source_clustering_measure_lst.append(source_clustering_measure)
    target_clustering_measure_lst.append(target_clustering_measure)

    # K-means V-measure on the same features.
    source_v_measure = get_vmeasure(source_repr_np, source_label_np)
    target_v_measure = get_vmeasure(target_repr_np, target_label_np)
    source_v_measure_lst.append(source_v_measure)
    target_v_measure_lst.append(target_v_measure)

In [None]:
# Report each collected list rounded to two decimals, in this order:
# source clustering, target clustering, source V-measure, target V-measure.
print(np.round(source_clustering_measure_lst, decimals=2))
print(np.round(target_clustering_measure_lst, decimals=2))
print(np.round(source_v_measure_lst, decimals=2))
print(np.round(target_v_measure_lst, decimals=2))

In [None]:
# Track one checkpoint variant across training epochs.
type_dir = 'type4_strong_gamma500'

# NOTE: these lists reuse the names filled by the type_dir sweep, so running
# this cell replaces those results.
source_clustering_measure_lst, target_clustering_measure_lst = [], []
source_v_measure_lst, target_v_measure_lst = [], []

for epoch in range(100, 1000, 100):  # 100, 200, ..., 900
    pretrained_model = get_pretrained_model(
        model, method, pretrained_dataset, save_dir, type_dir,
        dataset_name=target_dataset, epoch=epoch,
    )

    # Extract features for both domains (source first, then target).
    source_repr, source_label = get_repr_label(pretrained_model, source_loader)
    source_repr_np, source_label_np = source_repr.numpy(), source_label.numpy()
    target_repr, target_label = get_repr_label(pretrained_model, target_loader)
    target_repr_np, target_label_np = target_repr.numpy(), target_label.numpy()

    # Only the within/between ratio (third return value) is recorded.
    _, _, source_clustering_measure = get_clustering_measure(source_repr, source_label)
    _, _, target_clustering_measure = get_clustering_measure(target_repr, target_label)
    source_clustering_measure_lst.append(source_clustering_measure)
    target_clustering_measure_lst.append(target_clustering_measure)

    # K-means V-measure on the same features.
    source_v_measure = get_vmeasure(source_repr_np, source_label_np)
    target_v_measure = get_vmeasure(target_repr_np, target_label_np)
    source_v_measure_lst.append(source_v_measure)
    target_v_measure_lst.append(target_v_measure)

In [None]:
# Report each collected list rounded to two decimals, in this order:
# source clustering, target clustering, source V-measure, target V-measure.
print(np.round(source_clustering_measure_lst, decimals=2))
print(np.round(target_clustering_measure_lst, decimals=2))
print(np.round(source_v_measure_lst, decimals=2))
print(np.round(target_v_measure_lst, decimals=2))

# t-SNE

In [None]:
from tsnecuda import TSNE

In [None]:
def plot_tsne(repr_np, label_np, perplexity=50.0, n_iter=1000, save_path=None):
    """Embed ``repr_np`` in 2-D with t-SNE and show a scatter plot coloured by label.

    Args:
        repr_np: Array of representations passed to ``TSNE.fit_transform``.
        label_np: Per-sample values used to colour the points.
        perplexity: t-SNE perplexity (default 50.0, the previous fixed value).
        n_iter: Number of t-SNE iterations (default 1000, the previous fixed value).
        save_path: Optional file path. When given, the figure is written there
            before it is shown.
    """
    tsne_model = TSNE(n_components=2, perplexity=perplexity, n_iter=n_iter)
    embedded_repr = tsne_model.fit_transform(repr_np)

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.scatter(embedded_repr[:, 0], embedded_repr[:, 1], c=label_np, alpha=0.4, s=1)
    # The embedding axes carry no meaning, so hide the ticks.
    ax.set_xticks([])
    ax.set_yticks([])

    if save_path is not None:
        # Save before show(): the figure may no longer be drawable afterwards.
        fig.savefig(save_path, bbox_inches='tight')

    plt.show()
    plt.close(fig)

In [None]:
type_dir = 'type1_strong'
# Alternatives: 'type3_strong', 'type4_strong_gamma500'

pretrained_model = get_pretrained_model(model, method, pretrained_dataset, save_dir, type_dir, dataset_name=target_dataset)
pretrained_model.eval()

# get_repr_label takes the model as its first argument; calling it with only
# the loader raises a TypeError.
source_repr, source_label = get_repr_label(pretrained_model, source_loader)
source_repr_np, source_label_np = source_repr.numpy(), source_label.numpy()
plot_tsne(source_repr_np, source_label_np)

target_repr, target_label = get_repr_label(pretrained_model, target_loader)
target_repr_np, target_label_np = target_repr.numpy(), target_label.numpy()
plot_tsne(target_repr_np, target_label_np)

In [None]:
# V-measure of the representations currently held in the *_repr_np arrays,
# source first, then target.
source_v_measure, target_v_measure = [
    get_vmeasure(*pair)
    for pair in ((source_repr_np, source_label_np), (target_repr_np, target_label_np))
]

In [None]:
# Show the source and target V-measure side by side.
print(source_v_measure, target_v_measure)