# What if we use an ImageNet-pretrained model?

In [1]:
from dlcliche.notebook import *
from dlcliche.utils import *
sys.path.append('..')

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from few_shot.models import get_few_shot_encoder
from few_shot.callbacks import *

# This notebook only runs on a CUDA device; fail fast if none is available.
assert torch.cuda.is_available()
device = torch.device('cuda')
# Optional cuDNN benchmark setting, currently commented out.
#torch.backends.cudnn.benchmark = True

from few_shot.extmodel_proto_net_clf import ExtModelProtoNetClf, BasePretrainedModel
from torchvision import models
from torch import nn
from few_shot.models import Flatten

In [2]:
# Make dataset split: source folders, working folders, and episode settings.
from config import DATA_PATH
# Rebind the imported DATA_PATH to its miniImageNet subfolder.
DATA_PATH = Path(DATA_PATH)/'miniImageNet'
SRC_EVAL_PATH = DATA_PATH/'images_evaluation'
SRC_TRN_PATH = DATA_PATH/'images_background'

# Working folders (relative to the current working directory) that
# rebuild_data_files passes to ensure_delete and then refills with copies.
EVAL_TRN_PATH = Path('data/clf_eval_train')
EVAL_VAL_PATH = Path('data/clf_eval_valid')
K_WAY  = 5  # Number of classes to classify among
N_SKOT = 5  # Samples per class used to build a prototype (presumably a typo for N_SHOT; kept because callers pass it by keyword)
N_INPUT_CHANNELS = 3  # Not referenced in the cells visible here

def _imagenet_eval_transform():
    """Build the preprocessing pipeline shared by the train and validation sets.

    Resizes to 224x224, converts to a tensor, and normalizes each channel with
    the mean/std values conventionally used for ImageNet-pretrained
    torchvision models.
    """
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])


def rebuild_data_files(src=SRC_EVAL_PATH, trn_path=EVAL_TRN_PATH, val_path=EVAL_VAL_PATH,
                       K_WAY=K_WAY, N_SKOT=N_SKOT, batch_size=8):
    """Rebuild the K-way / N-shot working folders and the data loaders over them.

    The first ``K_WAY`` class folders of ``src`` (in sorted name order) are
    selected. For each class, the first ``N_SKOT`` ``*.jpg`` files (sorted) are
    copied to ``trn_path/<class>`` and all remaining ones to
    ``val_path/<class>``.

    Args:
        src: Folder containing one sub-folder per class.
        trn_path: Destination for the prototype (support) images; it is passed
            to ``ensure_delete`` first, so previous content is discarded.
        val_path: Destination for the evaluation images; handled like
            ``trn_path``.
        K_WAY: Number of classes to use.
        N_SKOT: Number of images per class used to build a prototype.
        batch_size: Batch size for both data loaders.

    Side effects:
        Rebinds the module-level globals ``plain_train_ds``, ``train_dl``,
        ``valid_ds`` and ``valid_dl``; later cells read them.

    Notes:
        Classes are sorted before slicing because glob order is
        filesystem-dependent; without it the selected subset could differ
        between machines or runs. Non-directory entries in ``src`` are
        skipped. A class with fewer than ``N_SKOT`` images contributes all of
        its images to the training folder and none to the validation folder.
    """
    global plain_train_ds, train_dl, valid_ds, valid_dl

    ensure_delete(trn_path)
    ensure_delete(val_path)

    # Deterministic class selection: directories only, sorted by name.
    classes = sorted(d.name for d in src.glob('*') if d.is_dir())[:K_WAY]
    for cls in classes:
        dest_trn_folder = trn_path/cls
        dest_val_folder = val_path/cls
        ensure_folder(dest_trn_folder)
        ensure_folder(dest_val_folder)
        files = sorted((src/cls).glob('*.jpg'))
        # Slicing (rather than indexing) tolerates classes with < N_SKOT files.
        for f in files[:N_SKOT]:
            copy_file(f, dest_trn_folder/f.name)
        for f in files[N_SKOT:]:
            copy_file(f, dest_val_folder/f.name)

    plain_train_ds = datasets.ImageFolder(trn_path, transform=_imagenet_eval_transform())
    train_dl = DataLoader(plain_train_ds, batch_size=batch_size)
    valid_ds = datasets.ImageFolder(val_path, transform=_imagenet_eval_transform())
    valid_dl = DataLoader(valid_ds, batch_size=batch_size)

## Create model

In [3]:
# Wrap torchvision's ResNet-18 constructor in the project's pretrained-model wrapper.
# NOTE(review): presumably this loads ImageNet-pretrained weights — confirm in
# few_shot.extmodel_proto_net_clf.
model = BasePretrainedModel(base_model=models.resnet18)
model.to(device, dtype=torch.float)
# Evaluation mode: the cells below only build prototypes and evaluate.
model.eval()

# Prototype-based classifier that uses the model above as its encoder.
proto_net_clf = ExtModelProtoNetClf(model, device=device)

## Test 5-way

In [4]:
# 5-way evaluation on each source split: 5-shot first, then 1-shot.
# train_dl / valid_dl are globals that rebuild_data_files rebinds on each call.
for src_path in (SRC_TRN_PATH, SRC_EVAL_PATH):
    for n_shot in (5, 1):
        rebuild_data_files(src=src_path, K_WAY=5, N_SKOT=n_shot)
        prototypes = proto_net_clf.make_prototypes(train_dl)
        print('F1/Recall/Precision/Accuracy =', proto_net_clf.evaluate(data_loader=valid_dl))

100%|██████████| 4/4 [00:00<00:00, 42.06it/s]
2019-01-23 12:29:17,494 dlcliche.utils make_prototypes [INFO]: Making new prototypes.
100%|██████████| 372/372 [00:08<00:00, 43.71it/s]
100%|██████████| 2975/2975 [00:00<00:00, 20467.07it/s]


F1/Recall/Precision/Accuracy = (0.9882539207551945, 0.9882352941176471, 0.9883894134090533, 0.9882352941176471)


100%|██████████| 1/1 [00:00<00:00, 69.66it/s]
2019-01-23 12:29:26,553 dlcliche.utils make_prototypes [INFO]: Making new prototypes.
100%|██████████| 375/375 [00:08<00:00, 43.87it/s]
100%|██████████| 2995/2995 [00:00<00:00, 21287.13it/s]


F1/Recall/Precision/Accuracy = (0.8297963255981461, 0.8410684474123539, 0.9035678645381544, 0.8410684474123539)


100%|██████████| 4/4 [00:00<00:00, 58.42it/s]
2019-01-23 12:29:35,697 dlcliche.utils make_prototypes [INFO]: Making new prototypes.
100%|██████████| 372/372 [00:08<00:00, 36.94it/s]
100%|██████████| 2975/2975 [00:00<00:00, 20611.19it/s]


F1/Recall/Precision/Accuracy = (0.9804990776074115, 0.9805042016806723, 0.9811282105492662, 0.9805042016806723)


100%|██████████| 1/1 [00:00<00:00, 63.47it/s]
2019-01-23 12:29:45,100 dlcliche.utils make_prototypes [INFO]: Making new prototypes.
100%|██████████| 375/375 [00:08<00:00, 42.41it/s]
100%|██████████| 2995/2995 [00:00<00:00, 20847.89it/s]

F1/Recall/Precision/Accuracy = (0.7793769948555661, 0.7796327212020033, 0.8696235374513783, 0.7796327212020033)





## Test 80-way (only available with SRC_TRN_PATH)

In [6]:
# 80-way evaluation on the background split: 1-shot first, then 5-shot.
# train_dl / valid_dl are globals that rebuild_data_files rebinds on each call.
for n_shot in (1, 5):
    rebuild_data_files(src=SRC_TRN_PATH, K_WAY=80, N_SKOT=n_shot)
    prototypes = proto_net_clf.make_prototypes(train_dl)
    print('F1/Recall/Precision/Accuracy =', proto_net_clf.evaluate(data_loader=valid_dl))

100%|██████████| 10/10 [00:00<00:00, 46.86it/s]
2019-01-23 12:31:05,807 dlcliche.utils make_prototypes [INFO]: Making new prototypes.
100%|██████████| 5990/5990 [02:12<00:00, 43.69it/s]
100%|██████████| 47920/47920 [00:27<00:00, 1748.40it/s]


F1/Recall/Precision/Accuracy = (0.5320421397494145, 0.532262103505843, 0.6443292247615288, 0.532262103505843)


100%|██████████| 50/50 [00:01<00:00, 43.68it/s]
2019-01-23 12:33:53,551 dlcliche.utils make_prototypes [INFO]: Making new prototypes.
100%|██████████| 5950/5950 [02:14<00:00, 44.39it/s]
100%|██████████| 47600/47600 [00:27<00:00, 1720.76it/s]


F1/Recall/Precision/Accuracy = (0.8249790996404818, 0.8227100840336135, 0.8402276374817964, 0.8227100840336135)
