# Train a model with Classical Training

Although episodic training has attracted a lot of interest in the early years of Few-Shot Learning research, more recent works suggest that competitive results can be achieved with a simple cross entropy loss across all training classes. Therefore, it is becoming more and more common to use this classical process to train the backbone, that will be common to all methods compared at test time.

This is in fact more representative of real use cases: episodic training assumes that, at training time, you have access to the shape of the few-shot tasks that will be encountered at test time (indeed you choose a specific number of ways for episodic training). You also "force" your inference method into the training of the network. Switching the few-shot learning logic to inference (i.e. no episodic training) allows methods to be agnostic of the backbone.

Nonetheless, if you need to perform episodic training, we also provide [an example notebook](episodic_training.ipynb) for that.

## Getting started
First we're going to do some imports (this is not the interesting part).

In [1]:
try:
    import google.colab
    colab = True
except:
    colab = False

In [2]:
if colab is True:
    # Running in Google Colab
    # Clone the repo
    !git clone https://github.com/sicara/easy-few-shot-learning
    %cd easy-few-shot-learning
    !pip install .
else:
    # Run locally
    # Ensure working directory is the project's root
    # Make sure easyfsl is installed!
    %cd ..

Cloning into 'easy-few-shot-learning'...
remote: Enumerating objects: 827, done.[K
remote: Counting objects: 100% (288/288), done.[K
remote: Compressing objects: 100% (162/162), done.[K
remote: Total 827 (delta 165), reused 169 (delta 123), pack-reused 539[K
Receiving objects: 100% (827/827), 2.21 MiB | 3.21 MiB/s, done.
Resolving deltas: 100% (455/455), done.
/content/easy-few-shot-learning
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Processing /content/easy-few-shot-learning
[33m  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.
   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.[0m
Building wheels for collected packages: easyfsl
  Buil

In [3]:
from pathlib import Path
import random
from statistics import mean

import numpy as np
import torch
from torch import nn
from tqdm import tqdm

Then we're gonna do the most important thing in Machine Learning research: ensuring reproducibility by setting the random seed. We're going to set the seed for all random packages that we could possibly use, plus some other stuff to make CUDA deterministic (see [here](https://pytorch.org/docs/stable/notes/randomness.html)).

I strongly encourage that you do this in **all your scripts**.

In [4]:
random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random.seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Then we're gonna create our data loader for the training set. You can see that I chose tu use CUB in this notebook, because it’s a small dataset, so we can have good results quite quickly. I set a batch size of 128 but feel free to adapt it to your constraints.

Note that we're not using the `TaskSampler` for the train data loader, because we won't be sampling training data in the shape of tasks as we would have in episodic training. We do it **normally**.

In [5]:
# Download the CUB dataset
!bash scripts/download_CUB.sh

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
images/167.Hooded_Warbler/Hooded_Warbler_0024_2654633686.jpg
images/167.Hooded_Warbler/Hooded_Warbler_0025_2690306769.jpg
images/167.Hooded_Warbler/Hooded_Warbler_0026_2817738193.jpg
images/166.Golden_winged_Warbler/Golden_winged_Warbler_0001_495149583.jpg
images/166.Golden_winged_Warbler/Golden_winged_Warbler_0002_2471636465.jpg
images/166.Golden_winged_Warbler/Golden_winged_Warbler_0003_2540117270.jpg
images/166.Golden_winged_Warbler/Golden_winged_Warbler_0004_2301685839.jpg
images/166.Golden_winged_Warbler/Golden_winged_Warbler_0005_498814033.jpg
images/166.Golden_winged_Warbler/Golden_winged_Warbler_0006_2519317376.jpg
images/166.Golden_winged_Warbler/Golden_winged_Warbler_0007_2468962218.jpg
images/166.Golden_winged_Warbler/Golden_winged_Warbler_0008_2464457843.jpg
images/166.Golden_winged_Warbler/Golden_winged_Warbler_0009_2591413864.jpg
images/166.Golden_winged_Warbler/Golden_winged_Warbler_0010_2316599673.jpg
imag

In [6]:
from easyfsl.datasets import CUB
from torch.utils.data import DataLoader

batch_size = 128
n_workers = 12

train_set = CUB(split="train", training=True)
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    num_workers=n_workers,
    pin_memory=True,
    shuffle=True,
)

  cpuset_checked))


Now, we are going to create the model that we want to train. Here we choose the ResNet12 that is very often used in Few-Shot Learning research. Note that the default setting of these networks in EasyFSL is to not have a last fully connected layer (as it is usual for most Few-Shot Learning methods), but for classical training we need this layer! We also force it to output a vector which size is the number of different classes in the training set.

In [7]:
from easyfsl.modules import resnet12

DEVICE = "cuda"

model = resnet12(
    use_fc=True,
    num_classes=len(set(train_set.get_labels())),
).to(DEVICE)

Now, we still need validation ! Since we're training a model to perform few-shot classification, we will validate on few-shot tasks, so now we'll use the `TaskSampler`. We arbitrarily set the shape of the validation tasks. Ideally, you'd like to perform validation on various shapes of tasks, but we didn't implement this yet (feel free to contribute!).

We also need to define the few-shot classification method that we will use during validation of the neural network we're training.
Here we choose Prototypical Networks, because it's simple and efficient, but this is still an arbitrary choice.

In [8]:
from easyfsl.methods import PrototypicalNetworks
from easyfsl.samplers import TaskSampler

n_way = 5
n_shot = 5
n_query = 10
n_validation_tasks = 500

val_set = CUB(split="val", training=False)
val_sampler = TaskSampler(
    val_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)
val_loader = DataLoader(
    val_set,
    batch_sampler=val_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=val_sampler.episodic_collate_fn,
)

few_shot_classifier = PrototypicalNetworks(model).to(DEVICE)

## Training

Now let's define our training helpers ! I chose to use Stochastic Gradient Descent on 200 epochs with a scheduler that divides the learning rate by 10 after 120 and 160 epochs. The strategy is derived from [this repo](https://github.com/fiveai/on-episodes-fsl).

We're also gonna use a TensorBoard because it's always good to see what your training curves look like.

An other thing: we're doing 200 epochs like in [the episodic training notebook](notebooks/episodic_training.ipynb), but keep in mind that an epoch in classical training means one pass through the 6000 images of the dataset, while in episodic training it's an arbitrary number of episodes. In the episodic training notebook an epoch is 500 episodes of 5-way, 5-shot, 10-query tasks, so 37500 images. TL;DR you may want to monitor your training and increase the number of epochs if necessary.

In [9]:
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter


LOSS_FUNCTION = nn.CrossEntropyLoss()

n_epochs = 200
scheduler_milestones = [150, 180]
scheduler_gamma = 0.1
learning_rate = 1e-01
tb_logs_dir = Path(".")

train_optimizer = SGD(
    model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4
)
train_scheduler = MultiStepLR(
    train_optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)

tb_writer = SummaryWriter(log_dir=str(tb_logs_dir))

And now let's get to it! Here we define the function that performs a training epoch.

We use tqdm to monitor the training in real time in our logs.

In [10]:
def training_epoch(model_: nn.Module, data_loader: DataLoader, optimizer: Optimizer):
    all_loss = []
    model_.train()
    with tqdm(data_loader, total=len(data_loader), desc="Training") as tqdm_train:
        for images, labels in tqdm_train:
            optimizer.zero_grad()

            loss = LOSS_FUNCTION(model_(images.to(DEVICE)), labels.to(DEVICE))
            loss.backward()
            optimizer.step()

            all_loss.append(loss.item())

            tqdm_train.set_postfix(loss=mean(all_loss))

    return mean(all_loss)

And we have everything we need! This is now the time to **start training**.

A few notes:

- We only validate every 10 epochs (you may set an even less frequent validation) because a training epoch is much faster than 500 few-shot tasks, and we don't want validation to be the bottleneck of our training process.

- I also added something to log the state of the model that gave the best performance on the validation set.

In [11]:
from easyfsl.methods.utils import evaluate


best_state = model.state_dict()
best_validation_accuracy = 0.0
validation_frequency = 10
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(model, train_loader, train_optimizer)

    if epoch % validation_frequency == validation_frequency - 1:

        # We use this very convenient method from EasyFSL's ResNet to specify
        # that the model shouldn't use its last fully connected layer during validation.
        model.set_use_fc(False)
        validation_accuracy = evaluate(
            few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
        )
        model.set_use_fc(True)

        if validation_accuracy > best_validation_accuracy:
            best_validation_accuracy = validation_accuracy
            best_state = model.state_dict()
            print("Ding ding ding! We found a new best model!")

        tb_writer.add_scalar("Val/acc", validation_accuracy, epoch)

    tb_writer.add_scalar("Train/loss", average_loss, epoch)

    # Warn the scheduler that we did an epoch
    # so it knows when to decrease the learning rate
    train_scheduler.step()

Epoch 0


  cpuset_checked))
Training: 100%|██████████| 34/34 [00:52<00:00,  1.53s/it, loss=5.06]


Epoch 1


Training: 100%|██████████| 34/34 [00:37<00:00,  1.11s/it, loss=4.79]


Epoch 2


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=4.66]


Epoch 3


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=4.62]


Epoch 4


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=4.57]


Epoch 5


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=4.53]


Epoch 6


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=4.51]


Epoch 7


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=4.48]


Epoch 8


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=4.45]


Epoch 9


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=4.4]
Validation: 100%|██████████| 500/500 [02:59<00:00,  2.78it/s, accuracy=0.408]


Ding ding ding! We found a new best model!
Epoch 10


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=4.38]


Epoch 11


Training: 100%|██████████| 34/34 [00:37<00:00,  1.12s/it, loss=4.32]


Epoch 12


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=4.32]


Epoch 13


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=4.29]


Epoch 14


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=4.23]


Epoch 15


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=4.19]


Epoch 16


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=4.15]


Epoch 17


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=4.13]


Epoch 18


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=4.1]


Epoch 19


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=4.06]
Validation: 100%|██████████| 500/500 [02:59<00:00,  2.79it/s, accuracy=0.492]


Ding ding ding! We found a new best model!
Epoch 20


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=3.98]


Epoch 21


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=3.95]


Epoch 22


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=3.9]


Epoch 23


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=3.84]


Epoch 24


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=3.81]


Epoch 25


Training: 100%|██████████| 34/34 [00:40<00:00,  1.19s/it, loss=3.74]


Epoch 26


Training: 100%|██████████| 34/34 [00:40<00:00,  1.19s/it, loss=3.74]


Epoch 27


Training: 100%|██████████| 34/34 [00:40<00:00,  1.19s/it, loss=3.7]


Epoch 28


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=3.61]


Epoch 29


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=3.58]
Validation: 100%|██████████| 500/500 [02:58<00:00,  2.80it/s, accuracy=0.565]


Ding ding ding! We found a new best model!
Epoch 30


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=3.51]


Epoch 31


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=3.47]


Epoch 32


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=3.39]


Epoch 33


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=3.38]


Epoch 34


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=3.32]


Epoch 35


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=3.29]


Epoch 36


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=3.23]


Epoch 37


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=3.22]


Epoch 38


Training: 100%|██████████| 34/34 [00:38<00:00,  1.15s/it, loss=3.13]


Epoch 39


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=3.12]
Validation: 100%|██████████| 500/500 [02:57<00:00,  2.81it/s, accuracy=0.646]


Ding ding ding! We found a new best model!
Epoch 40


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=3]


Epoch 41


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=3.02]


Epoch 42


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=2.97]


Epoch 43


Training: 100%|██████████| 34/34 [00:37<00:00,  1.11s/it, loss=2.87]


Epoch 44


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=2.82]


Epoch 45


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=2.83]


Epoch 46


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=2.76]


Epoch 47


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=2.7]


Epoch 48


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=2.64]


Epoch 49


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=2.57]
Validation: 100%|██████████| 500/500 [02:58<00:00,  2.80it/s, accuracy=0.658]


Ding ding ding! We found a new best model!
Epoch 50


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=2.58]


Epoch 51


Training: 100%|██████████| 34/34 [00:37<00:00,  1.10s/it, loss=2.5]


Epoch 52


Training: 100%|██████████| 34/34 [00:37<00:00,  1.11s/it, loss=2.5]


Epoch 53


Training: 100%|██████████| 34/34 [00:37<00:00,  1.10s/it, loss=2.49]


Epoch 54


Training: 100%|██████████| 34/34 [00:37<00:00,  1.10s/it, loss=2.41]


Epoch 55


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=2.38]


Epoch 56


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=2.29]


Epoch 57


Training: 100%|██████████| 34/34 [00:37<00:00,  1.10s/it, loss=2.25]


Epoch 58


Training: 100%|██████████| 34/34 [00:36<00:00,  1.09s/it, loss=2.22]


Epoch 59


Training: 100%|██████████| 34/34 [00:37<00:00,  1.10s/it, loss=2.22]
Validation: 100%|██████████| 500/500 [02:56<00:00,  2.84it/s, accuracy=0.671]


Ding ding ding! We found a new best model!
Epoch 60


Training: 100%|██████████| 34/34 [00:36<00:00,  1.07s/it, loss=2.18]


Epoch 61


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=2.12]


Epoch 62


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=2.15]


Epoch 63


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=2.1]


Epoch 64


Training: 100%|██████████| 34/34 [00:36<00:00,  1.09s/it, loss=2.09]


Epoch 65


Training: 100%|██████████| 34/34 [00:36<00:00,  1.09s/it, loss=2.01]


Epoch 66


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=2.02]


Epoch 67


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=1.96]


Epoch 68


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=1.93]


Epoch 69


Training: 100%|██████████| 34/34 [00:37<00:00,  1.10s/it, loss=1.91]
Validation: 100%|██████████| 500/500 [02:57<00:00,  2.81it/s, accuracy=0.7]


Ding ding ding! We found a new best model!
Epoch 70


Training: 100%|██████████| 34/34 [00:36<00:00,  1.07s/it, loss=1.88]


Epoch 71


Training: 100%|██████████| 34/34 [00:36<00:00,  1.07s/it, loss=1.78]


Epoch 72


Training: 100%|██████████| 34/34 [00:36<00:00,  1.07s/it, loss=1.8]


Epoch 73


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=1.74]


Epoch 74


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=1.78]


Epoch 75


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=1.71]


Epoch 76


Training: 100%|██████████| 34/34 [00:37<00:00,  1.10s/it, loss=1.76]


Epoch 77


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=1.63]


Epoch 78


Training: 100%|██████████| 34/34 [00:37<00:00,  1.10s/it, loss=1.62]


Epoch 79


Training: 100%|██████████| 34/34 [00:36<00:00,  1.09s/it, loss=1.64]
Validation: 100%|██████████| 500/500 [02:56<00:00,  2.83it/s, accuracy=0.694]


Epoch 80


Training: 100%|██████████| 34/34 [00:36<00:00,  1.06s/it, loss=1.65]


Epoch 81


Training: 100%|██████████| 34/34 [00:36<00:00,  1.07s/it, loss=1.54]


Epoch 82


Training: 100%|██████████| 34/34 [00:36<00:00,  1.09s/it, loss=1.52]


Epoch 83


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=1.46]


Epoch 84


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=1.45]


Epoch 85


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=1.45]


Epoch 86


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=1.45]


Epoch 87


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=1.43]


Epoch 88


Training: 100%|██████████| 34/34 [00:36<00:00,  1.09s/it, loss=1.46]


Epoch 89


Training: 100%|██████████| 34/34 [00:37<00:00,  1.10s/it, loss=1.44]
Validation: 100%|██████████| 500/500 [02:57<00:00,  2.81it/s, accuracy=0.728]


Ding ding ding! We found a new best model!
Epoch 90


Training: 100%|██████████| 34/34 [00:36<00:00,  1.07s/it, loss=1.49]


Epoch 91


Training: 100%|██████████| 34/34 [00:36<00:00,  1.09s/it, loss=1.47]


Epoch 92


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=1.38]


Epoch 93


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=1.35]


Epoch 94


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=1.3]


Epoch 95


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=1.33]


Epoch 96


Training: 100%|██████████| 34/34 [00:36<00:00,  1.09s/it, loss=1.34]


Epoch 97


Training: 100%|██████████| 34/34 [00:37<00:00,  1.10s/it, loss=1.27]


Epoch 98


Training: 100%|██████████| 34/34 [00:37<00:00,  1.11s/it, loss=1.22]


Epoch 99


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=1.3]
Validation: 100%|██████████| 500/500 [02:58<00:00,  2.80it/s, accuracy=0.723]


Epoch 100


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=1.29]


Epoch 101


Training: 100%|██████████| 34/34 [00:37<00:00,  1.11s/it, loss=1.22]


Epoch 102


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=1.19]


Epoch 103


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=1.15]


Epoch 104


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=1.2]


Epoch 105


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=1.15]


Epoch 106


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=1.16]


Epoch 107


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=1.16]


Epoch 108


Training: 100%|██████████| 34/34 [00:37<00:00,  1.11s/it, loss=1.13]


Epoch 109


Training: 100%|██████████| 34/34 [00:37<00:00,  1.11s/it, loss=1.12]
Validation: 100%|██████████| 500/500 [02:58<00:00,  2.81it/s, accuracy=0.739]


Ding ding ding! We found a new best model!
Epoch 110


Training: 100%|██████████| 34/34 [00:36<00:00,  1.07s/it, loss=1.12]


Epoch 111


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=1.12]


Epoch 112


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=1.07]


Epoch 113


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=1.06]


Epoch 114


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=1.05]


Epoch 115


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=1.05]


Epoch 116


Training: 100%|██████████| 34/34 [00:36<00:00,  1.09s/it, loss=1.05]


Epoch 117


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=1.05]


Epoch 118


Training: 100%|██████████| 34/34 [00:37<00:00,  1.10s/it, loss=1.08]


Epoch 119


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=1.07]
Validation: 100%|██████████| 500/500 [02:57<00:00,  2.82it/s, accuracy=0.697]


Epoch 120


Training: 100%|██████████| 34/34 [00:36<00:00,  1.07s/it, loss=1.04]


Epoch 121


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=1.06]


Epoch 122


Training: 100%|██████████| 34/34 [00:36<00:00,  1.07s/it, loss=1.05]


Epoch 123


Training: 100%|██████████| 34/34 [00:37<00:00,  1.10s/it, loss=1.01]


Epoch 124


Training: 100%|██████████| 34/34 [00:37<00:00,  1.10s/it, loss=0.983]


Epoch 125


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=0.955]


Epoch 126


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=0.974]


Epoch 127


Training: 100%|██████████| 34/34 [00:36<00:00,  1.09s/it, loss=1.01]


Epoch 128


Training: 100%|██████████| 34/34 [00:37<00:00,  1.10s/it, loss=0.908]


Epoch 129


Training: 100%|██████████| 34/34 [00:36<00:00,  1.09s/it, loss=0.908]
Validation: 100%|██████████| 500/500 [02:57<00:00,  2.82it/s, accuracy=0.725]


Epoch 130


Training: 100%|██████████| 34/34 [00:36<00:00,  1.07s/it, loss=0.926]


Epoch 131


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=0.945]


Epoch 132


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=0.979]


Epoch 133


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=0.972]


Epoch 134


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=0.981]


Epoch 135


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=0.98]


Epoch 136


Training: 100%|██████████| 34/34 [00:36<00:00,  1.09s/it, loss=0.94]


Epoch 137


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=0.894]


Epoch 138


Training: 100%|██████████| 34/34 [00:37<00:00,  1.09s/it, loss=0.924]


Epoch 139


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=0.859]
Validation: 100%|██████████| 500/500 [02:57<00:00,  2.81it/s, accuracy=0.715]


Epoch 140


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=0.867]


Epoch 141


Training: 100%|██████████| 34/34 [00:36<00:00,  1.08s/it, loss=0.974]


Epoch 142


Training: 100%|██████████| 34/34 [00:36<00:00,  1.09s/it, loss=0.973]


Epoch 143


Training: 100%|██████████| 34/34 [00:36<00:00,  1.09s/it, loss=0.858]


Epoch 144


Training: 100%|██████████| 34/34 [00:36<00:00,  1.09s/it, loss=0.872]


Epoch 145


Training: 100%|██████████| 34/34 [00:36<00:00,  1.09s/it, loss=0.865]


Epoch 146


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.853]


Epoch 147


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.879]


Epoch 148


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.921]


Epoch 149


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.881]
Validation: 100%|██████████| 500/500 [02:57<00:00,  2.82it/s, accuracy=0.729]


Epoch 150


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=0.679]


Epoch 151


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.506]


Epoch 152


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.504]


Epoch 153


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.473]


Epoch 154


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.415]


Epoch 155


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.407]


Epoch 156


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.429]


Epoch 157


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.413]


Epoch 158


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.4]


Epoch 159


Training: 100%|██████████| 34/34 [00:37<00:00,  1.11s/it, loss=0.421]
Validation: 100%|██████████| 500/500 [02:57<00:00,  2.82it/s, accuracy=0.751]


Ding ding ding! We found a new best model!
Epoch 160


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.368]


Epoch 161


Training: 100%|██████████| 34/34 [00:37<00:00,  1.12s/it, loss=0.381]


Epoch 162


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.39]


Epoch 163


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.351]


Epoch 164


Training: 100%|██████████| 34/34 [00:37<00:00,  1.12s/it, loss=0.392]


Epoch 165


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.371]


Epoch 166


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.349]


Epoch 167


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.371]


Epoch 168


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.346]


Epoch 169


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.34]
Validation: 100%|██████████| 500/500 [02:57<00:00,  2.82it/s, accuracy=0.736]


Epoch 170


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.348]


Epoch 171


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.367]


Epoch 172


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.338]


Epoch 173


Training: 100%|██████████| 34/34 [00:38<00:00,  1.14s/it, loss=0.335]


Epoch 174


Training: 100%|██████████| 34/34 [00:37<00:00,  1.11s/it, loss=0.362]


Epoch 175


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.328]


Epoch 176


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.334]


Epoch 177


Training: 100%|██████████| 34/34 [00:37<00:00,  1.11s/it, loss=0.308]


Epoch 178


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.329]


Epoch 179


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.304]
Validation: 100%|██████████| 500/500 [02:57<00:00,  2.82it/s, accuracy=0.748]


Epoch 180


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.334]


Epoch 181


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.288]


Epoch 182


Training: 100%|██████████| 34/34 [00:37<00:00,  1.12s/it, loss=0.317]


Epoch 183


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.295]


Epoch 184


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.325]


Epoch 185


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.291]


Epoch 186


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.311]


Epoch 187


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.315]


Epoch 188


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.289]


Epoch 189


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.316]
Validation: 100%|██████████| 500/500 [02:57<00:00,  2.81it/s, accuracy=0.748]


Epoch 190


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.323]


Epoch 191


Training: 100%|██████████| 34/34 [00:37<00:00,  1.12s/it, loss=0.296]


Epoch 192


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.308]


Epoch 193


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.302]


Epoch 194


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.283]


Epoch 195


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.306]


Epoch 196


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.308]


Epoch 197


Training: 100%|██████████| 34/34 [00:37<00:00,  1.12s/it, loss=0.304]


Epoch 198


Training: 100%|██████████| 34/34 [00:38<00:00,  1.12s/it, loss=0.305]


Epoch 199


Training: 100%|██████████| 34/34 [00:38<00:00,  1.13s/it, loss=0.301]
Validation: 100%|██████████| 500/500 [02:59<00:00,  2.79it/s, accuracy=0.747]


Yay we successfully performed Classical Training! Now if you want to you can retrieve the best model's state.

In [12]:
model.load_state_dict(best_state)

<All keys matched successfully>

## Evaluation

Now that our model is trained, we want to test it.

First step: we fetch the test data. Note that we'll evaluate on the same shape of tasks as in validation. This is malicious practice, because it means that we used *a priori* information about the evaluation tasks during training. This is still less malicious than episodic training, though.

In [17]:
n_test_tasks = 1000

test_set = CUB(split="test", training=False)
test_sampler = TaskSampler(
    test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks
)
test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

Second step: we instantiate a few-shot classifier using our trained ResNet as backbone, and run it on the test data. We keep using Prototypical Networks for consistence, but at this point you could basically use any few-shot classifier that takes no additional trainable parameters.

Like we did during validation, we need to tell our ResNet to not use its last fully connected layer.

In [18]:
model.set_use_fc(False)

accuracy = evaluate(few_shot_classifier, test_loader, device=DEVICE)
print(f"Average accuracy : {(100 * accuracy):.2f} %")

100%|██████████| 1000/1000 [06:01<00:00,  2.77it/s, accuracy=0.706]

Average accuracy : 70.57 %



