In [1]:
# Enable IPython's autoreload extension so that edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules before each cell is executed.
%autoreload 2

In [2]:
from hcmus.utils import data_utils

# Fetch the dataset splits. According to the recorded log output of this cell,
# the call loads tasks through the Label Studio connector and downloads the
# referenced images, so it needs network access.
# NOTE(review): presumably yields train/val/test splits (three load/download
# rounds are logged, and later cells use the "train" and "val" datasets) —
# confirm against data_utils.get_data_splits.
splits = data_utils.get_data_splits()

[32m2025-07-22 09:32:25.831[0m | [1mINFO    [0m | [36mhcmus.core.appconfig[0m:[36m<module>[0m:[36m7[0m - [1mLoad DotEnv: True[0m
[32m2025-07-22 09:32:27.122[0m | [1mINFO    [0m | [36mhcmus.lbs._label_studio_connector[0m:[36mget_tasks[0m:[36m152[0m - [1mNew `page_to` applied: 35[0m
Loading tasks: 100%|██████████| 35/35 [00:11<00:00,  3.01it/s]
Downloading images: 100%|██████████| 3443/3443 [00:06<00:00, 540.24it/s] 
[32m2025-07-22 09:32:45.342[0m | [1mINFO    [0m | [36mhcmus.lbs._label_studio_connector[0m:[36mget_tasks[0m:[36m152[0m - [1mNew `page_to` applied: 5[0m
Loading tasks: 100%|██████████| 5/5 [00:03<00:00,  1.42it/s]
Downloading images: 100%|██████████| 402/402 [00:02<00:00, 197.78it/s]
[32m2025-07-22 09:32:50.972[0m | [1mINFO    [0m | [36mhcmus.lbs._label_studio_connector[0m:[36mget_tasks[0m:[36m152[0m - [1mNew `page_to` applied: 4[0m
Loading tasks: 100%|██████████| 4/4 [00:01<00:00,  2.57it/s]
Downloading images: 100%|██████████|

In [3]:
import random
from torchvision import transforms as T

# Train-time augmentation pipeline: PIL image in, 224x224 tensor out.
transform_train = T.Compose([
    # Resolution augmentation: resize the shorter side to a random length in
    # [32, 224], then force the image back to 224x224. Crops that were resized
    # to a small size come back blurry, mimicking low-resolution inputs.
    T.Lambda(lambda img: T.Resize(random.randint(32, 224))(img)),
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(180),  # angle drawn from [-180, 180] degrees
    # Brightness-only jitter; the other three factors are explicitly disabled.
    T.ColorJitter(
        brightness=0.2,
        contrast=0.0,
        saturation=0.0,
        hue=0.0
    ),
    T.RandomResizedCrop(
        size=224,
        # `scale` is the crop area as a fraction of the source image area, so
        # the upper bound must not exceed 1.0. The previous bound of 1.2 asked
        # for crops larger than the image: those draws can never fit, are
        # rejected, and only waste sampling attempts.
        scale=(0.8, 1.0),
        ratio=(0.75, 1.3333)
    ),
    # T.RandAugment(num_ops=5),
    T.ToTensor()
])

# Evaluation pipeline: deterministic resize to the model input size, no
# augmentation.
transform_test = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor()
])

In [4]:
# Wrap the splits as torch datasets. The recorded log shows random_margin=0.2
# being applied to one dataset and random_margin=0 to the other two.
# NOTE(review): presumably the 0.2 margin goes to the training split only —
# confirm in data_utils.get_image_datasets_v2.
datasets = data_utils.get_image_datasets_v2(
    splits,
    transform_train,
    transform_test,
    random_margin=0.2,
    return_metadata=False,
)

# Generic (non-episodic) loaders are currently disabled:
# dataloaders = data_utils.get_data_loaders_v2(datasets, {
#     "train": True
# })

[32m2025-07-22 09:32:54.513[0m | [1mINFO    [0m | [36mhcmus.data._torch_dataset_v2[0m:[36m__init__[0m:[36m57[0m - [1mApply random_margin=0.2[0m
[32m2025-07-22 09:32:54.827[0m | [1mINFO    [0m | [36mhcmus.data._torch_dataset_v2[0m:[36m__init__[0m:[36m68[0m - [1mAuto infer `label2idx` mapping, mapping length: 99.[0m
[32m2025-07-22 09:32:54.831[0m | [1mINFO    [0m | [36mhcmus.data._torch_dataset_v2[0m:[36m__init__[0m:[36m57[0m - [1mApply random_margin=0[0m
[32m2025-07-22 09:32:55.205[0m | [1mINFO    [0m | [36mhcmus.data._torch_dataset_v2[0m:[36m__init__[0m:[36m57[0m - [1mApply random_margin=0[0m


In [5]:
import torch
import random
from torch.utils.data import Dataset
from typing import Tuple, List, Dict, Any
from collections import defaultdict
from torch.utils.data import DataLoader
from easyfsl.samplers import TaskSampler
from hcmus.data import CroppedImageDatasetV2

class CrossDatasetEpisodeSampler:
    """
    Sampler that creates episodes for prototypical networks where:
    - Support samples (n-shot) come from training dataset
    - Query samples (n-query) come from validation dataset

    This enables cross-dataset evaluation and domain adaptation scenarios.

    The datasets are used through a small duck-typed surface: ``classes``,
    ``samples`` (a sequence of dict-like records carrying a ``"label_idx"``
    entry), ``label2idx``, ``return_metadata`` and integer indexing that yields
    ``(image, label)`` or ``(image, label, metadata)``.

    NOTE(review): classes are matched between the two datasets purely by label
    index, which assumes both datasets share the same ``label2idx`` mapping —
    confirm against the dataset construction code.

    The sampler is also an infinite iterator: every ``next()`` call returns a
    freshly sampled episode.
    """

    def __init__(self,
                 train_dataset: "CroppedImageDatasetV2",
                 val_dataset: "CroppedImageDatasetV2",
                 n_way: int,
                 n_shot: int,
                 n_query: int,
                 train_class_to_indices: Dict[int, List[int]] = None,
                 val_class_to_indices: Dict[int, List[int]] = None):
        """
        Args:
            train_dataset: Training dataset for support samples
            val_dataset: Validation dataset for query samples
            n_way: Number of classes per episode
            n_shot: Number of support samples per class (from train_dataset)
            n_query: Number of query samples per class (from val_dataset)
            train_class_to_indices: Mapping of class labels to indices in train_dataset
            val_class_to_indices: Mapping of class labels to indices in val_dataset

        Raises:
            ValueError: If n_way, n_shot or n_query is smaller than 1, if fewer
                than n_way classes are usable in both datasets, or if a usable
                class has fewer than n_shot train / n_query val samples.
        """
        # Fail early with a readable message instead of a late torch.stack
        # error on an empty list when an episode is sampled.
        for arg_name, arg_value in (("n_way", n_way), ("n_shot", n_shot), ("n_query", n_query)):
            if arg_value < 1:
                raise ValueError(f"{arg_name} must be at least 1, got {arg_value}")

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.n_way = n_way
        self.n_shot = n_shot
        self.n_query = n_query

        # Build class-to-indices mappings if not provided
        self.train_class_to_indices = train_class_to_indices or self._build_class_indices(train_dataset)
        self.val_class_to_indices = val_class_to_indices or self._build_class_indices(val_dataset)

        # Label indices declared by each dataset (0 .. len(classes) - 1)
        train_classes = set(range(len(train_dataset.classes)))
        val_classes = set(range(len(val_dataset.classes)))

        # Keep only classes that are declared by both datasets AND actually
        # have samples in both. Sorted so that episode sampling is
        # reproducible once `random` is seeded.
        self.common_classes = sorted(
            train_classes
            & val_classes
            & set(self.train_class_to_indices)
            & set(self.val_class_to_indices)
        )

        if len(self.common_classes) < n_way:
            raise ValueError(f"Not enough common classes ({len(self.common_classes)}) for {n_way}-way episodes")

        # Validate that each class has enough samples
        self._validate_sample_counts()

    def _build_class_indices(self, dataset: "CroppedImageDatasetV2") -> Dict[int, List[int]]:
        """Build mapping from class labels to sample indices"""
        class_to_indices = defaultdict(list)

        for idx, sample in enumerate(dataset.samples):
            # Each entry of `dataset.samples` is a dict-like record; only its
            # "label_idx" field is read, so no image is loaded here.
            class_to_indices[sample.get("label_idx")].append(idx)

        return dict(class_to_indices)

    def _validate_sample_counts(self):
        """Validate that each common class has enough samples"""
        for class_id in self.common_classes:
            train_count = len(self.train_class_to_indices[class_id])
            val_count = len(self.val_class_to_indices[class_id])

            if train_count < self.n_shot:
                raise ValueError(f"Class {class_id} has only {train_count} samples in train dataset, "
                               f"but {self.n_shot} shots required")

            if val_count < self.n_query:
                raise ValueError(f"Class {class_id} has only {val_count} samples in val dataset, "
                               f"but {self.n_query} queries required")

    @staticmethod
    def _load_image(dataset, idx: int) -> torch.Tensor:
        """Fetch item `idx` from `dataset` and return only its image tensor."""
        # Items are 3-tuples when the dataset returns metadata, 2-tuples otherwise.
        if dataset.return_metadata:
            image, _, _ = dataset[idx]
        else:
            image, _ = dataset[idx]
        return image

    def sample_episode(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
        """
        Sample one episode with support from train_dataset and query from val_dataset

        Returns:
            support_data: (n_way * n_shot, ...) tensor of support samples
            support_labels: (n_way * n_shot,) tensor of support labels
            query_data: (n_way * n_query, ...) tensor of query samples
            query_labels: (n_way * n_query,) tensor of query labels
            episode_classes: list of the n_way original class ids; position i
                is the class that was relabelled to episode label i
        """
        # Sample n_way classes
        episode_classes = random.sample(self.common_classes, self.n_way)

        support_data = []
        support_labels = []
        query_data = []
        query_labels = []

        # Labels are remapped to 0 .. n_way - 1 in the order classes were drawn.
        for new_label, class_id in enumerate(episode_classes):
            # Sample support samples from training dataset
            for idx in random.sample(self.train_class_to_indices[class_id], self.n_shot):
                support_data.append(self._load_image(self.train_dataset, idx))
                support_labels.append(new_label)

            # Sample query samples from validation dataset
            for idx in random.sample(self.val_class_to_indices[class_id], self.n_query):
                query_data.append(self._load_image(self.val_dataset, idx))
                query_labels.append(new_label)

        # Convert to tensors
        return (
            torch.stack(support_data),
            torch.tensor(support_labels),
            torch.stack(query_data),
            torch.tensor(query_labels),
            episode_classes,
        )

    def __iter__(self):
        """Make the sampler iterable"""
        return self

    def __next__(self):
        """Generate next episode (never raises StopIteration)"""
        return self.sample_episode()

    def get_class_info(self) -> Dict[str, Any]:
        """Get information about classes in both datasets"""
        return {
            'train_classes': self.train_dataset.classes,
            'val_classes': self.val_dataset.classes,
            'common_classes_count': len(self.common_classes),
            'common_class_names': [self.train_dataset.classes[i] for i in self.common_classes],
            'train_label2idx': self.train_dataset.label2idx,
            'val_label2idx': self.val_dataset.label2idx,
        }


class CrossDatasetEpisodeDataset(Dataset):
    """
    PyTorch Dataset wrapper for cross-dataset episode sampling

    The dataset has a fixed length but no fixed content: every lookup asks the
    sampler for a freshly sampled episode, so the same index returns different
    episodes on different calls.
    """

    def __init__(self,
                 sampler: "CrossDatasetEpisodeSampler",
                 num_episodes: int):
        """
        Args:
            sampler: CrossDatasetEpisodeSampler instance
            num_episodes: Number of episodes in the dataset

        Raises:
            ValueError: If num_episodes is negative
        """
        if num_episodes < 0:
            raise ValueError(f"num_episodes must be non-negative, got {num_episodes}")
        self.sampler = sampler
        self.num_episodes = num_episodes

    def __len__(self):
        return self.num_episodes

    def __getitem__(self, idx):
        """
        Return a newly sampled episode for position `idx`.

        Raises:
            IndexError: If `idx` is outside [-num_episodes, num_episodes).
        """
        # Without this bound, plain iteration (`for ep in dataset`, `list(dataset)`)
        # never terminates: Python's index-based iteration protocol relies on
        # IndexError to detect the end of the sequence.
        if not -self.num_episodes <= idx < self.num_episodes:
            raise IndexError(f"Episode index {idx} out of range for {self.num_episodes} episodes")
        return self.sampler.sample_episode()

In [13]:
# Episode geometry.
n_way = 64            # classes per episode (training and validation)
n_shot = 10           # support images per class (training and validation)
n_query = 1           # query images per class in validation episodes
n_query_train = 5     # query images per class in training episodes
n_tasks = 960         # number of training episodes requested from TaskSampler
n_val_episodes = 24   # episodes averaged per validation pass

# Training episodes: support and query images both come from the train split.
train_sampler = TaskSampler(
    dataset=datasets["train"],
    n_way=n_way,
    n_shot=n_shot,
    n_query=n_query_train,
    n_tasks=n_tasks
)

train_dataloader = DataLoader(
    datasets["train"],
    batch_sampler=train_sampler,
    collate_fn=train_sampler.episodic_collate_fn
)

# Validation episodes: support from the train split, queries from the val
# split. Index with [] rather than .get() so that a missing split fails right
# here with a KeyError instead of passing None into the sampler.
val_sampler = CrossDatasetEpisodeSampler(
    train_dataset=datasets["train"],
    val_dataset=datasets["val"],
    n_way=n_way,
    n_shot=n_shot,
    n_query=n_query
)

val_dataset = CrossDatasetEpisodeDataset(val_sampler, n_val_episodes)

In [14]:
from hcmus.models.backbone import CLIPBackbone
from hcmus.models.backbone import DinoBackbone

In [8]:
# Backbones to benchmark, as (backbone class, constructor kwargs) pairs.
# Only the two CLIP variants are active. Disabled DINO candidates, each of the
# form (DinoBackbone, {"model_id": <id>}):
#   facebook/dinov2-small, facebook/dinov2-base,
#   facebook/dino-vitb8, facebook/dino-vits8,
#   facebook/dino-vits16, facebook/dino-vitb16
backbone_list = [
    (CLIPBackbone, {"backbone_name": clip_variant})
    for clip_variant in ("ViT-B/32", "ViT-B/16")
]

In [15]:
import mlflow
def get_or_create_experiment(name: str = "/PrototypicalNetworks") -> str:
    """
    Return the id of the MLflow experiment called `name`, creating it if needed.

    The experiment is looked up first and only created when the lookup finds
    nothing, so an existing experiment never triggers a failing create call
    and genuine errors (tracking server unreachable, bad credentials, ...) are
    no longer silently swallowed.

    Args:
        name: Experiment name; defaults to the experiment used by this notebook.

    Returns:
        The experiment id as reported by MLflow.
    """
    experiment = mlflow.get_experiment_by_name(name)
    if experiment is not None:
        return experiment.experiment_id

    # create_experiment returns the id of the newly created experiment.
    return mlflow.create_experiment(name)

In [16]:
from loguru import logger
from tqdm import tqdm

In [17]:
from torch import nn
from torch import optim
from hcmus.models.prototype import PrototypicalNetwork
from hcmus.models.prototype import PrototypicalTrainer

In [None]:
experiment_id = get_or_create_experiment()

# Single source of truth for the learning rate. The optimizer was previously
# built with a hard-coded 1e-3 while a separate `lr = 1e-2` was logged, so
# MLflow recorded a value that was never used. 1e-3 is kept because it is the
# rate the recorded runs actually trained with.
lr = 1e-3
# Run a validation pass after every this many training episodes.
eval_every_n_step = 16

# One MLflow run per backbone configuration.
for cls, params in backbone_list:
    with mlflow.start_run(experiment_id=experiment_id):
        mlflow.log_params(params)
        mlflow.log_param("lr", lr)
        mlflow.log_params({
            "n_way": n_way,
            "n_shot": n_shot,
            "n_tasks": n_tasks
        })
        mlflow.log_param("optim", "adam")
        mlflow.log_param("criterion", "cross_entropy")

        backbone = cls(**params)
        model = PrototypicalNetwork(backbone)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        trainer = PrototypicalTrainer(model, optimizer, criterion)

        # Each batch produced by train_dataloader is one training episode.
        for step, batch in enumerate(train_dataloader):
            support_data, support_labels, query_data, query_labels, classes = batch
            loss, acc = trainer.train_episode(support_data, support_labels, query_data, query_labels)
            logger.info(f"Train loss={loss}, acc={acc}")
            mlflow.log_metric("train_loss", loss, step=step)
            mlflow.log_metric("train_acc", acc, step=step)

            if (step + 1) % eval_every_n_step != 0:
                continue

            # Validation pass: average loss/accuracy over len(val_dataset)
            # episodes. Dedicated variable names keep the training episode's
            # batch, loss and accuracy from being overwritten.
            n_step = len(val_dataset)
            val_loss = 0
            val_acc = 0
            for idx in tqdm(range(n_step), desc="Evaluating..."):
                val_support, val_support_labels, val_query, val_query_labels, val_classes = val_dataset[idx]
                episode_loss, episode_acc = trainer.evaluate_episode(
                    val_support, val_support_labels, val_query, val_query_labels
                )
                val_loss += episode_loss
                val_acc += episode_acc
            logger.info(f"Validation loss={val_loss/n_step}, acc={val_acc/n_step}")
            mlflow.log_metric("val_loss", val_loss/n_step, step=step)
            mlflow.log_metric("val_acc", val_acc/n_step, step=step)

[32m2025-07-22 21:02:30.170[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mTrain loss=4.128411293029785, acc=0.14687499403953552[0m
[32m2025-07-22 21:02:53.079[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mTrain loss=4.095317363739014, acc=0.203125[0m
[32m2025-07-22 21:03:17.692[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mTrain loss=4.05486536026001, acc=0.28125[0m
[32m2025-07-22 21:03:40.052[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mTrain loss=4.006790637969971, acc=0.2593750059604645[0m
[32m2025-07-22 21:04:01.918[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mTrain loss=3.9507012367248535, acc=0.30000001192092896[0m
[32m2025-07-22 21:04:23.902[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mTrain loss=3.9213550090789795, acc=0.27812498807907104[0m
[32m2025-07-22 21:04:48.59

🏃 View run delicate-kite-926 at: http://jimica.ddns.net:5050/#/experiments/1/runs/0fd0d594b7bc41fc85d4fae7b0692e50
🧪 View experiment at: http://jimica.ddns.net:5050/#/experiments/1


[32m2025-07-23 09:45:45.055[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mTrain loss=4.126090049743652, acc=0.17499999701976776[0m
[32m2025-07-23 09:46:34.932[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mTrain loss=4.098790168762207, acc=0.26249998807907104[0m
[32m2025-07-23 09:47:24.931[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mTrain loss=4.039800643920898, acc=0.3812499940395355[0m
[32m2025-07-23 09:48:15.877[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mTrain loss=4.010291576385498, acc=0.32499998807907104[0m
[32m2025-07-23 09:49:04.982[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mTrain loss=3.940647840499878, acc=0.38749998807907104[0m
[32m2025-07-23 09:49:54.862[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mTrain loss=3.873452663421631, acc=0.37812501192092896[0m
[32m

🏃 View run salty-shoat-295 at: http://jimica.ddns.net:5050/#/experiments/1/runs/c9a64f8e7edd4cdfaf7955ffab2b73c6
🧪 View experiment at: http://jimica.ddns.net:5050/#/experiments/1
