<h1>TinyImageNet</h1>

In [2]:
import time
from torch import nn
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.datasets import ImageFolder
from sklearn.metrics import accuracy_score
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision.models import vit_b_16
from torchvision.models.feature_extraction import create_feature_extractor
import os

class SimpleProjectionModel(nn.Module):
    def __init__(self, vision_dim=256, teacher_output_dim=768):
        super(SimpleProjectionModel, self).__init__()
        # Vision encoder
        self.vision_encoder = nn.Linear(teacher_output_dim, vision_dim)
        # Logit scale for distillation
        self.logit_scale = nn.Parameter(torch.ones([]) * 0.07)

    def forward(self, teacher_features):
        # Normalize and project features
        vision_proj = self.vision_encoder(teacher_features)
        vision_proj = vision_proj / vision_proj.norm(dim=-1, keepdim=True)
        # Compute logits (self-similarity or pairwise similarity)
        logits = self.logit_scale.exp() * vision_proj @ vision_proj.t()
        return logits

def compute_model_size(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad) * 4 / (1024**2)
    
# Save checkpoint
def save_checkpoint(student_model, optimizer, epoch, loss, checkpoint_dir, prefix):
    if checkpoint_dir and prefix:
        checkpoint_path = os.path.join(checkpoint_dir, f"{prefix}{epoch + 1}.pt")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': student_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

# Load checkpoint
def load_checkpoint(checkpoint_dir, prefix, student_model, optimizer):
    start_epoch = 0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)

    if checkpoint_dir and os.path.exists(checkpoint_dir):
        checkpoint_files = [
            f for f in os.listdir(checkpoint_dir) if f.startswith(prefix) and f.endswith(".pt")
        ]
        if checkpoint_files:
            latest_checkpoint = max(
                checkpoint_files,
                key=lambda x: int(x[len(prefix):-3])
            )
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
            print(f"Loading checkpoint from {checkpoint_path}...")
            checkpoint = torch.load(checkpoint_path)
            student_model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint['epoch']
            print(f"Resuming training from epoch {start_epoch}.")
    return start_epoch

# Initialize Tiny ImageNet dataset loaders
def init_tiny_imagenet_data(data_dir, batch_size=32):
    transform = Compose([
        Resize((224, 224)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_dir = os.path.join(data_dir, "train")
    train_dataset = ImageFolder(root=train_dir, transform=transform)

    val_dir = os.path.join(data_dir, "val/images")
    val_dataset = ImageFolder(root=val_dir, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    return train_loader, val_loader
    
data_dir = "./data/tiny-imagenet-200"
train_loader, val_loader = init_tiny_imagenet_data(data_dir, batch_size=8)

<h2>BEiT</h2>

<h4>5 epochs</h4>

In [2]:
# Training logic
def train_student_model_beit(
    teacher_model, teacher_processor, student_model, train_loader, optimizer, 
    num_epochs=10, temperature=2.0, checkpoint_dir=None, prefix=None
):
    epoch_times = []
    start_epoch = 0
    if checkpoint_dir and prefix:
        start_epoch = load_checkpoint(checkpoint_dir, prefix, student_model, optimizer)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)
    student_model.train()

    for epoch in range(start_epoch, num_epochs):
        print(f"Training epoch {epoch + 1}/{num_epochs}...")
        start_time = time.time()
        epoch_loss = 0.0

        for images, _ in train_loader:
            pil_images = [transforms.ToPILImage()(img) for img in images]

            # Teacher model logits
            with torch.no_grad():
                teacher_inputs = teacher_processor(images=pil_images, return_tensors="pt").to(device)
                teacher_outputs = teacher_model(**teacher_inputs)
                teacher_features = teacher_outputs.last_hidden_state.mean(dim=1)

            # Student model logits
            student_logits = student_model(teacher_features)

            # Compute loss
            loss = F.cross_entropy(student_logits, torch.arange(len(student_logits)).to(device))

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        scheduler.step()

        print(f"Epoch {epoch + 1}/{num_epochs} Loss: {epoch_loss:.4f}")
        end_time = time.time()
        epoch_time = end_time - start_time
        epoch_times.append(epoch_time)
        print(f"Epoch {epoch + 1} completed in {epoch_time:.2f} seconds.")

        if checkpoint_dir and prefix:
            save_checkpoint(student_model, optimizer, epoch, epoch_loss, checkpoint_dir, prefix)
            print(f"saving checkpoint {prefix}...")

    print(f"Student Model Size: {compute_model_size(student_model):.2f} MB")
    return epoch_times
    
# Evaluation logic
def evaluate_student_model_beit(
    student_model, teacher_model, teacher_processor, val_loader, checkpoint_dir=None, prefix=None
):
    if checkpoint_dir and prefix:
        print(f"Searching for checkpoints in {checkpoint_dir}...")
        checkpoint_files = [
            f for f in os.listdir(checkpoint_dir) if f.startswith(prefix) and f.endswith(".pt")
        ]
        if checkpoint_files:
            latest_checkpoint = max(
                checkpoint_files,
                key=lambda x: int(x[len(prefix):-3])
            )
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
            print(f"Loading checkpoint from {checkpoint_path}...")
            checkpoint = torch.load(checkpoint_path)
            student_model.load_state_dict(checkpoint['model_state_dict'])

    student_model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)

    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, _ in val_loader:
            pil_images = [transforms.ToPILImage()(img) for img in images]

            # Teacher model logits
            teacher_inputs = teacher_processor(images=pil_images, return_tensors="pt").to(device)
            teacher_outputs = teacher_model(**teacher_inputs)
            teacher_features = teacher_outputs.last_hidden_state.mean(dim=1)

            # Student model logits
            student_logits = student_model(teacher_features)
            predictions = torch.argmax(student_logits, dim=-1)

            all_predictions.extend(predictions.cpu().tolist())
            all_labels.extend(torch.arange(len(predictions)).tolist())

    accuracy = accuracy_score(all_labels, all_predictions)
    print(f"Evaluation Accuracy: {accuracy:.4f}")

# Initialize student model
student_output_dim = 256
beit_teacher_output_dim = 768
student_model = SimpleProjectionModel(vision_dim=student_output_dim, teacher_output_dim=beit_teacher_output_dim)
optimizer = torch.optim.AdamW(student_model.parameters(), lr=3e-4)

# Load teacher model
from transformers import BeitModel, BeitFeatureExtractor
teacher_model = BeitModel.from_pretrained("microsoft/beit-base-patch16-224").to('cuda')
teacher_processor = BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224")

# Train the student model - 5 epoches
checkpoint_dir = "./checkpoints"
prefix = "simpleprojection_beit_epoch_"

print("Training Student Model with Tiny ImageNet...")
epoch_times = train_student_model_beit(
    teacher_model, teacher_processor, student_model, train_loader, optimizer, 
    num_epochs=5, checkpoint_dir=checkpoint_dir, prefix=prefix
)

print("Training complete.")

# Evaluate the student model
print("Evaluating Student Model...")
evaluate_student_model_beit(
    student_model, teacher_model, teacher_processor, val_loader, 
    checkpoint_dir=checkpoint_dir, prefix=prefix
)

Training Student Model with Tiny ImageNet...
Loading checkpoint from ./checkpoints/simpleprojection_beit_epoch_5.pt...
Resuming training from epoch 5.
Student Model Size: 0.75 MB
Training complete.
Evaluating Student Model...
Searching for checkpoints in ./checkpoints...
Loading checkpoint from ./checkpoints/simpleprojection_beit_epoch_5.pt...


  return func(*args, **kwargs)
  checkpoint = torch.load(checkpoint_path)
  checkpoint = torch.load(checkpoint_path)


Evaluation Accuracy: 1.0000


<h4>10 epochs</h4>

In [3]:
# Train the student model - 10 epoches
checkpoint_dir = "./checkpoints"
prefix = "simpleprojection_beit_epoch_"

print("Training Student Model with Tiny ImageNet...")
epoch_times = train_student_model_beit(
    teacher_model, teacher_processor, student_model, train_loader, optimizer, 
    num_epochs=10, checkpoint_dir=checkpoint_dir, prefix=prefix
)

print("Training complete.")

# Evaluate the student model
print("Evaluating Student Model...")
evaluate_student_model_beit(
    student_model, teacher_model, teacher_processor, val_loader, 
    checkpoint_dir=checkpoint_dir, prefix=prefix
)

Training Student Model with Tiny ImageNet...
Loading checkpoint from ./checkpoints/simpleprojection_beit_epoch_5.pt...
Resuming training from epoch 5.
Training epoch 6/10...


  checkpoint = torch.load(checkpoint_path)


Epoch 6/10 Loss: 0.0048
Epoch 6 completed in 1756.25 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_beit_epoch_6.pt
saving checkpoint simpleprojection_beit_epoch_...
Training epoch 7/10...
Epoch 7/10 Loss: 0.0043
Epoch 7 completed in 1761.47 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_beit_epoch_7.pt
saving checkpoint simpleprojection_beit_epoch_...
Training epoch 8/10...
Epoch 8/10 Loss: 0.0036
Epoch 8 completed in 1758.53 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_beit_epoch_8.pt
saving checkpoint simpleprojection_beit_epoch_...
Training epoch 9/10...
Epoch 9/10 Loss: 0.0021
Epoch 9 completed in 1793.96 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_beit_epoch_9.pt
saving checkpoint simpleprojection_beit_epoch_...
Training epoch 10/10...
Epoch 10/10 Loss: 0.0012
Epoch 10 completed in 1994.52 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_beit_epoch_10.pt
saving checkpoint simpleprojection_beit_epoch_...
Student M

  checkpoint = torch.load(checkpoint_path)


Evaluation Accuracy: 1.0000


<h2>DINO</h2>

<h4>5 epochs</h4>

In [3]:
# Training function for the student model using DINO with SimpleProjection
def train_student_model_dino(
    teacher_model, student_model, train_loader, optimizer, 
    num_epochs=10, checkpoint_dir=None, prefix=None
):
    epoch_times = []
    start_epoch = 0
    if checkpoint_dir and prefix:
        start_epoch = load_checkpoint(checkpoint_dir, prefix, student_model, optimizer)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)
    student_model.train()

    for epoch in range(start_epoch, num_epochs):
        print(f"Training epoch {epoch + 1}/{num_epochs}...")
        start_time = time.time()
        epoch_loss = 0.0

        for images, _ in train_loader:
            images = images.to(device)

            # Teacher model logits (DINO)
            with torch.no_grad():
                teacher_outputs = teacher_model(images)
                teacher_features = teacher_outputs["features"].mean(dim=1)  # Global pooling

            # Student model logits
            student_logits = student_model(teacher_features)

            # Compute distillation loss
            loss = torch.nn.functional.cross_entropy(student_logits, torch.arange(len(student_logits)).to(device))

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        scheduler.step()

        print(f"Epoch {epoch + 1}/{num_epochs} Loss: {epoch_loss:.4f}")
        end_time = time.time()
        epoch_time = end_time - start_time
        epoch_times.append(epoch_time)
        print(f"Epoch {epoch + 1} completed in {epoch_time:.2f} seconds.")

        if checkpoint_dir and prefix:
            save_checkpoint(student_model, optimizer, epoch, epoch_loss, checkpoint_dir, prefix)
            
    print(f"Student Model Size: {compute_model_size(student_model):.2f} MB")
    return epoch_times


# Evaluation function
def evaluate_student_model_dino(student_model, teacher_model, val_loader, checkpoint_dir=None, prefix=None):
    if checkpoint_dir and prefix:
        print(f"Searching for checkpoints in {checkpoint_dir}...")
        checkpoint_files = [
            f for f in os.listdir(checkpoint_dir) if f.startswith(prefix) and f.endswith(".pt")
        ]
        if checkpoint_files:
            latest_checkpoint = max(
                checkpoint_files,
                key=lambda x: int(x[len(prefix):-3])
            )
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
            print(f"Loading checkpoint from {checkpoint_path}...")
            checkpoint = torch.load(checkpoint_path)
            student_model.load_state_dict(checkpoint['model_state_dict'])

    student_model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)

    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, _ in val_loader:
            images = images.to(device)

            # Teacher model logits (DINO)
            teacher_outputs = teacher_model(images)
            teacher_features = teacher_outputs["features"].mean(dim=1)  # Global pooling

            # Student model logits
            student_logits = student_model(teacher_features)
            predictions = torch.argmax(student_logits, dim=-1)

            all_predictions.extend(predictions.cpu().tolist())
            all_labels.extend(torch.arange(len(predictions)).tolist())

    accuracy = accuracy_score(all_labels, all_predictions)
    print(f"Evaluation Accuracy: {accuracy:.4f}")

# Main Code to Train and Evaluate
checkpoint_dir = "./checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Initialize SimpleProjection Student Model
print("Initializing SimpleProjection Student Model...")
student_model = SimpleProjectionModel(vision_dim=256, teacher_output_dim=768)
optimizer = torch.optim.AdamW(student_model.parameters(), lr=3e-4)

# Load DINO Teacher Model
print("Loading DINO Teacher Model...")
dino_teacher_model = vit_b_16(weights="IMAGENET1K_V1").to('cuda')
dino_teacher_model = create_feature_extractor(dino_teacher_model, return_nodes={"encoder.layers": "features"})

# Load the latest checkpoint if it exists
prefix = "simpleprojection_dino_epoch_"

# Train the student model
print("Training SimpleProjection Student Model...")
epoch_times = train_student_model_dino(
    teacher_model=dino_teacher_model,
    student_model=student_model,
    train_loader=train_loader,
    optimizer=optimizer,
    num_epochs=5,
    checkpoint_dir=checkpoint_dir,
    prefix=prefix
)

# Evaluate the student model
print("Evaluating SimpleProjection Student Model...")
evaluate_student_model_dino(
    student_model=student_model,
    teacher_model=dino_teacher_model,
    val_loader=val_loader,
    checkpoint_dir=checkpoint_dir,
    prefix=prefix
)

Initializing SimpleProjection Student Model...
Loading DINO Teacher Model...
Loading checkpoint from ./checkpoints/simpleprojection_dino_epoch_2.pt...
Resuming training from epoch 2.
Training SimpleProjection Student Model...
Loading checkpoint from ./checkpoints/simpleprojection_dino_epoch_2.pt...
Resuming training from epoch 2.
Training epoch 3/5...


  checkpoint = torch.load(checkpoint_path)


Epoch 3/5 Loss: 0.1055
Epoch 3 completed in 754.40 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_dino_epoch_3.pt
Training epoch 4/5...
Epoch 4/5 Loss: 0.0555
Epoch 4 completed in 797.98 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_dino_epoch_4.pt
Training epoch 5/5...
Epoch 5/5 Loss: 0.0221
Epoch 5 completed in 1226.30 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_dino_epoch_5.pt
Student Model Size: 0.75 MB
Evaluating SimpleProjection Student Model...
Searching for checkpoints in ./checkpoints...
Loading checkpoint from ./checkpoints/simpleprojection_dino_epoch_5.pt...


  checkpoint = torch.load(checkpoint_path)


Evaluation Accuracy: 1.0000


<h4>10 epochs</h4>

In [3]:
# Train the student model
print("Training SimpleProjection Student Model...")
epoch_times = train_student_model_dino(
    teacher_model=dino_teacher_model,
    student_model=student_model,
    train_loader=train_loader,
    optimizer=optimizer,
    num_epochs=10,
    checkpoint_dir=checkpoint_dir,
    prefix=prefix
)

# Evaluate the student model
print("Evaluating SimpleProjection Student Model...")
evaluate_student_model_dino(
    student_model=student_model,
    teacher_model=dino_teacher_model,
    val_loader=val_loader,
    checkpoint_dir=checkpoint_dir,
    prefix=prefix
)

Initializing SimpleProjection Student Model...
Loading DINO Teacher Model...
Training SimpleProjection Student Model...
Loading checkpoint from ./checkpoints/simpleprojection_dino_epoch_5.pt...
Resuming training from epoch 5.
Training epoch 6/10...


  checkpoint = torch.load(checkpoint_path)


Epoch 6/10 Loss: 0.0107
Epoch 6 completed in 753.19 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_dino_epoch_6.pt
Training epoch 7/10...
Epoch 7/10 Loss: 0.0034
Epoch 7 completed in 392.35 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_dino_epoch_7.pt
Training epoch 8/10...
Epoch 8/10 Loss: 0.0043
Epoch 8 completed in 378.20 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_dino_epoch_8.pt
Training epoch 9/10...
Epoch 9/10 Loss: 0.0035
Epoch 9 completed in 375.94 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_dino_epoch_9.pt
Training epoch 10/10...
Epoch 10/10 Loss: 0.0023
Epoch 10 completed in 380.70 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_dino_epoch_10.pt
Student Model Size: 0.75 MB
Evaluating SimpleProjection Student Model...
Searching for checkpoints in ./checkpoints...
Loading checkpoint from ./checkpoints/simpleprojection_dino_epoch_10.pt...


  checkpoint = torch.load(checkpoint_path)


Evaluation Accuracy: 1.0000


<h1>COCO</h1>

In [1]:
import time
from torch import nn
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.metrics import accuracy_score
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision.models import vit_b_16
from torchvision.models.feature_extraction import create_feature_extractor
from pycocotools.coco import COCO
from PIL import Image
import os

# COCO dataset loader for classification
class COCOClassification(torch.utils.data.Dataset):
    def __init__(self, root, annotation_file, transform=None):
        self.root = root
        self.coco = COCO(annotation_file)
        self.transform = transform
        self.image_ids = list(self.coco.imgToAnns.keys())
        self.classes = list(self.coco.cats.keys())

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        annotations = self.coco.loadAnns(self.coco.getAnnIds(imgIds=image_id))

        # Load image
        image_info = self.coco.loadImgs(image_id)[0]
        image_path = os.path.join(self.root, image_info['file_name'])
        image = Image.open(image_path).convert("RGB")

        # Create multi-label vector
        labels = [0] * len(self.classes)
        for ann in annotations:
            category_idx = self.classes.index(ann['category_id'])
            labels[category_idx] = 1

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(labels, dtype=torch.float32)

# Save checkpoint
def save_checkpoint(student_model, optimizer, epoch, loss, checkpoint_dir, prefix):
    if checkpoint_dir and prefix:
        checkpoint_path = os.path.join(checkpoint_dir, f"{prefix}{epoch + 1}.pt")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': student_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

# Load checkpoint
def load_checkpoint(checkpoint_dir, prefix, student_model, optimizer=None):
    start_epoch = 0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)

    if checkpoint_dir and os.path.exists(checkpoint_dir):
        checkpoint_files = [
            f for f in os.listdir(checkpoint_dir) if f.startswith(prefix) and f.endswith(".pt")
        ]
        if checkpoint_files:
            latest_checkpoint = max(
                checkpoint_files,
                key=lambda x: int(x[len(prefix):-3])
            )
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
            print(f"Loading checkpoint from {checkpoint_path}...")
            checkpoint = torch.load(checkpoint_path)
            student_model.load_state_dict(checkpoint['model_state_dict'])
            if optimizer and 'optimizer_state_dict' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            if 'epoch' in checkpoint:
                start_epoch = checkpoint['epoch']
            print(f"Resuming training from epoch {start_epoch}.")
    return start_epoch
    
# Initialize COCO dataset loaders
def init_coco_data(data_dir, annotation_file, batch_size=32):
    transform = Compose([
        Resize((224, 224)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    dataset = COCOClassification(root=data_dir, annotation_file=annotation_file, transform=transform)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    return data_loader

# Student Model
def compute_model_size(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad) * 4 / (1024**2)

class SimpleProjectionModel(nn.Module):
    def __init__(self, vision_dim=256, num_classes=80):
        super(SimpleProjectionModel, self).__init__()
        self.vision_encoder = nn.Linear(vision_dim, num_classes)
        self.logit_scale = nn.Parameter(torch.ones([]) * 0.07)

    def forward(self, teacher_features):
        logits = self.vision_encoder(teacher_features)
        return logits


# Paths for COCO dataset
train_data_dir = "/home/yx3493/train2017/train2017"
train_annotation_file = "/home/yx3493/annotations_trainval2017/annotations/instances_train2017.json"
val_data_dir = "/home/yx3493/val2017/val2017"
val_annotation_file = "/home/yx3493/annotations_trainval2017/annotations/instances_val2017.json"

train_loader = init_coco_data(train_data_dir, train_annotation_file, batch_size=8)
val_loader = init_coco_data(val_data_dir, val_annotation_file, batch_size=8)

loading annotations into memory...
Done (t=17.66s)
creating index...
index created!
loading annotations into memory...
Done (t=2.40s)
creating index...
index created!


<h2>BEiT</h2>

<h4>5 epochs</h4>

In [2]:
# Training logic
def train_student_model_beit(
    teacher_model, teacher_processor, student_model, train_loader, optimizer,
    num_epochs=10, checkpoint_dir=None, prefix=None
):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)
    start_epoch = load_checkpoint(checkpoint_dir, prefix, student_model, optimizer)

    # Multi-label loss
    criterion = torch.nn.BCEWithLogitsLoss()

    for epoch in range(start_epoch, num_epochs):
        print(f"Training epoch {epoch + 1}/{num_epochs}...")
        epoch_loss = 0.0

        for images, labels in train_loader:
            pil_images = [transforms.ToPILImage()(img) for img in images]

            # Teacher model logits
            with torch.no_grad():
                teacher_inputs = teacher_processor(images=pil_images, return_tensors="pt").to(device)
                teacher_outputs = teacher_model(**teacher_inputs)
                teacher_features = teacher_outputs.last_hidden_state.mean(dim=1)

            # Student model logits
            student_logits = student_model(teacher_features)

            # Compute loss (multi-label)
            loss = criterion(student_logits, labels.to(device))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print(f"Epoch {epoch + 1}/{num_epochs} Loss: {epoch_loss:.4f}")
        save_checkpoint(student_model, optimizer, epoch, epoch_loss, checkpoint_dir, prefix)
    print(f"Student Model Size: {compute_model_size(student_model):.2f} MB")

# Evaluation logic
def evaluate_student_model_beit(
    student_model, teacher_model, teacher_processor, val_loader, checkpoint_dir=None, prefix=None
):
    start_epoch = load_checkpoint(checkpoint_dir, prefix, student_model, optimizer=None)
    student_model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)

    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, labels in val_loader:
            pil_images = [transforms.ToPILImage()(img) for img in images]

            # Teacher model logits
            teacher_inputs = teacher_processor(images=pil_images, return_tensors="pt").to(device)
            teacher_outputs = teacher_model(**teacher_inputs)
            teacher_features = teacher_outputs.last_hidden_state.mean(dim=1)

            # Student model logits
            student_logits = student_model(teacher_features)
            predictions = torch.sigmoid(student_logits) > 0.5  # Multi-label thresholding

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Compute multi-label accuracy
    accuracy = accuracy_score(
        [tuple(map(int, x)) for x in all_labels], [tuple(map(int, x)) for x in all_predictions]
    )
    print(f"Evaluation Accuracy: {accuracy:.4f}")

# Initialize teacher and student models
from transformers import BeitModel, BeitFeatureExtractor
teacher_model = BeitModel.from_pretrained("microsoft/beit-base-patch16-224").to('cuda')
teacher_processor = BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224")

student_output_dim = 256
beit_teacher_output_dim = 768
num_classes = 80  # Number of COCO categories
student_model = SimpleProjectionModel(vision_dim=beit_teacher_output_dim, num_classes=num_classes)
optimizer = torch.optim.AdamW(student_model.parameters(), lr=3e-4)

# Train the student model
checkpoint_dir = "./checkpoints"
prefix = "simpleprojection_beit_coco_epoch_"

print("Training Student Model with COCO dataset...")
train_student_model_beit(
    teacher_model, teacher_processor, student_model, train_loader, optimizer,
    num_epochs=5, checkpoint_dir=checkpoint_dir, prefix=prefix
)

print("Training complete.")

# Evaluate the student model
print("Evaluating Student Model...")
evaluate_student_model_beit(
    student_model, teacher_model, teacher_processor, val_loader, checkpoint_dir=checkpoint_dir, prefix=prefix
)

print("Evaluation complete.")

Training Student Model with COCO dataset...
Loading checkpoint from ./checkpoints/simpleprojection_beit_coco_epoch_5.pt...
Resuming training from epoch 5.
Student Model Size: 0.23 MB
Training complete.
Evaluating Student Model...
Loading checkpoint from ./checkpoints/simpleprojection_beit_coco_epoch_5.pt...
Resuming training from epoch 5.


  return func(*args, **kwargs)
  checkpoint = torch.load(checkpoint_path)


Evaluation Accuracy: 0.2078
Evaluation complete.


<h4>10 epochs</h4>

In [4]:
print("Training Student Model with COCO dataset...")
train_student_model_beit(
    teacher_model, teacher_processor, student_model, train_loader, optimizer,
    num_epochs=10, checkpoint_dir=checkpoint_dir, prefix=prefix
)

print("Training complete.")

# Evaluate the student model
print("Evaluating Student Model...")
evaluate_student_model_beit(
    student_model, teacher_model, teacher_processor, val_loader, checkpoint_dir=checkpoint_dir, prefix=prefix
)

print("Evaluation complete.")

Training Student Model with COCO dataset...
Loading checkpoint from ./checkpoints/simpleprojection_beit_coco_epoch_8.pt...
Resuming training from epoch 8.
Training epoch 9/10...


  return func(*args, **kwargs)
  checkpoint = torch.load(checkpoint_path)


Epoch 9/10 Loss: 1086.5213
Checkpoint saved at ./checkpoints/simpleprojection_beit_coco_epoch_9.pt
Training epoch 10/10...
Epoch 10/10 Loss: 1086.4022
Checkpoint saved at ./checkpoints/simpleprojection_beit_coco_epoch_10.pt
Student Model Size: 0.23 MB
Training complete.
Evaluating Student Model...
Loading checkpoint from ./checkpoints/simpleprojection_beit_coco_epoch_10.pt...
Resuming training from epoch 10.
Evaluation Accuracy: 0.2137
Evaluation complete.


<h2>DINO</h2>

<h4>5 epochs</h4>

In [4]:
# Training function for the student model using DINO with SimpleProjection
def train_student_model_dino(
    teacher_model, student_model, train_loader, optimizer, 
    num_epochs=10, checkpoint_dir=None, prefix=None
):
    epoch_times = []
    start_epoch = 0
    if checkpoint_dir and prefix:
        start_epoch = load_checkpoint(checkpoint_dir, prefix, student_model, optimizer)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)
    student_model.train()

    for epoch in range(start_epoch, num_epochs):
        print(f"Training epoch {epoch + 1}/{num_epochs}...")
        start_time = time.time()
        epoch_loss = 0.0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Teacher model features (DINO)
            with torch.no_grad():
                teacher_outputs = teacher_model(images)
                teacher_features = teacher_outputs["features"].mean(dim=1)  # Global pooling

            # Student model logits
            student_logits = student_model(teacher_features)

            # Compute distillation loss (multi-label classification)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(student_logits, labels)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        scheduler.step()

        print(f"Epoch {epoch + 1}/{num_epochs} Loss: {epoch_loss:.4f}")
        end_time = time.time()
        epoch_time = end_time - start_time
        epoch_times.append(epoch_time)
        print(f"Epoch {epoch + 1} completed in {epoch_time:.2f} seconds.")

        if checkpoint_dir and prefix:
            save_checkpoint(student_model, optimizer, epoch, epoch_loss, checkpoint_dir, prefix)
            
    print(f"Student Model Size: {compute_model_size(student_model):.2f} MB")
    return epoch_times


# Evaluation function
def evaluate_student_model_dino(student_model, teacher_model, val_loader, checkpoint_dir=None, prefix=None):
    if checkpoint_dir and prefix:
        print(f"Searching for checkpoints in {checkpoint_dir}...")
        checkpoint_files = [
            f for f in os.listdir(checkpoint_dir) if f.startswith(prefix) and f.endswith(".pt")
        ]
        if checkpoint_files:
            latest_checkpoint = max(
                checkpoint_files,
                key=lambda x: int(x[len(prefix):-3])
            )
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
            print(f"Loading checkpoint from {checkpoint_path}...")
            checkpoint = torch.load(checkpoint_path)
            student_model.load_state_dict(checkpoint['model_state_dict'])

    student_model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)

    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            # Teacher model features (DINO)
            teacher_outputs = teacher_model(images)
            teacher_features = teacher_outputs["features"].mean(dim=1)  # Global pooling

            # Student model logits
            student_logits = student_model(teacher_features)
            predictions = torch.sigmoid(student_logits) > 0.5  # Multi-label thresholding

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Compute multi-label accuracy
    accuracy = accuracy_score(
        [tuple(map(int, x)) for x in all_labels], [tuple(map(int, x)) for x in all_predictions]
    )
    print(f"Evaluation Accuracy: {accuracy:.4f}")


# Main Code to Train and Evaluate
checkpoint_dir = "./checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Initialize SimpleProjection Student Model
print("Initializing SimpleProjection Student Model...")
teacher_output_dim = 768  # Output dimension of DINO teacher
student_output_dim = 256  # Dimension of student model bottleneck
num_classes = 80          # Number of COCO categories
student_model = SimpleProjectionModel(vision_dim=teacher_output_dim, num_classes=num_classes)


optimizer = torch.optim.AdamW(student_model.parameters(), lr=3e-4)

# Load DINO Teacher Model
print("Loading DINO Teacher Model...")
dino_teacher_model = vit_b_16(weights="IMAGENET1K_V1").to('cuda')
dino_teacher_model = create_feature_extractor(dino_teacher_model, return_nodes={"encoder.layers": "features"})

# Load the latest checkpoint if it exists
prefix = "simpleprojection_dino_coco_epoch_"

# Train the student model
print("Training SimpleProjection Student Model...")
epoch_times = train_student_model_dino(
    teacher_model=dino_teacher_model,
    student_model=student_model,
    train_loader=train_loader,
    optimizer=optimizer,
    num_epochs=5,
    checkpoint_dir=checkpoint_dir,
    prefix=prefix
)

# Evaluate the student model
print("Evaluating SimpleProjection Student Model...")
evaluate_student_model_dino(
    student_model=student_model,
    teacher_model=dino_teacher_model,
    val_loader=val_loader,
    checkpoint_dir=checkpoint_dir,
    prefix=prefix
)

Initializing SimpleProjection Student Model...
Loading DINO Teacher Model...
Training SimpleProjection Student Model...
Training epoch 1/5...
Epoch 1/5 Loss: 1227.8319
Epoch 1 completed in 606.73 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_dino_coco_epoch_1.pt
Training epoch 2/5...
Epoch 2/5 Loss: 1025.7893
Epoch 2 completed in 552.57 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_dino_coco_epoch_2.pt
Training epoch 3/5...
Epoch 3/5 Loss: 983.6097
Epoch 3 completed in 865.27 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_dino_coco_epoch_3.pt
Training epoch 4/5...
Epoch 4/5 Loss: 963.2624
Epoch 4 completed in 1143.16 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_dino_coco_epoch_4.pt
Training epoch 5/5...
Epoch 5/5 Loss: 953.6911
Epoch 5 completed in 1307.75 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_dino_coco_epoch_5.pt
Student Model Size: 0.23 MB
Evaluating SimpleProjection Student Model...
Searching for checkpoin

  checkpoint = torch.load(checkpoint_path)


Evaluation Accuracy: 0.2254


<h4>10 epochs</h4>

In [2]:
# Train the student model
print("Training SimpleProjection Student Model...")
epoch_times = train_student_model_dino(
    teacher_model=dino_teacher_model,
    student_model=student_model,
    train_loader=train_loader,
    optimizer=optimizer,
    num_epochs=10,
    checkpoint_dir=checkpoint_dir,
    prefix=prefix
)

# Evaluate the student model
print("Evaluating SimpleProjection Student Model...")
evaluate_student_model_dino(
    student_model=student_model,
    teacher_model=dino_teacher_model,
    val_loader=val_loader,
    checkpoint_dir=checkpoint_dir,
    prefix=prefix
)

Initializing SimpleProjection Student Model...
Loading DINO Teacher Model...
Training SimpleProjection Student Model...
Loading checkpoint from ./checkpoints/simpleprojection_dino_coco_epoch_8.pt...
Resuming training from epoch 8.
Training epoch 9/10...


  checkpoint = torch.load(checkpoint_path)


Epoch 9/10 Loss: 951.9704
Epoch 9 completed in 1304.23 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_dino_coco_epoch_9.pt
Training epoch 10/10...
Epoch 10/10 Loss: 951.9878
Epoch 10 completed in 1300.92 seconds.
Checkpoint saved at ./checkpoints/simpleprojection_dino_coco_epoch_10.pt
Student Model Size: 0.23 MB
Evaluating SimpleProjection Student Model...
Searching for checkpoints in ./checkpoints...
Loading checkpoint from ./checkpoints/simpleprojection_dino_coco_epoch_10.pt...


  checkpoint = torch.load(checkpoint_path)


Evaluation Accuracy: 0.2254
