In [46]:
import math
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from clip import CLIPModel
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from datasets import load_dataset
from transformers import AutoTokenizer

# Pick the compute backend in order of preference: CUDA GPU, then Apple MPS,
# and finally plain CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [47]:
# Column names read by the dataset wrapper defined further down.
image_col = 'img'
captions_col = 'text_EN'

# Only the 'train' split of the Hub dataset is used from here on.
ds = load_dataset("Attila1011/img_caption_EN_AppleFlair_Blip")['train']

# Hold out 2000 rows for evaluation; everything else is used for training.
# The fixed seed keeps the partition identical across runs.
train_test_split = ds.train_test_split(seed=42, test_size=2000)

# Unpack the two partitions under the names the later cells expect.
train_split, test_split = train_test_split['train'], train_test_split['test']


In [48]:
# Persist the held-out test split under the local 'test_split' path
# (relative to the notebook's working directory).
test_split.save_to_disk('test_split')

Saving the dataset (0/1 shards):   0%|          | 0/2000 [00:00<?, ? examples/s]

In [49]:
# GPT-2 tokenizer used for the captions. The dataset class pads every caption
# to a fixed length (padding='max_length'), which requires a pad token, so the
# end-of-sequence token is reused for padding.
# NOTE(review): padded positions therefore carry the same id as EOS and can
# only be told apart through the attention mask.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

In [50]:
class HuggingFaceImageTextDataset(Dataset):
    """Torch dataset pairing each image of a Hugging Face dataset with its tokenized caption.

    All captions are tokenized once, at construction time, with the
    module-level ``tokenizer`` (padded/truncated to 16 tokens). Images are
    fetched and transformed lazily in ``__getitem__``.

    Args:
        hf_dataset: Dataset exposing the columns named by the module-level
            ``image_col`` and ``captions_col`` variables.
        transform: Optional callable applied to every image before it is
            returned.
    """

    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

        # Tokenize the whole caption column in one call so __getitem__ only
        # has to index into ready-made tensors.
        tokenized_output = tokenizer(
            self.dataset[captions_col],
            padding='max_length',
            truncation=True,
            max_length=16,
            return_tensors='pt',
        )
        self.caption_tokens = tokenized_output['input_ids']
        self.attention_mask = tokenized_output['attention_mask']

    def __len__(self):
        """Return the number of image/caption pairs."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, caption_token_ids, attention_mask)`` for row ``idx``.

        PIL-style images that are not already RGB (grayscale, RGBA, palette,
        ...) are converted to RGB so that the sample keeps its own caption.
        Previously such rows were silently replaced by sample 0, which
        duplicated that sample and discarded the row's data.
        """
        image = self.dataset[idx][image_col]
        caption = self.caption_tokens[idx]
        mask = self.attention_mask[idx]

        convert = getattr(image, 'convert', None)
        if callable(convert):
            # PIL-like object: normalise the colour mode instead of dropping the row.
            if getattr(image, 'mode', None) != 'RGB':
                image = convert('RGB')
        elif np.array(image).shape[-1] != 3:
            # Non-PIL input that cannot be converted here: keep the original
            # behaviour of substituting sample 0 (image, caption and mask).
            image = self.dataset[0][image_col]
            caption = self.caption_tokens[0]
            mask = self.attention_mask[0]

        # Apply transformations to the image
        if self.transform:
            image = self.transform(image)

        return image, caption, mask

# Image preprocessing: resize to the 224x224 input size (matches
# model_params['image_size']), convert to a tensor and normalize with the
# ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Wrap both splits in the custom dataset (captions are tokenized here).
train_dataset = HuggingFaceImageTextDataset(train_split, transform=transform)
val_dataset = HuggingFaceImageTextDataset(test_split, transform=transform)

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Validation is NOT shuffled so every epoch is evaluated on the same batches.
# NOTE(review): if the loss is computed within each batch (as is usual for
# CLIP-style training), shuffling would change the batch composition between
# epochs and make validation losses / best-checkpoint selection noisy.
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [55]:
# Hyper-parameters forwarded verbatim to the CLIPModel constructor; train_clip
# also stores this dict inside the checkpoints it writes.
model_params = dict(
    embed_dim=256,
    img_embed_dim=512,
    patch_size=16,
    image_size=224,
    num_layers=3,
    num_heads=4,
    mlp_ratio=4,
)

# Build the model from the configuration above.
model = CLIPModel(**model_params)

def initialize_weights(m):
    """Initialize the parameters of a single module in place.

    Intended to be passed to ``nn.Module.apply``. Only three module types are
    touched; every other module is left unchanged.

    * ``nn.Linear``: Xavier/Glorot uniform weights, zero bias.
    * ``nn.Conv2d``: Kaiming normal weights (fan_out, ReLU gain), zero bias.
    * ``nn.LayerNorm``: weight 1 and bias 0, i.e. an identity affine transform.

    Missing parameters (``bias=False`` layers, or a LayerNorm created with
    ``elementwise_affine=False``) are skipped instead of raising.

    Args:
        m: The module to initialize.
    """
    if isinstance(m, nn.Linear):
        # Xavier/Glorot initialization for linear layers
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv2d):
        # Kaiming initialization for convolutional layers
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        # Identity affine transform; either parameter may be absent.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)

# Run initialize_weights over the model's modules. The call is the cell's
# last expression, so its return value (the model's repr) is what the cell
# displays.
model.apply(initialize_weights)

CLIPModel(
  (image_embeddings): ImageEmbeddings(
    (patch_embedding): Conv2d(3, 512, kernel_size=(16, 16), stride=(16, 16), padding=valid)
    (position_embedding): Embedding(196, 512)
  )
  (image_encoder): ImageEncoder(
    (layers): ModuleList(
      (0-2): 3 x ImageEncoderLayer(
        (norm1): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (wq): Linear(in_features=512, out_features=512, bias=False)
          (wk): Linear(in_features=512, out_features=512, bias=False)
          (wv): Linear(in_features=512, out_features=512, bias=False)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (wo): Linear(in_features=512, out_features=512, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (act): GELU(approximate='none')
          (fc2):

In [56]:
def get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    min_lr_ratio: float = 0.01
):
    """
    Create a LambdaLR schedule with linear warmup followed by cosine decay.

    The returned multiplier is applied to the initial lr set in the optimizer:

    * for ``current_step < num_warmup_steps`` it rises linearly from 0 to 1
      (the ``min_lr_ratio`` floor is not applied during warmup);
    * afterwards it follows ``0.5 * (1 + cos(2 * pi * num_cycles * progress))``
      where ``progress`` runs from 0 to 1 over the remaining steps, and is
      never allowed to drop below ``min_lr_ratio``.

    Steps beyond ``num_training_steps`` keep the final value: ``progress`` is
    clamped to 1 so the learning rate does not climb back up the cosine if the
    scheduler is stepped more often than planned.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_steps: Number of linear warmup steps.
        num_training_steps: Total number of steps (warmup included).
        num_cycles: Number of cosine cycles over the decay phase; the default
            0.5 is a single half-wave from 1 down to 0.
        min_lr_ratio: Lower bound of the multiplier during the decay phase.

    Returns:
        A ``LambdaLR`` scheduler wrapping ``optimizer``.
    """
    # Loop-invariant denominators, guarded against division by zero.
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        # Warmup
        if current_step < num_warmup_steps:
            return float(current_step) / warmup_span

        # Cosine decay; clamp so steps past the end hold the final value.
        progress = min(1.0, float(current_step - num_warmup_steps) / decay_span)
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))

        # Ensure we don't go below min_lr_ratio
        return max(min_lr_ratio, cosine_decay)

    return LambdaLR(optimizer, lr_lambda)

def train_clip(
    model,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    num_epochs: int = 20,
    learning_rate: float = 1e-4,
    weight_decay: float = 0.1,
    warmup_ratio: float = 0.1,
    min_lr_ratio: float = 0.01,
    device: str = 'cuda',
    model_params=None,
    checkpoint_path: str = 'best_clip_model.pth'
):
    """Train ``model`` with AdamW and a warmup + cosine schedule, validating every epoch.

    The model is called as ``model(images, captions, mask)`` and must return
    ``(image_embeddings, text_embeddings)``; the loss comes from
    ``model.clip_loss(image_embeddings, text_embeddings)``. Gradients are
    clipped to a global norm of 1.0 and the scheduler is stepped once per
    batch. Whenever the mean validation loss improves, a checkpoint (model,
    optimizer and scheduler state, zero-based epoch, loss, model_params) is
    written to ``checkpoint_path``.

    Args:
        model: Module exposing the call signature and ``clip_loss`` described above.
        train_dataloader: Yields ``(images, captions, mask)`` training batches.
        val_dataloader: Yields ``(images, captions, mask)`` validation batches.
        num_epochs: Number of passes over ``train_dataloader``.
        learning_rate: Peak learning rate for AdamW.
        weight_decay: AdamW weight decay.
        warmup_ratio: Fraction of all optimizer steps spent in linear warmup.
        min_lr_ratio: Floor of the cosine multiplier (relative to ``learning_rate``).
        device: Device the model and batches are moved to.
        model_params: Constructor arguments stored in the checkpoint. When
            omitted, the module-level ``model_params`` variable is used if it
            exists (the previous, implicit behaviour), otherwise ``None``.
        checkpoint_path: Destination of the best-model checkpoint.

    Returns:
        ``(avg_train_losses, avg_val_losses)``: two lists holding the mean
        per-batch loss of every epoch.

    Raises:
        ValueError: If either dataloader yields no batches.
    """
    # Backward compatibility: earlier versions read the notebook-level
    # `model_params` directly; keep doing so when the caller passes nothing.
    if model_params is None:
        model_params = globals().get('model_params')

    avg_train_losses = []
    avg_val_losses = []
    # Move model to device
    model.to(device)

    # Calculate total steps (one scheduler step per batch)
    total_steps = len(train_dataloader) * num_epochs
    warmup_steps = int(total_steps * warmup_ratio)

    # Initialize optimizer
    optimizer = AdamW(
        model.parameters(),
        lr=learning_rate,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=weight_decay
    )

    # Initialize scheduler
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
        min_lr_ratio=min_lr_ratio
    )

    # Training loop
    best_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        losses = []

        # Progress bar for batches
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for images, captions, mask in progress_bar:
            # Zero the gradients
            optimizer.zero_grad()

            # Move data to device
            images = images.to(device)
            captions = captions.to(device)
            if mask is not None:
                mask = mask.to(device)

            # Forward pass
            image_embeddings, text_embeddings = model(images, captions, mask)

            # Compute loss
            loss = model.clip_loss(image_embeddings, text_embeddings)

            # Backward pass
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Optimizer step, then scheduler step (per batch, not per epoch)
            optimizer.step()
            scheduler.step()

            # Read the scalar once; each .item() call forces a device sync.
            batch_loss = loss.item()
            losses.append(batch_loss)

            # Update progress bar
            current_lr = scheduler.get_last_lr()[0]
            progress_bar.set_postfix({
                'loss': f"{batch_loss:.4f}",
                'lr': f"{current_lr:.2e}"
            })

        if not losses:
            raise ValueError("train_dataloader yielded no batches")

        # Calculate average loss for the epoch
        avg_train_loss = sum(losses) / len(losses)

        # Validation loop
        model.eval()
        val_losses = []

        with torch.no_grad():
            for images, captions, mask in tqdm(val_dataloader, desc="Validation Loop"):
                # Move data to device
                images = images.to(device)
                captions = captions.to(device)
                if mask is not None:
                    mask = mask.to(device)

                # Forward pass
                image_embeddings, text_embeddings = model(images, captions, mask)

                # Compute loss
                loss = model.clip_loss(image_embeddings, text_embeddings)

                # Record loss
                val_losses.append(loss.item())

        if not val_losses:
            raise ValueError("val_dataloader yielded no batches")

        # Calculate average validation loss
        avg_val_loss = sum(val_losses) / len(val_losses)

        # Save best model
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': best_loss,
                'model_params': model_params
            }, checkpoint_path)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Average Train Loss: {avg_train_loss:.4f}")
        print(f"Average Validation Loss: {avg_val_loss:.4f}")
        print(f"Learning Rate: {scheduler.get_last_lr()[0]:.2e}")
        print("-" * 50)

        avg_train_losses.append(avg_train_loss)
        avg_val_losses.append(avg_val_loss)
    return avg_train_losses, avg_val_losses


In [57]:
# Run the 10-epoch training job on the selected device and keep the per-epoch
# loss histories for later inspection.
train_loss, val_loss = train_clip(
    model,
    train_dataloader,
    val_dataloader,
    num_epochs=10,
    learning_rate=1e-4,
    weight_decay=0.1,
    warmup_ratio=0.1,
    device=device,
)

Epoch 1/10: 100%|██████████| 563/563 [02:08<00:00,  4.39it/s, loss=2.2053, lr=1.00e-04]
Validation Loop: 100%|██████████| 63/63 [00:08<00:00,  7.52it/s]


Epoch 1/10
Average Train Loss: 3.2760
Average Validation Loss: 3.0762
Learning Rate: 1.00e-04
--------------------------------------------------


Epoch 2/10: 100%|██████████| 563/563 [02:10<00:00,  4.32it/s, loss=2.1652, lr=9.70e-05]
Validation Loop: 100%|██████████| 63/63 [00:08<00:00,  7.23it/s]


Epoch 2/10
Average Train Loss: 2.9630
Average Validation Loss: 2.8288
Learning Rate: 9.70e-05
--------------------------------------------------


Epoch 3/10: 100%|██████████| 563/563 [02:11<00:00,  4.27it/s, loss=1.8403, lr=8.83e-05]
Validation Loop: 100%|██████████| 63/63 [00:08<00:00,  7.52it/s]


Epoch 3/10
Average Train Loss: 2.7236
Average Validation Loss: 2.7497
Learning Rate: 8.83e-05
--------------------------------------------------


Epoch 4/10: 100%|██████████| 563/563 [02:15<00:00,  4.16it/s, loss=2.3180, lr=7.50e-05]
Validation Loop: 100%|██████████| 63/63 [00:08<00:00,  7.25it/s]


Epoch 4/10
Average Train Loss: 2.5562
Average Validation Loss: 2.6685
Learning Rate: 7.50e-05
--------------------------------------------------


Epoch 5/10: 100%|██████████| 563/563 [02:17<00:00,  4.10it/s, loss=2.0332, lr=5.87e-05]
Validation Loop: 100%|██████████| 63/63 [00:09<00:00,  6.84it/s]


Epoch 5/10
Average Train Loss: 2.4025
Average Validation Loss: 2.6494
Learning Rate: 5.87e-05
--------------------------------------------------


Epoch 6/10: 100%|██████████| 563/563 [02:17<00:00,  4.10it/s, loss=1.7398, lr=4.13e-05]
Validation Loop: 100%|██████████| 63/63 [00:08<00:00,  7.04it/s]


Epoch 6/10
Average Train Loss: 2.2595
Average Validation Loss: 2.6569
Learning Rate: 4.13e-05
--------------------------------------------------


Epoch 7/10: 100%|██████████| 563/563 [02:18<00:00,  4.06it/s, loss=1.3657, lr=2.50e-05]
Validation Loop: 100%|██████████| 63/63 [00:08<00:00,  7.19it/s]


Epoch 7/10
Average Train Loss: 2.1296
Average Validation Loss: 2.6698
Learning Rate: 2.50e-05
--------------------------------------------------


Epoch 8/10: 100%|██████████| 563/563 [02:19<00:00,  4.02it/s, loss=1.6164, lr=1.17e-05]
Validation Loop: 100%|██████████| 63/63 [00:08<00:00,  7.09it/s]


Epoch 8/10
Average Train Loss: 2.0062
Average Validation Loss: 2.6931
Learning Rate: 1.17e-05
--------------------------------------------------


Epoch 9/10: 100%|██████████| 563/563 [02:21<00:00,  3.99it/s, loss=1.1964, lr=3.02e-06]
Validation Loop: 100%|██████████| 63/63 [00:09<00:00,  6.89it/s]


Epoch 9/10
Average Train Loss: 1.9203
Average Validation Loss: 2.7532
Learning Rate: 3.02e-06
--------------------------------------------------


Epoch 10/10: 100%|██████████| 563/563 [02:18<00:00,  4.06it/s, loss=0.9494, lr=1.00e-06]
Validation Loop: 100%|██████████| 63/63 [00:09<00:00,  7.00it/s]

Epoch 10/10
Average Train Loss: 1.8755
Average Validation Loss: 2.7671
Learning Rate: 1.00e-06
--------------------------------------------------





In [54]:
def plot_losses(train_losses, val_losses):
    """Plot per-epoch training and validation losses on one figure.

    Args:
        train_losses: Sequence with one training loss per epoch.
        val_losses: Sequence with one validation loss per epoch; must have the
            same length as ``train_losses``.

    Raises:
        ValueError: If the sequences are empty or differ in length.
    """
    # Accept any iterable (lists, tuples, arrays): with arrays, `+` below
    # would add element-wise instead of concatenating.
    train_losses = list(train_losses)
    val_losses = list(val_losses)

    if not train_losses:
        raise ValueError("plot_losses needs at least one epoch of losses")
    if len(train_losses) != len(val_losses):
        raise ValueError("train_losses and val_losses must have the same length")

    num_epochs = len(train_losses)
    epochs = range(1, num_epochs + 1)

    # A single point is invisible with a pure line style, so mark it.
    point_style = {'marker': 'o'} if num_epochs == 1 else {}

    plt.figure(figsize=(8, 6))

    # Plotting the losses
    plt.plot(epochs, train_losses, 'b-', label='Training Loss', linewidth=2, **point_style)
    plt.plot(epochs, val_losses, 'r-', label='Validation Loss', linewidth=2, **point_style)

    # Adding titles and labels
    plt.title('Training and Validation Loss Over Epochs', fontsize=14)
    plt.xlabel('Epochs', fontsize=12)
    plt.ylabel('Loss', fontsize=12)

    # Adding a grid
    plt.grid(True, linestyle='--', alpha=0.7)

    # Adding a legend
    plt.legend(loc='upper right', fontsize=12)

    # X range: epochs 1..N; widen around the lone epoch to avoid identical limits.
    if num_epochs == 1:
        plt.xlim(0.5, 1.5)
    else:
        plt.xlim(1, num_epochs)

    # Y range: 5% padding away from the data. For non-negative losses this is
    # min * 0.95 .. max * 1.05; using abs() keeps negative values inside the
    # frame as well.
    all_losses = train_losses + val_losses
    lowest, highest = min(all_losses), max(all_losses)
    lower = lowest - 0.05 * abs(lowest)
    upper = highest + 0.05 * abs(highest)
    if lower < upper:
        plt.ylim(lower, upper)

    # Show the plot
    plt.tight_layout()
    plt.show()
