<h1>Tiny ImageNet</h1>

In [1]:
import time
from torch import nn
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
import os
from sklearn.metrics import accuracy_score
from torchvision.models import vit_b_16
from torchvision.models.feature_extraction import create_feature_extractor
from transformers import BeitModel

# Attention-based UnifiedStudentModel
class UnifiedStudentModel(nn.Module):
    def __init__(self, vision_dim=256, teacher_output_dim=768, num_heads=4):
        super(UnifiedStudentModel, self).__init__()
        # Vision attention
        self.vision_attention = nn.MultiheadAttention(embed_dim=teacher_output_dim, num_heads=num_heads)
        # Projection layer
        self.vision_proj = nn.Linear(teacher_output_dim, vision_dim)
        # Logit scale for similarity computation
        self.logit_scale = nn.Parameter(torch.ones([]) * 0.07)

    def forward(self, vision_features):
        """
        vision_features: Tensor of shape (batch_size, seq_len, teacher_output_dim)
        """
        # Attention over vision features
        vision_attn_output, _ = self.vision_attention(vision_features, vision_features, vision_features)
        
        # Project the attention output to vision dimension
        vision_proj = self.vision_proj(vision_attn_output)
        vision_proj = vision_proj / vision_proj.norm(dim=-1, keepdim=True)
        
        # Compute logits (self-similarity or pairwise similarity)
        logits = self.logit_scale.exp() * vision_proj @ vision_proj.transpose(-1, -2)
        return logits
        
# Initialize Tiny ImageNet dataset loaders
def init_tiny_imagenet_data(data_dir, batch_size=32, num_workers=4):
    """
    Prepares the Tiny ImageNet dataset for training and validation.
    """
    transform = Compose([
        Resize((224, 224)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Training dataset and loader
    train_dir = os.path.join(data_dir, "train")
    train_dataset = ImageFolder(root=train_dir, transform=transform)

    val_dir = os.path.join(data_dir, "val/images")
    val_dataset = ImageFolder(root=val_dir, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    return train_loader, val_loader

# Save and load checkpoint functions
def save_checkpoint(student_model, optimizer, epoch, loss, checkpoint_dir, prefix):
    if checkpoint_dir and prefix:
        checkpoint_path = os.path.join(checkpoint_dir, f"{prefix}{epoch + 1}.pt")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': student_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

def load_checkpoint(checkpoint_dir, prefix, student_model, optimizer):
    start_epoch = 0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)

    if checkpoint_dir and os.path.exists(checkpoint_dir):
        checkpoint_files = [
            f for f in os.listdir(checkpoint_dir) if f.startswith(prefix) and f.endswith(".pt")
        ]
        if checkpoint_files:
            latest_checkpoint = max(
                checkpoint_files,
                key=lambda x: int(x[len(prefix):-3])
            )
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
            print(f"Loading checkpoint from {checkpoint_path}...")
            checkpoint = torch.load(checkpoint_path)
            student_model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint['epoch']
            print(f"Resuming training from epoch {start_epoch}.")
    return start_epoch

# Compute model size
def compute_model_size(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad) * 4 / (1024**2)

# Initialize Tiny ImageNet
tiny_imagenet_dir = "./data/tiny-imagenet-200"
train_loader, val_loader = init_tiny_imagenet_data(tiny_imagenet_dir)

<h2>BEiT</h2>

<h4>5 epochs</h4>

In [None]:
def train_student_model_beit(teacher_model, student_model, train_loader, optimizer, num_epochs=10, checkpoint_dir=None, prefix=None):
    """
    Trains the student model using features extracted from the BEiT teacher model.
    """
    epoch_times = []
    start_epoch = 0
    if checkpoint_dir and prefix:
        start_epoch = load_checkpoint(checkpoint_dir, prefix, student_model, optimizer)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)
    student_model.train()

    for epoch in range(start_epoch, num_epochs):
        start_time = time.time()
        print(f"Training epoch {epoch + 1}/{num_epochs}...")
        epoch_loss = 0.0

        for images, _ in train_loader:
            images = images.to(device)

            # Teacher model logits
            with torch.no_grad():
                teacher_outputs = teacher_model(images)
                teacher_features = teacher_outputs.last_hidden_state  # Shape: [batch_size, seq_len, 768]
                teacher_features = teacher_features.permute(1, 0, 2)  # Shape: [seq_len, batch_size, 768]

            # Student model logits
            student_logits = student_model(teacher_features)

            # Compute similarity logits and distillation loss
            similarity_logits = student_logits @ student_logits.transpose(-1, -2)  # Shape: [seq_len, batch_size, batch_size]

            # Create target tensor (diagonal of the similarity matrix)
            targets = torch.arange(similarity_logits.size(1)).to(device)  # Shape: [batch_size]
            targets = targets.unsqueeze(0).expand(similarity_logits.size(0), -1)  # Shape: [seq_len, batch_size]

            # Cross-entropy loss
            loss = F.cross_entropy(similarity_logits.reshape(-1, similarity_logits.size(-1)), targets.reshape(-1))

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        scheduler.step()

        print(f"Epoch {epoch + 1}/{num_epochs} Loss: {epoch_loss / len(train_loader):.4f}")
        end_time = time.time()
        epoch_time = end_time - start_time
        epoch_times.append(epoch_time)
        print(f"Epoch {epoch + 1} completed in {epoch_time:.2f} seconds.")

        if checkpoint_dir and prefix:
            checkpoint_path = os.path.join(checkpoint_dir, f"{prefix}{epoch + 1}.pt")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': student_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss / len(train_loader)
            }, checkpoint_path)
            print(f"Checkpoint saved at {checkpoint_path}")

    print(f"Student Model Size: {compute_model_size(student_model):.2f} MB")
    return epoch_times
    
def evaluate_student_model_beit(student_model, teacher_model, val_loader, checkpoint_dir=None, prefix=None):
    """
    Evaluates the student model using features extracted from the BEiT teacher model.
    """
    if checkpoint_dir and prefix:
        print(f"Searching for checkpoints in {checkpoint_dir}...")
        checkpoint_files = [
            f for f in os.listdir(checkpoint_dir) if f.startswith(prefix) and f.endswith(".pt")
        ]
        if checkpoint_files:
            latest_checkpoint = max(
                checkpoint_files,
                key=lambda x: int(x[len(prefix):-3])
            )
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
            print(f"Loading checkpoint from {checkpoint_path}...")
            checkpoint = torch.load(checkpoint_path)
            student_model.load_state_dict(checkpoint['model_state_dict'])

    student_model.eval()
    teacher_model.eval()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)
    teacher_model.to(device)

    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, _ in val_loader:
            images = images.to(device)

            # Teacher model logits
            teacher_outputs = teacher_model(images)
            teacher_features = teacher_outputs.last_hidden_state  # Shape: [batch_size, seq_len, 768]

            # Project teacher features to match vision_dim
            teacher_features = teacher_features.permute(1, 0, 2)  # [seq_len, batch_size, embed_dim]
            student_logits = student_model(teacher_features)

            # Compute similarity logits
            similarity_logits = student_logits @ student_logits.transpose(-1, -2)

            # Predictions and ground truth targets
            predictions = torch.argmax(similarity_logits, dim=-1).flatten()
            targets = torch.arange(similarity_logits.size(1)).to(device).repeat(similarity_logits.size(0))

            all_predictions.extend(predictions.cpu().tolist())
            all_labels.extend(targets.cpu().tolist())

    # Ensure correct format for sklearn accuracy_score
    all_predictions = torch.tensor(all_predictions, dtype=torch.int64).flatten().tolist()
    all_labels = torch.tensor(all_labels, dtype=torch.int64).flatten().tolist()

    accuracy = accuracy_score(all_labels, all_predictions)
    print(f"Evaluation Accuracy: {accuracy:.4f}")


# Student model
vision_dim = 256
teacher_output_dim = 768
num_heads = 4
student_model = UnifiedStudentModel(vision_dim=vision_dim, teacher_output_dim=teacher_output_dim, num_heads=num_heads)
optimizer = torch.optim.AdamW(student_model.parameters(), lr=3e-4)

# Load BEiT teacher model
beit_teacher_model = BeitModel.from_pretrained("microsoft/beit-base-patch16-224").to('cuda')

# Training
checkpoint_dir = "./checkpoints"
prefix = "attention_beit_epoch_"

print("Training Student Model with BEiT...")
epoch_times = train_student_model_beit(
    beit_teacher_model, student_model, train_loader, optimizer,
    num_epochs=5, checkpoint_dir=checkpoint_dir, prefix=prefix
)

# Evaluation
print("Evaluating Student Model with BEiT...")
evaluate_student_model_beit(student_model, beit_teacher_model, val_loader, checkpoint_dir=checkpoint_dir, prefix=prefix)

<h4>10 epochs</h4>

In [5]:
print("Training Student Model with BEiT...")
epoch_times = train_student_model_beit(
    beit_teacher_model, student_model, train_loader, optimizer, 
    num_epochs=10, checkpoint_dir=checkpoint_dir, prefix=prefix
)

# Evaluation
print("Evaluating Student Model with BEiT...")
evaluate_student_model_beit(student_model, beit_teacher_model, val_loader, checkpoint_dir=checkpoint_dir, prefix=prefix)

Training Student Model with BEiT...
Loading checkpoint from ./checkpoints/attention_beit_epoch_5.pt...
Resuming training from epoch 5.
Training epoch 6/10...


  checkpoint = torch.load(checkpoint_path)


Epoch 6/10 Loss: 0.3591
Epoch 6 completed in 365.52 seconds.
Checkpoint saved at ./checkpoints/attention_beit_epoch_6.pt
Training epoch 7/10...
Epoch 7/10 Loss: 0.3589
Epoch 7 completed in 364.49 seconds.
Checkpoint saved at ./checkpoints/attention_beit_epoch_7.pt
Training epoch 8/10...
Epoch 8/10 Loss: 0.3588
Epoch 8 completed in 364.68 seconds.
Checkpoint saved at ./checkpoints/attention_beit_epoch_8.pt
Training epoch 9/10...
Epoch 9/10 Loss: 0.3573
Epoch 9 completed in 364.70 seconds.
Checkpoint saved at ./checkpoints/attention_beit_epoch_9.pt
Training epoch 10/10...
Epoch 10/10 Loss: 0.3597
Epoch 10 completed in 364.51 seconds.
Checkpoint saved at ./checkpoints/attention_beit_epoch_10.pt
Student Model Size: 9.76 MB
Evaluating Student Model with BEiT...
Searching for checkpoints in ./checkpoints...
Loading checkpoint from ./checkpoints/attention_beit_epoch_10.pt...


  checkpoint = torch.load(checkpoint_path)


Evaluation Accuracy: 0.8450


<h2>DINO</h2>

<h4>5 epochs</h4>

In [5]:
import time
from torch import nn
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
import os
from sklearn.metrics import accuracy_score
from torchvision.models import vit_b_16
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

# Attention-based UnifiedStudentModel
class UnifiedStudentModel(nn.Module):
    def __init__(self, vision_dim=256, teacher_output_dim=768, num_heads=4):
        super(UnifiedStudentModel, self).__init__()
        # Vision attention
        self.vision_attention = nn.MultiheadAttention(embed_dim=teacher_output_dim, num_heads=num_heads)
        # Projection layer
        self.vision_proj = nn.Linear(teacher_output_dim, vision_dim)
        # Logit scale for similarity computation
        self.logit_scale = nn.Parameter(torch.ones([]) * 0.07)

    def forward(self, vision_features):
        """
        vision_features: Tensor of shape (batch_size, seq_len, teacher_output_dim)
        """
        # Attention over vision features
        vision_attn_output, _ = self.vision_attention(vision_features, vision_features, vision_features)
        
        # Project the attention output to vision dimension
        vision_proj = self.vision_proj(vision_attn_output)
        vision_proj = vision_proj / vision_proj.norm(dim=-1, keepdim=True)
        
        # Compute logits (self-similarity or pairwise similarity)
        logits = self.logit_scale.exp() * vision_proj @ vision_proj.transpose(-1, -2)
        return logits
        
# Initialize Tiny ImageNet dataset loaders
def init_tiny_imagenet_data(data_dir, batch_size=32, num_workers=4):
    """
    Prepares the Tiny ImageNet dataset for training and validation.
    """
    transform = Compose([
        Resize((224, 224)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Training dataset and loader
    train_dir = os.path.join(data_dir, "train")
    train_dataset = ImageFolder(root=train_dir, transform=transform)

    val_dir = os.path.join(data_dir, "val/images")
    val_dataset = ImageFolder(root=val_dir, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    return train_loader, val_loader

# Save and load checkpoint functions
def save_checkpoint(student_model, optimizer, epoch, loss, checkpoint_dir, prefix):
    if checkpoint_dir and prefix:
        checkpoint_path = os.path.join(checkpoint_dir, f"{prefix}{epoch + 1}.pt")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': student_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

def load_checkpoint(checkpoint_dir, prefix, student_model, optimizer):
    start_epoch = 0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)

    if checkpoint_dir and os.path.exists(checkpoint_dir):
        checkpoint_files = [
            f for f in os.listdir(checkpoint_dir) if f.startswith(prefix) and f.endswith(".pt")
        ]
        if checkpoint_files:
            latest_checkpoint = max(
                checkpoint_files,
                key=lambda x: int(x[len(prefix):-3])
            )
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
            print(f"Loading checkpoint from {checkpoint_path}...")
            checkpoint = torch.load(checkpoint_path)
            student_model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint['epoch']
            print(f"Resuming training from epoch {start_epoch}.")
    return start_epoch

# Compute model size
def compute_model_size(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad) * 4 / (1024**2)

# Initialize Tiny ImageNet
tiny_imagenet_dir = "./data/tiny-imagenet-200"
train_loader, val_loader = init_tiny_imagenet_data(tiny_imagenet_dir)

def train_student_model_dino(teacher_model, student_model, train_loader, optimizer, num_epochs=10, checkpoint_dir=None, prefix=None):
    """
    Trains the student model using features extracted from the DINO teacher model.
    """
    epoch_times = []
    start_epoch = 0
    if checkpoint_dir and prefix:
        start_epoch = load_checkpoint(checkpoint_dir, prefix, student_model, optimizer)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)
    teacher_model.to(device)
    student_model.train()

    for epoch in range(start_epoch, num_epochs):
        start_time = time.time()
        print(f"Training epoch {epoch + 1}/{num_epochs}...")
        epoch_loss = 0.0

        for images, _ in train_loader:
            images = images.to(device)

            # Teacher model logits (DINO)
            with torch.no_grad():
                teacher_outputs = teacher_model(images)
                teacher_features = teacher_outputs["features"]  # Use pooled features
                teacher_features = F.normalize(teacher_features, p=2, dim=-1)  # Normalize features

            # Student model logits
            student_logits = student_model(teacher_features.unsqueeze(0))  # Add sequence dimension

            # Compute similarity logits
            similarity_logits = student_logits @ student_logits.transpose(-1, -2)
            similarity_logits /= torch.sqrt(torch.tensor(student_logits.size(-1), dtype=torch.float32, device=device))

            # Ensure logits shape matches [batch_size, batch_size]
            similarity_logits = similarity_logits.squeeze(0)

            # Create target tensor
            targets = torch.arange(similarity_logits.size(0)).to(device)  # [batch_size]

            # Cross-entropy loss
            loss = F.cross_entropy(similarity_logits, targets)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        scheduler.step()

        print(f"Epoch {epoch + 1}/{num_epochs} Loss: {epoch_loss / len(train_loader):.4f}")
        end_time = time.time()
        epoch_time = end_time - start_time
        epoch_times.append(epoch_time)
        print(f"Epoch {epoch + 1} completed in {epoch_time:.2f} seconds.")

        if checkpoint_dir and prefix:
            checkpoint_path = os.path.join(checkpoint_dir, f"{prefix}{epoch + 1}.pt")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': student_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss / len(train_loader)
            }, checkpoint_path)
            print(f"Checkpoint saved at {checkpoint_path}")

    print(f"Student Model Size: {compute_model_size(student_model):.2f} MB")
    return epoch_times

# Adjust feature extractor for DINO
print("Loading DINO Teacher Model...")
dino_teacher_model = vit_b_16(weights="IMAGENET1K_V1").to('cuda')
train_nodes, eval_nodes = get_graph_node_names(dino_teacher_model)
print("Train Nodes:", train_nodes)
print("Eval Nodes:", eval_nodes)
dino_teacher_model = create_feature_extractor(
    dino_teacher_model, 
    return_nodes={"getitem_5": "features"}
)

# Initialize optimizer and student model
student_model = UnifiedStudentModel(vision_dim=256, teacher_output_dim=768, num_heads=4)
optimizer = torch.optim.AdamW(student_model.parameters(), lr=3e-4, weight_decay=1e-4)

# Training
checkpoint_dir = "./checkpoints"
prefix = "attention2_dino_epoch_"

print("Training Attention-based Student Model...")
epoch_times = train_student_model_dino(
    teacher_model=dino_teacher_model,
    student_model=student_model,
    train_loader=train_loader,
    optimizer=optimizer,
    num_epochs=5,
    checkpoint_dir=checkpoint_dir,
    prefix=prefix
)

# Evaluation
def evaluate_student_model_dino(student_model, teacher_model, val_loader, checkpoint_dir=None, prefix=None):
    """
    Evaluates the student model using features extracted from the DINO teacher model.
    """
    if checkpoint_dir:
        print(f"Searching for checkpoints in {checkpoint_dir}...")
        checkpoint_files = [
            f for f in os.listdir(checkpoint_dir) if f.startswith(prefix) and f.endswith(".pt")
        ]
        if checkpoint_files:
            latest_checkpoint = max(
                checkpoint_files,
                key=lambda x: int(x[len(prefix):-3])
            )
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
            print(f"Loading checkpoint from {checkpoint_path}...")
            checkpoint = torch.load(checkpoint_path)
            student_model.load_state_dict(checkpoint['model_state_dict'])

    student_model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, _ in val_loader:
            images = images.to(device)

            teacher_outputs = teacher_model(images)
            teacher_features = teacher_outputs["features"]  # Global pooling
            teacher_features = F.normalize(teacher_features, p=2, dim=-1)  # Normalize features

            # Student logits
            student_logits = student_model(teacher_features.unsqueeze(0))
            similarity_logits = student_logits @ student_logits.transpose(-1, -2)
            similarity_logits /= torch.sqrt(torch.tensor(student_logits.size(-1), dtype=torch.float32, device=device))

            # Generate predictions and targets
            predictions = torch.argmax(similarity_logits, dim=1).flatten()
            batch_size = images.size(0)  # Ensure targets match batch size
            targets = torch.arange(batch_size).to(device).flatten()

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(targets.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_predictions)
    print(f"Evaluation Accuracy: {accuracy:.4f}")

print("Evaluating Attention-based Student Model...")
evaluate_student_model_dino(
    student_model=student_model,
    teacher_model=dino_teacher_model,
    val_loader=val_loader,
    checkpoint_dir=checkpoint_dir,
    prefix=prefix
)


Loading DINO Teacher Model...
Train Nodes: ['x', 'getattr', 'getitem', 'getitem_1', 'getitem_2', 'getitem_3', 'eq', '_assert', 'eq_1', '_assert_1', 'floordiv', 'floordiv_1', 'conv_proj', 'mul', 'reshape', 'permute', 'getattr_1', 'getitem_4', 'class_token', 'expand', 'cat', 'encoder.dim', 'encoder.eq', 'encoder.getattr', 'encoder._assert', 'encoder.encoder_pos_embedding', 'encoder.add', 'encoder.dropout', 'encoder.layers.encoder_layer_0.dim', 'encoder.layers.encoder_layer_0.eq', 'encoder.layers.encoder_layer_0.getattr', 'encoder.layers.encoder_layer_0._assert', 'encoder.layers.encoder_layer_0.ln', 'encoder.layers.encoder_layer_0.self_attention', 'encoder.layers.encoder_layer_0.getitem', 'encoder.layers.encoder_layer_0.getitem_1', 'encoder.layers.encoder_layer_0.dropout', 'encoder.layers.encoder_layer_0.add', 'encoder.layers.encoder_layer_0.ln_1', 'encoder.layers.encoder_layer_0.mlp', 'encoder.layers.encoder_layer_0.add_1', 'encoder.layers.encoder_layer_1.dim', 'encoder.layers.encoder_la

  checkpoint = torch.load(checkpoint_path)


Evaluation Accuracy: 0.2004


<h4>10 epochs</h4>

In [6]:
print("Training Attention-based Student Model...")
epoch_times = train_student_model_dino(
    teacher_model=dino_teacher_model,
    student_model=student_model,
    train_loader=train_loader,
    optimizer=optimizer,
    num_epochs=10,
    checkpoint_dir=checkpoint_dir,
    prefix=prefix
)

print("Evaluating Attention-based Student Model...")
evaluate_student_model_dino(
    student_model=student_model,
    teacher_model=dino_teacher_model,
    val_loader=val_loader,
    checkpoint_dir=checkpoint_dir,
    prefix=prefix
)


Training Attention-based Student Model...
Loading checkpoint from ./checkpoints/attention2_dino_epoch_5.pt...
Resuming training from epoch 5.
Training epoch 6/10...


  checkpoint = torch.load(checkpoint_path)


Epoch 6/10 Loss: 1.2058
Epoch 6 completed in 877.40 seconds.
Checkpoint saved at ./checkpoints/attention2_dino_epoch_6.pt
Training epoch 7/10...
Epoch 7/10 Loss: 1.1912
Epoch 7 completed in 877.39 seconds.
Checkpoint saved at ./checkpoints/attention2_dino_epoch_7.pt
Training epoch 8/10...
Epoch 8/10 Loss: 1.2033
Epoch 8 completed in 877.98 seconds.
Checkpoint saved at ./checkpoints/attention2_dino_epoch_8.pt
Training epoch 9/10...
Epoch 9/10 Loss: 1.2065
Epoch 9 completed in 875.80 seconds.
Checkpoint saved at ./checkpoints/attention2_dino_epoch_9.pt
Training epoch 10/10...
Epoch 10/10 Loss: 1.1927
Epoch 10 completed in 877.62 seconds.
Checkpoint saved at ./checkpoints/attention2_dino_epoch_10.pt
Student Model Size: 9.76 MB
Evaluating Attention-based Student Model...
Searching for checkpoints in ./checkpoints...
Loading checkpoint from ./checkpoints/attention2_dino_epoch_10.pt...


  checkpoint = torch.load(checkpoint_path)


Evaluation Accuracy: 0.2004


<h1>COCO</h1>

In [1]:
import time
from torch import nn
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.metrics import accuracy_score
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision.models import vit_b_16
from torchvision.models.feature_extraction import create_feature_extractor
from pycocotools.coco import COCO
from PIL import Image
import os
from sklearn.utils.class_weight import compute_class_weight
import numpy as np

# COCO dataset loader for classification
class COCOClassification(torch.utils.data.Dataset):
    def __init__(self, root, annotation_file, transform=None):
        self.root = root
        self.coco = COCO(annotation_file)
        self.transform = transform
        self.image_ids = list(self.coco.imgToAnns.keys())
        self.classes = list(self.coco.cats.keys())

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        annotations = self.coco.loadAnns(self.coco.getAnnIds(imgIds=image_id))

        # Load image
        image_info = self.coco.loadImgs(image_id)[0]
        image_path = os.path.join(self.root, image_info['file_name'])
        image = Image.open(image_path).convert("RGB")

        # Create multi-label vector
        labels = [0] * len(self.classes)
        for ann in annotations:
            category_idx = self.classes.index(ann['category_id'])
            labels[category_idx] = 1

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(labels, dtype=torch.float32)

# Save checkpoint
def save_checkpoint(student_model, optimizer, epoch, loss, checkpoint_dir, prefix):
    if checkpoint_dir and prefix:
        checkpoint_path = os.path.join(checkpoint_dir, f"{prefix}{epoch + 1}.pt")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': student_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

# Load checkpoint
def load_checkpoint(checkpoint_dir, prefix, student_model, optimizer=None):
    start_epoch = 0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)

    if checkpoint_dir and os.path.exists(checkpoint_dir):
        checkpoint_files = [
            f for f in os.listdir(checkpoint_dir) if f.startswith(prefix) and f.endswith(".pt")
        ]
        if checkpoint_files:
            latest_checkpoint = max(
                checkpoint_files,
                key=lambda x: int(x[len(prefix):-3])
            )
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
            print(f"Loading checkpoint from {checkpoint_path}...")
            checkpoint = torch.load(checkpoint_path)
            student_model.load_state_dict(checkpoint['model_state_dict'])
            if optimizer and 'optimizer_state_dict' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            if 'epoch' in checkpoint:
                start_epoch = checkpoint['epoch']
            print(f"Resuming training from epoch {start_epoch}.")
    return start_epoch
    
# Initialize COCO dataset loaders
def init_coco_data(data_dir, annotation_file, batch_size=32):
    transform = Compose([
        Resize((224, 224)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    dataset = COCOClassification(root=data_dir, annotation_file=annotation_file, transform=transform)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    return data_loader

# Student Model
def compute_model_size(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad) * 4 / (1024**2)

class SimpleProjectionModel(nn.Module):
    def __init__(self, vision_dim=256, num_classes=80):
        super(SimpleProjectionModel, self).__init__()
        self.vision_encoder = nn.Linear(vision_dim, num_classes)
        self.logit_scale = nn.Parameter(torch.ones([]) * 0.07)

    def forward(self, teacher_features):
        logits = self.vision_encoder(teacher_features)
        return logits


# Paths for COCO dataset
train_data_dir = "/home/yx3493/train2017/train2017"
train_annotation_file = "/home/yx3493/annotations_trainval2017/annotations/instances_train2017.json"
val_data_dir = "/home/yx3493/val2017/val2017"
val_annotation_file = "/home/yx3493/annotations_trainval2017/annotations/instances_val2017.json"

train_loader = init_coco_data(train_data_dir, train_annotation_file, batch_size=8)
val_loader = init_coco_data(val_data_dir, val_annotation_file, batch_size=8)


loading annotations into memory...
Done (t=18.73s)
creating index...
index created!
loading annotations into memory...
Done (t=2.61s)
creating index...
index created!


<h2>BEiT</h2>

<h4>5 epochs</h4>

In [2]:
# Training logic
def train_student_model_beit(
    teacher_model, teacher_processor, student_model, train_loader, optimizer,
    num_epochs=10, checkpoint_dir=None, prefix=None
):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)
    start_epoch = load_checkpoint(checkpoint_dir, prefix, student_model, optimizer)

    # Multi-label loss
    criterion = torch.nn.BCEWithLogitsLoss()

    for epoch in range(start_epoch, num_epochs):
        print(f"Training epoch {epoch + 1}/{num_epochs}...")
        epoch_loss = 0.0

        for images, labels in train_loader:
            pil_images = [transforms.ToPILImage()(img) for img in images]

            # Teacher model logits
            with torch.no_grad():
                teacher_inputs = teacher_processor(images=pil_images, return_tensors="pt").to(device)
                teacher_outputs = teacher_model(**teacher_inputs)
                teacher_features = teacher_outputs.last_hidden_state.mean(dim=1)

            # Student model logits
            student_logits = student_model(teacher_features)

            # Compute loss (multi-label)
            loss = criterion(student_logits, labels.to(device))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print(f"Epoch {epoch + 1}/{num_epochs} Loss: {epoch_loss:.4f}")
        save_checkpoint(student_model, optimizer, epoch, epoch_loss, checkpoint_dir, prefix)
    print(f"Student Model Size: {compute_model_size(student_model):.2f} MB")

# Evaluation logic
def evaluate_student_model_beit(
    student_model, teacher_model, teacher_processor, val_loader, checkpoint_dir=None, prefix=None
):
    start_epoch = load_checkpoint(checkpoint_dir, prefix, student_model, optimizer=None)
    student_model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)

    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, labels in val_loader:
            pil_images = [transforms.ToPILImage()(img) for img in images]

            # Teacher model logits
            teacher_inputs = teacher_processor(images=pil_images, return_tensors="pt").to(device)
            teacher_outputs = teacher_model(**teacher_inputs)
            teacher_features = teacher_outputs.last_hidden_state.mean(dim=1)

            # Student model logits
            student_logits = student_model(teacher_features)
            predictions = torch.sigmoid(student_logits) > 0.5  # Multi-label thresholding

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Compute multi-label accuracy
    accuracy = accuracy_score(
        [tuple(map(int, x)) for x in all_labels], [tuple(map(int, x)) for x in all_predictions]
    )
    print(f"Evaluation Accuracy: {accuracy:.4f}")

# Initialize teacher and student models
from transformers import BeitModel, BeitFeatureExtractor
teacher_model = BeitModel.from_pretrained("microsoft/beit-base-patch16-224").to('cuda')
teacher_processor = BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224")

student_output_dim = 256
beit_teacher_output_dim = 768
num_classes = 80  # Number of COCO categories
student_model = SimpleProjectionModel(vision_dim=beit_teacher_output_dim, num_classes=num_classes)
optimizer = torch.optim.AdamW(student_model.parameters(), lr=5e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

# Train the student model
checkpoint_dir = "./checkpoints"
prefix = "attention_beit_coco_epoch_"

print("Training Student Model with COCO dataset...")
train_student_model_beit(
    teacher_model, teacher_processor, student_model, train_loader, optimizer,
    num_epochs=5, checkpoint_dir=checkpoint_dir, prefix=prefix
)

print("Training complete.")

# Evaluate the student model
print("Evaluating Student Model...")
evaluate_student_model_beit(
    student_model, teacher_model, teacher_processor, val_loader, checkpoint_dir=checkpoint_dir, prefix=prefix
)

print("Evaluation complete.")

Training Student Model with COCO dataset...
Loading checkpoint from ./checkpoints/attention_beit_coco_epoch_2.pt...
Resuming training from epoch 2.
Training epoch 3/5...


  return func(*args, **kwargs)
  checkpoint = torch.load(checkpoint_path)


Epoch 3/5 Loss: 1093.0803
Checkpoint saved at ./checkpoints/attention_beit_coco_epoch_3.pt
Training epoch 4/5...
Epoch 4/5 Loss: 1090.1349
Checkpoint saved at ./checkpoints/attention_beit_coco_epoch_4.pt
Training epoch 5/5...
Epoch 5/5 Loss: 1088.9544
Checkpoint saved at ./checkpoints/attention_beit_coco_epoch_5.pt
Student Model Size: 0.23 MB
Training complete.
Evaluating Student Model...
Loading checkpoint from ./checkpoints/attention_beit_coco_epoch_5.pt...
Resuming training from epoch 5.
Evaluation Accuracy: 0.2161
Evaluation complete.


<h4>10 epochs</h4>

In [None]:
print("Training Student Model with COCO dataset...")
train_student_model_beit(
    teacher_model, teacher_processor, student_model, train_loader, optimizer,
    num_epochs=10, checkpoint_dir=checkpoint_dir, prefix=prefix
)

print("Training complete.")

# Evaluate the student model
print("Evaluating Student Model...")
evaluate_student_model_beit(
    student_model, teacher_model, teacher_processor, val_loader, checkpoint_dir=checkpoint_dir, prefix=prefix
)

print("Evaluation complete.")