<a name="top"></a>
# **HW15 Meta Learning: Few-shot Classification**

This is the sample code for homework 15.

Please mail to mlta-2022-spring@googlegroups.com if you have any questions.

In [1]:
from google.colab import drive
import os
# Mount Google Drive
drive.mount('/content/drive')

Mounted at /content/drive


## **Step 0: Check GPU**

In [2]:
!nvidia-smi

Thu Sep 18 23:56:19 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   38C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

## **Step 1: Download Data**

Run the cell to download data, which has been pre-processed by TAs.  
The dataset has been augmented, so extra data augmentation is not required.


In [3]:
workspace_dir = '/content/drive/MyDrive/ml2022spring-hw15'

# Download dataset
# !wget https://github.com/xraychen/shiny-disco/releases/download/Latest/omniglot.tar.gz \
#     -O "{workspace_dir}/Omniglot.tar.gz"
# !wget https://github.com/xraychen/shiny-disco/releases/download/Latest/omniglot-test.tar.gz \
#     -O "{workspace_dir}/Omniglot-test.tar.gz"

# Use the `tar` command to decompress
# !tar -zxf "{workspace_dir}/Omniglot.tar.gz" -C "{workspace_dir}/"
# !tar -zxf "{workspace_dir}/Omniglot-test.tar.gz" -C "{workspace_dir}/"

## **Step 2: Build the model**

### Library importation

In [4]:
# Import modules we need
import glob, random
from collections import OrderedDict

import numpy as np
from tqdm.auto import tqdm

import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

from PIL import Image
from IPython.display import display

# Check device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"DEVICE = {device}")

# Fix random seeds
random_seed = 0
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)

DEVICE = cuda


### Model Construction Preliminaries

Since our task is image classification, we need to build a CNN-based model.  
However, to implement the MAML algorithm, we need to adjust some of the code in `nn.Module`.


Take a look at the MAML pseudocode:

<img src="https://i.imgur.com/9aHlvfX.png" width="50%" />

On line 10, the gradient is taken with respect to $\theta$, i.e.  
<font color="#0CC">**the original model parameters**</font> (outer loop), not the adapted parameters of the  
<font color="#0C0">**inner loop**</font>. The inner loop therefore computes the output logits with `functional_forward`,  
which takes the parameters as an explicit argument, instead of `forward` in `nn.Module`.
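
The dependence on $\theta$ can be made explicit with the chain rule. As a sketch, assume a single inner step with learning rate $\alpha$, and write $\mathcal{L}^{s}$ and $\mathcal{L}^{q}$ for the support and query losses (this notation is introduced here for illustration and does not appear in the pseudocode):

$$
\theta' = \theta - \alpha \nabla_{\theta} \mathcal{L}^{s}(\theta),
\qquad
\nabla_{\theta} \mathcal{L}^{q}(\theta')
= \left(I - \alpha \nabla_{\theta}^{2} \mathcal{L}^{s}(\theta)\right) \nabla_{\theta'} \mathcal{L}^{q}(\theta')
$$

The Hessian term is only available when $\theta'$ is kept as a differentiable function of $\theta$, which is what passing explicit parameters to `functional_forward` allows. Dropping that term, i.e. treating the factor in parentheses as $I$, gives the first-order approximation (FO-MAML).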

The following defines these functions.

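<!-- On line 10 we differentiate with respect to the original parameters θ, not the inner-loop (lines 5~8) parameters θ'. In the inner loop we therefore compute the output logits of the input image with a functional forward instead of the forward in nn.Module (which differentiates directly with respect to θ). Below we define both the functional forward and the forward function. -->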

### Model block definition

In [5]:
def ConvBlock(in_ch: int, out_ch: int):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )


def ConvBlockFunction(x, w, b, w_bn, b_bn):
    x = F.conv2d(x, w, b, padding=1)
    x = F.batch_norm(
        x, running_mean=None, running_var=None, weight=w_bn, bias=b_bn, training=True
    )
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2)
    return x

### Model definition

In [6]:
class Classifier(nn.Module):
    def __init__(self, in_ch, k_way):
        super(Classifier, self).__init__()
        self.conv1 = ConvBlock(in_ch, 64)
        self.conv2 = ConvBlock(64, 64)
        self.conv3 = ConvBlock(64, 64)
        self.conv4 = ConvBlock(64, 64)
        self.logits = nn.Linear(64, k_way)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.view(x.shape[0], -1)
        x = self.logits(x)
        return x

    def functional_forward(self, x, params):
        """
        Arguments:
        x: input images [batch, 1, 28, 28]
        params: the model parameters, i.e. the weights and biases of the
                convolution and batch normalization layers, plus those of
                the final linear layer, as an `OrderedDict`
        """
        for block in [1, 2, 3, 4]:
            x = ConvBlockFunction(
                x,
                params[f"conv{block}.0.weight"],
                params[f"conv{block}.0.bias"],
                params.get(f"conv{block}.1.weight"),
                params.get(f"conv{block}.1.bias"),
            )
        x = x.view(x.shape[0], -1)
        x = F.linear(x, params["logits.weight"], params["logits.bias"])
        return x

### Create Label

This function creates the labels.  
In an N-way K-shot few-shot classification problem,
each task has `n_way` classes with `k_shot` images per class,  
so each class index is repeated `k_shot` times in order.


In [7]:
def create_label(n_way, k_shot):
    return torch.arange(n_way).repeat_interleave(k_shot).long()


# Try to create labels for 5-way 2-shot setting
create_label(5, 2)

tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])

### Accuracy calculation

In [8]:
def calculate_accuracy(logits, labels):
    """utility function for accuracy calculation"""
    acc = np.asarray(
        [(torch.argmax(logits, -1).cpu().numpy() == labels.cpu().numpy())]
    ).mean()
    return acc
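
For readers who want to check the computation by hand, here is a dependency-free sketch of the same idea: take the argmax of each row of scores and compare it with the label. The function name `accuracy_from_scores` and the sample numbers are made up for illustration; the notebook itself uses `calculate_accuracy` on tensors.

```python
def accuracy_from_scores(scores, labels):
    """Fraction of rows whose highest score sits at the labelled index."""
    # argmax of each row, written with plain Python lists
    predictions = [max(range(len(row)), key=row.__getitem__) for row in scores]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


# Hypothetical logits for three query images in a 3-way task
scores = [[2.0, 0.1, -1.0], [0.3, 0.2, 0.9], [0.0, 1.5, 1.0]]
labels = [0, 2, 2]
acc = accuracy_from_scores(scores, labels)
print(acc)
```

The first two rows are classified correctly and the third is not, so two of the three predictions count toward the accuracy.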

### Define Dataset

Define the dataset.  
Each item holds `k_shot + q_query` randomly sampled images of one character,  
so the returned tensor has size `[k_shot+q_query, 1, 28, 28]`.  


In [9]:
# Dataset for train and val
class Omniglot(Dataset):
    def __init__(self, data_dir, k_shot, q_query, task_num=None):
        self.file_list = [
            f for f in glob.glob(data_dir + "**/character*", recursive=True)
        ]
        # limit task number if task_num is set
        if task_num is not None:
            self.file_list = self.file_list[: min(len(self.file_list), task_num)]
        self.transform = transforms.Compose([transforms.ToTensor()])
        # number of images returned per character
        self.n = k_shot + q_query

    def __getitem__(self, idx):
        sample = np.arange(20)

        # Shuffle the indices 0..19 to pick which images of this character to use.
        np.random.shuffle(sample)
        img_path = self.file_list[idx]
        img_list = [f for f in glob.glob(img_path + "**/*.png", recursive=True)]
        img_list.sort()
        imgs = [self.transform(Image.open(img_file)) for img_file in img_list]
        # Keep `k_shot + q_query` examples for this character
        imgs = torch.stack(imgs)[sample[: self.n]]
        return imgs

    def __len__(self):
        return len(self.file_list)

## **Step 3: Learning Algorithms**

### Transfer learning

The solver first chooses five tasks from the training set and then performs ordinary classification training on them. At inference time, the model is fine-tuned for `inner_train_step` steps on the support set images and then predicts on the query set images.

To stay consistent with the meta-learning solver, the base solver has exactly the same input and output format.



In [10]:
def BaseSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False,
):
    criterion, task_loss, task_acc = loss_fn, [], []
    labels = []

    for meta_batch in x:
        # Get data
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]

        if train:
            """ training loop """
            # Use the support set to calculate loss
            labels = create_label(n_way, k_shot).to(device)
            logits = model.forward(support_set)
            loss = criterion(logits, labels)

            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, labels))
        else:
            """ validation / testing loop """
            # First update model with support set images for `inner_train_step` steps
            fast_weights = OrderedDict(model.named_parameters())


            for inner_step in range(inner_train_step):
                # Simply training
                train_label = create_label(n_way, k_shot).to(device)
                logits = model.functional_forward(support_set, fast_weights)
                loss = criterion(logits, train_label)

                grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
                # Perform SGD
                fast_weights = OrderedDict(
                    (name, param - inner_lr * grad)
                    for ((name, param), grad) in zip(fast_weights.items(), grads)
                )

            if not return_labels:
                """ validation """
                val_label = create_label(n_way, q_query).to(device)

                logits = model.functional_forward(query_set, fast_weights)
                loss = criterion(logits, val_label)
                task_loss.append(loss)
                task_acc.append(calculate_accuracy(logits, val_label))
            else:
                """ testing """
                logits = model.functional_forward(query_set, fast_weights)
                labels.extend(torch.argmax(logits, -1).cpu().numpy())

    if return_labels:
        return labels

    batch_loss = torch.stack(task_loss).mean()
    task_acc = np.mean(task_acc)

    if train:
        # Update model
        model.train()
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    return batch_loss, task_acc

### Meta Learning

Here is the main Meta Learning algorithm.

Please finish the TODO blocks for the inner and outer loop update rules.

- For implementing FO-MAML you can refer to [p.25 of the slides](http://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/Meta1%20(v6).pdf#page=25&view=FitW).

- For the original MAML, you can refer to [the slides of meta learning (p.13 ~ p.18)](http://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/Meta1%20(v6).pdf#page=13&view=FitW).
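
The difference between the two update rules can be seen on a toy problem. The sketch below is not the notebook's model: it assumes a single scalar parameter and quadratic support and query losses, so the derivatives can be written by hand. All function names and the numbers `a_s`, `a_q` are hypothetical.

```python
# Toy 1-D setting (an assumption for illustration):
#   L_support(t) = (t - a_s)^2,  L_query(t) = (t - a_q)^2

def inner_step(theta, a_s, alpha):
    """One SGD step on the support loss."""
    return theta - alpha * 2.0 * (theta - a_s)


def query_loss(theta, a_s, a_q, alpha):
    """Query loss evaluated after adapting theta on the support loss."""
    adapted = inner_step(theta, a_s, alpha)
    return (adapted - a_q) ** 2


def maml_grad(theta, a_s, a_q, alpha):
    """Exact meta-gradient: chain rule through the inner step."""
    adapted = inner_step(theta, a_s, alpha)
    d_adapted_d_theta = 1.0 - 2.0 * alpha  # derivative of inner_step w.r.t. theta
    return 2.0 * (adapted - a_q) * d_adapted_d_theta


def fomaml_grad(theta, a_s, a_q, alpha):
    """First-order approximation: treat d(adapted)/d(theta) as 1."""
    adapted = inner_step(theta, a_s, alpha)
    return 2.0 * (adapted - a_q)


theta, a_s, a_q, alpha = 0.5, 1.0, 2.0, 0.1
eps = 1e-6
# Central finite difference of the query loss with respect to the original theta
numeric = (
    query_loss(theta + eps, a_s, a_q, alpha)
    - query_loss(theta - eps, a_s, a_q, alpha)
) / (2 * eps)

full = maml_grad(theta, a_s, a_q, alpha)
first_order = fomaml_grad(theta, a_s, a_q, alpha)
print(full, first_order, numeric)
```

In this toy setting the exact meta-gradient matches the finite-difference estimate, while the first-order value differs from it by the factor `1 - 2 * alpha`. In the PyTorch code, that factor corresponds to the second-order terms that `create_graph=True` keeps in the graph.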


In [11]:
def MetaSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False
):
    criterion, task_loss, task_acc = loss_fn, [], []
    labels = []

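    for meta_batch in x:  # iterate over every task in the meta batch
        # ===== Added: data augmentation =====
        if torch.rand(1).item() > 0.6:  # augment with 40% probability (random number > 0.6)
            # Pick the rotation at random: 90 degrees or 270 degrees, each with 50% probability
            times = 1 if torch.rand(1).item() > 0.5 else 3  # times=1 is 90 degrees, times=3 is 270 degrees
            # Rotate the images; [-1, -2] selects the last two dimensions (height and width)
            meta_batch = torch.rot90(meta_batch, times, [-1, -2])

        # Split into support set and query set
        support_set = meta_batch[: n_way * k_shot]    # first n_way*k_shot samples: support set (for fast adaptation)
        query_set = meta_batch[n_way * k_shot :]      # remaining samples: query set (to evaluate the adaptation)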

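        # ===== Core of the MAML algorithm =====
        # Gather the model's current parameters in a new dict as the starting point
        # for fast adaptation (the tensors themselves are not copied)
        fast_weights = OrderedDict(model.named_parameters())

        # Inner loop: fast adaptation on the support set
        for inner_step in range(inner_train_step):  # run inner_train_step inner-loop updates
            # Create labels for the support set, e.g. [0, 1, 2, 3, 4] for 5-way 1-shot
            train_label = create_label(n_way, k_shot).to(device)

            # Forward pass with the current fast_weights
            logits = model.functional_forward(support_set, fast_weights)

            # Loss on the support set
            loss = criterion(logits, train_label)

            # ===== Key step: compute gradients and update fast_weights =====
            # Gradient of the loss with respect to every parameter in fast_weights
            # create_graph=True keeps the graph for second-order gradients (required by MAML)
            grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)

            # Gradient descent on fast_weights: new param = old param - learning rate * gradient
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)  # update each parameter
                for ((name, param), grad) in zip(fast_weights.items(), grads)  # pair each parameter with its gradient
            )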

        # After the inner loop, evaluate the adapted fast_weights on the query set
        if not return_labels:  # normal training / validation mode
            # Create labels for the query set
            val_label = create_label(n_way, q_query).to(device)

            # Predict on the query set with the adapted parameters
            logits = model.functional_forward(query_set, fast_weights)

            # Loss on the query set (this loss is used to update the original model parameters)
            loss = criterion(logits, val_label)
            task_loss.append(loss)  # store this task's loss
            task_acc.append(calculate_accuracy(logits, val_label))  # store this task's accuracy
        else:  # test mode: return the predicted labels
            logits = model.functional_forward(query_set, fast_weights)
            labels.extend(torch.argmax(logits, -1).cpu().numpy())  # predicted classes

    # In test mode, return the predicted labels directly
    if return_labels:
        return labels

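    # ===== Outer loop: meta update =====
    model.train()  # put the model in training mode
    optimizer.zero_grad()  # clear gradients from the previous step

    # Average the losses of all tasks (the meta loss)
    meta_batch_loss = torch.stack(task_loss).mean()

    if train:  # training mode
        # ===== Key step: perform the meta update =====
        meta_batch_loss.backward()  # backpropagate to get gradients of the original model parameters
        optimizer.step()            # update the original model parameters with those gradients

    # Average accuracy over tasks
    task_acc = np.mean(task_acc)

    return meta_batch_loss, task_acc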

## **Step 4: Initialization**

After defining all the components we need, the following cells initialize a model before training.

### Hyperparameters

In [12]:
n_way = 5
k_shot = 1
q_query = 1
train_inner_train_step = 1
val_inner_train_step = 3
inner_lr = 0.4
meta_lr = 0.001
meta_batch_size = 32
max_epoch = 100
eval_batches = 20
train_data_path = "/content/drive/MyDrive/ml2022spring-hw15/omniglot/Omniglot/images_background/"
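
As a quick sanity check on these settings, the sizes they imply can be worked out by hand. The sketch below only repeats the values from the cell above; the derived variable names are introduced here for illustration.

```python
# Values copied from the hyperparameter cell above.
n_way, k_shot, q_query = 5, 1, 1
meta_batch_size = 32

support_per_task = n_way * k_shot        # images used for the inner-loop update
query_per_task = n_way * q_query         # images used for the meta loss
images_per_task = support_per_task + query_per_task
images_per_meta_batch = meta_batch_size * images_per_task

# Shape of the tensor that one training step passes to the solver.
meta_batch_shape = (meta_batch_size, images_per_task, 1, 28, 28)

# Accuracy of uniform random guessing, useful as a baseline when reading the logs.
chance_accuracy = 1 / n_way

print(support_per_task, query_per_task, images_per_meta_batch)
print(meta_batch_shape, chance_accuracy)
```

With one query image per class, each task contributes only five predictions to the reported accuracy, which is one reason per-epoch validation accuracy can fluctuate noticeably.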

### Dataloader initialization

In [13]:
def dataloader_init(datasets, shuffle=True, num_workers=2):
    train_set, val_set = datasets
    train_loader = DataLoader(
        train_set,
        # The "batch_size" here is not \
        #    the meta batch size, but  \
        #    how many different        \
        #    characters in a task,     \
        #    i.e. the "n_way" in       \
        #    few-shot classification.
        batch_size=n_way,
        num_workers=num_workers,
        shuffle=shuffle,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set, batch_size=n_way, num_workers=num_workers, shuffle=shuffle, drop_last=True
    )

    train_iter = iter(train_loader)
    val_iter = iter(val_loader)
    return (train_loader, val_loader), (train_iter, val_iter)
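
Because each `DataLoader` batch is one task, the number of solver calls per epoch follows from simple integer arithmetic. The sketch below is an estimate under a loud assumption: it takes 964 characters in `images_background`, the size of the standard Omniglot background set, and the TA-preprocessed copy may differ. The 80/20 split ratio and the `max(1, ...)` bound are taken from the main program; the helper name `steps_per_epoch` is hypothetical.

```python
def steps_per_epoch(num_items, n_way, tasks_per_step):
    """Number of solver calls per epoch, mirroring the main loop's bounds."""
    tasks = num_items // n_way  # DataLoader with batch_size=n_way, drop_last=True
    return max(1, tasks // tasks_per_step)


num_characters = 964  # ASSUMPTION: standard Omniglot background set size
train_chars = int(0.8 * num_characters)
val_chars = num_characters - train_chars

train_steps = steps_per_epoch(train_chars, n_way=5, tasks_per_step=32)  # meta_batch_size
val_steps = steps_per_epoch(val_chars, n_way=5, tasks_per_step=20)      # eval_batches
print(train_chars, val_chars, train_steps, val_steps)
```

Under that assumption the estimate agrees with the `0/4` and `0/1` progress-bar totals shown in the training output.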

### Model & optimizer initialization

In [14]:
def model_init():
    meta_model = Classifier(1, n_way).to(device)
    optimizer = torch.optim.Adam(meta_model.parameters(), lr=meta_lr)
    loss_fn = nn.CrossEntropyLoss().to(device)
    return meta_model, optimizer, loss_fn

### Utility function to get a meta-batch

In [15]:
def get_meta_batch(meta_batch_size, k_shot, q_query, data_loader, iterator):
    data = []
    for _ in range(meta_batch_size):
        try:
            # a "task_data" tensor is representing \
            #     the data of a task, with size of \
            #     [n_way, k_shot+q_query, 1, 28, 28]
            task_data = next(iterator)
        except StopIteration:
            iterator = iter(data_loader)
            task_data = next(iterator)
        train_data = task_data[:, :k_shot].reshape(-1, 1, 28, 28)
        val_data = task_data[:, k_shot:].reshape(-1, 1, 28, 28)
        task_data = torch.cat((train_data, val_data), 0)
        data.append(task_data)
    return torch.stack(data).to(device), iterator
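
The reshaping in `get_meta_batch` determines the order of images inside a task, and that order has to match what `create_label` produces. The following dependency-free sketch mimics the layout with nested lists and strings instead of tensors; the small 3-way setting and the `c{class}_img{index}` names are made up for readability.

```python
n_way, k_shot, q_query = 3, 1, 2  # small hypothetical setting

# task[c][i] stands in for image i of class c, like a [n_way, k_shot+q_query, ...] tensor
task = [[f"c{c}_img{i}" for i in range(k_shot + q_query)] for c in range(n_way)]

# Row-major flattening, as task_data[:, :k_shot].reshape(-1, ...) would do
support = [img for row in task for img in row[:k_shot]]
query = [img for row in task for img in row[k_shot:]]
flat = support + query  # same order as torch.cat((train_data, val_data), 0)

# Labels in the same order as create_label(n_way, k_shot) and create_label(n_way, q_query)
support_labels = [c for c in range(n_way) for _ in range(k_shot)]
query_labels = [c for c in range(n_way) for _ in range(q_query)]

print(flat[: n_way * k_shot], support_labels)
print(flat[n_way * k_shot :], query_labels)
```

The first `n_way * k_shot` entries are the support images grouped by class, which is why the solvers can split a task with `meta_batch[: n_way * k_shot]`.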

<a name="mainprog" id="mainprog"></a>
## **Step 5: Main program for training & testing**

### Start training!
With `solver = 'base'`, the solver is a transfer learning algorithm.

Once you finish the TODO blocks in `MetaSolver`, set `solver = 'meta'` to start training with the meta learning algorithm.


In [16]:
solver = 'meta'  # base, meta
meta_model, optimizer, loss_fn = model_init()

# init solver and dataset according to solver type
if solver == 'base':
    max_epoch = 5  # the base solver only needs 5 epochs
    Solver = BaseSolver
    train_set, val_set = torch.utils.data.random_split(
        Omniglot(train_data_path, k_shot, q_query, task_num=10), [5, 5]
    )
    (train_loader, val_loader), (train_iter, val_iter) = dataloader_init((train_set, val_set), shuffle=False)

elif solver == 'meta':
    Solver = MetaSolver
    dataset = Omniglot(train_data_path, k_shot, q_query)
    train_split = int(0.8 * len(dataset))
    val_split = len(dataset) - train_split
    train_set, val_set = torch.utils.data.random_split(
        dataset, [train_split, val_split]
    )
    (train_loader, val_loader), (train_iter, val_iter) = dataloader_init((train_set, val_set))
else:
    raise NotImplementedError

# Save-best-model and early-stopping settings
best_val_acc = 0.0
best_model_path = "/content/drive/MyDrive/ml2022spring-hw15/best_meta_model.pth"
patience = 20                         # epochs to wait without improvement before stopping
patience_counter = 0                  # epochs since the last improvement

# main training loop
for epoch in range(max_epoch):
    print("Epoch %d" % (epoch + 1))
    train_meta_loss = []
    train_acc = []

    # The "step" here is a meta-gradient update step
    for step in tqdm(range(max(1, len(train_loader) // meta_batch_size))):
        x, train_iter = get_meta_batch(
            meta_batch_size, k_shot, q_query, train_loader, train_iter
        )
        meta_loss, acc = Solver(
            meta_model,
            optimizer,
            x,
            n_way,
            k_shot,
            q_query,
            loss_fn,
            inner_train_step=train_inner_train_step
        )
        train_meta_loss.append(meta_loss.item())
        train_acc.append(acc)

    print("  Loss    : ", "%.3f" % (np.mean(train_meta_loss)), end="\t")
    print("  Accuracy: ", "%.3f %%" % (np.mean(train_acc) * 100))

    # See the validation accuracy after each epoch.
    val_acc = []
    for eval_step in tqdm(range(max(1, len(val_loader) // (eval_batches)))):
        x, val_iter = get_meta_batch(
            eval_batches, k_shot, q_query, val_loader, val_iter
        )
        _, acc = Solver(
            meta_model,
            optimizer,
            x,
            n_way,
            k_shot,
            q_query,
            loss_fn,
            inner_train_step=val_inner_train_step,
            train=False,
        )
        val_acc.append(acc)

    current_val_acc = np.mean(val_acc)
    print("  Validation accuracy: ", "%.3f %%" % (current_val_acc * 100))

    # ===== Save best model =====
    if current_val_acc > best_val_acc:
        best_val_acc = current_val_acc
        patience_counter = 0  # reset the patience counter

        # Save the best model
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': meta_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_acc': best_val_acc,
            'n_way': n_way,
            'k_shot': k_shot,
            'q_query': q_query,
            'solver_type': solver
        }, best_model_path)

        print(f"  New best model saved! Validation accuracy: {best_val_acc:.4f}")
    else:
        patience_counter += 1
        print(f"  No improvement. Best validation accuracy: {best_val_acc:.4f} (Patience: {patience_counter}/{patience})")

    # ===== Early stopping =====
    if patience_counter >= patience:
        print(f"Early stopping after {patience} epochs without improvement!")
        break

# Summary after training
print("\nTraining finished!")
print(f"Best validation accuracy: {best_val_acc:.4f}")
print(f"Best model saved to: {best_model_path}")

Epoch 1


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  2.320	  Accuracy:  28.438 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  39.000 %
  New best model saved! Validation accuracy: 0.3900
Epoch 2


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  1.338	  Accuracy:  43.594 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  41.000 %
  New best model saved! Validation accuracy: 0.4100
Epoch 3


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  1.255	  Accuracy:  49.062 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  44.000 %
  New best model saved! Validation accuracy: 0.4400
Epoch 4


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  1.222	  Accuracy:  54.375 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  43.000 %
  No improvement. Best validation accuracy: 0.4400 (Patience: 1/20)
Epoch 5


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  1.191	  Accuracy:  56.563 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  51.000 %
  New best model saved! Validation accuracy: 0.5100
Epoch 6


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  1.161	  Accuracy:  55.313 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  45.000 %
  No improvement. Best validation accuracy: 0.5100 (Patience: 1/20)
Epoch 7


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  1.135	  Accuracy:  57.188 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  49.000 %
  No improvement. Best validation accuracy: 0.5100 (Patience: 2/20)
Epoch 8


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  1.067	  Accuracy:  61.094 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  52.000 %
  New best model saved! Validation accuracy: 0.5200
Epoch 9


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  1.067	  Accuracy:  57.344 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  46.000 %
  No improvement. Best validation accuracy: 0.5200 (Patience: 1/20)
Epoch 10


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  1.044	  Accuracy:  60.469 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  52.000 %
  New best model saved! Validation accuracy: 0.5200
Epoch 11


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.964	  Accuracy:  64.844 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  42.000 %
  No improvement. Best validation accuracy: 0.5200 (Patience: 1/20)
Epoch 12


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.961	  Accuracy:  64.531 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  62.000 %
  New best model saved! Validation accuracy: 0.6200
Epoch 13


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.973	  Accuracy:  63.750 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  59.000 %
  No improvement. Best validation accuracy: 0.6200 (Patience: 1/20)
Epoch 14


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.891	  Accuracy:  66.875 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  53.000 %
  No improvement. Best validation accuracy: 0.6200 (Patience: 2/20)
Epoch 15


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.882	  Accuracy:  67.031 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  60.000 %
  No improvement. Best validation accuracy: 0.6200 (Patience: 3/20)
Epoch 16


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.828	  Accuracy:  70.469 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  52.000 %
  No improvement. Best validation accuracy: 0.6200 (Patience: 4/20)
Epoch 17


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.856	  Accuracy:  68.594 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  65.000 %
  New best model saved! Validation accuracy: 0.6500
Epoch 18


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.797	  Accuracy:  72.656 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  58.000 %
  No improvement. Best validation accuracy: 0.6500 (Patience: 1/20)
Epoch 19


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.748	  Accuracy:  75.156 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  62.000 %
  No improvement. Best validation accuracy: 0.6500 (Patience: 2/20)
Epoch 20


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.753	  Accuracy:  75.312 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  76.000 %
  New best model saved! Validation accuracy: 0.7600
Epoch 21


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.686	  Accuracy:  76.719 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  72.000 %
  No improvement. Best validation accuracy: 0.7600 (Patience: 1/20)
Epoch 22


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.697	  Accuracy:  77.344 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  73.000 %
  No improvement. Best validation accuracy: 0.7600 (Patience: 2/20)
Epoch 23


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.639	  Accuracy:  78.438 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  70.000 %
  No improvement. Best validation accuracy: 0.7600 (Patience: 3/20)
Epoch 24


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.608	  Accuracy:  80.312 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  74.000 %
  No improvement. Best validation accuracy: 0.7600 (Patience: 4/20)
Epoch 25


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.545	  Accuracy:  82.500 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  79.000 %
  New best model saved! Validation accuracy: 0.7900
Epoch 26


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.570	  Accuracy:  81.250 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  77.000 %
  No improvement. Best validation accuracy: 0.7900 (Patience: 1/20)
Epoch 27


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.516	  Accuracy:  84.219 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  89.000 %
  New best model saved! Validation accuracy: 0.8900
Epoch 28


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.508	  Accuracy:  84.062 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  78.000 %
  No improvement. Best validation accuracy: 0.8900 (Patience: 1/20)
Epoch 29


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.475	  Accuracy:  84.531 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  82.000 %
  No improvement. Best validation accuracy: 0.8900 (Patience: 2/20)
Epoch 30


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.477	  Accuracy:  84.375 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  84.000 %
  No improvement. Best validation accuracy: 0.8900 (Patience: 3/20)
Epoch 31


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.474	  Accuracy:  84.844 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  87.000 %
  No improvement. Best validation accuracy: 0.8900 (Patience: 4/20)
Epoch 32


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.444	  Accuracy:  86.406 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  95.000 %
  New best model saved! Validation accuracy: 0.9500
Epoch 33


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.396	  Accuracy:  87.656 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  93.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 1/20)
Epoch 34


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.431	  Accuracy:  85.312 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  91.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 2/20)
Epoch 35


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.418	  Accuracy:  86.875 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  85.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 3/20)
Epoch 36


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.448	  Accuracy:  85.625 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  88.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 4/20)
Epoch 37


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.378	  Accuracy:  87.969 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  93.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 5/20)
Epoch 38


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.416	  Accuracy:  86.094 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  84.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 6/20)
Epoch 39


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.371	  Accuracy:  88.438 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  89.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 7/20)
Epoch 40


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.356	  Accuracy:  89.531 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  89.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 8/20)
Epoch 41


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.399	  Accuracy:  86.094 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  90.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 9/20)
Epoch 42


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.379	  Accuracy:  89.062 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  90.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 10/20)
Epoch 43


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.373	  Accuracy:  87.969 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  94.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 11/20)
Epoch 44


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.389	  Accuracy:  87.188 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  89.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 12/20)
Epoch 45


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.343	  Accuracy:  89.844 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  89.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 13/20)
Epoch 46


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.344	  Accuracy:  90.000 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  87.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 14/20)
Epoch 47


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.329	  Accuracy:  90.469 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  93.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 15/20)
Epoch 48


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.364	  Accuracy:  89.219 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  94.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 16/20)
Epoch 49


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.275	  Accuracy:  92.812 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  94.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 17/20)
Epoch 50


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.278	  Accuracy:  91.719 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  94.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 18/20)
Epoch 51


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.284	  Accuracy:  91.562 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  92.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 19/20)
Epoch 52


  0%|          | 0/4 [00:00<?, ?it/s]

  Loss    :  0.324	  Accuracy:  89.844 %


  0%|          | 0/1 [00:00<?, ?it/s]

  Validation accuracy:  91.000 %
  No improvement. Best validation accuracy: 0.9500 (Patience: 20/20)
Early stopping after 20 epochs without improvement!

Training finished!
Best validation accuracy: 0.9500
Best model saved to: /content/drive/MyDrive/ml2022spring-hw15/best_meta_model.pth
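
The save-best and early-stopping bookkeeping in the training loop can be separated from the model code and replayed on a list of numbers. This is a minimal sketch: the function name `run_early_stopping` and the accuracy list are hypothetical, and a checkpoint save is represented only by recording the epoch.

```python
def run_early_stopping(val_accs, patience):
    """Replay validation accuracies; return (best_acc, best_epoch, stop_epoch)."""
    best_acc, best_epoch, counter = 0.0, None, 0
    for epoch, acc in enumerate(val_accs, start=1):
        if acc > best_acc:
            # This is where the training loop would save a checkpoint.
            best_acc, best_epoch, counter = acc, epoch, 0
        else:
            counter += 1
        if counter >= patience:
            return best_acc, best_epoch, epoch  # stopped early
    return best_acc, best_epoch, None  # ran through every epoch


# Made-up accuracies, not the logged ones
hypothetical = [0.39, 0.41, 0.44, 0.43, 0.51, 0.45, 0.49, 0.50]
result = run_early_stopping(hypothetical, patience=3)
print(result)
```

Note that the weights in memory at the stopping epoch are those of the last epoch, not the best one; the best weights are the ones stored in the checkpoint.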


### Testing the result

Since the testing data is sampled by TAs in advance, you should not change the code in the `OmniglotTest` dataset; otherwise your score on the Kaggle leaderboard may not be correct.

However, feel free to change the variable `test_inner_train_step` to set the number of training steps on the support set images.

In [17]:
import os

# test dataset
class OmniglotTest(Dataset):
    def __init__(self, test_dir):
        self.test_dir = test_dir
        self.n = 5

        self.transform = transforms.Compose([transforms.ToTensor()])

    def __getitem__(self, idx):
        support_files = [
            os.path.join(self.test_dir, "support", f"{idx:>04}", f"image_{i}.png")
            for i in range(self.n)
        ]
        query_files = [
            os.path.join(self.test_dir, "query", f"{idx:>04}", f"image_{i}.png")
            for i in range(self.n)
        ]

        support_imgs = torch.stack(
            [self.transform(Image.open(e)) for e in support_files]
        )
        query_imgs = torch.stack([self.transform(Image.open(e)) for e in query_files])

        return support_imgs, query_imgs

    def __len__(self):
        return len(os.listdir(os.path.join(self.test_dir, "support")))
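
The directory layout that `OmniglotTest` expects can be read off from its path construction: one zero-padded folder per task under `support` and `query`, each holding `image_0.png` to `image_4.png`. The sketch below reproduces only that path logic; the helper name `task_image_paths` is hypothetical and no files are opened.

```python
import os


def task_image_paths(test_dir, split, idx, n=5):
    """Paths of the n images of task `idx` in `split` ("support" or "query")."""
    return [
        os.path.join(test_dir, split, f"{idx:>04}", f"image_{i}.png")
        for i in range(n)
    ]


paths = task_image_paths("Omniglot-test", "support", 7)
for p in paths:
    print(p)
```

The format spec `>04` right-aligns the task index in a field of width four padded with zeros, so task 7 is looked up in a folder named `0007`.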

In [18]:
test_inner_train_step = 10 # you can change this

test_batches = 20
#test_dataset = OmniglotTest("Omniglot-test")
test_dataset = OmniglotTest("/content/drive/MyDrive/ml2022spring-hw15/omniglot-test/Omniglot-test")
test_loader = DataLoader(test_dataset, batch_size=test_batches, shuffle=False)

output = []
for _, batch in enumerate(tqdm(test_loader)):
    support_set, query_set = batch
    x = torch.cat([support_set, query_set], dim=1)
    x = x.to(device)

    labels = Solver(
        meta_model,
        optimizer,
        x,
        n_way,
        k_shot,
        q_query,
        loss_fn,
        inner_train_step=test_inner_train_step,
        train=False,
        return_labels=True,
    )

    output.extend(labels)

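# write to csv
output_path = "/content/drive/MyDrive/outputs/output.csv"
os.makedirs(os.path.dirname(output_path), exist_ok=True)  # create the output folder if it is missing
with open(output_path, "w") as f:
    f.write("id,class\n")
    for i, label in enumerate(output):
        f.write(f"{i},{label}\n")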

  0%|          | 0/32 [00:00<?, ?it/s]

Download the `output.csv` and submit to Kaggle!
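
Before uploading, it can help to confirm that the file has the layout written by the cell above: an `id,class` header followed by consecutive ids and class indices below `n_way`. The checker below is a sketch using only the standard library; the function name `check_submission` and the sample text are hypothetical, and it does not reflect any official Kaggle validation.

```python
import csv
import io


def check_submission(csv_text, n_way=5):
    """Validate header, consecutive ids, and label range; return the row count."""
    rows = list(csv.reader(io.StringIO(csv_text)))
    if rows[0] != ["id", "class"]:
        raise ValueError("unexpected header")
    for expected_id, (row_id, label) in enumerate(rows[1:]):
        if int(row_id) != expected_id:
            raise ValueError(f"id {row_id} out of order")
        if not 0 <= int(label) < n_way:
            raise ValueError(f"label {label} outside 0..{n_way - 1}")
    return len(rows) - 1


sample = "id,class\n0,3\n1,0\n2,4\n"  # made-up predictions
n_rows = check_submission(sample)
print(n_rows)
```

To check a real file, its contents could be read with `open(output_path).read()` and passed to the same function.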

## **Reference**
1. Chelsea Finn, Pieter Abbeel, & Sergey Levine. (2017). [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.](https://arxiv.org/abs/1703.03400)
1. Aniruddh Raghu, Maithra Raghu, Samy Bengio, & Oriol Vinyals. (2020). [Rapid Learning or Feature Reuse? Towards Understanding the Effectiveness of MAML.](https://arxiv.org/abs/1909.09157)