In [None]:
import os

In [None]:
# Make the project directory the working directory; the relative data and
# checkpoint paths used later ("./few_shot_data/...", "./pretrained_model/...")
# are resolved against it.
# NOTE(review): the absolute path looks environment-specific (Colab-style) —
# confirm or adjust before running elsewhere.
os.chdir("/content/fsl")

In [None]:
# import packages
from tqdm import tqdm

import torch as th
from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.nn.modules.loss import CrossEntropyLoss, MSELoss

from torchvision.transforms import Compose, Grayscale, CenterCrop, ToTensor, ToPILImage, Resize
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.datasets import ImageFolder

from src import Learner, TaskSampler, PrototypicalNetworks, RelationNetworks, CNNEncoder

In [None]:
# configurations
# Episode layout used by every task sampler below:
#   'shot'  - value passed as n_shot
#   'way'   - value passed as n_way
#   'query' - value passed as n_query
CONFIG = {
    'shot': 5,
    'way': 5,
    'query': 10,
}

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = 'cuda' if th.cuda.is_available() else 'cpu'

# Compare by value (==), not identity (is): `is` against a string literal only
# works through interning and raises a SyntaxWarning on Python 3.8+.
if DEVICE == 'cuda':
  print ('cuda is available as device')

cuda is available as device


In [None]:
# Image transformation pipeline, applied identically to every split:
# 3-channel grayscale -> center crop (224) -> resize (128) -> tensor.
image_transform_pipe = Compose([
    Grayscale(num_output_channels=3),
    CenterCrop(224),
    Resize(128),
    ToTensor(),
])

# Load the train / validation / test datasets, one ImageFolder per split
# directory, all sharing the pipeline above.
train_dataset, val_dataset, test_dataset = (
    ImageFolder(root=split_root, transform=image_transform_pipe)
    for split_root in (
        "./few_shot_data/train",
        "./few_shot_data/val",
        "./few_shot_data/test",
    )
)

In [None]:
# Create task samplers
# Every split uses the same way/shot/query layout from CONFIG; only the
# dataset and the number of tasks (500 / 100 / 20) differ.
train_sampler, validation_sampler, test_sampler = (
    TaskSampler(
        split_dataset,
        n_way=CONFIG['way'],
        n_shot=CONFIG['shot'],
        n_query=CONFIG['query'],
        n_tasks=split_n_tasks,
    )
    for split_dataset, split_n_tasks in (
        (train_dataset, 500),
        (val_dataset, 100),
        (test_dataset, 20),
    )
)

In [None]:
# Create dataloaders
def _episodic_loader(dataset, sampler):
    """Wrap `dataset` in a DataLoader whose batches are drawn by `sampler`.

    The sampler is used both as the ``batch_sampler`` and as the provider of
    the collate function (``sampler.episodic_collate_fn``). Worker count and
    memory pinning are the same for every split.
    """
    return DataLoader(
        dataset,
        batch_sampler=sampler,
        num_workers=2,
        pin_memory=True,
        collate_fn=sampler.episodic_collate_fn,
    )


train_loader = _episodic_loader(train_dataset, train_sampler)
val_loader = _episodic_loader(val_dataset, validation_sampler)
test_loader = _episodic_loader(test_dataset, test_sampler)

## Run a baseline model.
Use a Prototypical Networks model without any fitting, with an ImageNet-pretrained ResNet as its backbone. The result serves as a baseline for an out-of-the-box model.

In [None]:
# Load an ImageNet-pretrained ResNet-18 to use as the backbone.
pretrained_weights = ResNet18_Weights.IMAGENET1K_V1
backbone_model = resnet18(weights=pretrained_weights)

# Change the final layer from a class-score output to the flattened feature
# vector produced by ResNet's preceding architecture.
backbone_model.fc = nn.Flatten()

# Place the model on the device chosen in the configuration cell instead of
# calling .cuda() unconditionally, which raises on machines without a GPU.
model = backbone_model.to(DEVICE)

In [None]:
# Baseline few-shot classifier: Prototypical Networks on top of the pretrained
# backbone, constructed with output_softmax_score enabled.
baseline_classifier = PrototypicalNetworks(
    backbone=model,
    output_softmax_score=True,
)

In [None]:
# Score the baseline classifier on the test tasks (inference only).
baseline_classifier.eval()
with th.no_grad():
  # total=20 matches the n_tasks value given to test_sampler.
  with tqdm(test_loader, total=20) as prediction_tasks:
    # Running counts across all tasks.
    total_predictions = 0
    total_correct_predictions = 0

    for (
      task_support_images,
      task_support_labels,
      task_query_images,
      task_query_labels,
      task_class_ids,
    ) in prediction_tasks:

      # Scores for the query images given the labelled support set.
      task_prediction_scores = baseline_classifier(
        task_support_images.to(DEVICE),
        task_support_labels.to(DEVICE),
        task_query_images.to(DEVICE),
      )

      # Predicted label = index of the highest score along the last axis.
      task_prediction_labels = task_prediction_scores.argmax(dim=-1)
      # 1.0 where the prediction matches the query label, 0.0 otherwise.
      task_correct_predictions = task_prediction_labels.eq(task_query_labels.to(DEVICE)).float()
      task_accuracy = task_correct_predictions.mean().item()

      total_predictions += len(task_prediction_labels)
      total_correct_predictions += task_correct_predictions.sum().item()

      # Show the accuracy of the current task on the progress bar.
      prediction_tasks.set_postfix(task_accuracy = task_accuracy)

    # Accuracy pooled over every query prediction of every task.
    overall_accuracy = total_correct_predictions / total_predictions

  print (f"\n Test overall average accuracy: {overall_accuracy}")

100%|██████████| 20/20 [00:04<00:00,  4.72it/s, task_accuracy=0.6]


 Test overall average accuracy: 0.582





## Understand basic performance of Prototypical Networks after simple training

In [None]:
# Load a fresh ImageNet-pretrained ResNet-18 as the backbone for this run.
pretrained_weights = ResNet18_Weights.IMAGENET1K_V1
backbone_model_1 = resnet18(weights=pretrained_weights)

# Change the final layer from a class-score output to the flattened feature
# vector produced by ResNet's preceding architecture.
backbone_model_1.fc = nn.Flatten()

# Place the model on the device chosen in the configuration cell instead of
# calling .cuda() unconditionally, which raises on machines without a GPU.
model_1 = backbone_model_1.to(DEVICE)

# Prototypical Networks classifier to be trained, constructed with
# output_softmax_score disabled.
classifier_1 = PrototypicalNetworks(backbone=model_1, output_softmax_score=False)

In [None]:
# Training setup for the Prototypical Networks classifier:
# cross-entropy loss, plain SGD over all classifier parameters.
loss_function = CrossEntropyLoss()
learning_rate = 0.05
optimizer_1 = SGD(classifier_1.parameters(), lr=learning_rate)

# Fit for 10 epochs on the train loader with the validation loader supplied;
# TensorBoard logging is disabled (path is None). The value returned by
# Learner.fit is kept for inspection in the next cell.
c1_best_val_accuracy = Learner.fit(
    model=classifier_1,
    train_data_loader=train_loader,
    val_data_loader=val_loader,
    loss_function=loss_function,
    optimizer=optimizer_1,
    epochs=10,
    tensorboard_log_path=None,
)

Training Epoch 0


100%|██████████| 500/500 [00:26<00:00, 19.23it/s, episode_accuracy=0.94, episode_loss=0.185, epoch_accuracy=0.892, epoch_loss=0.308]


Validating Epoch 0


100%|██████████| 100/100 [00:05<00:00, 18.42it/s, episode_accuracy=1, overall_accuracy=0.874]


Training Epoch 1


100%|██████████| 500/500 [00:25<00:00, 19.46it/s, episode_accuracy=0.82, episode_loss=0.506, epoch_accuracy=0.93, epoch_loss=0.197]


Validating Epoch 1


100%|██████████| 100/100 [00:05<00:00, 18.38it/s, episode_accuracy=0.8, overall_accuracy=0.89]


Training Epoch 2


100%|██████████| 500/500 [00:25<00:00, 19.38it/s, episode_accuracy=0.88, episode_loss=0.229, epoch_accuracy=0.944, epoch_loss=0.159]


Validating Epoch 2


100%|██████████| 100/100 [00:05<00:00, 17.78it/s, episode_accuracy=0.8, overall_accuracy=0.902]


Training Epoch 3


100%|██████████| 500/500 [00:25<00:00, 19.37it/s, episode_accuracy=0.98, episode_loss=0.0671, epoch_accuracy=0.954, epoch_loss=0.132]


Validating Epoch 3


100%|██████████| 100/100 [00:05<00:00, 18.02it/s, episode_accuracy=1, overall_accuracy=0.911]


Training Epoch 4


100%|██████████| 500/500 [00:25<00:00, 19.37it/s, episode_accuracy=1, episode_loss=0.0334, epoch_accuracy=0.96, epoch_loss=0.117]


Validating Epoch 4


100%|██████████| 100/100 [00:05<00:00, 18.05it/s, episode_accuracy=0.98, overall_accuracy=0.906]


Training Epoch 5


100%|██████████| 500/500 [00:25<00:00, 19.33it/s, episode_accuracy=0.96, episode_loss=0.166, epoch_accuracy=0.966, epoch_loss=0.102]


Validating Epoch 5


100%|██████████| 100/100 [00:05<00:00, 17.76it/s, episode_accuracy=0.96, overall_accuracy=0.913]


Training Epoch 6


100%|██████████| 500/500 [00:25<00:00, 19.30it/s, episode_accuracy=0.98, episode_loss=0.0711, epoch_accuracy=0.968, epoch_loss=0.0949]


Validating Epoch 6


100%|██████████| 100/100 [00:05<00:00, 17.81it/s, episode_accuracy=1, overall_accuracy=0.914]


Training Epoch 7


100%|██████████| 500/500 [00:26<00:00, 19.13it/s, episode_accuracy=1, episode_loss=0.0144, epoch_accuracy=0.97, epoch_loss=0.0857]


Validating Epoch 7


100%|██████████| 100/100 [00:05<00:00, 17.64it/s, episode_accuracy=0.96, overall_accuracy=0.906]


Training Epoch 8


100%|██████████| 500/500 [00:26<00:00, 19.12it/s, episode_accuracy=0.94, episode_loss=0.156, epoch_accuracy=0.975, epoch_loss=0.0755]


Validating Epoch 8


100%|██████████| 100/100 [00:05<00:00, 17.68it/s, episode_accuracy=0.92, overall_accuracy=0.877]


Training Epoch 9


100%|██████████| 500/500 [00:26<00:00, 19.05it/s, episode_accuracy=0.98, episode_loss=0.0528, epoch_accuracy=0.976, epoch_loss=0.0715]


Validating Epoch 9


100%|██████████| 100/100 [00:05<00:00, 17.56it/s, episode_accuracy=0.94, overall_accuracy=0.91]


In [None]:
# Display the value returned by Learner.fit for the Prototypical Networks run.
c1_best_val_accuracy

0.9138

## Understand basic performance of Relation Networks after simple training

In [None]:
# Load the pretrained CNNEncoder weights (miniImageNet, 5-way 5-shot checkpoint).
# The tensors are mapped to the CPU first and moved to the target device below.
# NOTE(review): th.load unpickles the file, and unpickling can execute
# arbitrary code — only load checkpoints that come from a trusted source.
pretrained_weights = th.load(
    "./pretrained_model/miniimagenet_feature_encoder_5way_5shot.pkl",
    map_location=th.device('cpu'),
)
pretrained_encoder = CNNEncoder()
pretrained_encoder.load_state_dict(pretrained_weights)

# Place the encoder on the device chosen in the configuration cell instead of
# calling .cuda() unconditionally, which raises on machines without a GPU.
model_2 = pretrained_encoder.to(DEVICE)

# Relation Networks classifier to be trained, constructed with
# output_softmax_score disabled.
classifier_2 = RelationNetworks(backbone=model_2, output_softmax_score=False)

In [None]:
# Training setup for the Relation Networks classifier:
# mean-squared-error loss, plain SGD over all classifier parameters.
loss_function = MSELoss()
learning_rate = 0.05
optimizer_2 = SGD(classifier_2.parameters(), lr=learning_rate)

# Fit for 10 epochs on the train loader with the validation loader supplied;
# TensorBoard logging is disabled (path is None). The value returned by
# Learner.fit is kept for inspection in the next cell.
c2_best_val_accuracy = Learner.fit(
    model=classifier_2,
    train_data_loader=train_loader,
    val_data_loader=val_loader,
    loss_function=loss_function,
    optimizer=optimizer_2,
    epochs=10,
    tensorboard_log_path=None,
)

Training Epoch 0


100%|██████████| 500/500 [01:34<00:00,  5.28it/s, episode_accuracy=0.28, episode_loss=0.158, epoch_accuracy=0.231, epoch_loss=0.161]


Validating Epoch 0


100%|██████████| 100/100 [00:16<00:00,  5.89it/s, episode_accuracy=0.28, overall_accuracy=0.263]


Training Epoch 1


100%|██████████| 500/500 [01:21<00:00,  6.16it/s, episode_accuracy=0.2, episode_loss=0.159, epoch_accuracy=0.306, epoch_loss=0.157]


Validating Epoch 1


100%|██████████| 100/100 [00:15<00:00,  6.35it/s, episode_accuracy=0.3, overall_accuracy=0.331]


Training Epoch 2


100%|██████████| 500/500 [01:21<00:00,  6.15it/s, episode_accuracy=0.44, episode_loss=0.14, epoch_accuracy=0.354, epoch_loss=0.152]


Validating Epoch 2


100%|██████████| 100/100 [00:15<00:00,  6.64it/s, episode_accuracy=0.32, overall_accuracy=0.378]


Training Epoch 3


100%|██████████| 500/500 [01:20<00:00,  6.18it/s, episode_accuracy=0.58, episode_loss=0.134, epoch_accuracy=0.397, epoch_loss=0.148]


Validating Epoch 3


100%|██████████| 100/100 [00:14<00:00,  6.68it/s, episode_accuracy=0.46, overall_accuracy=0.413]


Training Epoch 4


100%|██████████| 500/500 [01:20<00:00,  6.21it/s, episode_accuracy=0.36, episode_loss=0.147, epoch_accuracy=0.454, epoch_loss=0.141]


Validating Epoch 4


100%|██████████| 100/100 [00:15<00:00,  6.61it/s, episode_accuracy=0.46, overall_accuracy=0.447]


Training Epoch 5


100%|██████████| 500/500 [01:20<00:00,  6.24it/s, episode_accuracy=0.56, episode_loss=0.121, epoch_accuracy=0.498, epoch_loss=0.134]


Validating Epoch 5


100%|██████████| 100/100 [00:15<00:00,  6.61it/s, episode_accuracy=0.44, overall_accuracy=0.438]


Training Epoch 6


100%|██████████| 500/500 [01:20<00:00,  6.24it/s, episode_accuracy=0.54, episode_loss=0.128, epoch_accuracy=0.526, epoch_loss=0.129]


Validating Epoch 6


100%|██████████| 100/100 [00:15<00:00,  6.55it/s, episode_accuracy=0.34, overall_accuracy=0.498]


Training Epoch 7


100%|██████████| 500/500 [01:20<00:00,  6.19it/s, episode_accuracy=0.62, episode_loss=0.116, epoch_accuracy=0.564, epoch_loss=0.123]


Validating Epoch 7


100%|██████████| 100/100 [00:15<00:00,  6.61it/s, episode_accuracy=0.54, overall_accuracy=0.534]


Training Epoch 8


100%|██████████| 500/500 [01:20<00:00,  6.19it/s, episode_accuracy=0.64, episode_loss=0.105, epoch_accuracy=0.581, epoch_loss=0.12]


Validating Epoch 8


100%|██████████| 100/100 [00:15<00:00,  6.58it/s, episode_accuracy=0.5, overall_accuracy=0.544]


Training Epoch 9


100%|██████████| 500/500 [01:20<00:00,  6.18it/s, episode_accuracy=0.62, episode_loss=0.107, epoch_accuracy=0.602, epoch_loss=0.116]


Validating Epoch 9


100%|██████████| 100/100 [00:15<00:00,  6.61it/s, episode_accuracy=0.44, overall_accuracy=0.56]


In [None]:
# Display the value returned by Learner.fit for the Relation Networks run.
c2_best_val_accuracy

0.56