# Description

* This notebook demonstrates few-shot learning on the BirdCLEF 2022 dataset.
* For metric learning, I used a Prototypical Network [1], since it is one of the basic approaches to few-shot learning.
* Most of the code in this notebook was adapted from [2].

## What this notebook supports

- learning a classifier that does not overfit to the training dataset even when samples are scarce (few-shot learning)
    - classifying 5 unseen classes given only 5 samples per class

## What this notebook doesn't support

- call/no call classifier
    - since the models are trained only on positive (bird call) samples, they cannot distinguish background noise from a bird call (you may need a separate classifier, or you may need to add background samples as inputs).
- multi-label classification
    - it only predicts the single most probable class for each fixed-length audio crop (`wav_crop_sec`, 7 seconds in the configuration below)
    - `secondary_labels` are completely ignored

# Reference

* [1] https://arxiv.org/abs/1703.05175
* [2] https://github.com/Frankluox/LightningFSL

# Environment Setup

In [None]:
!pip install nb_black > /dev/null
!pip install torchinfo > /dev/null

In [None]:
# Report the GPU(s) visible to this session.
!nvidia-smi

In [None]:
# Record the installed versions of every package whose name contains "torch".
!pip freeze | grep torch

In [None]:
# Record the installed librosa version.
!pip freeze | grep librosa

In [None]:
from collections import Counter
import math
import os

from argparse import Namespace
from typing import List, Union, TypeVar, Iterator, Optional, Dict, Tuple, Any

from IPython.core.debugger import set_trace, Pdb
import librosa
import librosa.display
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
import matplotlib.pyplot as plt
import wandb

from kaggle_secrets import UserSecretsClient
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities.cli import (
    LightningCLI,
    LightningArgumentParser,
    SaveConfigCallback,
)
from pytorch_lightning.utilities.seed import seed_everything

from pytorch_lightning.trainer.trainer import Trainer
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from torch import nn, Tensor
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader, Sampler

from torchmetrics import Accuracy, MeanMetric, Metric
from torchinfo import summary

# Use the "ggplot" matplotlib style for every figure drawn in this notebook.
plt.style.use("ggplot")

# Notebook extensions: lab_black (presumably provided by the nb_black package
# installed above) and autoreload in mode 2.
%load_ext lab_black
%load_ext autoreload
%autoreload 2

# Configuration

In [None]:
# Shared settings. The dataset turns `wav_crop_sec` into a number of spectrogram
# frames via wav_crop_sec / hop_length * sample_rate, and `random_seed` is the
# value passed to seed_everything before the species split.
cfg = {
    "sample_rate": 32_000,
    "hop_length": 500,
    "n_fft": 2_000,
    "fmin": 0,
    "fmax": 16_000,
    "wav_crop_sec": 7,
    "random_seed": 3154,
    "max_epochs": 10,
}

# Locations of the pre-computed mel-spectrogram dataset: the metadata table and
# the directory the per-recording array files are loaded from.
metadata_path = (
    "../input/birdclef-2022-precomputed-melspec-hop-size500/spec_metadata.csv"
)
img_path = "../input/birdclef-2022-precomputed-melspec-hop-size500/train_audio"

# Augmentations applied to training spectrograms. All candidates are currently
# disabled, so the resulting nn.Sequential is empty.
train_transforms = [
    # T.TimeMasking(time_mask_param=80),
    # T.FrequencyMasking(freq_mask_param=32),
    # T.PitchShift(
    #    sample_rate=cfg["sample_rate"],
    #    n_steps=5,
    #    n_fft=cfg["n_fft"],
    #    hop_length=cfg["hop_length"],
    # ),
]
# No augmentation is configured for validation.
val_transforms = []

train_transform = nn.Sequential(*train_transforms)
val_transform = nn.Sequential(*val_transforms)

# W&B login

In [None]:
# Fetch the W&B API key from Kaggle secrets and authenticate through the Python
# API, so the key is never interpolated into a shell command line.
user_secrets = UserSecretsClient()

personal_key_for_api = user_secrets.get_secret("ke")
wandb.login(key=personal_key_for_api);

# Split data

In [None]:
# Load the per-recording metadata. The location comes from `metadata_path` in
# the configuration cell, so the path is declared in exactly one place.
trainval = pd.read_csv(metadata_path)

In [None]:
# Keep only species with at least 20 recordings. The episodic sampler draws
# shot + num_query samples per class (5 + 15 with the data module defaults),
# so a class with fewer recordings cannot fill one task.
samples_per_species = trainval.groupby("primary_label")["length"].count()
trainval_species = set(samples_per_species[samples_per_species >= 20].index)
print(f"* {len(trainval_species)} species which have sample >= 20")

In [None]:
def train_val_split(trainval_species, train_ratio=0.8):
    """Randomly split a collection of species into disjoint train/val sets.

    Args:
        trainval_species: Set (or other iterable) of species identifiers.
        train_ratio: Fraction of the species assigned to the training set;
            the count is truncated with ``int``.

    Returns:
        Tuple ``(train_species, val_species)`` of sets. ``val_species`` holds
        every input species that was not drawn for training.
    """
    all_species = set(trainval_species)
    # A set of strings iterates in a hash-dependent order that changes between
    # interpreter sessions. Sorting first makes the draw depend only on NumPy's
    # global seed, so the split is reproducible after a kernel restart.
    candidates = sorted(all_species)
    n_train = int(len(candidates) * train_ratio)
    train_species = set(np.random.choice(candidates, n_train, replace=False))
    val_species = all_species - train_species
    return train_species, val_species

In [None]:
# Seed before splitting: train_val_split draws from NumPy's global generator
# (seed_everything is expected to seed it — confirm for the installed version).
seed_everything(cfg["random_seed"])
train_species, val_species = train_val_split(trainval_species)
# Sanity check: sizes of both splits and a few example species from each.
print(len(train_species), len(val_species))
print(list(train_species)[:3], list(val_species)[:3])

In [None]:
# Split the recordings by the species split above; train and val therefore
# share no species.
train = trainval[trainval["primary_label"].isin(train_species)]
val = trainval[trainval["primary_label"].isin(val_species)]
# Rows kept in either split vs. all rows; the gap is the recordings of species
# with fewer than 20 samples.
len(train) + len(val), len(trainval)

# Dataset

In [None]:
class BirdClefDataset(Dataset):
    """Dataset of pre-computed spectrograms cropped/padded to a fixed length.

    Each item is loaded with ``np.load`` from ``img_dir / row["filename"]``,
    randomly cropped (or padded by repetition) along the time axis to
    ``wav_crop_sec / hop_length * sample_rate`` frames, standardised with its
    own mean and standard deviation, and returned as a ``(1, F, T)`` tensor.

    Args:
        metadata_df: DataFrame with at least ``filename`` and ``primary_label``
            columns, one row per recording.
        img_dir: Directory containing the spectrogram array files.
        cfg: Mapping providing ``wav_crop_sec``, ``hop_length`` and
            ``sample_rate``.
        transform: Optional callable applied to the tensor before it is
            returned.
        epsilon: Lower bound for the standard deviation used in the
            normalisation, preventing division by zero on constant inputs.

    Attributes:
        species: Sorted unique values of ``primary_label``.
        species2index: Mapping from species to its position in ``species``.
        label: Integer class index per row (used by the episodic sampler).
    """

    def __init__(
        self,
        metadata_df,
        img_dir,
        cfg,
        transform=None,
        epsilon=1e-12,
    ):
        self.metadata_df = metadata_df
        self.img_dir = img_dir
        self.transform = transform
        self.cfg = cfg
        label = self.metadata_df.primary_label.to_numpy()
        self.species = np.unique(label)
        self.species2index = {s: i for i, s in enumerate(self.species)}
        self.label = np.array([self.species2index[s] for s in label])
        # Honour the caller-supplied value (it was previously hard-coded).
        self.epsilon = epsilon

    def __len__(self):
        return len(self.metadata_df)

    def __getitem__(self, idx):
        """Return ``(spec, label)`` for row ``idx``.

        ``spec`` is a float tensor of shape ``(1, F, duration)``; ``label`` is
        the raw ``primary_label`` value of the row (not the integer index).
        """
        row = self.metadata_df.iloc[idx]
        cfg = self.cfg

        data_path = os.path.join(self.img_dir, row["filename"])
        spec = np.load(data_path).astype(np.float32)

        # Target length in spectrogram frames,
        # e.g. 7 / 500 * 32_000 = 448 with the notebook configuration.
        wav_len = spec.shape[1]
        duration = int(cfg["wav_crop_sec"] / cfg["hop_length"] * cfg["sample_rate"])

        # Random start so that a full window fits; 0 for short recordings.
        max_offset = max(wav_len - duration, 0)
        offset = np.random.randint(max_offset + 1)
        spec = spec[:, offset:]

        if wav_len < duration:
            # Too short: `offset` is 0 here, so `spec` is the full recording.
            pad = duration - wav_len
            if pad >= wav_len:
                # Less than half the target: repeat the recording until it
                # covers the window.
                n_repeat = duration // wav_len + 1
                spec = np.tile(spec, [1, n_repeat])
            else:
                # Otherwise append a random `pad`-frame slice of the recording.
                max_offs = wav_len - pad
                offs = np.random.randint(max_offs + 1)
                spec = np.concatenate([spec, spec[:, offs : offs + pad]], axis=1)
        spec = spec[:, :duration]

        # Standardise over the whole crop; epsilon guards against std == 0.
        mean, std = spec.mean(), spec.std()
        spec = (spec - mean) / max(std, self.epsilon)

        # (F, T) -> (1, F, T)
        spec = np.expand_dims(spec, 0)

        label = row["primary_label"]

        spec = torch.from_numpy(spec).float()

        if self.transform:
            spec = self.transform(spec)
        return spec, label

# Data Module

In [None]:
T_co = TypeVar("T_co", covariant=True)


class CategoriesSampler(Sampler[T_co]):
    r"""Batch sampler that groups dataset indices into few-shot tasks.

    Every iteration yields a flat index array containing ``total_batch_size``
    tasks. Within one task the indices are ordered so that the classes cycle
    fastest (``c0, c1, ..., c{way-1}, c0, c1, ...``).
    """

    def __init__(
        self,
        labels: Union[List, "np.ndarray"],
        num_task: int,
        way: int,
        total_sample_per_class: int,
        total_batch_size: int = 4,
        drop_last: bool = False,
    ) -> None:
        """
        Args:
           labels: The label of every item of the dataset. Any values accepted
               by ``np.unique`` work (e.g. strings or non-contiguous integers).
           num_task: The number of tasks within one epoch.
           way: The number of classes within one task.
           total_sample_per_class: The number of samples drawn per class in a
               task (support and query together). Every class should have at
               least this many items, otherwise the per-class index arrays
               cannot be stacked.
           total_batch_size: The number of tasks to handle per iteration.
           drop_last (bool, optional): if ``True`` the number of iterations is
               ``floor(num_task / total_batch_size)``, otherwise it is the
               ceiling, so slightly more than ``num_task`` tasks may be drawn.
               Default: ``False``.
        """
        self.num_task = num_task
        self.way = way
        self.total_sample_per_class = total_sample_per_class
        self.drop_last = drop_last
        self.total_batch_size = total_batch_size
        self.per_gpu_batch_size = self.total_batch_size

        iterations = self.num_task / self.total_batch_size
        if self.drop_last:
            self.num_iteration = math.floor(iterations)
        else:
            self.num_iteration = math.ceil(iterations)

        labels = np.array(labels)  # all data labels
        # Sorted unique class values; positions into this array are what gets
        # sampled, so the label values themselves need not be 0..C-1.
        self.classes = np.unique(labels)
        # Dataset indices belonging to each class.
        self.m_ind = {c: np.argwhere(labels == c).reshape(-1) for c in self.classes}

    def __len__(self) -> int:
        return self.num_iteration

    def __iter__(self) -> Iterator[T_co]:
        for _ in range(self.num_iteration):
            tasks = []
            for _ in range(self.per_gpu_batch_size):
                # Pick `way` distinct class positions, then translate them to
                # the actual class values used as keys of `m_ind`.
                class_pos = torch.randperm(len(self.classes))[: self.way].numpy()
                task = []
                for c in self.classes[class_pos]:
                    class_indices = self.m_ind[c]
                    # Random subset (without replacement) of this class.
                    pos = torch.randperm(len(class_indices))[
                        : self.total_sample_per_class
                    ].numpy()
                    task.append(class_indices[pos])
                # (way, n) -> (n, way) -> flat, so that classes cycle fastest.
                tasks.append(np.stack(task).transpose().reshape(-1))
            yield np.stack(tasks).reshape(-1)

In [None]:
class FewShotDataModule(LightningDataModule):
    """Data module serving episodic (N-way, K-shot) batches for train and val.

    ``setup`` builds one ``BirdClefDataset`` and one ``CategoriesSampler`` per
    split; the dataloaders then use those samplers as ``batch_sampler``.
    """

    def __init__(
        self,
        train,
        val,
        img_path,
        cfg,
        train_transform,
        val_transform,
        way: int = 5,
        num_query: int = 15,
        drop_last: Optional[bool] = None,
        num_gpus: int = 1,
        train_batch_size: int = 1,
        val_batch_size: int = 4,
        train_num_workers: int = 2,
        val_num_workers: int = 2,
        train_num_task_per_epoch: Optional[int] = 1000,
        val_num_task: int = 200,
        train_shot: Optional[int] = 5,
        val_shot: int = 5,
        train_dataset_params: Dict = {},
        val_test_dataset_params: Dict = {},
    ) -> None:
        """
        Args:
            train: Metadata frame of the training recordings.
            val: Metadata frame of the validation recordings.
            img_path: Directory holding the spectrogram files.
            cfg: Configuration mapping forwarded to the datasets.
            train_transform: Transform applied to training spectrograms.
            val_transform: Transform applied to validation spectrograms.
            way: Number of classes per task.
            num_query: Number of query samples per class in a task.
            drop_last: Forwarded to both samplers.
            num_gpus: Stored only; not used by this class.
            train_batch_size: Tasks per training iteration.
            val_batch_size: Tasks per validation iteration.
            train_num_workers: Worker processes of the training loader.
            val_num_workers: Worker processes of the validation loader.
            train_num_task_per_epoch: Number of training tasks per epoch.
            val_num_task: Number of validation tasks.
            train_shot: Support samples per class during training.
            val_shot: Support samples per class during validation.
            train_dataset_params: Stored only; not used by this class.
            val_test_dataset_params: Stored only; not used by this class.
        """
        super().__init__()

        # Data sources and preprocessing.
        self.train, self.val = train, val
        self.img_path, self.cfg = img_path, cfg
        self.train_transform, self.val_transform = train_transform, val_transform

        # Task geometry shared by both splits.
        self.way, self.num_query = way, num_query
        self.drop_last, self.num_gpus = drop_last, num_gpus

        # Training split settings.
        self.train_num_workers = train_num_workers
        self.train_batch_size = train_batch_size
        self.train_batch_sampler = None
        self.train_num_task_per_epoch = train_num_task_per_epoch
        self.train_shot = train_shot
        self.train_dataset_params = train_dataset_params

        # Validation split settings.
        self.val_num_workers = val_num_workers
        self.val_batch_size = val_batch_size
        self.val_batch_sampler = None
        self.val_num_task = val_num_task
        self.val_shot = val_shot
        self.val_test_dataset_params = val_test_dataset_params

    def set_train_dataset(self):
        """Create ``self.train_dataset`` from the training metadata."""
        self.train_dataset = BirdClefDataset(
            self.train, self.img_path, self.cfg, transform=self.train_transform
        )

    def set_val_dataset(self):
        """Create ``self.val_dataset`` from the validation metadata."""
        self.val_dataset = BirdClefDataset(
            self.val, self.img_path, self.cfg, transform=self.val_transform
        )

    def set_sampler(self):
        """Create the episodic batch samplers; requires both datasets."""
        self.train_batch_sampler = CategoriesSampler(
            labels=self.train_dataset.label,
            num_task=self.train_num_task_per_epoch,
            way=self.way,
            total_sample_per_class=self.train_shot + self.num_query,
            total_batch_size=self.train_batch_size,
            drop_last=self.drop_last,
        )
        self.val_batch_sampler = CategoriesSampler(
            labels=self.val_dataset.label,
            num_task=self.val_num_task,
            way=self.way,
            total_sample_per_class=self.val_shot + self.num_query,
            total_batch_size=self.val_batch_size,
            drop_last=self.drop_last,
        )

    def setup(self, stage=None):
        """Build datasets first, then the samplers that index into them."""
        self.set_train_dataset()
        self.set_val_dataset()
        self.set_sampler()

    def _episodic_loader(self, dataset, batch_sampler, num_workers):
        """Build a DataLoader whose batches are dictated by ``batch_sampler``."""
        return DataLoader(
            dataset,
            shuffle=False,
            num_workers=num_workers,
            batch_sampler=batch_sampler,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return the loader yielding training tasks."""
        return self._episodic_loader(
            self.train_dataset, self.train_batch_sampler, self.train_num_workers
        )

    def val_dataloader(self):
        """Return the loader yielding validation tasks."""
        return self._episodic_loader(
            self.val_dataset, self.val_batch_sampler, self.val_num_workers
        )

In [None]:
# Smoke test: build the datasets and samplers, then pull one training batch and
# print its shape. With the defaults the first dimension is
# train_batch_size * way * (train_shot + num_query) = 1 * 5 * (5 + 15) = 100.
datamodule = FewShotDataModule(
    train, val, img_path, cfg, train_transform, val_transform
)
datamodule.setup()
for batch, label in datamodule.train_dataloader():
    print(batch.shape)
    break

# Model

## Backbone

In [None]:
def floor_power(num, divisor, power):
    """Divide ``num`` by ``divisor`` and floor the result, ``power`` times.

    This gives the size left after ``power`` successive floor-divisions, which
    the feature extractor uses to estimate the spatial extent that remains
    after its stacked pooling layers.

    Args:
        num (int or float): The starting value.
        divisor (int or float): The value to divide by at every step.
        power (int): How many divide-then-floor steps to apply.

    Returns:
        The floored value as produced by ``np.floor`` (a NumPy float), or
        ``num`` unchanged when ``power`` is 0.
    """
    result = num
    for _step in range(power):
        result = np.floor(result / divisor)
    return result

In [None]:
def gem(x, p=3, eps=1e-6):
    """Generalised-mean pooling over the last two dimensions of ``x``.

    Values are clamped to at least ``eps``, raised to the power ``p``, averaged
    over the full extent of the last two dimensions and finally raised to
    ``1 / p``. Those two dimensions therefore come out with size 1.
    """
    full_extent = (x.size(-2), x.size(-1))
    powered = x.clamp(min=eps).pow(p)
    pooled = F.avg_pool2d(powered, full_extent)
    return pooled.pow(1.0 / p)


class GeMPooling(nn.Module):
    """Module wrapper around :func:`gem` with fixed ``p`` and ``eps``."""

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = p
        self.eps = eps

    def forward(self, x):
        """Apply generalised-mean pooling to ``x``."""
        return gem(x, p=self.p, eps=self.eps)

In [None]:
def conv3x3(in_planes, out_planes):
    """Return a bias-free 3x3 convolution with padding 1."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1, bias=False)


def norm_layer(planes):
    """Return the normalisation layer used after every convolution."""
    return nn.BatchNorm2d(planes)


class Block(nn.Module):
    """One encoder stage: conv3x3 -> batch norm -> LeakyReLU(0.1) -> max pool."""

    def __init__(self, inplanes, outplanes, pool_dim):
        super().__init__()
        self.relu = nn.LeakyReLU(0.1)
        self.conv = conv3x3(inplanes, outplanes)
        self.bn = norm_layer(outplanes)
        self.maxpool = nn.MaxPool2d(pool_dim)

    def forward(self, x):
        """Run the stage on ``x`` and return the pooled activation."""
        return self.maxpool(self.relu(self.bn(self.conv(x))))


class ConvN(nn.Module):
    """Stack of :class:`Block` stages, one per consecutive pair in ``channels``.

    Args:
        channels: Channel counts; stage ``k`` maps ``channels[k]`` to
            ``channels[k + 1]``.
        pool_dim: Max-pool kernel shared by every stage.
    """

    def __init__(self, channels=[1, 64, 64, 64, 64], pool_dim=(3, 3)):
        super().__init__()

        stages = []
        for n_in, n_out in zip(channels, channels[1:]):
            stages.append(Block(n_in, n_out, pool_dim))
        self.nn = nn.Sequential(*stages)

    def forward(self, x):
        """Pass ``x`` through every stage in order."""
        return self.nn(x)


class ConvNFeatureExtractor(nn.Module):
    """Convolutional encoder followed by a selectable pooling head.

    Args:
        trial_shape: Expected input shape ``[batch, channel, height, width]``;
            only the last two entries are used, to size the linear layer of
            the ``"dense"`` head.
        channels: Channel counts of the encoder stages (see ``ConvN``).
        pool_dim: Max-pool kernel of every encoder stage.
        out_dim: Output size of the ``"dense"`` head.
        pool_type: How the encoder output is reduced to ``(batch, features)``:
            ``"flatten"``, ``"dense"`` (dropout, batch norm, linear),
            ``"adaptive_average"``, ``"normed_adaptive_average"`` (L2-normalise
            the channels, then average) or ``"gem"``.
    """

    def __init__(
        self,
        trial_shape=[-1, -1, 128, 448],
        channels=[1, 64, 64, 128, 128],
        pool_dim=(2, 3),
        out_dim=128,
        pool_type="flatten",
    ):
        super().__init__()
        self.pool_type = pool_type

        self.conv_encoder = ConvN(channels=channels, pool_dim=pool_dim)
        if self.pool_type == "dense":
            # Flattened encoder output size: every stage floor-divides the
            # spatial extent by its pooling kernel.
            num_convs = len(channels) - 1
            num_logits = int(
                channels[-1]
                * floor_power(trial_shape[2], pool_dim[0], num_convs)
                * floor_power(trial_shape[3], pool_dim[1], num_convs)
            )
            self.pool = nn.Sequential(
                nn.Dropout(p=0.3),
                nn.BatchNorm1d(num_logits, eps=1e-05, momentum=0.1, affine=True),
                nn.Linear(in_features=num_logits, out_features=out_dim),
            )
        elif self.pool_type == "gem":
            self.pool = GeMPooling(p=3)

    def forward(self, x):
        """Encode ``x`` and reduce it to a 2-D ``(batch, features)`` tensor.

        Raises:
            NotImplementedError: If ``pool_type`` is not one of the supported
                values.
        """
        x = self.conv_encoder(x)
        batch = x.size(0)
        if self.pool_type == "flatten":
            x = x.view(batch, -1)
        elif self.pool_type == "dense":
            x = self.pool(x.view(batch, -1))
        elif self.pool_type == "adaptive_average":
            x = F.adaptive_avg_pool2d(x, 1).view(batch, -1)
        elif self.pool_type == "normed_adaptive_average":
            # F.normalize is not in-place: its result must be kept, otherwise
            # this branch is identical to "adaptive_average".
            x = F.normalize(x, p=2, dim=1, eps=1e-12)
            x = F.adaptive_avg_pool2d(x, 1).view(batch, -1)
        elif self.pool_type == "gem":
            x = self.pool(x).view(batch, -1)
        else:
            raise NotImplementedError(f"unsupported pool_type: {self.pool_type!r}")
        assert x.dim() == 2, x.dim()

        return x

In [None]:
# Print a layer summary of the default backbone for a realistic input size:
# `batch_size` tasks, each holding way * (shot + n_query) spectrograms.
batch_size = 2
shot = 5
way = 5
n_query = 15
# Number of spectrogram frames per crop; same formula as in BirdClefDataset.
duration = int(cfg["wav_crop_sec"] / cfg["hop_length"] * cfg["sample_rate"])
total_samples = batch_size * (way * (shot + n_query))
summary(ConvNFeatureExtractor(), input_size=(total_samples, 1, 128, duration))

## Classifier

In [None]:
def L2SquareDist(A: Tensor, B: Tensor, average: bool = True) -> Tensor:
    r"""Pairwise squared Euclidean distance between two batches of features.

    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 a.b`` so the whole
    distance matrix is obtained from a single batched matrix product.

    Args:
        A: Feature tensor of size [Batch_size, Na, nC].
        B: Feature tensor of size [Batch_size, Nb, nC].
        average: If true, divide every distance by the feature size nC.
    Output:
        dist: The distance tensor of size [Batch_size, Na, Nb].
    """
    assert A.dim() == 3
    assert B.dim() == 3
    assert A.size(0) == B.size(0) and A.size(2) == B.size(2)
    n_batch, n_a, n_feat = A.shape
    n_b = B.size(1)

    # [n_batch x n_a x n_feat] @ [n_batch x n_feat x n_b] -> [n_batch x n_a x n_b]
    cross = torch.bmm(A, B.transpose(1, 2))

    sq_a = (A * A).sum(dim=2).reshape(n_batch, n_a, 1)
    sq_b = (B * B).sum(dim=2).reshape(n_batch, 1, n_b)

    dist = sq_a.expand_as(cross) + sq_b.expand_as(cross) - 2 * cross
    if average:
        dist = dist / n_feat

    return dist

In [None]:
class PN_head(nn.Module):
    r"""Prototype-based head from ``Prototypical Networks for Few-shot Learning''.

    The output is distance-like: smaller values mean a query is closer to a
    class prototype (negated, scaled cosine similarity, or scaled squared
    Euclidean distance).

    Args:
        metric: Either "cosine" or "euclidean".
        scale_cls: Initial value of the multiplicative scale.
        learn_scale: Whether the scale is a learnable parameter.
        normalize: Stored on the module; not consulted by ``forward``.
        pool_feature: Accepted but unused.

    NOTE(review): prototypes are L2-normalised for both metrics while query
    features are normalised only for "cosine" — confirm this is intended for
    the "euclidean" metric.
    """

    def __init__(
        self,
        metric: str = "euclidean",
        scale_cls: int = 10.0,
        learn_scale: bool = True,
        normalize: bool = True,
        pool_feature: bool = False,
    ) -> None:
        super().__init__()
        assert metric in ["cosine", "euclidean"]
        self.scale_cls = (
            nn.Parameter(torch.FloatTensor(1).fill_(scale_cls), requires_grad=True)
            if learn_scale
            else scale_cls
        )
        self.metric = metric
        self.normalize = normalize

    def forward(
        self, features_test: Tensor, features_train: Tensor, way: int, shot: int
    ) -> Tensor:
        r"""Score every query example against the class prototypes of its task.

        Args:
            features_test: Query examples. size: [batch_size, num_query, d]
            features_train: Support examples ordered with the classes cycling
                            fastest, i.e. labels like [abcdabcdabcd].
                            size: [batch_size, way*shot, d]
            way: The number of classes of each few-shot classification task.
            shot: The number of support examples per class in each task.
        Output:
            classification_scores: Distance-like scores (lower is closer).
                                   size: [batch_size, num_query, way]
        """
        assert features_train.dim() == 3
        assert features_test.dim() == 3

        use_cosine = self.metric == "cosine"
        n_tasks = features_train.size(0)
        if use_cosine:
            features_train = F.normalize(features_train, p=2, dim=2, eps=1e-12)
            features_test = F.normalize(features_test, p=2, dim=2, eps=1e-12)

        # Average the `shot` support examples of each class:
        # [n_tasks, shot, way, d] -> [n_tasks, way, d]
        per_class = features_train.reshape(n_tasks, shot, way, -1)
        prototypes = F.normalize(per_class.mean(dim=1), p=2, dim=2, eps=1e-12)

        if use_cosine:
            # [n_tasks, num_query, d] @ [n_tasks, d, way] -> [n_tasks, num_query, way]
            similarity = torch.bmm(features_test, prototypes.transpose(1, 2))
            classification_scores = -self.scale_cls * similarity
        elif self.metric == "euclidean":
            classification_scores = self.scale_cls * L2SquareDist(
                features_test, prototypes
            )
        return classification_scores

# Train/Evaluation Module

In [None]:
class CategoricalAccuracy(Metric):
    """Accuracy of arg-max predictions, accumulated over ``update`` calls."""

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        # Two running counters, summed when states are reduced.
        for state_name in ("correct", "total"):
            self.add_state(state_name, default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Count the predictions whose arg-max over the last dim equals ``target``."""
        predictions = preds.argmax(dim=-1)
        assert predictions.shape == target.shape
        self.correct += (predictions == target).sum()
        self.total += target.numel()

    def compute(self):
        """Return ``correct / total`` as a float tensor."""
        return self.correct.float() / self.total

In [None]:
def epoch_wrapup(pl_module: LightningModule, mode: str):
    r"""Finalize the ``{mode}_loss`` and ``{mode}_acc`` metrics at epoch end.

    Each metric is computed and then reset. The computed epoch value is
    logged (as ``{mode}/loss_epoch`` / ``{mode}/acc_epoch``) only when
    ``mode`` is "train".

    Args:
        pl_module: An instance of LightningModule.
        mode: The current mode (train, val or test).
    """
    assert mode in ["train", "val", "test"]
    for name in ("loss", "acc"):
        metric = getattr(pl_module, f"{mode}_{name}")
        value = metric.compute()
        if mode == "train":
            pl_module.log(f"{mode}/{name}_epoch", value)
        metric.reset()


def set_schedule(pl_module):
    r"""Build the optimizer (and optional LR scheduler) for training.

    Supported optimizer (``hparams.optim_type``):
        "adamw" (AdamW with amsgrad) and "sgd" (Nesterov momentum 0.9)
    Supported scheduler (``hparams.decay_scheduler``):
        "cosine" (stepped every optimizer step), "specified_epochs"
        (MultiStepLR stepped every epoch) or None (no scheduler)

    Args:
        pl_module: An instance of LightningModule. Its ``hparams`` must provide
            ``lr``, ``weight_decay``, ``decay_scheduler`` and ``optim_type``,
            plus ``decay_epochs`` / ``decay_power`` for "specified_epochs".

    Returns:
        The bare optimizer when no scheduler is requested, otherwise a dict
        with the keys "optimizer" and "lr_scheduler".

    Raises:
        RuntimeError: If the optimizer or scheduler name is not supported.
    """
    # The notebook's import cell only brings in AdamW and CosineAnnealingLR, so
    # the "sgd" and "specified_epochs" branches used to fail with NameError.
    from torch.optim import SGD
    from torch.optim.lr_scheduler import MultiStepLR

    lr = pl_module.hparams.lr
    wd = pl_module.hparams.weight_decay
    decay_scheduler = pl_module.hparams.decay_scheduler
    optim_type = pl_module.hparams.optim_type

    if optim_type == "adamw":
        optimizer = AdamW(pl_module.parameters(), weight_decay=wd, lr=lr, amsgrad=True)
    elif optim_type == "sgd":
        optimizer = SGD(
            pl_module.parameters(), momentum=0.9, nesterov=True, weight_decay=wd, lr=lr
        )
    else:
        raise RuntimeError(
            "optim_type not supported.\
                            Try to implement your own optimizer."
        )

    if decay_scheduler == "cosine":
        max_steps = pl_module.trainer.max_steps
        if (max_steps is None) or (max_steps == -1):
            # No explicit step budget: derive it from the epoch budget and the
            # number of batches per epoch.
            length_epoch = len(pl_module.trainer.datamodule.train_dataloader())
            max_steps = length_epoch * pl_module.trainer.max_epochs

        print(f"max_steps: {max_steps}")
        scheduler = {
            "scheduler": CosineAnnealingLR(optimizer, max_steps),
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
    elif decay_scheduler == "specified_epochs":
        decay_epochs = pl_module.hparams.decay_epochs
        decay_power = pl_module.hparams.decay_power
        assert decay_epochs is not None and decay_power is not None
        scheduler = {
            "scheduler": MultiStepLR(
                optimizer, milestones=decay_epochs, gamma=decay_power
            ),
            "interval": "epoch",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
    elif decay_scheduler is None:
        return optimizer
    else:
        raise RuntimeError(
            "decay scheduler not supported.\
                            Try to implement your own scheduler."
        )

In [None]:
class BaseFewShotModule(LightningModule):
    r"""Template for all few-shot learning models.

    Subclasses implement ``train_forward`` / ``val_test_forward``. Their output
    is treated as a distance: ``shared_step`` negates it before the softmax, so
    the class with the smallest value is the prediction.
    """

    def __init__(
        self,
        backbone_name: str = "ConvN",
        way: int = 5,
        train_shot: Optional[int] = None,
        val_shot: int = 5,
        test_shot: int = 5,
        num_query: int = 15,
        train_batch_size_per_gpu: Optional[int] = None,
        val_batch_size_per_gpu: int = 4,
        test_batch_size_per_gpu: int = 4,
        lr: float = 0.001,
        weight_decay: float = 5e-4,
        decay_scheduler: Optional[str] = "cosine",
        optim_type: str = "adamw",
        decay_epochs: Union[List, Tuple, None] = None,
        decay_power: Optional[float] = None,
        backbone_kwargs: Dict = {},
    ) -> None:
        """
        Args:
            backbone_name: The name of the feature extractor. Stored as a
                        hyper-parameter only; the backbone is always
                        ``ConvNFeatureExtractor``.
            way: The number of classes within one task.
            train_shot: The number of samples within each few-shot
                        support class during training.
                        For meta-learning only.
            val_shot: The number of samples within each few-shot
                    support class during validation.
            test_shot: The number of samples within each few-shot
                    support class during testing.
            num_query: The number of samples within each few-shot
                    query class.
            train_batch_size_per_gpu: The batch size of training per GPU.
            val_batch_size_per_gpu: The batch size of validation per GPU.
            test_batch_size_per_gpu: The batch size of testing per GPU.
            lr: The initial learning rate.
            weight_decay: The weight decay parameter.
            decay_scheduler: The scheduler of optimizer.
                            "cosine", "specified_epochs" or None.
            optim_type: The optimizer type.
                        "adamw" or "sgd"
            decay_epochs: The list of decay epochs of decay_scheduler "specified_epochs".
            decay_power: The decay power of decay_scheduler "specified_epochs"
                        at each specified epoch.
                        i.e., adjusted_lr = lr * decay_power
            backbone_kwargs: Keyword arguments forwarded to
                        ``ConvNFeatureExtractor``.
        """
        super().__init__()
        self.save_hyperparameters()
        # Forward the documented backbone options (previously ignored); the
        # default empty dict keeps the default backbone.
        self.backbone = ConvNFeatureExtractor(**backbone_kwargs)
        # Query labels of one task: the classes cycle fastest (0..way-1 repeated
        # num_query times), matching the sample order the sampler produces.
        # Built directly as int64 so that `way` is not limited to the int8 range.
        self.label = torch.arange(way, dtype=torch.long).repeat(num_query)

        self.set_metrics()

    def train_forward(self, batch, batch_size=None, way=None, shot=None):
        r"""Here implements the forward function of training.

        The optional arguments mirror the call made by ``shared_step``.

        Output: logits
        Args: (can be dynamically adjusted)
            batch: a batch from train_dataloader.
            batch_size: number of tasks during one iteration.
            way: The number of classes within one task.
            shot: The number of samples within each few-shot support class.
        """
        raise NotImplementedError

    def val_test_forward(self, batch, batch_size, way, shot):
        r"""Here implements the forward function of validation and testing.

        Output: logits
        Args: (can be dynamically adjusted)
            batch: a batch from val_dataloader.
            batch_size: number of tasks during one iteration.
            way: The number of classes within one task.
            shot: The number of samples within each few-shot support class.
        """
        raise NotImplementedError

    def shared_step(self, batch, mode):
        r"""The shared operation across
            validation, testing and potentially training (meta-learning).

        Runs the mode-specific forward function, computes the cross-entropy of
        the negated distances against the fixed query labels, updates and logs
        the ``{mode}/loss`` and ``{mode}/acc`` metrics, and returns the loss.

        Args:
            batch: a batch from the dataloader of the given mode.
            mode: train, val or test
        """
        assert mode in ["train", "val", "test"]
        flag = "train" if mode == "train" else "val_test"
        forward_function = getattr(self, f"{flag}_forward")
        batch_size_per_gpu = getattr(self.hparams, f"{mode}_batch_size_per_gpu")
        shot = getattr(self.hparams, f"{mode}_shot")

        distances = forward_function(batch, batch_size_per_gpu, self.hparams.way, shot)
        # Repeat the per-task labels once per task in the batch.
        label = (
            torch.unsqueeze(self.label, 0)
            .repeat(batch_size_per_gpu, 1)
            .reshape(-1)
            .to(distances.device)
        )
        distances = distances.reshape(label.size(0), -1)

        # Smaller distance -> larger logit.
        loss = F.nll_loss(F.log_softmax(-distances, dim=1), label)

        y_pred = (-distances).softmax(dim=1)

        log_loss = getattr(self, f"{mode}_loss")(loss)
        accuracy = getattr(self, f"{mode}_acc")(y_pred, label)
        self.log(f"{mode}/loss", log_loss)
        self.log(f"{mode}/acc", accuracy)
        return loss

    def training_step(self, batch, batch_idx):
        """Run one meta-learning training step and return its loss.

        Raises:
            RuntimeError: If ``train_shot`` or ``train_batch_size_per_gpu`` was
                not provided.
        """
        if (
            self.hparams.train_shot is None
            or self.hparams.train_batch_size_per_gpu is None
        ):
            raise RuntimeError(
                "train_shot or train_batch_size not specified.\
                                Please implement your own training step if the\
                                 training is not meta-learning."
            )
        return self.shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Update the validation metrics for one batch."""
        _ = self.shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Update the test metrics for one batch."""
        _ = self.shared_step(batch, "test")

    def training_epoch_end(self, outs):
        epoch_wrapup(self, "train")

    def validation_epoch_end(self, outs):
        epoch_wrapup(self, "val")

    def test_epoch_end(self, outs):
        epoch_wrapup(self, "test")

    def configure_optimizers(self):
        """Delegate optimizer/scheduler construction to ``set_schedule``."""
        return set_schedule(self)

    def set_metrics(self):
        r"""Set basic logging metrics for few-shot learning."""
        for split in ["train", "val", "test"]:
            setattr(self, f"{split}_loss", MeanMetric())
            setattr(self, f"{split}_acc", CategoricalAccuracy())

In [None]:
class ProtoNet(BaseFewShotModule):
    r"""The few-shot module implementing Prototypical Network.

    The backbone is provided by ``BaseFewShotModule``; this class adds a
    ``PN_head`` classifier that compares query features with support features.
    """

    def __init__(
        self,
        metric: str = "euclidean",
        scale_cls: float = 10.0,
        normalize: bool = True,
        backbone_name: str = "ConvN",
        way: int = 5,
        train_shot: int = 5,
        val_shot: int = 5,
        test_shot: int = 5,
        num_query: int = 15,
        train_batch_size_per_gpu: int = 1,
        val_batch_size_per_gpu: int = 4,
        test_batch_size_per_gpu: int = 4,
        lr: float = 0.1,
        weight_decay: float = 5e-4,
        decay_scheduler: Optional[str] = "cosine",
        optim_type: str = "adamw",
        decay_epochs: Union[List, Tuple, None] = None,
        decay_power: Optional[float] = None,
        backbone_kwargs: Optional[Dict] = None,
        **kwargs,
    ) -> None:
        """
        Args:
            metric: The distance metric applied. "cosine" or "euclidean".
            scale_cls: The initial scale number which affects the
                    following softmax function.
            normalize: Whether to normalize each spatial dimension of image
                    features before average pooling.
            backbone_name: The name of the feature extractor,
                        which should match the corresponding
                        file name in architectures.feature_extractor
            way: The number of classes within one task.
            train_shot: The number of samples within each few-shot
                        support class during training.
                        For meta-learning only.
            val_shot: The number of samples within each few-shot
                    support class during validation.
            test_shot: The number of samples within each few-shot
                    support class during testing.
            num_query: The number of samples within each few-shot
                    query class.
            train_batch_size_per_gpu: The batch size of training per GPU.
            val_batch_size_per_gpu: The batch size of validation per GPU.
            test_batch_size_per_gpu: The batch size of testing per GPU.
            lr: The initial learning rate.
            weight_decay: The weight decay parameter.
            decay_scheduler: The scheduler of optimizer.
                            "cosine" or "specified_epochs".
            optim_type: The optimizer type, e.g. "sgd" or "adam".
                        Defaults to "adamw".
            decay_epochs: The list of decay epochs of decay_scheduler "specified_epochs".
            decay_power: The decay power of decay_scheduler "specified_epochs"
                        at each specified epoch.
                        i.e., adjusted_lr = lr * decay_power
            backbone_kwargs: The parameters for creating backbone network.
                        ``None`` (the default) means an empty dict.
            **kwargs: Extra keyword arguments. They are accepted but neither
                        forwarded to the base class nor to the classifier head.
        """
        # A literal ``{}`` default is created once and shared by every call,
        # so a mutation made for one model would leak into the next one.
        # Rebinding the local keeps the value seen by the base class a dict.
        if backbone_kwargs is None:
            backbone_kwargs = {}
        super().__init__(
            backbone_name,
            way,
            train_shot,
            val_shot,
            test_shot,
            num_query,
            train_batch_size_per_gpu,
            val_batch_size_per_gpu,
            test_batch_size_per_gpu,
            lr,
            weight_decay,
            decay_scheduler,
            optim_type,
            decay_epochs,
            decay_power,
            backbone_kwargs,
        )
        self.classifier = PN_head(metric, scale_cls, normalize=normalize)

    def forward(self, batch, batch_size, way, shot):
        r"""Since PN is a meta-learning method,
            the model forward process is the same for train, val and test.

        Args:
            batch: a ``(data, labels)`` pair from a dataloader; the labels
                are not used here.
            batch_size: number of tasks during one iteration.
            way: The number of classes within one task.
            shot: The number of samples within each few-shot support class.

        Returns:
            The output of ``self.classifier`` for the query features against
            the support features.
        """
        num_support_samples = way * shot
        data, _ = batch
        data = self.backbone(data)  # (B * (N + Q), D)

        # The backbone must already deliver one flat feature vector per sample.
        assert data.dim() == 2
        data = data.reshape([batch_size, -1, data.size(-1)])  # (B, N + Q, D)
        # Within each task the first ``way * shot`` rows are the support set,
        # the remaining rows are the queries.
        data_support = data[:, :num_support_samples]  # (B, N, D)
        data_query = data[:, num_support_samples:]  # (B, Q, D)
        distances = self.classifier(data_query, data_support, way, shot)
        return distances

    def train_forward(self, batch, batch_size, way, shot):
        r"""Training forward pass; identical to ``forward``."""
        return self(batch, batch_size, way, shot)

    def val_test_forward(self, batch, batch_size, way, shot):
        r"""Validation/test forward pass; identical to ``forward``."""
        return self(batch, batch_size, way, shot)

In [None]:
# Number of few-shot tasks per batch for each stage.
train_batch_size = 2
val_batch_size = 4

# Seed before the model is built: seeding only right before ``fit`` leaves the
# weight initialisation in this cell unseeded.
seed_everything(cfg["random_seed"])

model = ProtoNet(
    train_batch_size_per_gpu=train_batch_size, val_batch_size_per_gpu=val_batch_size
)

wandb_logger = WandbLogger(
    project="BirdCLEF22",
    name="5-way-5-shot-prototypical-conv4w-flatten-7s",
    log_model="all",
)
callbacks = [
    LearningRateMonitor(logging_interval="step"),
    # Keep the checkpoint with the highest validation accuracy plus the last one.
    ModelCheckpoint(verbose=True, save_last=True, monitor="val/acc", mode="max"),
]

# Mixed precision (``precision=16``) is left disabled; it needs torch >= 1.10.
trainer = Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=cfg["max_epochs"],
    logger=wandb_logger,
    callbacks=callbacks,
)
datamodule = FewShotDataModule(
    train,
    val,
    img_path,
    cfg,
    train_transform,
    val_transform,
    train_batch_size=train_batch_size,
    val_batch_size=val_batch_size,
)

try:
    trainer.fit(model, datamodule=datamodule)
finally:
    # Close the W&B run even when training fails or is interrupted.
    wandb.finish()