# Train a model with Episodic Training
Episodic training attracted a lot of interest in the early years of Few-Shot Learning research. Some papers still use it and refer to it as "meta-learning".

Recent works distinguish the Few-Shot Classifier from the training framework, so as of v1.0 of EasyFSL, the methods to episodically train a classifier were taken out of the `FewShotClassifier` class. Instead, this notebook provides an example of how to perform episodic training on a few-shot classifier.

Use it, copy it, change it, get crazy.

## Getting started
First we're going to do some imports (this is not the interesting part).

In [1]:
# IMPORTANT: never leave real credentials in a shared notebook.
# Replace the placeholders below with your own Kaggle username and API token.
# Note: `! export VAR=...` only affects a short-lived subshell, so we set the
# variables through os.environ, which persists for the whole session.
import os
os.environ['KAGGLE_USERNAME'] = "<your-kaggle-username>"
os.environ['KAGGLE_KEY'] = "<your-kaggle-api-token>"

In [5]:
! kaggle datasets download -d arjunashok33/miniimagenet

Dataset URL: https://www.kaggle.com/datasets/arjunashok33/miniimagenet
License(s): CC0-1.0
^C


In [15]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("arjunashok33/miniimagenet")

print("Path to dataset files:", path)

Using Colab cache for faster access to the 'miniimagenet' dataset.
Path to dataset files: /kaggle/input/miniimagenet


In [16]:
# kagglehub already returns the extracted dataset directory, so there is nothing to unzip.
# Verify the files are there (optional)
! ls /kaggle/input/miniimagenet

n01532829  n02101006  n02174001  n03047690  n03544143  n04149813  n04612504
n01558993  n02105505  n02219486  n03062245  n03584254  n04243546  n06794110
n01704323  n02108089  n02443484  n03075370  n03676483  n04251144  n07584110
n01749939  n02108551  n02457408  n03127925  n03770439  n04258138  n07613480
n01770081  n02108915  n02606052  n03146219  n03773504  n04275548  n07697537
n01843383  n02110063  n02687172  n03207743  n03775546  n04296562  n07747607
n01855672  n02110341  n02747177  n03220513  n03838899  n04389033  n09246464
n01910747  n02111277  n02795169  n03272010  n03854065  n04418357  n09256479
n01930112  n02113712  n02823428  n03337140  n03888605  n04435653  n13054560
n01981276  n02114548  n02871525  n03347037  n03908618  n04443257  n13133613
n02074367  n02116738  n02950826  n03400231  n03924679  n04509417
n02089867  n02120079  n02966193  n03417042  n03980874  n04515003
n02091244  n021291

In [20]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [23]:
! mkdir -p data/mini_imagenet/images

In [24]:
!cp -r /kaggle/input/miniimagenet/* data/mini_imagenet/images/

In [54]:
!wget https://raw.githubusercontent.com/sicara/easy-few-shot-learning/master/data/mini_imagenet/train.csv -O ./data/mini_imagenet/train.csv
!wget https://raw.githubusercontent.com/sicara/easy-few-shot-learning/master/data/mini_imagenet/val.csv -O ./data/mini_imagenet/val.csv
!wget https://raw.githubusercontent.com/sicara/easy-few-shot-learning/master/data/mini_imagenet/test.csv -O ./data/mini_imagenet/test.csv


--2025-11-22 06:42:55--  https://raw.githubusercontent.com/sicara/easy-few-shot-learning/master/data/mini_imagenet/train.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1166529 (1.1M) [text/plain]
Saving to: ‘./data/mini_imagenet/train.csv’


2025-11-22 06:42:55 (184 MB/s) - ‘./data/mini_imagenet/train.csv’ saved [1166529/1166529]

--2025-11-22 06:42:55--  https://raw.githubusercontent.com/sicara/easy-few-shot-learning/master/data/mini_imagenet/val.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 290868 (284K) [text/plain]


In [18]:
try:
    import google.colab
    colab = True
except ImportError:
    colab = False

In [None]:
if colab:
    # Running in Google Colab
    # Clone the repo
    !git clone https://github.com/sicara/easy-few-shot-learning
    %cd easy-few-shot-learning
    !pip install .
else:
    # Run locally
    # Ensure working directory is the project's root
    # Make sure easyfsl is installed!
    %cd ..

In [21]:
import copy
from pathlib import Path
import random
from statistics import mean

import numpy as np
import torch
from torch import nn
from tqdm import tqdm

Then we're gonna do the most important thing in Machine Learning research: ensuring reproducibility by setting the random seed. We're going to set the seed for all random packages that we could possibly use, plus some other stuff to make CUDA deterministic (see [here](https://pytorch.org/docs/stable/notes/randomness.html)).

I strongly encourage you to do this in **all your scripts**.

In [25]:
random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random.seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
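
As a quick illustration of what seeding buys us, here is a minimal sketch using only Python's built-in `random` module. The helper `draw_numbers` is a hypothetical name introduced for this example; the same idea applies to the NumPy and PyTorch generators seeded above.

```python
import random


def draw_numbers(seed, count=3):
    """Seed Python's generator, then draw `count` pseudo-random integers."""
    random.seed(seed)
    return [random.randint(0, 999) for _ in range(count)]


# Two runs with the same seed produce exactly the same sequence
first_run = draw_numbers(seed=0)
second_run = draw_numbers(seed=0)
print(first_run == second_run)  # True
```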

Then we're gonna set the shape of our problem.

We also define our set-up, here the number of workers for data loading. The device is selected further down, when we build the model (it falls back to CPU if you don't have CUDA).

In [40]:
n_way = 5
n_shot = 5
n_query = 10

n_workers = 12
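
To make this shape concrete, here is a small sketch of how many images one task (episode) contains with these settings. The names `support_size`, `query_size` and `task_size` are introduced here for illustration only and are not part of EasyFSL.

```python
# Shape of one few-shot task with the settings above
n_way = 5      # classes per task
n_shot = 5     # labeled support images per class
n_query = 10   # query images per class

support_size = n_way * n_shot           # images used to build the classifier
query_size = n_way * n_query            # images the classifier must label
task_size = support_size + query_size   # images loaded for one episode

print(support_size, query_size, task_size)  # 25 50 75
```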

## Training

First we define our data loaders for training and validation. This notebook uses mini-ImageNet, which we downloaded above. We use `MiniImageNet` and `TaskSampler`, which are built-in objects from EasyFSL.

In [46]:
from easyfsl.datasets import MiniImageNet
from easyfsl.samplers import TaskSampler
from torch.utils.data import DataLoader


n_tasks_per_epoch = 700
n_validation_tasks = 200


train_set = MiniImageNet(root="./data/mini_imagenet/data", split="train")
val_set = MiniImageNet(root="./data/mini_imagenet/data", split="val")
test_set = MiniImageNet(root="./data/mini_imagenet/data", split="test")

# Those are special batch samplers that sample few-shot classification tasks with a pre-defined shape
train_sampler = TaskSampler(
    train_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_tasks_per_epoch
)
val_sampler = TaskSampler(
    val_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)

# Finally, the DataLoader. We customize the collate_fn so that batches are delivered
# in the shape: (support_images, support_labels, query_images, query_labels, class_ids)
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)
val_loader = DataLoader(
    val_set,
    batch_sampler=val_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=val_sampler.episodic_collate_fn,
)
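
To give an intuition of what a task sampler does, here is a simplified sketch in plain Python. It is not EasyFSL's `TaskSampler` implementation: the function `sample_task` and the toy label list are assumptions made for illustration. For each episode it picks `n_way` classes, then `n_shot + n_query` items in each of them.

```python
import random
from collections import defaultdict


def sample_task(labels, n_way, n_shot, n_query, rng):
    """Return (support_indices, query_indices) for one episode.

    Simplified illustration only: this is not EasyFSL's TaskSampler.
    """
    # Group dataset indices by class label
    indices_per_label = defaultdict(list)
    for index, label in enumerate(labels):
        indices_per_label[label].append(index)

    # Pick n_way distinct classes, then n_shot + n_query items in each class
    support, query = [], []
    for label in rng.sample(sorted(indices_per_label), n_way):
        picked = rng.sample(indices_per_label[label], n_shot + n_query)
        support.extend(picked[:n_shot])
        query.extend(picked[n_shot:])
    return support, query


# Toy dataset: 20 classes with 30 items each
toy_labels = [label for label in range(20) for _ in range(30)]
support, query = sample_task(
    toy_labels, n_way=5, n_shot=5, n_query=10, rng=random.Random(0)
)
print(len(support), len(query))  # 25 50
```

Support and query items never overlap, and both sets cover the same 5 classes, which is what lets the classifier be evaluated on the query set of each episode.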



And then we define the network. Here I chose Prototypical Networks and the built-in ResNet12 from EasyFSL because it's easy.

In [47]:
from easyfsl.methods import PrototypicalNetworks, FewShotClassifier
from easyfsl.modules import resnet12
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

convolutional_network = resnet12()
few_shot_classifier = PrototypicalNetworks(convolutional_network).to(DEVICE)
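
For intuition, here is a minimal sketch of the idea behind Prototypical Networks, written in plain Python rather than with EasyFSL. Hand-written 2-D vectors stand in for the backbone's output features, and `compute_prototypes` and `classify` are hypothetical helpers, not library functions. Each class prototype is the mean of its support features, and a query is scored by its negative Euclidean distance to each prototype.

```python
import math
from collections import defaultdict


def compute_prototypes(support_features, support_labels):
    """Average the support feature vectors of each class."""
    grouped = defaultdict(list)
    for features, label in zip(support_features, support_labels):
        grouped[label].append(features)
    return {
        label: [sum(values) / len(values) for values in zip(*vectors)]
        for label, vectors in grouped.items()
    }


def classify(query_features, prototypes):
    """Score each class by negative Euclidean distance to its prototype."""
    scores = {
        label: -math.dist(query_features, prototype)
        for label, prototype in prototypes.items()
    }
    return max(scores, key=scores.get), scores


# Toy 2-D "embeddings" for a 2-way 2-shot task
support_features = [[0.0, 0.0], [0.0, 2.0], [4.0, 4.0], [6.0, 4.0]]
support_labels = [0, 0, 1, 1]
prototypes = compute_prototypes(support_features, support_labels)
predicted_label, scores = classify([1.0, 1.0], prototypes)
print(prototypes, predicted_label)
```

Here the query `[1.0, 1.0]` lies at distance 1 from the prototype of class 0 and at distance 5 from the prototype of class 1, so class 0 gets the highest score.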

Now let's define our training helpers! I chose to use Stochastic Gradient Descent for 150 epochs with a scheduler that divides the learning rate by 10 after 100 and 130 epochs. The strategy is adapted from [this repo](https://github.com/fiveai/on-episodes-fsl).

We're also gonna use TensorBoard because it's always good to see what your training curves look like.

In [48]:
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter


LOSS_FUNCTION = nn.CrossEntropyLoss()

n_epochs = 150
scheduler_milestones = [100, 130]
scheduler_gamma = 0.1
learning_rate = 1e-2
tb_logs_dir = Path(".")

train_optimizer = SGD(
    few_shot_classifier.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4
)
train_scheduler = MultiStepLR(
    train_optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)

tb_writer = SummaryWriter(log_dir=str(tb_logs_dir))
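
To see what this schedule does, here is a small sketch that computes the learning rate in effect at a given epoch using the closed form of a multi-step decay. The helper `multistep_lr` is a hypothetical name written for this example; it assumes the scheduler is stepped once at the end of every epoch, as in the training loop of this notebook.

```python
from bisect import bisect_right


def multistep_lr(epoch, base_lr, milestones, gamma):
    """Learning rate in effect during `epoch` (0-indexed) for a step-decay schedule."""
    # Each milestone already reached multiplies the rate by gamma once
    return base_lr * gamma ** bisect_right(milestones, epoch)


for epoch in (0, 99, 100, 129, 130, 149):
    print(epoch, multistep_lr(epoch, base_lr=1e-2, milestones=[100, 130], gamma=0.1))
```

With the values used here, epochs 0 to 99 train at about 1e-2, epochs 100 to 129 at about 1e-3, and the remaining epochs at about 1e-4.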

And now let's get to it! Here we define the function that performs a training epoch.

We use tqdm to monitor the training in real time in our logs.

In [49]:
def training_epoch(
    model: FewShotClassifier, data_loader: DataLoader, optimizer: Optimizer
):
    all_loss = []
    model.train()
    with tqdm(
        enumerate(data_loader), total=len(data_loader), desc="Training"
    ) as tqdm_train:
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            _,
        ) in tqdm_train:
            optimizer.zero_grad()
            model.process_support_set(
                support_images.to(DEVICE), support_labels.to(DEVICE)
            )
            classification_scores = model(query_images.to(DEVICE))

            loss = LOSS_FUNCTION(classification_scores, query_labels.to(DEVICE))
            loss.backward()
            optimizer.step()

            all_loss.append(loss.item())

            tqdm_train.set_postfix(loss=mean(all_loss))

    return mean(all_loss)

And we have everything we need! To perform validations we'll just use the built-in `evaluate` function from `easyfsl.utils`.

This is now the time to **start training**.

I added something to log the state of the model that gave the best performance on the validation set.
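
When keeping track of the best state, it matters whether we store a copy or a reference. The toy sketch below uses a plain dict in place of a real model state and made-up validation accuracies, both of which are assumptions for illustration. It shows why a deep copy is needed when the state keeps being modified in place.

```python
import copy

# Toy "model state": a dict holding a mutable list, updated in place during training
state = {"weights": [0.0, 0.0]}
fake_validation_accuracies = [0.40, 0.55, 0.50]  # made-up values for illustration

best_accuracy = 0.0
best_state_by_reference = state
best_state_by_copy = copy.deepcopy(state)

for epoch, accuracy in enumerate(fake_validation_accuracies):
    # Simulate a training epoch that modifies the weights in place
    state["weights"][0] += 1.0
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        best_state_by_reference = state            # keeps following later updates
        best_state_by_copy = copy.deepcopy(state)  # frozen snapshot of this epoch

print(best_state_by_reference["weights"], best_state_by_copy["weights"])
# [3.0, 0.0] [2.0, 0.0]
```

The reference ends up reflecting the last epoch, while the deep copy still holds the state from the epoch with the best accuracy.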

In [32]:
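# The MiniImageNet datasets above use root="./data/mini_imagenet/data",
# so this rename must be run before creating them.
! mv data/mini_imagenet/images data/mini_imagenet/data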

In [None]:
from easyfsl.utils import evaluate
import copy
import torch
import os

# --- 1. MOUNT GOOGLE DRIVE (Run this cell once before the loop) ---
from google.colab import drive
drive.mount('/content/drive')
CHECKPOINT_DIR = '/content/drive/MyDrive/miniimagenet_checkpoints/'
# Create the checkpoint directory if it doesn't exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
# ---------------------------------------------------------------------

best_state = few_shot_classifier.state_dict()
best_validation_accuracy = 0.0

# Assume DEVICE, n_epochs, training_epoch, train_loader, val_loader,
# train_optimizer, tb_writer, and train_scheduler are defined previously

for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(few_shot_classifier, train_loader, train_optimizer)
    validation_accuracy = evaluate(
        few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
    )

    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
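        # state_dict() returns a reference to the still evolving model's state so we deepcopy
        # https://pytorch.org/tutorials/beginner/saving_loading_models
        best_state = copy.deepcopy(few_shot_classifier.state_dict())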
        print("Ding ding ding! We found a new best model!")

        # --- SAVE THE BEST MODEL ANYTIME IT IMPROVES ---
        best_checkpoint = {
            'epoch': epoch,
            'model_state_dict': few_shot_classifier.state_dict(),
            'optimizer_state_dict': train_optimizer.state_dict(),
            'best_validation_accuracy': best_validation_accuracy
        }
        torch.save(best_checkpoint, os.path.join(CHECKPOINT_DIR, 'best_model.pth'))


    # --- 2. SAVE CHECKPOINT EVERY 10TH EPOCH ---
    if (epoch + 1) % 10 == 0:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': few_shot_classifier.state_dict(),
            'optimizer_state_dict': train_optimizer.state_dict(),
            'best_validation_accuracy': best_validation_accuracy
        }
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f'epoch_{epoch+1}.pth')
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved at epoch {epoch+1} to {checkpoint_path}")
    # -----------------------------------------------

    tb_writer.add_scalar("Train/loss", average_loss, epoch)
    tb_writer.add_scalar("Val/acc", validation_accuracy, epoch)

    train_scheduler.step()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Epoch 0


Training: 100%|██████████| 700/700 [04:57<00:00,  2.35it/s, loss=1.32]
Validation: 100%|██████████| 200/200 [01:04<00:00,  3.12it/s, accuracy=0.44]


Ding ding ding! We found a new best model!
Epoch 1


Training: 100%|██████████| 700/700 [05:02<00:00,  2.31it/s, loss=1.12]
Validation: 100%|██████████| 200/200 [01:08<00:00,  2.94it/s, accuracy=0.463]


Ding ding ding! We found a new best model!
Epoch 2


Training: 100%|██████████| 700/700 [05:03<00:00,  2.30it/s, loss=1.01]
Validation: 100%|██████████| 200/200 [01:05<00:00,  3.06it/s, accuracy=0.537]


Ding ding ding! We found a new best model!
Epoch 3


Training: 100%|██████████| 700/700 [05:04<00:00,  2.30it/s, loss=0.93]
Validation: 100%|██████████| 200/200 [01:06<00:00,  3.02it/s, accuracy=0.552]


Ding ding ding! We found a new best model!
Epoch 4


Training:  17%|█▋        | 116/700 [00:53<04:11,  2.32it/s, loss=0.895]

Yay, we successfully performed Episodic Training! Now, if you want, you can retrieve the best model's state.

In [None]:
few_shot_classifier.load_state_dict(best_state)

## Evaluation

Now that our model is trained, we want to test it.

First step: we fetch the test data.

In [None]:
n_test_tasks = 1000

test_set = MiniImageNet(root="./data/mini_imagenet/data", split="test")
test_sampler = TaskSampler(
    test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks
)
test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

Second step: we run the few-shot classifier on the test data.

In [None]:
accuracy = evaluate(few_shot_classifier, test_loader, device=DEVICE)
print(f"Average accuracy : {(100 * accuracy):.2f} %")

Congrats! You performed Episodic Training using EasyFSL. If you want to compare with a model trained using classical training, look at [this other example notebook](classical_training.ipynb).
