# 1. Introduction

This notebook outlines the creation, compilation, and training of a Swin Transformer network to classify 101 types of food. To this end, the **distillation technique** is applied so that the student learns from a larger, pre-trained transformer model, specifically a ViT-Base/16-384 transformer.

# 2. Importing Libraries

In [None]:
import os
import torch
import torchvision
import torch.backends.cudnn as cudnn

from torchvision.transforms import v2
from torchinfo import summary
from pathlib import Path
from torchvision import datasets
from torch.optim.lr_scheduler import CosineAnnealingLR, ConstantLR, SequentialLR, CosineAnnealingWarmRestarts
from tqdm import tqdm

# Import custom libraries
from utils.classification_utils import set_seeds, display_random_images
from engines.classification import ClassificationEngine, Common
from dataloaders.image_dataloaders import create_distillation_dataloaders, create_dataloaders
from engines.loss_functions import DistillationLoss

# Dataset
from datasets import load_dataset

import warnings
os.environ['TORCH_USE_CUDA_DSA'] = "1"
warnings.filterwarnings("ignore", category=UserWarning, module="torch.autograd.graph")
warnings.filterwarnings("ignore", category=FutureWarning, module="onnxscript.converter")

# 3. Importing Dataset

In [None]:
# Define some constants
# os.cpu_count() returns None when the CPU count cannot be determined
NUM_WORKERS = os.cpu_count() or 1
# Fraction of the dataset; in this notebook it is only used to build TARGET_DIR_NAME
AMOUNT_TO_GET = 1.0
SEED = 42

# Define target data directory.
# round() instead of int(): int(0.29 * 100) == 28 because of floating-point error.
TARGET_DIR_NAME = f"data/food-101_{round(AMOUNT_TO_GET * 100)}_percent"

# Setup training and test directories
TARGET_DIR = Path(TARGET_DIR_NAME)
TRAIN_DIR = TARGET_DIR / "train"
TEST_DIR = TARGET_DIR / "test"
TARGET_DIR.mkdir(parents=True, exist_ok=True)

# Create target model directory (teacher/student checkpoints are loaded from and saved to it)
MODEL_DIR = Path("outputs")
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Set seeds
set_seeds(SEED)

# Set to True to download Food-101 from Hugging Face and export it into TRAIN_DIR / TEST_DIR
IMPORT_DATASET = False

In [None]:
# Fetch Food-101 from the Hugging Face Hub only when an export to disk is requested
if IMPORT_DATASET:
    ds = load_dataset(path="ethz/food101")

In [None]:
if IMPORT_DATASET:
    # Class names in label-index order, taken from the dataset's "label" feature
    class_names = ds["train"].features["label"].names

    def save_images(split, target_dir):
        """Export every image of a dataset split to ``target_dir/<class_name>/<index>.jpg``.

        Args:
            split: Name of the split in ``ds`` (e.g. "train" or "validation").
            target_dir: ``pathlib.Path`` root directory for this split.
        """
        # Next free file index per class. Each counter is seeded once from the number of
        # entries already in the class directory, so re-runs keep appending. This avoids
        # re-listing the directory for every image, which was quadratic per class.
        next_index = {}
        for example in tqdm(ds[split], desc=f"Saving {split} images"):
            class_name = class_names[example["label"]]
            class_dir = target_dir / class_name

            if class_name not in next_index:
                class_dir.mkdir(parents=True, exist_ok=True)
                next_index[class_name] = sum(1 for _ in class_dir.iterdir())

            # Save image
            img_path = class_dir / f"{next_index[class_name]}.jpg"
            example["image"].save(img_path)
            next_index[class_name] += 1

    # Save training and test images (the Hugging Face "validation" split is used as test set)
    save_images("train", TRAIN_DIR)
    save_images("validation", TEST_DIR)

    print("Dataset has been saved successfully!")

# 4. Specifying Target Device

In [None]:
# Let the cuDNN autotuner pick the fastest convolution algorithms for the input shapes seen
cudnn.benchmark = True

# Use the GPU when PyTorch can see one; otherwise run on the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

#if device == "cuda":
#    !nvidia-smi

# 5. Image Visualization

In [None]:
# Display a 5x5 grid of random training images. No Normalize step, so pixel values stay in [0, 1].
visualization_steps = [
    v2.Resize(256),                    # int size: resize the shorter side to 256 px
    v2.RandomCrop(size=(256, 256)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
]
manual_transforms = v2.Compose(visualization_steps)

train_data = datasets.ImageFolder(TRAIN_DIR, transform=manual_transforms)
display_random_images(
    train_data,
    n=25,
    classes=train_data.classes,
    rows=5,
    cols=5,
    display_shape=False,
    seed=None,
)

# 6. Creating Teacher and Student - 101 Classes

To download an example teacher model, click [here](https://drive.google.com/file/d/1uHQ9WotGHh6suSMvkyO0vLj1O8XlJ_yE/view?usp=sharing).

In [None]:
# Specify transformations
IMG_SIZE_TCH = 384  # Teacher (ViT-Base/16-384) input resolution
IMG_SIZE_STD = 256  # Student (Swin V2-Tiny) input resolution
BATCH_SIZE = 16

# Teacher pipelines: resize the shorter side to IMG_SIZE_TCH, then centre-crop a square.
# All pipelines normalize with the ImageNet mean/std.
transform_train_tch = v2.Compose([
    v2.TrivialAugmentWide(),
    v2.Resize((IMG_SIZE_TCH)),
    v2.CenterCrop((IMG_SIZE_TCH, IMG_SIZE_TCH)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225])
])

transform_test_tch = v2.Compose([
    v2.Resize((IMG_SIZE_TCH)),
    v2.CenterCrop((IMG_SIZE_TCH, IMG_SIZE_TCH)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225])
])

# Student pipelines: resize the shorter side to 260, then crop IMG_SIZE_STD
# (random crop for training, centre crop for testing).
transform_train_std = v2.Compose([
    v2.TrivialAugmentWide(),
    v2.Resize((260)),
    v2.RandomCrop((IMG_SIZE_STD, IMG_SIZE_STD)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225])
])

transform_test_std = v2.Compose([
    v2.Resize((260)),
    v2.CenterCrop((IMG_SIZE_STD, IMG_SIZE_STD)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225])
])

# Create data loaders (each sample carries both a student view and a teacher view)
train_dataloader, test_dataloader, class_names = create_distillation_dataloaders(
    train_dir=TRAIN_DIR,
    test_dir=TEST_DIR,
    transform_student_train=transform_train_std,
    transform_teacher_train=transform_train_tch,
    transform_student_test=transform_test_std,
    transform_teacher_test=transform_test_tch,
    batch_size=BATCH_SIZE
)

dataloaders = {
    'train': train_dataloader,
    'test':  test_dataloader
}

# Load ViT-Base/16-384. Run classification_example.ipynb to generate the teacher model.
model_tch_type = "teacher_model"
model_tch_name = model_tch_type + ".pth"

# Instantiate the teacher architecture with a NUM_CLASSES-way head
NUM_CLASSES = len(class_names)
model_tch = torchvision.models.vit_b_16(image_size=IMG_SIZE_TCH).to(device)
model_tch.heads = torch.nn.Linear(in_features=model_tch.hidden_dim, out_features=NUM_CLASSES).to(device)
# Compiled before loading, presumably so state-dict keys match how the checkpoint
# was saved — TODO confirm against Common.load_model
model_tch = torch.compile(model_tch, backend="aot_eager")

# Load the trained weights
model_tch = Common().load_model(
    model=model_tch,
    target_dir=MODEL_DIR,
    model_name=model_tch_name)

# Seed before building the student so its new classification head is initialised reproducibly
set_seeds(SEED)

# Instantiate the student: Swin Transformer V2-Tiny with torchvision's default pretrained weights.
# Move the whole model (not only the new head) to the device so backbone and head share a device.
model_std = torchvision.models.swin_v2_t(weights=torchvision.models.Swin_V2_T_Weights.DEFAULT).to(device)
model_std.head = torch.nn.Linear(in_features=model_std.head.in_features, out_features=NUM_CLASSES).to(device)

# Unfreeze the base parameters (full fine-tuning)
for parameter in model_std.parameters():
    parameter.requires_grad = True

# Print summary
#summary(model_tch,
#        input_size=(BATCH_SIZE,3,IMG_SIZE_TCH, IMG_SIZE_TCH),
#        col_names=["input_size", "output_size", "num_params", "trainable"],
#        col_width=20,
#        row_settings=["var_names"])

# Print summary
summary(model_std,
        input_size=(BATCH_SIZE, 3, IMG_SIZE_STD, IMG_SIZE_STD),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

# 7. Training the Model

In [None]:
# Train the model
model_std_type = "student_model"
model_std_name = model_std_type + ".pth"

# Epochs and learning rate
EPOCHS = 30
COSINE_EPOCHS = 20  # Epochs of cosine decay LR -> MIN_LR; the remaining epochs stay at MIN_LR
LR = 1e-4
MIN_LR = 1e-6

# Create optimizer
optimizer = torch.optim.AdamW(
    params=model_std.parameters(),
    lr=LR,
    betas=(0.9, 0.999),
    weight_decay=0.01
)

# Create loss function
loss_fn = DistillationLoss(alpha=0.1, temperature=2, label_smoothing=0.1)

# Initialize the scheduler (phases assume one scheduler step per epoch — TODO confirm in the engine).
# ConstantLR restores the base LR once total_iters steps have elapsed, so its length is derived
# from EPOCHS to cover the rest of training (at least 1, as it is never reached otherwise).
cosine = CosineAnnealingLR(optimizer, T_max=COSINE_EPOCHS, eta_min=MIN_LR)   # LR -> MIN_LR (cosine)
fixed = ConstantLR(optimizer, factor=MIN_LR / LR,
                   total_iters=max(EPOCHS - COSINE_EPOCHS, 1))               # LR = MIN_LR
scheduler = SequentialLR(
    optimizer,
    schedulers=[cosine, fixed],
    milestones=[COSINE_EPOCHS]
)

# Set seeds
set_seeds(SEED)

# And train...

# Instantiate the engine
engine = ClassificationEngine(
    model=model_std,                                # Model to be trained
    model_teacher=model_tch,                        # Teacher model if knowledge distillation training is enabled (use_distillation=True)
    use_distillation=True,                          # Whether training uses knowledge distillation, then model_teacher is required
    optimizer=optimizer,                            # Optimizer
    loss_fn=loss_fn,                                # Loss function
    scheduler=scheduler,                            # Scheduler
    theme='dark',                                   # Theme
    device=device                                   # Target device
    )

# Configure the training method
results = engine.train(
    target_dir=MODEL_DIR,                           # Directory where the model will be saved
    model_name=model_std_name,                      # Name of the student model
    resume=True,                                    # Resume training from the last saved checkpoint
    dataloaders=dataloaders,                        # Dictionary with the dataloaders
    save_best_model=["last", "loss", "acc", "f1"],  # Save the best models based on different criteria
    keep_best_models_in_memory=False,               # Do not keep the models stored in memory for the sake of training time and memory efficiency
    apply_validation=True,                          # Enable validation step
    augmentation_strategy="always",                 # Augmentation strategy
    recall_threshold=0.995,                         # False positive rate at recall_threshold recall
    recall_threshold_pauc=0.95,                     # Partial AUC score above recall_threshold_pauc recall
    epochs=EPOCHS,                                  # Total number of epochs
    amp=True,                                       # Enable Automatic Mixed Precision (AMP)
    enable_clipping=False,                          # Disable clipping on gradients, only useful if training becomes unstable
    debug_mode=False,                               # Disable debug mode
    accumulation_steps=4,                           # Effective batch size = BATCH_SIZE x accumulation_steps
    )

## 7.1. (Optional) Retraining the Last Best-performing Model

In [None]:
# Specify transformations
IMG_SIZE_TCH = 384  # Teacher (ViT-Base/16-384) input resolution
IMG_SIZE_STD = 224  # Student input resolution for retraining (lower than the 256 used in section 6)
BATCH_SIZE = 32

# Teacher pipelines: resize the shorter side to IMG_SIZE_TCH, then centre-crop a square.
# All pipelines normalize with the ImageNet mean/std.
transform_train_tch = v2.Compose([
    v2.TrivialAugmentWide(),
    v2.Resize((IMG_SIZE_TCH)),
    v2.CenterCrop((IMG_SIZE_TCH, IMG_SIZE_TCH)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225])
])

transform_test_tch = v2.Compose([
    v2.Resize((IMG_SIZE_TCH)),
    v2.CenterCrop((IMG_SIZE_TCH, IMG_SIZE_TCH)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225])
])

# Student pipelines: resize the shorter side to 256, then crop IMG_SIZE_STD
# (random crop for training, centre crop for testing).
transform_train_std = v2.Compose([
    v2.TrivialAugmentWide(),
    v2.Resize((256)),
    v2.RandomCrop((IMG_SIZE_STD, IMG_SIZE_STD)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225])
])

transform_test_std = v2.Compose([
    v2.Resize((256)),
    v2.CenterCrop((IMG_SIZE_STD, IMG_SIZE_STD)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225])
])

# Create data loaders (each sample carries both a student view and a teacher view)
train_dataloader, test_dataloader, class_names = create_distillation_dataloaders(
    train_dir=TRAIN_DIR,
    test_dir=TEST_DIR,
    transform_student_train=transform_train_std,
    transform_teacher_train=transform_train_tch,
    transform_student_test=transform_test_std,
    transform_teacher_test=transform_test_tch,
    batch_size=BATCH_SIZE
)

dataloaders = {
    'train': train_dataloader,
    'test':  test_dataloader
}

# Load ViT-Base/16-384. Run classification_example.ipynb to generate the teacher model.
model_tch_type = "teacher_model"
model_tch_name = model_tch_type + ".pth"

# Instantiate the teacher model with a NUM_CLASSES-way head
NUM_CLASSES = len(class_names)
model_tch = torchvision.models.vit_b_16(image_size=IMG_SIZE_TCH).to(device)
model_tch.heads = torch.nn.Linear(in_features=model_tch.hidden_dim, out_features=NUM_CLASSES).to(device)
# Compiled before loading, presumably so state-dict keys match how the checkpoint
# was saved — TODO confirm against Common.load_model
model_tch = torch.compile(model_tch, backend="aot_eager")

# Load the trained weights
model_tch = Common().load_model(
    model=model_tch,
    target_dir=MODEL_DIR,
    model_name=model_tch_name)

# Instantiate the base student model: Swin Transformer V2-Tiny.
# Move the whole model (not only the new head) to the device so backbone and head share a device.
model_std = torchvision.models.swin_v2_t(weights=torchvision.models.Swin_V2_T_Weights.DEFAULT).to(device)
model_std.head = torch.nn.Linear(in_features=model_std.head.in_features, out_features=NUM_CLASSES).to(device)

# Unfreeze the base parameters (full fine-tuning)
for parameter in model_std.parameters():
    parameter.requires_grad = True

# Load the trained weights of a previously saved student checkpoint from MODEL_DIR
model_std = Common().load_model(
    model=model_std,
    target_dir=MODEL_DIR,
    model_name="student_model2_acc_epoch25.pth") # Modify the file name if necessary


In [None]:
# Train model
model_std_type = "student_model_retrained"
model_std_name = model_std_type + ".pth"

# Epochs and learning rate
EPOCHS = 30
COSINE_EPOCHS = 20  # Epochs of cosine decay LR -> MIN_LR; the remaining epochs stay at MIN_LR
LR = 1e-4
MIN_LR = 1e-6

# Create optimizer
optimizer = torch.optim.AdamW(
    params=model_std.parameters(),
    lr=LR,
    betas=(0.9, 0.999),
    weight_decay=0.01
)

# Create loss function
loss_fn = DistillationLoss(alpha=0.1, temperature=2, label_smoothing=0.1)

# Initialize the scheduler (phases assume one scheduler step per epoch — TODO confirm in the engine).
# Alternative: CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, last_epoch=-1, eta_min=MIN_LR)
# ConstantLR restores the base LR once total_iters steps have elapsed, so its length is derived
# from EPOCHS to cover the rest of training (at least 1, as it is never reached otherwise).
cosine = CosineAnnealingLR(optimizer, T_max=COSINE_EPOCHS, eta_min=MIN_LR)   # LR -> MIN_LR (cosine)
fixed = ConstantLR(optimizer, factor=MIN_LR / LR,
                   total_iters=max(EPOCHS - COSINE_EPOCHS, 1))               # LR = MIN_LR
scheduler = SequentialLR(
    optimizer,
    schedulers=[cosine, fixed],
    milestones=[COSINE_EPOCHS]
)

# Set seeds (as in the first training run) so retraining is reproducible
set_seeds(SEED)

# And train...

# Instantiate the engine
engine = ClassificationEngine(
    model=model_std,                                # Model to be trained
    model_teacher=model_tch,                        # Teacher model if knowledge distillation training is enabled (use_distillation=True)
    use_distillation=True,                          # Whether training uses knowledge distillation, then model_teacher is required
    optimizer=optimizer,                            # Optimizer
    loss_fn=loss_fn,                                # Loss function
    scheduler=scheduler,                            # Scheduler
    color_map={"train": 'blue', "test": "orange"},  # Color map
    theme='dark',                                   # Color theme
    device=device                                   # Target device
    )

# Configure the training method
results = engine.train(
    target_dir=MODEL_DIR,                           # Directory where the model will be saved
    model_name=model_std_name,                      # Name of the retrained student model
    resume=False,                                   # Start a new run; do not resume from a saved checkpoint
    dataloaders=dataloaders,                        # Dictionary with the dataloaders
    save_best_model=["last", "loss", "acc", "f1"],  # Save the best models based on different criteria
    keep_best_models_in_memory=False,               # Do not keep the models stored in memory for the sake of training time and memory efficiency
    apply_validation=True,                          # Enable validation step
    augmentation_strategy="always",                 # Augmentation strategy
    recall_threshold=0.995,                         # False positive rate at recall_threshold recall
    recall_threshold_pauc=0.95,                     # Partial AUC score above recall_threshold_pauc recall
    epochs=EPOCHS,                                  # Total number of epochs
    amp=True,                                       # Enable Automatic Mixed Precision (AMP)
    enable_clipping=False,                          # Disable clipping on gradients, only useful if training becomes unstable
    debug_mode=False,                               # Disable debug mode
    accumulation_steps=2,                           # Effective batch size = BATCH_SIZE x accumulation_steps
    )

# 8. Making Predictions

In [None]:
# Create dataloader for the student.
# NOTE: transform_test_std / BATCH_SIZE come from whichever of sections 6 or 7.1 ran last
# (256 px vs 224 px); make sure they match the checkpoint loaded below.
_, test_dataloader, _ = create_dataloaders(
    train_dir=TRAIN_DIR,
    test_dir=TEST_DIR,
    train_transform=transform_train_std,
    test_transform=transform_test_std,
    batch_size=BATCH_SIZE
)

# Make predictions with the student.
# Move the whole model (not only the new head) to the device so backbone and head share a device.
model_std = torchvision.models.swin_v2_t(weights=torchvision.models.Swin_V2_T_Weights.DEFAULT).to(device)
model_std.head = torch.nn.Linear(in_features=model_std.head.in_features, out_features=NUM_CLASSES).to(device)
preds_std = ClassificationEngine(
    model=model_std,
    device=device).load(
        target_dir=MODEL_DIR,
        model_name="student_model.pth" # Only parameters, not the engine model. Modify the file name if necessary.
    ).predict(
        dataloader=test_dataloader,
        output_type="argmax"
        )

In [None]:
# Create dataloader for the teacher
_, test_dataloader, _ = create_dataloaders(
    train_dir=TRAIN_DIR,
    test_dir=TEST_DIR,
    train_transform=transform_train_tch,
    test_transform=transform_test_tch,
    batch_size=BATCH_SIZE
)

# Make predictions with the student
model_tch = torchvision.models.vit_b_16(image_size=IMG_SIZE_TCH).to(device)
model_tch.heads = torch.nn.Linear(in_features=768, out_features=NUM_CLASSES).to(device)
model_tch = torch.compile(model_tch, backend="aot_eager")

preds_tch = ClassificationEngine(
    model=model_tch,    
    device=device).load(
        target_dir=MODEL_DIR,
        model_name="teacher_model.pth" # Only parameters, not the enginer model
    ).predict(
        dataloader=test_dataloader, #[image_std, image_tch, class]
        output_type="argmax",
        )

In [None]:
# Compare results: fraction of test samples on which student and teacher predict the same class.
# Both prediction cells must visit the test set in the same order (assumes the test
# dataloader is not shuffled — TODO confirm in create_dataloaders).
if preds_std.shape != preds_tch.shape:
    raise ValueError(
        f"Prediction shapes differ: student {tuple(preds_std.shape)} vs teacher {tuple(preds_tch.shape)}"
    )
# Compare on CPU so the two prediction tensors never need to share a device
matches = preds_std.cpu() == preds_tch.cpu()
agreement = matches.float().mean().item()
print(f"Student vs Teacher agree on {agreement*100:.1f}% of samples")