In [1]:
from torch import nn
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.datasets import CocoCaptions
from sklearn.metrics import accuracy_score
import timm
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision.transforms import ToPILImage
import os
from torchvision.models import resnet18

In [2]:
class UnifiedStudentModel(nn.Module):
    def __init__(self, vision_dim=256, text_dim=256, teacher_output_dim=512, bottleneck_dim=128):
        super(UnifiedStudentModel, self).__init__()
        
        # ResNet backbone for vision features
        self.vision_backbone = resnet18(pretrained=True)
        in_features = self.vision_backbone.fc.in_features
        self.vision_backbone.fc = nn.Identity() 

        self.vision_proj = nn.Sequential(
            nn.Linear(in_features, bottleneck_dim),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, vision_dim)
        )

        self.text_encoder = nn.Sequential(
            nn.Linear(teacher_output_dim, bottleneck_dim),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, text_dim)
        )
  
        self.logit_scale = nn.Parameter(torch.ones([]) * 0.07)

    def forward(self, vision_features, text_features):
        # Vision features through ResNet backbone
        vision_backbone_features = self.vision_backbone(vision_features)
        vision_proj = self.vision_proj(vision_backbone_features)
        vision_proj = vision_proj / vision_proj.norm(dim=-1, keepdim=True)

        # Text features through text encoder
        text_proj = self.text_encoder(text_features)
        text_proj = text_proj / text_proj.norm(dim=-1, keepdim=True)

        # Compute similarity logits
        logits = self.logit_scale.exp() * vision_proj @ text_proj.t()
        return logits, vision_proj, text_proj


In [3]:
def contrastive_loss(vision_proj, text_proj, temperature=0.07):
    logits = torch.matmul(vision_proj, text_proj.t()) / temperature
    labels = torch.arange(logits.size(0)).to(logits.device)
    loss = nn.CrossEntropyLoss()(logits, labels)
    return loss

def kl_divergence_loss(student_logits, teacher_logits, temperature=2.0):
    return nn.KLDivLoss(reduction="batchmean")(
        torch.log_softmax(student_logits / temperature, dim=-1),
        torch.softmax(teacher_logits / temperature, dim=-1),
    ) * (temperature ** 2)

def combined_loss(vision_proj, text_proj, student_logits, teacher_logits, temperature=0.07):

    contrastive = contrastive_loss(vision_proj, text_proj, temperature)
    kl_divergence = kl_divergence_loss(student_logits, teacher_logits, temperature)

    total_loss = contrastive + kl_divergence
    return total_loss


In [4]:
def init_coco_data_clip(data_dir, batch_size=8):
    """
    Initialize MS COCO dataset for CLIP.
    """
    transform = Compose([
        Resize((224, 224)),  # CLIP expects raw image data
        ToTensor(),  # Ensures the image is in the range [0, 1]
    ])
    def collate_fn(batch):
        """
        Custom collate function to handle missing captions and ensure consistent batching.
        """
        images, captions = zip(*batch)
        
        # Filter out entries without captions
        filtered_images = []
        filtered_captions = []
        for img, caption in zip(images, captions):
            if caption:  # Ensure there is at least one caption
                filtered_images.append(img)
                filtered_captions.append(caption[0])  # Use the first caption if multiple are present
        
        if not filtered_images:
            raise ValueError("No valid images with captions found in batch.")
        
        # Stack images into a single tensor batch
        image_tensor = torch.stack(filtered_images)
        return image_tensor, filtered_captions

    train_dataset = CocoCaptions(
        root=os.path.join(data_dir, "train2017"),
        annFile=os.path.join(data_dir, "annotations", "captions_train2017.json"),
        transform=transform
    )
    
    val_dataset = CocoCaptions(
        root=os.path.join(data_dir, "val2017"),
        annFile=os.path.join(data_dir, "annotations", "captions_val2017.json"),
        transform=transform
    )
    
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn
    )
    
    return train_loader, val_loader


def init_coco_data_blip(data_dir, batch_size=8):

    transform = Compose([
        Resize((224, 224)),
        ToTensor(),  # Ensures the image is in the range [0, 1]
    ])
    def collate_fn(batch):
        """
        Custom collate function to handle missing captions and ensure consistent batching.
        """
        images, captions = zip(*batch)
        
        # Filter out entries without captions
        filtered_images = []
        filtered_captions = []
        for img, caption in zip(images, captions):
            if caption:  # Ensure there is at least one caption
                filtered_images.append(img)
                filtered_captions.append(caption[0])  # Use the first caption if multiple are present
        
        if not filtered_images:
            raise ValueError("No valid images with captions found in batch.")
        
        # Stack images into a single tensor batch
        image_tensor = torch.stack(filtered_images)
        return image_tensor, filtered_captions

    train_dataset = CocoCaptions(
        root=os.path.join(data_dir, "train2017"),
        annFile=os.path.join(data_dir, "annotations", "captions_train2017.json"),
        transform=transform
    )
    
    val_dataset = CocoCaptions(
        root=os.path.join(data_dir, "val2017"),
        annFile=os.path.join(data_dir, "annotations", "captions_val2017.json"),
        transform=transform
    )
    
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn
    )
    
    return train_loader, val_loader



In [5]:
def compute_model_size(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad) * 4 / (1024**2)

In [6]:
def save_checkpoint(student_model, optimizer, epoch, loss, checkpoint_dir, prefix):
    if checkpoint_dir and prefix:
        checkpoint_path = os.path.join(checkpoint_dir, f"{prefix}{epoch + 1}.pt")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': student_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")


In [7]:
def load_checkpoint(checkpoint_dir, prefix, student_model, optimizer):

    start_epoch = 0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)

    if checkpoint_dir and os.path.exists(checkpoint_dir):
        checkpoint_files = [
            f for f in os.listdir(checkpoint_dir) if f.startswith(prefix) and f.endswith(".pt")
        ]
        if checkpoint_files:
            latest_checkpoint = max(
                checkpoint_files,
                key=lambda x: int(x[len(prefix):-3])
            )
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
            print(f"Loading checkpoint from {checkpoint_path}...")
            checkpoint = torch.load(checkpoint_path)
            student_model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

            # Move optimizer states to the same device as the model
            for state in optimizer.state.values():
                if isinstance(state, torch.Tensor):
                    state.data = state.data.to(device)
                elif isinstance(state, dict):
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.to(device)

            start_epoch = checkpoint['epoch']
            print(f"Resuming training from epoch {start_epoch}.")
    return start_epoch

In [8]:
def train_student_model(teacher_model, teacher_processor, student_model, train_loader, optimizer, num_epochs=5, checkpoint_dir=None, prefix=None):
    start_epoch = load_checkpoint(checkpoint_dir, prefix, student_model, optimizer)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)
    student_model.train()

    for epoch in range(start_epoch, num_epochs):
        print(f"Training epoch {epoch + 1}/{num_epochs}...")
        epoch_loss = 0.0
        for images, captions in train_loader:
            images = images.to('cuda')
            student_model.to('cuda')

            # Teacher model logits
            with torch.no_grad():
                teacher_inputs = teacher_processor(images=images, return_tensors="pt").to('cuda')
                teacher_vision_features = teacher_model.get_image_features(pixel_values=teacher_inputs['pixel_values'])

                random_texts = [caption[0] if caption else "" for caption in captions]

                text_inputs = teacher_processor(text=random_texts, return_tensors="pt", padding=True).to('cuda')

                teacher_text_features = teacher_model.get_text_features(
                    input_ids=text_inputs['input_ids'],
                    attention_mask=text_inputs['attention_mask']
                )

                teacher_logits = teacher_vision_features @ teacher_text_features.t()

            # Student model logits
            student_logits, vision_proj, text_proj = student_model(images, teacher_text_features)

            # Compute distillation loss
            loss = combined_loss(vision_proj, text_proj, student_logits, teacher_logits)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs} Loss: {epoch_loss:.4f}")

        # Save checkpoint after each epoch
        save_checkpoint(student_model, optimizer, epoch, epoch_loss, checkpoint_dir, prefix)

    print(f"Student Model Size: {compute_model_size(student_model):.2f} MB")

In [9]:
from sklearn.metrics import accuracy_score
def evaluate_student_model(student_model, teacher_model, teacher_processor, val_loader, checkpoint_dir=None, prefix=None):
    if checkpoint_dir:
        print(f"Searching for checkpoints in {checkpoint_dir}...")
        checkpoint_files = [
            f for f in os.listdir(checkpoint_dir) if f.startswith(prefix) and f.endswith(".pt")
        ]
        if checkpoint_files:
            latest_checkpoint = max(
                checkpoint_files,
                key=lambda x: int(x[len(prefix):-3])
            )
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
            print(f"Loading checkpoint from {checkpoint_path}...")
            checkpoint = torch.load(checkpoint_path, weights_only=True)
            student_model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Model loaded from checkpoint at epoch {checkpoint['epoch']}.")

    student_model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_model.to(device)
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, captions in val_loader:
            images = images.to('cuda')

            teacher_inputs = teacher_processor(images=images, return_tensors="pt").to('cuda')
            teacher_vision_features = teacher_model.get_image_features(pixel_values=teacher_inputs['pixel_values'])

            random_texts = [caption[0] if caption else "" for caption in captions]

            text_inputs = teacher_processor(text=random_texts, return_tensors="pt", padding=True).to('cuda')
            teacher_text_features = teacher_model.get_text_features(
                input_ids=text_inputs['input_ids'],
                attention_mask=text_inputs['attention_mask']
            )

            student_logits, _, _ = student_model(images, teacher_text_features)
            predictions = torch.argmax(student_logits, dim=-1)
            # print(f"Batch predictions: {predictions}")

            all_predictions.extend(predictions.cpu().tolist())
            all_labels.extend([0] * len(predictions))  # Replace [0] with actual labels if available

    # Calculate accuracy
    if all_labels:
        accuracy = accuracy_score(all_labels, all_predictions)
        print(f"Evaluation Accuracy: {accuracy:.4f}")
    else:
        print("No ground truth labels available to calculate accuracy.")

    print(f"Complete the evaluation")

In [11]:
import os


data_dir = "./coco/coco2017"
train_loader, val_loader = init_coco_data_clip(data_dir, batch_size=8)

# Define student model configuration
student_output_dim = 256

# CLIP
clip_teacher_output_dim = 512
clip_student_model = UnifiedStudentModel(vision_dim=student_output_dim, text_dim=student_output_dim, teacher_output_dim=clip_teacher_output_dim)
clip_optimizer = torch.optim.AdamW(clip_student_model.parameters(), lr=5e-5)

from transformers import CLIPModel, CLIPProcessor
clip_teacher_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to('cuda')
clip_teacher_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

print("Training Student Model with CLIP...")
train_student_model(clip_teacher_model, clip_teacher_processor, clip_student_model, train_loader, clip_optimizer, num_epochs=10, checkpoint_dir="./checkpoints_VLM_ResNet_COCO", prefix="clip_student_model_epoch_")

print("Evaluating Student Model with CLIP...")
evaluate_student_model(clip_student_model, clip_teacher_model, clip_teacher_processor, val_loader, checkpoint_dir="./checkpoints_VLM_ResNet_COCO", prefix="clip_student_model_epoch_")


loading annotations into memory...
Done (t=0.93s)
creating index...
index created!
loading annotations into memory...
Done (t=0.05s)
creating index...
index created!


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Training Student Model with CLIP...
Loading checkpoint from ./checkpoints_VLM_ResNet_COCO/clip_student_model_epoch_5.pt...


  checkpoint = torch.load(checkpoint_path)


Resuming training from epoch 5.
Training epoch 6/10...


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 6/10 Loss: 26478.6722
Checkpoint saved at ./checkpoints_VLM_ResNet_COCO/clip_student_model_epoch_6.pt
Training epoch 7/10...


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 7/10 Loss: 25218.6552
Checkpoint saved at ./checkpoints_VLM_ResNet_COCO/clip_student_model_epoch_7.pt
Training epoch 8/10...


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 8/10 Loss: 24365.2572
Checkpoint saved at ./checkpoints_VLM_ResNet_COCO/clip_student_model_epoch_8.pt
Training epoch 9/10...


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 9/10 Loss: 23921.0767
Checkpoint saved at ./checkpoints_VLM_ResNet_COCO/clip_student_model_epoch_9.pt
Training epoch 10/10...


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 10/10 Loss: 23608.5956
Checkpoint saved at ./checkpoints_VLM_ResNet_COCO/clip_student_model_epoch_10.pt
Student Model Size: 43.39 MB
Evaluating Student Model with CLIP...
Searching for checkpoints in ./checkpoints_VLM_ResNet_COCO...
Loading checkpoint from ./checkpoints_VLM_ResNet_COCO/clip_student_model_epoch_10.pt...
Model loaded from checkpoint at epoch 10.


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Evaluation Accuracy: 0.6462
Complete the evaluation


In [15]:
data_dir = "./coco/coco2017"
student_output_dim = 256

train_loader, val_loader = init_coco_data_blip(data_dir, batch_size=8)
blip_teacher_output_dim = 512
blip_student_model = UnifiedStudentModel(vision_dim=student_output_dim, text_dim=student_output_dim, teacher_output_dim=blip_teacher_output_dim)
blip_optimizer = torch.optim.AdamW(blip_student_model.parameters(), lr=5e-5)

from transformers import BlipModel, BlipProcessor
blip_teacher_model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base").to('cuda')
blip_teacher_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

print("Training Student Model with BLIP...")
train_student_model(blip_teacher_model, blip_teacher_processor, blip_student_model, train_loader, blip_optimizer, num_epochs=10, checkpoint_dir="./checkpoints_VLM_ResNet_COCO", prefix="blip_student_model_epoch_")

print("Evaluating Student Model with BLIP...")
evaluate_student_model(blip_student_model, blip_teacher_model, blip_teacher_processor, val_loader, checkpoint_dir="./checkpoints_VLM_ResNet_COCO", prefix="blip_student_model_epoch_")

loading annotations into memory...
Done (t=0.93s)
creating index...
index created!
loading annotations into memory...
Done (t=0.05s)
creating index...
index created!


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
`BlipModel` is going to be deprecated in future release, please use `BlipForConditionalGeneration`, `BlipForQuestionAnswering` or `BlipForImageTextRetrieval` depending on your usecase.
Some weights of BlipModel were not initialized from the model checkpoint at Salesforce/blip-image-captioning-base and are newly initialized: ['logit_scale', 'text_model.embeddings.LayerNorm.bias', 'text_model.embeddings.LayerNorm.weight', 'text_model.embeddings.position_embeddings.weight', 'text_model.embeddings.word_embeddings.weight', 'text_model.encoder.layer.0.attention.output.LayerNorm.bias', 'text_model.encoder.layer.0.attention.output.LayerNorm.weight', 'text_model.encoder.layer.0.attention.output.dense.bias', 'text_mod

Training Student Model with BLIP...
Loading checkpoint from ./checkpoints_VLM_ResNet_COCO/blip_student_model_epoch_10.pt...


  checkpoint = torch.load(checkpoint_path)


Resuming training from epoch 10.
Student Model Size: 43.39 MB
Evaluating Student Model with BLIP...
Searching for checkpoints in ./checkpoints_VLM_ResNet_COCO...
Loading checkpoint from ./checkpoints_VLM_ResNet_COCO/blip_student_model_epoch_10.pt...
Model loaded from checkpoint at epoch 10.


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Evaluation Accuracy: 0.4504
Complete the evaluation
