<h1>Tiny ImageNet</h1>

In [1]:
import time
from torch import nn
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
import os
from sklearn.metrics import accuracy_score
from torchvision.models import vit_b_16
from torchvision.models.feature_extraction import create_feature_extractor

# Define the UnifiedStudentModel (Basic Block)
class UnifiedStudentModel(nn.Module):
    def __init__(self, vision_dim=256, teacher_output_dim=768, bottleneck_dim=128):
        super(UnifiedStudentModel, self).__init__()
        # Vision encoder with bottleneck
        self.vision_encoder = nn.Sequential(
            nn.Linear(teacher_output_dim, bottleneck_dim),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, vision_dim),
        )
        self.logit_scale = nn.Parameter(torch.ones([]) * 0.07)

    def forward(self, teacher_features):
        # Normalize and project features
        vision_proj = self.vision_encoder(teacher_features)
        vision_proj = vision_proj / vision_proj.norm(dim=-1, keepdim=True)
        logits = self.logit_scale.exp() * vision_proj @ vision_proj.t()
        return logits

# Initialize Tiny ImageNet dataset loaders
def init_tiny_imagenet_data(data_dir, batch_size=32):
    transform = Compose([
        Resize((224, 224)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_dir = os.path.join(data_dir, "train")
    train_dataset = ImageFolder(root=train_dir, transform=transform)

    val_dir = os.path.join(data_dir, "val/images")
    val_dataset = ImageFolder(root=val_dir, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    return train_loader, val_loader

# Save and load checkpoint functions
def save_checkpoint(student_model, optimizer, epoch, loss, checkpoint_dir, prefix):
    if checkpoint_dir and prefix:
        checkpoint_path = os.path.join(checkpoint_dir, f"{prefix}{epoch + 1}.pt")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': student_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

def load_checkpoint(checkpoint_dir, prefix, student_model, optimizer):
    start_epoch = 0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)

    if checkpoint_dir and os.path.exists(checkpoint_dir):
        checkpoint_files = [
            f for f in os.listdir(checkpoint_dir) if f.startswith(prefix) and f.endswith(".pt")
        ]
        if checkpoint_files:
            latest_checkpoint = max(
                checkpoint_files,
                key=lambda x: int(x[len(prefix):-3])
            )
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
            print(f"Loading checkpoint from {checkpoint_path}...")
            checkpoint = torch.load(checkpoint_path)
            student_model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint['epoch']
            print(f"Resuming training from epoch {start_epoch}.")
    return start_epoch

# Compute model size
def compute_model_size(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad) * 4 / (1024**2)

data_dir = "./data/tiny-imagenet-200"
train_loader, val_loader = init_tiny_imagenet_data(data_dir, batch_size=8)

<h2>BEiT</h2>

<h4>5 epochs</h4>

In [3]:
# Training function for the student model
def train_student_model_beit(
    teacher_model, student_model, train_loader, optimizer, 
    num_epochs=10, checkpoint_dir=None, prefix=None
):
    epoch_times = []
    start_epoch = 0
    if checkpoint_dir and prefix:
        start_epoch = load_checkpoint(checkpoint_dir, prefix, student_model, optimizer)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)
    student_model.train()

    for epoch in range(start_epoch, num_epochs):
        print(f"Training epoch {epoch + 1}/{num_epochs}...")
        start_time = time.time()
        epoch_loss = 0.0

        for images, _ in train_loader:
            images = images.to(device)

            # Teacher model logits
            with torch.no_grad():
                teacher_outputs = teacher_model(images)
                teacher_features = teacher_outputs.last_hidden_state.mean(dim=1)  # Global pooling

            # Student model logits
            student_logits = student_model(teacher_features)

            # Compute distillation loss
            loss = torch.nn.functional.cross_entropy(student_logits, torch.arange(len(student_logits)).to(device))

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        scheduler.step()

        print(f"Epoch {epoch + 1}/{num_epochs} Loss: {epoch_loss:.4f}")
        end_time = time.time()
        epoch_time = end_time - start_time
        epoch_times.append(epoch_time)
        print(f"Epoch {epoch + 1} completed in {epoch_time:.2f} seconds.")

        if checkpoint_dir and prefix:
            save_checkpoint(student_model, optimizer, epoch, epoch_loss, checkpoint_dir, prefix)
            
    print(f"Student Model Size: {compute_model_size(student_model):.2f} MB")
    return epoch_times


# Evaluation function
def evaluate_student_model_beit(student_model, teacher_model, val_loader, checkpoint_dir=None, prefix=None):
    if checkpoint_dir and prefix:
        print(f"Searching for checkpoints in {checkpoint_dir}...")
        checkpoint_files = [
            f for f in os.listdir(checkpoint_dir) if f.startswith(prefix) and f.endswith(".pt")
        ]
        if checkpoint_files:
            latest_checkpoint = max(
                checkpoint_files,
                key=lambda x: int(x[len(prefix):-3])
            )
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
            print(f"Loading checkpoint from {checkpoint_path}...")
            checkpoint = torch.load(checkpoint_path)
            student_model.load_state_dict(checkpoint['model_state_dict'])

    student_model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)

    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, _ in val_loader:
            images = images.to(device)

            # Teacher model logits
            teacher_outputs = teacher_model(images)
            teacher_features = teacher_outputs.last_hidden_state.mean(dim=1)  # Global pooling

            # Student model logits
            student_logits = student_model(teacher_features)
            predictions = torch.argmax(student_logits, dim=-1)

            all_predictions.extend(predictions.cpu().tolist())
            all_labels.extend(torch.arange(len(predictions)).tolist())

    accuracy = accuracy_score(all_labels, all_predictions)
    print(f"Evaluation Accuracy: {accuracy:.4f}")

print("Initialize student model...")
# Initialize student model
student_output_dim = 256
teacher_output_dim = 768
student_model = UnifiedStudentModel(vision_dim=student_output_dim, teacher_output_dim=teacher_output_dim)
optimizer = torch.optim.AdamW(student_model.parameters(), lr=3e-4)

print("Load teacher model...")
# Load teacher model (DINO/BEiT)
from transformers import BeitModel
teacher_model = BeitModel.from_pretrained("microsoft/beit-base-patch16-224").to('cuda')

# Train the student model
checkpoint_dir = "./checkpoints"
prefix = "basic_block_beit_epoch_"

print("Training Student Model...")
epoch_times = train_student_model_beit(
    teacher_model, student_model, train_loader, optimizer, 
    num_epochs=5, checkpoint_dir=checkpoint_dir, prefix=prefix
)

print("Evaluating Student Model...")
evaluate_student_model_beit(
    student_model, teacher_model, val_loader, 
    checkpoint_dir=checkpoint_dir, prefix=prefix
)

Initialize student model...
Load teacher model...
Training Student Model...
Loading checkpoint from ./checkpoints/basic_block_beit_epoch_10.pt...
Resuming training from epoch 10.
Student Model Size: 0.50 MB
Evaluating Student Model...
Searching for checkpoints in ./checkpoints...
Loading checkpoint from ./checkpoints/basic_block_beit_epoch_10.pt...


  checkpoint = torch.load(checkpoint_path)
  checkpoint = torch.load(checkpoint_path)


Evaluation Accuracy: 1.0000


<h4>10 epochs</h4>

In [3]:
print("Training Student Model...")
epoch_times = train_student_model_beit(
    teacher_model, student_model, train_loader, optimizer, 
    num_epochs=10, checkpoint_dir=checkpoint_dir, prefix=prefix
)

print("Evaluating Student Model...")
evaluate_student_model_beit(
    student_model, teacher_model, val_loader, 
    checkpoint_dir=checkpoint_dir, prefix=prefix
)

Training Student Model...
Loading checkpoint from ./checkpoints/basic_block_beit_epoch_5.pt...
Resuming training from epoch 5.
Training epoch 6/10...


  checkpoint = torch.load(checkpoint_path)


Epoch 6/10 Loss: 0.2207
Epoch 6 completed in 1345.64 seconds.
Checkpoint saved at ./checkpoints/basic_block_beit_epoch_6.pt
Training epoch 7/10...
Epoch 7/10 Loss: 0.2700
Epoch 7 completed in 1344.25 seconds.
Checkpoint saved at ./checkpoints/basic_block_beit_epoch_7.pt
Training epoch 8/10...
Epoch 8/10 Loss: 0.2281
Epoch 8 completed in 1344.08 seconds.
Checkpoint saved at ./checkpoints/basic_block_beit_epoch_8.pt
Training epoch 9/10...
Epoch 9/10 Loss: 0.2606
Epoch 9 completed in 1344.90 seconds.
Checkpoint saved at ./checkpoints/basic_block_beit_epoch_9.pt
Training epoch 10/10...
Epoch 10/10 Loss: 0.1026
Epoch 10 completed in 1344.53 seconds.
Checkpoint saved at ./checkpoints/basic_block_beit_epoch_10.pt
Student Model Size: 0.50 MB
Evaluating Student Model...
Searching for checkpoints in ./checkpoints...
Loading checkpoint from ./checkpoints/basic_block_beit_epoch_10.pt...


  checkpoint = torch.load(checkpoint_path)


Evaluation Accuracy: 1.0000


<h2>DINO</h2>

<h4>5 epochs</h4>

In [3]:
# Training function for the student model using DINO
def train_student_model_dino(
    teacher_model, student_model, train_loader, optimizer, 
    num_epochs=10, checkpoint_dir=None, prefix=None
):
    epoch_times = []
    start_epoch = 0
    if checkpoint_dir and prefix:
        start_epoch = load_checkpoint(checkpoint_dir, prefix, student_model, optimizer)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)
    student_model.train()

    for epoch in range(start_epoch, num_epochs):
        print(f"Training epoch {epoch + 1}/{num_epochs}...")
        start_time = time.time()
        epoch_loss = 0.0

        for images, _ in train_loader:
            images = images.to(device)

            # Teacher model logits (DINO)
            with torch.no_grad():
                teacher_outputs = teacher_model(images)
                teacher_features = teacher_outputs["features"].mean(dim=1)  # Global pooling

            # Student model logits
            student_logits = student_model(teacher_features)

            # Compute distillation loss
            loss = torch.nn.functional.cross_entropy(student_logits, torch.arange(len(student_logits)).to(device))

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        scheduler.step()

        print(f"Epoch {epoch + 1}/{num_epochs} Loss: {epoch_loss:.4f}")
        end_time = time.time()
        epoch_time = end_time - start_time
        epoch_times.append(epoch_time)
        print(f"Epoch {epoch + 1} completed in {epoch_time:.2f} seconds.")

        if checkpoint_dir and prefix:
            save_checkpoint(student_model, optimizer, epoch, epoch_loss, checkpoint_dir, prefix)
            
    print(f"Student Model Size: {compute_model_size(student_model):.2f} MB")
    return epoch_times


# Evaluation function
def evaluate_student_model_dino(student_model, teacher_model, val_loader, checkpoint_dir=None, prefix=None):
    if checkpoint_dir and prefix:
        print(f"Searching for checkpoints in {checkpoint_dir}...")
        checkpoint_files = [
            f for f in os.listdir(checkpoint_dir) if f.startswith(prefix) and f.endswith(".pt")
        ]
        if checkpoint_files:
            latest_checkpoint = max(
                checkpoint_files,
                key=lambda x: int(x[len(prefix):-3])
            )
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
            print(f"Loading checkpoint from {checkpoint_path}...")
            checkpoint = torch.load(checkpoint_path)
            student_model.load_state_dict(checkpoint['model_state_dict'])

    student_model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)

    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, _ in val_loader:
            images = images.to(device)

            # Teacher model logits (DINO)
            teacher_outputs = teacher_model(images)
            teacher_features = teacher_outputs["features"].mean(dim=1)  # Global pooling

            # Student model logits
            student_logits = student_model(teacher_features)
            predictions = torch.argmax(student_logits, dim=-1)

            all_predictions.extend(predictions.cpu().tolist())
            all_labels.extend(torch.arange(len(predictions)).tolist())

    accuracy = accuracy_score(all_labels, all_predictions)
    print(f"Evaluation Accuracy: {accuracy:.4f}")

# Main Code to Train and Evaluate
checkpoint_dir = "./checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Initialize Bottleneck Student Model
print("Initializing Bottleneck Student Model...")
student_model = UnifiedStudentModel(vision_dim=256, teacher_output_dim=768, bottleneck_dim=128)
optimizer = torch.optim.AdamW(student_model.parameters(), lr=3e-4)

# Load DINO Teacher Model
print("Loading DINO Teacher Model...")
dino_teacher_model = vit_b_16(weights="IMAGENET1K_V1").to('cuda')
dino_teacher_model = create_feature_extractor(dino_teacher_model, return_nodes={"encoder.layers": "features"})

# Load the latest checkpoint if it exists
prefix = "bottleneck_dino_epoch_"

# Train the student model
print("Training Bottleneck Student Model...")
epoch_times = train_student_model_dino(
    teacher_model=dino_teacher_model,
    student_model=student_model,
    train_loader=train_loader,
    optimizer=optimizer,
    num_epochs= 5,
    checkpoint_dir=checkpoint_dir,
    prefix=prefix
)

# Evaluate the student model
print("Evaluating Bottleneck Student Model...")
evaluate_student_model_dino(
    student_model=student_model,
    teacher_model=dino_teacher_model,
    val_loader=val_loader,
    checkpoint_dir=checkpoint_dir,
    prefix=prefix
)

Initializing Bottleneck Student Model...
Loading DINO Teacher Model...
Training Bottleneck Student Model...
Loading checkpoint from ./checkpoints/bottleneck_dino_epoch_3.pt...
Resuming training from epoch 3.
Training epoch 4/5...


  checkpoint = torch.load(checkpoint_path)


Epoch 4/5 Loss: 0.1621
Epoch 4 completed in 744.66 seconds.
Checkpoint saved at ./checkpoints/bottleneck_dino_epoch_4.pt
Training epoch 5/5...
Epoch 5/5 Loss: 0.1039
Epoch 5 completed in 762.88 seconds.
Checkpoint saved at ./checkpoints/bottleneck_dino_epoch_5.pt
Student Model Size: 0.50 MB
Evaluating Bottleneck Student Model...
Searching for checkpoints in ./checkpoints...
Loading checkpoint from ./checkpoints/bottleneck_dino_epoch_5.pt...


  checkpoint = torch.load(checkpoint_path)


Evaluation Accuracy: 1.0000


<h4>10 epochs</h4>

In [None]:
print("Training Student Model with DINO...")
epoch_times = train_student_model_dino(
    dino_teacher_model, student_model, train_loader, optimizer, 
    num_epochs=10, checkpoint_dir=checkpoint_dir, prefix=prefix
)

print("Evaluating Student Model with DINO...")
evaluate_student_model_dino(
    student_model, dino_teacher_model, val_loader, 
    checkpoint_dir=checkpoint_dir, prefix=prefix
)