# 1. Introduction

This notebook outlines the creation, compilation, and training of a DeiT (Data-efficient image Transformer) student network to classify 101 types of food. To this end, the **distillation technique** is applied so that the student learns from a larger, pre-trained transformer teacher, specifically a ViT-Base/16-384 transformer.

# 2. Importing Libraries

In [2]:
import os
import re
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import CosineAnnealingLR, ConstantLR, SequentialLR, CosineAnnealingWarmRestarts
from torchinfo import summary
from torchvision import datasets
from torchvision.transforms import v2
from tqdm import tqdm

# Import custom libraries
from utils.classification_utils import display_random_images_classification
from utils.common_utils import set_seeds
from engines.classification import ClassificationEngine, Common
from engines.loss_functions import DistillationLossForDeiT
from dataloaders.image_dataloaders import create_distillation_dataloaders, create_classification_dataloaders
from models.pretrained_models import build_pretrained_model
from models.deit import DeiT

# Dataset
from datasets import load_dataset

import warnings
# Request CUDA device-side assertions.
# NOTE(review): TORCH_USE_CUDA_DSA appears to be a PyTorch *build-time* option, and it is
# set here only after `torch` has already been imported, so it likely has no runtime
# effect — confirm. For debugging CUDA errors, CUDA_LAUNCH_BLOCKING=1 may be what is wanted.
os.environ['TORCH_USE_CUDA_DSA'] = "1"
# Silence UserWarnings emitted from the torch.autograd.graph module
warnings.filterwarnings("ignore", category=UserWarning, module="torch.autograd.graph")
# Silence FutureWarnings emitted from the onnxscript.converter module
warnings.filterwarnings("ignore", category=FutureWarning, module="onnxscript.converter")

# 3. Importing Dataset

The dataset should be organized as follows, with one subdirectory per class containing the corresponding images:

```
dataset/
├── train/
│   └── <class_label>/
│       ├── img1.jpg
│       ├── img2.png
│       └── ...
└── test/  (or val/)
    └── <class_label>/
        ├── img1.jpg
        ├── img2.png
        └── ...
```

In [None]:
# --- Run configuration ---
NUM_WORKERS = os.cpu_count()  # NOTE(review): not passed to any dataloader factory in this notebook
AMOUNT_TO_GET = 1.0           # dataset fraction; only used to build the data directory name below
SEED = 42
THEME = 'light'               # plotting theme: 'light' (default) or 'dark'
IMPORT_DATASET = False        # set True to download Food-101 and write it under TARGET_DIR

# --- Paths ---
# Images live under data/food-101_<percent>_percent/{train,test}/<class_label>/
TARGET_DIR_NAME = f"data/food-101_{str(int(AMOUNT_TO_GET*100))}_percent"
TARGET_DIR = Path(TARGET_DIR_NAME)
TRAIN_DIR, TEST_DIR = TARGET_DIR / "train", TARGET_DIR / "test"
TARGET_DIR.mkdir(parents=True, exist_ok=True)

# Teacher and student checkpoints are loaded from / saved to this directory
MODEL_DIR = Path("outputs")

# Seed the random number generators via the project's helper
set_seeds(SEED)

In [None]:
# Optionally fetch Food-101 from the Hugging Face Hub; skipped unless IMPORT_DATASET is True.
# `ds` is only defined when this branch runs, and the next cell relies on it.
if IMPORT_DATASET:
    # Download dataset from Hugging Face
    ds = load_dataset("ethz/food101")

In [None]:
if IMPORT_DATASET:
    # Class names indexed by the integer label stored in each example
    class_names = ds["train"].features["label"].names

    def save_images(split, target_dir):
        """Write every image of a dataset split to ``target_dir/<class_name>/<index>.jpg``.

        File names come from a per-class running counter kept in memory. The
        directory is no longer re-listed for every image, which was O(n) per image.
        Re-running the cell therefore overwrites the same files instead of
        appending duplicate copies after the existing ones.

        Args:
            split: Split name in ``ds`` (e.g. "train" or "validation").
            target_dir: ``Path`` of the output root for this split.
        """
        next_index = {}  # class_name -> index of the next file to write
        for example in tqdm(ds[split], desc=f"Saving {split} images"):
            class_name = class_names[example["label"]]
            class_dir = target_dir / class_name

            # Create each class directory once, the first time the class is seen
            if class_name not in next_index:
                class_dir.mkdir(parents=True, exist_ok=True)
                next_index[class_name] = 0

            img_path = class_dir / f"{next_index[class_name]}.jpg"
            next_index[class_name] += 1

            image = example["image"]
            # JPEG cannot encode modes such as RGBA or P; convert those to RGB first
            if image.mode not in ("RGB", "L"):
                image = image.convert("RGB")
            image.save(img_path)

    # Save training and test images
    save_images("train", TRAIN_DIR)
    save_images("validation", TEST_DIR)

    print("Dataset has been saved successfully!")

# 4. Specifying Target Device

In [None]:
# Let cuDNN benchmark and cache the fastest kernels (inputs here are fixed-size crops)
cudnn.benchmark = True

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

#if device == "cuda":
#    !nvidia-smi

# 5. Image Visualization

In [None]:
# Preview pipeline: resize the shorter side to 256 px, take a random 256x256 crop,
# and convert to a float32 tensor scaled to [0, 1]. No Normalize step is applied
# here, so images are shown with their original colours.
manual_transforms = v2.Compose([
    v2.Resize(256),
    v2.RandomCrop((256, 256)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])

train_data = datasets.ImageFolder(TRAIN_DIR, transform=manual_transforms)

# 5x5 grid of randomly chosen training images (unseeded, so it changes per run)
display_random_images_classification(
    train_data,
    n=25,
    classes=train_data.classes,
    rows=5,
    cols=5,
    display_shape=False,
    seed=None,
    theme=THEME,
)

# 6. Creating Teacher and Student - 101 Classes

To download a teacher model example, click [here](https://drive.google.com/file/d/1uHQ9WotGHh6suSMvkyO0vLj1O8XlJ_yE/view?usp=sharing).

In [None]:
# Input resolutions and batch size
IMG_SIZE_TCH = 384  # teacher (ViT-Base/16-384) input side length
IMG_SIZE_STD = 224  # student input side length
BATCH_SIZE = 16

# ImageNet channel statistics, shared by every pipeline below
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Teacher pipelines: resize the shorter side to 384 px, then centre-crop 384x384.
# The training pipeline additionally applies TrivialAugmentWide first.
# NOTE(review): the teacher train pipeline centre-crops while the student's random-crops,
# and TrivialAugmentWide is random per call. If the distillation dataloader applies the
# two pipelines independently, teacher and student see different augmented views of the
# same image. Confirm this is intended.
transform_train_tch = v2.Compose([
    v2.TrivialAugmentWide(),
    v2.Resize(IMG_SIZE_TCH),
    v2.CenterCrop((IMG_SIZE_TCH, IMG_SIZE_TCH)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
transform_test_tch = v2.Compose([
    v2.Resize(IMG_SIZE_TCH),
    v2.CenterCrop((IMG_SIZE_TCH, IMG_SIZE_TCH)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Student pipelines: resize the shorter side to 256 px, then crop 224x224
# (random crop for training, centre crop for evaluation).
transform_train_std = v2.Compose([
    v2.TrivialAugmentWide(),
    v2.Resize(256),
    v2.RandomCrop((IMG_SIZE_STD, IMG_SIZE_STD)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
transform_test_std = v2.Compose([
    v2.Resize(256),
    v2.CenterCrop((IMG_SIZE_STD, IMG_SIZE_STD)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Paired dataloaders: each split gets a student pipeline and a teacher pipeline
# (how the two views are combined per sample is up to the project's factory).
train_dataloader, test_dataloader, class_names = create_distillation_dataloaders(
    train_dir=TRAIN_DIR,
    test_dir=TEST_DIR,
    transform_student_train=transform_train_std,
    transform_student_test=transform_test_std,
    transform_teacher_train=transform_train_tch,
    transform_teacher_test=transform_test_tch,
    batch_size=BATCH_SIZE,
)
dataloaders = dict(train=train_dataloader, test=test_dataloader)

# Number of classes, as reported by the dataloader factory
NUM_CLASSES = len(class_names)

# Teacher: ViT-Base/16-384 with a NUM_CLASSES-way head, loaded from MODEL_DIR/teacher_model.pth
model_tch_type = "teacher_model"
model_tch_name = model_tch_type + ".pth"

# Wrap with torch.compile *before* loading: the teacher checkpoint was produced by a
# compiled model, so (presumably) its parameter names only match the compiled wrapper.
model_tch = torch.compile(
    build_pretrained_model(
        model="vit_b_16_384",
        output_dim=NUM_CLASSES,
        seed=SEED,
        freeze_backbone=False,
        device=device,
    ),
    backend="aot_eager",
)

# Restore the trained teacher weights
model_tch = Common().load_model(model=model_tch, target_dir=MODEL_DIR, model_name=model_tch_name)

In [None]:
# Student: DeiT with 12 transformer layers, 192-dim embeddings and 3 attention heads,
# predicting NUM_CLASSES classes from IMG_SIZE_STD x IMG_SIZE_STD inputs
student_config = dict(
    img_size=IMG_SIZE_STD,
    in_channels=3,
    patch_size=16,
    num_transformer_layers=12,
    emb_dim=192,
    mlp_size=768,
    num_heads=3,
    attn_dropout=0,
    mlp_dropout=0.1,
    emb_dropout=0.05,
    drop_path_rate=0.1,
    num_classes=NUM_CLASSES,
)
model_std = DeiT(**student_config)

# Layer-by-layer summaries (shapes, parameter counts, trainability):
# teacher first at its resolution, then the student at its own
for net, side in ((model_tch, IMG_SIZE_TCH), (model_std, IMG_SIZE_STD)):
    display(summary(net,
                    input_size=(BATCH_SIZE, 3, side, side),
                    col_names=["input_size", "output_size", "num_params", "trainable"],
                    col_width=20,
                    row_settings=["var_names"]))

# 7. Training the Model

In [None]:
# Checkpoint naming for this training run
model_std_type = "student_deit"
model_std_name = model_std_type + ".pth"

# Optimisation hyperparameters
EPOCHS = 100
LR = 3e-4
MIN_LR = 1e-6
WARMUP_EPOCHS = 5
WARMUP_START_FACTOR = 1e-3   # the first warm-up epoch runs at LR * WARMUP_START_FACTOR
# Distillation hyperparameters
ALPHA_MAX = 0.7              # alpha value reached at the end of the ramp (alpha decreases toward it)
ALPHA_RAMP_EPOCHS = 10
T = 4.0                      # temperature passed to the distillation loss

optimizer = torch.optim.AdamW(
    params=model_std.parameters(),
    lr=LR,
    betas=(0.9, 0.95),
    weight_decay=0.05,
)

# Per-epoch alpha (one entry per epoch): linear ramp from 1.0 at epoch 0 down to
# ALPHA_MAX at epoch ALPHA_RAMP_EPOCHS, then held constant.
alpha_schedule = []
for epoch_idx in range(EPOCHS):
    ramp = min(1.0, epoch_idx / ALPHA_RAMP_EPOCHS)
    alpha_schedule.append(1.0 - (1.0 - ALPHA_MAX) * ramp)

# The loss is called with five arguments: the student's classification [CLS] logits,
# the student's distillation [DST] logits, the teacher's logits, the ground-truth labels,
# and the current epoch (used to index alpha_schedule; with a constant float alpha the
# epoch is irrelevant).
loss_fn = DistillationLossForDeiT(alpha=alpha_schedule, temperature=T, label_smoothing=0.1)

# LR schedule: linear warm-up for WARMUP_EPOCHS, then cosine decay down to MIN_LR
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=WARMUP_START_FACTOR, total_iters=WARMUP_EPOCHS
)
cosine = CosineAnnealingLR(optimizer, T_max=EPOCHS - WARMUP_EPOCHS, eta_min=MIN_LR)
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[WARMUP_EPOCHS])

# Re-seed immediately before the engine is built and training starts
set_seeds(SEED)

# The engine bundles the student, the teacher (required because use_distillation=True),
# the optimiser, the loss, the LR scheduler, the plotting theme and the target device
engine = ClassificationEngine(
    model=model_std,
    model_teacher=model_tch,
    use_distillation=True,
    optimizer=optimizer,
    loss_fn=loss_fn,
    scheduler=scheduler,
    theme=THEME,
    device=device,
)

train_options = dict(
    target_dir=MODEL_DIR,                            # where checkpoints are written
    model_name=model_std_name,                       # base checkpoint name for the student
    enable_resume=True,                              # resume from the last saved checkpoint
    dataloaders=dataloaders,                         # {'train': ..., 'test': ...}
    save_best_model=["last", "loss", "acc", "f1"],   # criteria for which checkpoints are saved
    keep_best_models_in_memory=False,                # keep best models on disk only (saves memory/time)
    apply_validation=True,                           # run the validation step
    augmentation_strategy="always",
    recall_threshold=0.995,                          # false-positive rate reported at this recall
    recall_threshold_pauc=0.95,                      # partial AUC computed above this recall
    epochs=EPOCHS,
    amp=True,                                        # automatic mixed precision
    enable_clipping=False,                           # gradient clipping; enable only if training is unstable
    debug_mode=False,
    accumulation_steps=4,                            # effective batch size = batch_size x 4
)
results = engine.train(**train_options)

## 7.1. (Optional) Retraining the Last Best-performing Model

In [None]:
# Fresh student with the same DeiT configuration as in section 6; its weights are
# replaced by the checkpoint loaded in the next step.
student_config = dict(
    img_size=IMG_SIZE_STD,
    in_channels=3,
    patch_size=16,
    num_transformer_layers=12,
    emb_dim=192,
    mlp_size=768,
    num_heads=3,
    attn_dropout=0,
    mlp_dropout=0.1,
    emb_dropout=0.05,
    drop_path_rate=0.1,
    num_classes=NUM_CLASSES,
)
model_std = DeiT(**student_config)

# Load the student weights from the "best accuracy" checkpoint with the highest epoch.
# Checkpoint names are assumed to look like "student_deit_acc_epoch<N>.pth" (the pattern
# searched below); presumably the highest N is the most recent accuracy improvement.
STUDENT_ACC_CKPT_PREFIX = "student_deit_acc_epoch"

def extract_epoch(path):
    """Return the integer epoch parsed from a checkpoint path's file name.

    Raises:
        ValueError: if the file name contains no ``epoch<digits>`` token.
    """
    fname = Path(path).name
    match = re.search(r"epoch(\d+)", fname)
    if match is None:
        raise ValueError(f"Could not parse epoch from {fname}")
    return int(match.group(1))

model_paths = list(Path(MODEL_DIR).glob(f"{STUDENT_ACC_CKPT_PREFIX}*.pth"))
# Explicit check (an `assert` would be stripped under `python -O` and gave no location)
if not model_paths:
    raise FileNotFoundError(
        f"No matching model files found: '{STUDENT_ACC_CKPT_PREFIX}*.pth' in {MODEL_DIR}"
    )

best_path = max(model_paths, key=extract_epoch)
model_name = best_path.name

model_std = Common().load_model(
    model=model_std,
    target_dir=MODEL_DIR,
    model_name=model_name
)

In [None]:
# Checkpoint naming for the fine-tuning run. It uses a distinct name, presumably so its
# checkpoints do not overwrite those of the first run.
model_std_type = "student_deit_pretrained"
model_std_name = model_std_type + ".pth"

# Optimisation hyperparameters: shorter run, lower peak LR, gentler warm-up
EPOCHS = 50
LR = 1e-4
MIN_LR = 1e-6
WARMUP_EPOCHS = 2
WARMUP_START_FACTOR = 0.1    # the first warm-up epoch runs at LR * WARMUP_START_FACTOR
# Distillation hyperparameters
ALPHA_MAX = 0.7              # alpha value reached at the end of the ramp (alpha decreases toward it)
ALPHA_RAMP_EPOCHS = 10
T = 4.0                      # temperature passed to the distillation loss

optimizer = torch.optim.AdamW(
    params=model_std.parameters(),
    lr=LR,
    betas=(0.9, 0.95),
    weight_decay=0.05,
)

# Per-epoch alpha (one entry per epoch): linear ramp from 1.0 at epoch 0 down to
# ALPHA_MAX at epoch ALPHA_RAMP_EPOCHS, then held constant.
alpha_schedule = []
for epoch_idx in range(EPOCHS):
    ramp = min(1.0, epoch_idx / ALPHA_RAMP_EPOCHS)
    alpha_schedule.append(1.0 - (1.0 - ALPHA_MAX) * ramp)

# The loss is called with five arguments: the student's classification [CLS] logits,
# the student's distillation [DST] logits, the teacher's logits, the ground-truth labels,
# and the current epoch (used to index alpha_schedule; with a constant float alpha the
# epoch is irrelevant).
loss_fn = DistillationLossForDeiT(alpha=alpha_schedule, temperature=T, label_smoothing=0.1)

# LR schedule: linear warm-up for WARMUP_EPOCHS, then cosine decay down to MIN_LR
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=WARMUP_START_FACTOR, total_iters=WARMUP_EPOCHS
)
cosine = CosineAnnealingLR(optimizer, T_max=EPOCHS - WARMUP_EPOCHS, eta_min=MIN_LR)
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[WARMUP_EPOCHS])

# Re-seed immediately before the engine is built and fine-tuning starts
set_seeds(SEED)

# The engine bundles the (checkpoint-initialised) student, the teacher (required because
# use_distillation=True), the optimiser, the loss, the LR scheduler, the theme and the device
engine = ClassificationEngine(
    model=model_std,
    model_teacher=model_tch,
    use_distillation=True,
    optimizer=optimizer,
    loss_fn=loss_fn,
    scheduler=scheduler,
    theme=THEME,
    device=device,
)

train_options = dict(
    target_dir=MODEL_DIR,                            # where checkpoints are written
    model_name=model_std_name,                       # base checkpoint name for the student
    enable_resume=True,                              # resume from the last saved checkpoint
    dataloaders=dataloaders,                         # {'train': ..., 'test': ...}
    save_best_model=["last", "loss", "acc", "f1"],   # criteria for which checkpoints are saved
    keep_best_models_in_memory=False,                # keep best models on disk only (saves memory/time)
    apply_validation=True,                           # run the validation step
    augmentation_strategy="always",
    recall_threshold=0.995,                          # false-positive rate reported at this recall
    recall_threshold_pauc=0.95,                      # partial AUC computed above this recall
    epochs=EPOCHS,
    amp=True,                                        # automatic mixed precision
    enable_clipping=False,                           # gradient clipping; enable only if training is unstable
    debug_mode=False,
    accumulation_steps=4,                            # effective batch size = batch_size x 4
)
results = engine.train(**train_options)

# 8. Making Predictions

In [None]:
# Test dataloader using the teacher's preprocessing (384x384 inputs); the train loader
# returned first is discarded.
_, test_dataloader, _ = create_classification_dataloaders(
    train_dir=TRAIN_DIR,
    test_dir=TEST_DIR,
    train_transform=transform_train_tch,
    test_transform=transform_test_tch,
    batch_size=BATCH_SIZE
)

# Make predictions with the teacher. Rebuild ViT-Base/16-384 and wrap it with
# torch.compile exactly as in section 6, because the checkpoint was saved from a
# compiled model. Then load the weights and predict argmax class indices.
teacher_eval_model = torch.compile(
    build_pretrained_model(
        model="vit_b_16_384",
        output_dim=NUM_CLASSES,
        seed=SEED,
        freeze_backbone=False,
        device=device,
    ),
    backend="aot_eager",
)
teacher_eval_engine = ClassificationEngine(model=teacher_eval_model, device=device).load(
    target_dir=MODEL_DIR,
    model_name="teacher_model.pth",  # parameters only, not a saved engine
)
# NOTE(review): this is a plain classification loader, presumably yielding (image, label) batches
preds_tch = teacher_eval_engine.predict(dataloader=test_dataloader, output_type="argmax")

# Drop the evaluation copies so they do not outlive this cell
del teacher_eval_model, teacher_eval_engine

In [None]:
# Test dataloader using the student's preprocessing (224x224 inputs); the train loader
# returned first is discarded.
_, test_dataloader, _ = create_classification_dataloaders(
    train_dir=TRAIN_DIR,
    test_dir=TEST_DIR,
    train_transform=transform_train_std,
    test_transform=transform_test_std,
    batch_size=BATCH_SIZE
)

# Student checkpoint to evaluate. The student trained in this notebook is a DeiT, so
# this must be a DeiT state dict, e.g. one of the "student_deit*" files in MODEL_DIR.
# Modify the file name if necessary.
STUDENT_CHECKPOINT = "student_model.pth"

# Make predictions with the student. The architecture must match the trained one
# (DeiT with the same hyperparameters as in section 6). A torchvision swin_v2_t, which
# was built here previously, has parameters that do not correspond to a DeiT checkpoint.
student_eval_model = DeiT(
    img_size=IMG_SIZE_STD,
    in_channels=3,
    patch_size=16,
    num_transformer_layers=12,
    emb_dim=192,
    mlp_size=768,
    num_heads=3,
    attn_dropout=0,
    mlp_dropout=0.1,
    emb_dropout=0.05,
    drop_path_rate=0.1,
    num_classes=NUM_CLASSES
).to(device)

preds_std = ClassificationEngine(
    model=student_eval_model,
    device=device
    ).load(
        target_dir=MODEL_DIR,
        model_name=STUDENT_CHECKPOINT  # Only parameters, not the engine model
    ).predict(
        dataloader=test_dataloader,
        output_type="argmax"
        )

del student_eval_model

In [None]:
# Per-sample agreement between student and teacher argmax predictions.
# NOTE(review): meaningful only if both test dataloaders iterate the same samples in
# the same order (i.e. no shuffling) — confirm for create_classification_dataloaders.
if preds_std.shape != preds_tch.shape:
    # Without this check, e.g. (N, 1) vs (N,) would silently broadcast to an N x N comparison
    raise ValueError(
        f"Prediction shapes differ: student {tuple(preds_std.shape)} vs teacher {tuple(preds_tch.shape)}"
    )
# Compare on the CPU so predictions returned on different devices can still be compared
matches = preds_std.cpu() == preds_tch.cpu()
agreement = matches.float().mean().item()
print(f"Student vs Teacher agree on {agreement*100:.1f}% of samples")