In [1]:
%%capture
!pip install open_clip_torch torch torchvision torchaudio pandas scikit-learn tqdm ftfy regex

In [2]:
import torch
import open_clip
import pandas as pd
import numpy as np
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms # We might use CLIP's specific transforms
from tqdm.notebook import tqdm
import torch.nn.functional as F
from sklearn.model_selection import train_test_split # If train_test_split.txt is not used
from sklearn.metrics import accuracy_score
import re
import torch.optim as optim
# The cosine-with-warmup LR schedule comes from `transformers` (simpler than hand-rolling one).
# If `transformers` is not available, install it in the pip cell at the top of the notebook:
# !pip install -q transformers
from transformers import get_cosine_schedule_with_warmup

In [3]:
# Root of the CUB-200-2011 dataset and the folder holding the image files.
DATA_DIR = "/kaggle/input/cub2002011/CUB_200_2011/"
IMAGE_DIR = os.path.join(DATA_DIR, "images")

# The metadata files are space-separated and have no header row.
# img_id -> relative image path
images_df = pd.read_csv(
    os.path.join(DATA_DIR, 'images.txt'),
    sep=' ',
    names=['img_id', 'filepath'],
)

# img_id -> class_id
image_class_labels_df = pd.read_csv(
    os.path.join(DATA_DIR, 'image_class_labels.txt'),
    sep=' ',
    names=['img_id', 'class_id'],
)

# class_id -> class name
classes_df = pd.read_csv(
    os.path.join(DATA_DIR, 'classes.txt'),
    sep=' ',
    names=['class_id', 'classname'],
)
# Normalise class names, e.g. "001.Black_footed_Albatross" -> "black footed albatross"
classes_df['classname'] = classes_df['classname'].map(
    lambda raw: raw.rsplit('.', 1)[-1].replace('_', ' ').lower()
)

# img_id -> 1 if the image belongs to the official training split, else 0
train_test_split_df = pd.read_csv(
    os.path.join(DATA_DIR, 'train_test_split.txt'),
    sep=' ',
    names=['img_id', 'is_training_img'],
)

# One row per image carrying its path, class id, class name and split flag.
data_df = (
    images_df
    .merge(image_class_labels_df, on='img_id')
    .merge(classes_df, on='class_id')
    .merge(train_test_split_df, on='img_id')
)

# Official train / test partition, each re-indexed from zero.
train_df = data_df.loc[data_df['is_training_img'].eq(1)].reset_index(drop=True)
test_df = data_df.loc[data_df['is_training_img'].eq(0)].reset_index(drop=True)

print(f"Total images: {len(data_df)}")
print(f"Training images: {len(train_df)}")
print(f"Testing images: {len(test_df)}")
print(f"Number of classes: {classes_df['class_id'].nunique()}")
data_df.head()

Total images: 11788
Training images: 5994
Testing images: 5794
Number of classes: 200


Unnamed: 0,img_id,filepath,class_id,classname,is_training_img
0,1,001.Black_footed_Albatross/Black_Footed_Albatr...,1,black footed albatross,0
1,2,001.Black_footed_Albatross/Black_Footed_Albatr...,1,black footed albatross,1
2,3,001.Black_footed_Albatross/Black_Footed_Albatr...,1,black footed albatross,0
3,4,001.Black_footed_Albatross/Black_Footed_Albatr...,1,black footed albatross,1
4,5,001.Black_footed_Albatross/Black_Footed_Albatr...,1,black footed albatross,1


In [4]:
# Simple templates used in many CLIP papers.
# Each template has a single `{}` slot that is filled with a class name.
# The first entry acts as the default: it is the template used to build the
# zero-shot prompts, while the whole list is the pool that the training dataset
# samples from when template augmentation is enabled.
prompt_templates = [
    'a photo of a {}.',
    'a bad photo of a {}.',
    'a photo of the {}.',
    'a picture of a {}.',
    'a bird called {}.',
    'the bird {}.',
    'a type of bird called {}.',
    '{} bird.',
]

def create_prompt(class_name, template):
    """Build a text prompt by substituting a class name into a template.

    Args:
        class_name: Value inserted into the template's `{}` placeholder.
        template: Format string containing one positional `{}` placeholder.

    Returns:
        The formatted prompt string.
    """
    prompt = str.format(template, class_name)
    return prompt

# Zero-shot evaluation uses exactly one prompt per class, built from the first template.
all_class_names = classes_df['classname'].to_list()
num_classes = len(all_class_names)
zeroshot_prompts = [
    create_prompt(name, prompt_templates[0])
    for name in all_class_names
]
print(f"Example prompt: {zeroshot_prompts[0]}")

Example prompt: a photo of a black footed albatross.


In [5]:
class CUBDataset(Dataset):
    """Training dataset yielding (image tensor, tokenized caption) pairs.

    Each sample pairs an image with a text prompt built from its class name,
    which is the input format of the contrastive training loop.

    Args:
        df: DataFrame with one row per image; the `filepath` and `classname`
            columns are read.
        image_dir: Directory that the `filepath` values are relative to.
        transform: Callable applied to the RGB PIL image to produce a tensor.
        tokenizer: Callable that maps a prompt string to a batch of token
            tensors; only the first element of the batch is used.
        templates: Sequence of format strings with one `{}` placeholder.
        augment_templates: If True, a random template is drawn for every
            sample; otherwise `templates[0]` is always used.
    """

    def __init__(self, df, image_dir, transform, tokenizer, templates, augment_templates=False):
        self.df = df
        self.image_dir = image_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.templates = templates
        self.augment_templates = augment_templates # Flag for template augmentation

    def __len__(self):
        return len(self.df)

    def _load_image(self, filepath):
        """Open `filepath` (relative to image_dir), convert to RGB and apply the transform."""
        # The context manager releases the file handle as soon as the pixels are converted.
        with Image.open(os.path.join(self.image_dir, filepath)) as img:
            return self.transform(img.convert('RGB'))

    def _pick_template(self):
        """Return a random template when augmentation is on, else the default first template."""
        # Decided by the instance flag alone: the dataset must not depend on a
        # notebook global (such as a particular transform object) being defined.
        if self.augment_templates:
            return self.templates[np.random.randint(len(self.templates))]
        return self.templates[0]

    def __getitem__(self, idx):
        """Return `(image_tensor, tokenized_text)` for row `idx`.

        If the image cannot be loaded, the first row of the dataframe is used
        as a placeholder (image and class name together, so the pair stays
        consistent). If that fails too, an all-zero image tensor is returned
        with the original class name.
        """
        row = self.df.iloc[idx]
        class_name = row['classname']

        try:
            image_tensor = self._load_image(row['filepath'])
        except Exception as e:
            img_path = os.path.join(self.image_dir, row['filepath'])
            print(f"Error loading image {img_path}: {e}. Using placeholder.")
            placeholder_row = self.df.iloc[0]
            try:
                image_tensor = self._load_image(placeholder_row['filepath'])
                # Keep image and caption matched; pairing the placeholder image with
                # the failed row's class would feed a wrong positive to the contrastive loss.
                class_name = placeholder_row['classname']
            except Exception: # Not a bare except: Ctrl-C must still interrupt loading
                image_tensor = torch.zeros((3, 224, 224)) # Adjust size if needed

        text_prompt = create_prompt(class_name, self._pick_template())
        # Note: open_clip tokenizer returns a tensor batch; we need the first element [0]
        tokenized_text = self.tokenizer(text_prompt)[0]

        return image_tensor, tokenized_text

In [6]:
# Run on the GPU when one is visible, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# --- Model Setup ---
# Architecture name and the tag of the pretrained weights to load.
model_name = 'ViT-B-32'
pretrained_dataset = 'laion2b_s34b_b79k'

model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    model_name, pretrained=pretrained_dataset, device=device, jit=False
)
tokenizer = open_clip.get_tokenizer(model_name)

# Full fine-tuning: mark every parameter as trainable.
for param in model.parameters():
    param.requires_grad = True

# --- Loss Function ---
# Contrastive loss configured for a single process (rank 0 of world size 1).
loss_fn = open_clip.ClipLoss(
    local_loss=False, gather_with_grad=False, cache_labels=True, rank=0, world_size=1
)

# --- Optimizer ---
# Small learning rate for fine-tuning pretrained weights.
learning_rate = 5e-7
weight_decay = 0.1
betas = (0.9, 0.98)
eps = 1e-6

params_to_optimize = (p for p in model.parameters() if p.requires_grad)
optimizer = optim.AdamW(
    params_to_optimize, lr=learning_rate, betas=betas, eps=eps, weight_decay=weight_decay
)

# --- Scheduler ---
# The scheduler needs the total step count, which depends on the length of the
# training loader, so it is only declared here and built once the loader exists.
# NOTE(review): the training cell assigns num_epochs again; the schedule length is
# derived from whichever value is in effect when the scheduler is built.
num_epochs = 10
warmup_steps = 100 # Number of warmup steps (e.g., ~1 epoch or fixed number)
scheduler = None

Using device: cuda


open_clip_model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

In [7]:
# Loader settings shared by the training and evaluation loaders.
batch_size = 64 # Adjust based on GPU memory
num_workers = 2 # Adjust based on system capability

# --- Training Dataset ---
# Uses the training transform and draws a random prompt template per sample.
train_dataset = CUBDataset(
    df=train_df,
    image_dir=IMAGE_DIR,
    transform=preprocess_train,
    tokenizer=tokenizer,
    templates=prompt_templates,
    augment_templates=True,
)

# Collate function that tolerates missing samples and collation failures.
def safe_collate(batch):
    """Collate a batch, dropping `None` samples.

    Args:
        batch: List of samples as produced by a dataset; entries may be None.

    Returns:
        The default-collated batch of the non-None samples, or None when no
        usable sample remains or when collation raises.
    """
    usable = [sample for sample in batch if sample is not None]
    if len(usable) == 0:
        return None # Nothing left to collate
    try:
        collated = torch.utils.data.dataloader.default_collate(usable)
    except Exception as e:
        print(f"Error in collate_fn: {e}")
        return None # Signal the caller to skip this batch
    return collated

# Shuffled training loader; the incomplete final batch is dropped.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True,
    collate_fn=safe_collate,
)

# --- Test Dataset for Evaluation (Zero-shot style) ---
class CUBTestDataset(Dataset):
    """Evaluation dataset yielding (image tensor, 0-indexed class label) pairs.

    Args:
        df: DataFrame with one row per image; the `filepath` and `class_id`
            columns are read.
        image_dir: Directory that the `filepath` values are relative to.
        transform: Callable applied to the RGB PIL image to produce a tensor.
    """

    def __init__(self, df, image_dir, transform):
        self.df = df
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def _load_image(self, filepath):
        """Open `filepath` (relative to image_dir), convert to RGB and apply the transform."""
        # The context manager releases the file handle as soon as the pixels are converted.
        with Image.open(os.path.join(self.image_dir, filepath)) as img:
            return self.transform(img.convert('RGB'))

    def __getitem__(self, idx):
        """Return `(image_tensor, label)` for row `idx`.

        The label is `class_id - 1` as a long tensor. If the image cannot be
        loaded, the image of the first dataframe row is used instead; if that
        fails as well, an all-zero tensor is returned. The label always
        belongs to the requested row.
        """
        row = self.df.iloc[idx]
        img_path = os.path.join(self.image_dir, row['filepath'])
        class_id = row['class_id'] - 1 # 0-indexed

        try:
            image_tensor = self._load_image(row['filepath'])
        except Exception as e:
            print(f"Error loading image {img_path}: {e}. Using placeholder.")
            # Fallback logic
            try:
                image_tensor = self._load_image(self.df.iloc[0]['filepath'])
            except Exception: # Not a bare except: Ctrl-C must still interrupt loading
                image_tensor = torch.zeros((3, 224, 224))

        return image_tensor, torch.tensor(class_id, dtype=torch.long)

# Evaluation data uses the validation transform and keeps its order (no shuffling).
test_dataset_eval = CUBTestDataset(df=test_df, image_dir=IMAGE_DIR, transform=preprocess_val)
test_loader_eval = DataLoader(
    test_dataset_eval,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    collate_fn=safe_collate,
)

print(f"Train loader size: {len(train_loader)} batches")
print(f"Test loader size: {len(test_loader_eval)} batches")

# --- Create Scheduler (Now that train_loader is defined) ---
# Linear warmup followed by cosine decay over len(train_loader) * num_epochs steps.
if len(train_loader) == 0:
    print("Warning: Train loader has zero length. Cannot create scheduler.")
    total_steps = 0
    scheduler = None
else:
    total_steps = len(train_loader) * num_epochs
    print(f"Total training steps: {total_steps}")
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )

Train loader size: 93 batches
Test loader size: 91 batches
Total training steps: 930


In [8]:
def train_one_epoch(model, loader, loss_fn, optimizer, device, epoch, scheduler=None, grad_clip_norm=1.0):
    """Run one epoch of contrastive fine-tuning.

    Args:
        model: Model called as `model(images, texts)`, returning image
            features, text features and the logit scale.
        loader: Iterable of `(images, texts)` batches; a batch may be None,
            in which case it is skipped.
        loss_fn: Callable `loss_fn(image_features, text_features, logit_scale)`.
        optimizer: Optimizer updating the model parameters.
        device: Device the batch tensors are moved to.
        epoch: Epoch number, used only for progress and log messages.
        scheduler: Optional LR scheduler, stepped once per optimizer step.
        grad_clip_norm: Max gradient norm; None disables clipping.

    Returns:
        Mean loss over the steps that completed, or 0 if none did.
    """
    import math  # local import: only needed for the logit-scale clamp bound

    model.train()
    # Check if loader is valid
    if not loader:
        print(f"Skipping training for epoch {epoch} due to invalid loader.")
        return 0.0

    pbar = tqdm(loader, desc=f"Epoch {epoch} Training")
    total_loss = 0.0
    steps = 0

    for i, batch in enumerate(pbar):
        # Handle potential errors from dataset loading/collation
        if batch is None:
            print(f"Skipping problematic batch {i}")
            continue
        images, texts = batch
        images = images.to(device, non_blocking=True)
        texts = texts.to(device, non_blocking=True)

        optimizer.zero_grad()

        # Forward pass
        try:
            image_features, text_features, logit_scale = model(images, texts)
        except Exception as e:
            print(f"Error during model forward pass in batch {i}: {e}")
            continue # Skip batch on forward error

        # Ensure features are valid
        if image_features is None or text_features is None:
             print(f"Skipping batch {i} due to None features")
             continue

        # Contrastive loss. The logit scale is passed exactly as the model returned it;
        # open_clip's forward pass returns the already-exponentiated temperature,
        # which is the form ClipLoss multiplies the similarities by.
        loss = loss_fn(image_features, text_features, logit_scale)

        loss.backward()

        # Gradient Clipping
        if grad_clip_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)

        optimizer.step()

        # Keep the learnable temperature within [1, 100] (log-space [0, ln 100]),
        # as the reference CLIP training loop does after every optimizer step.
        logit_scale_param = getattr(model, "logit_scale", None)
        if logit_scale_param is not None:
            with torch.no_grad():
                logit_scale_param.clamp_(0, math.log(100))

        if scheduler:
            scheduler.step() # Step scheduler

        loss_value = loss.item() # Single host sync per step
        total_loss += loss_value
        steps += 1
        current_lr = optimizer.param_groups[0]['lr'] # LR that the next step will use
        pbar.set_postfix({"Loss": loss_value, "Avg Loss": total_loss / steps, "LR": f"{current_lr:.2e}"}) # Display LR

    avg_loss = total_loss / steps if steps > 0 else 0
    print(f"Epoch {epoch} Average Training Loss: {avg_loss:.4f}")
    return avg_loss

In [9]:
# Zero-shot style evaluation: classify each image by its nearest class prompt.
def evaluate_zeroshot(model, loader, tokenizer, all_class_names, device): # Changed arg to all_class_names
    """Compute top-1 accuracy by matching images to class text prompts.

    One prompt per class is built from the first entry of the notebook-level
    `prompt_templates`, encoded and L2-normalised. Every image is then
    assigned the class whose prompt embedding has the highest cosine
    similarity with the image embedding.

    Args:
        model: Model exposing `encode_text` and `encode_image`.
        loader: Iterable of `(images, labels)` batches; a batch may be None,
            in which case it is skipped.
        tokenizer: Callable mapping a list of prompt strings to a token tensor.
        all_class_names: Class names ordered so that position i corresponds
            to label i.
        device: Device used for encoding and for the similarity computation.

    Returns:
        Accuracy in [0, 1] over the images that were encoded, or 0.0 if the
        loader is empty or no batch could be encoded.
    """
    model.eval()
    all_image_features = []
    all_labels = []

    # Encode all class prompts once
    print("Encoding text prompts for evaluation...")
    with torch.no_grad():
        # Use the first template for zero-shot evaluation consistency
        prompts_for_eval = [create_prompt(c, prompt_templates[0]) for c in all_class_names]
        text_inputs = tokenizer(prompts_for_eval).to(device)
        class_text_features = model.encode_text(text_inputs)
        class_text_features = F.normalize(class_text_features, dim=-1)
    print("Text prompts encoded.")

    # Check if loader is valid
    if not loader:
        print("Evaluation loader is invalid. Skipping evaluation.")
        return 0.0

    pbar = tqdm(loader, desc="Evaluating")
    with torch.no_grad():
        for i, batch in enumerate(pbar):
            # Handle potential errors from dataset loading/collation
            if batch is None:
                print(f"Skipping problematic batch {i} in evaluation")
                continue
            images, labels = batch
            images = images.to(device, non_blocking=True)

            try:
                image_features = model.encode_image(images)
                if image_features is None:
                    print(f"Skipping batch {i} due to None image features")
                    continue
                image_features = F.normalize(image_features, dim=-1)

                all_image_features.append(image_features.cpu())
                all_labels.append(labels.cpu())
            except Exception as e:
                print(f"Error during image encoding in batch {i}: {e}")
                continue

    if not all_image_features:
        print("No image features were collected for evaluation.")
        return 0.0

    all_image_features = torch.cat(all_image_features)
    all_labels = torch.cat(all_labels)

    # Top-1 prediction = class with the highest cosine similarity. Multiplying by the
    # (positive) temperature and applying softmax are both monotonic per row, so they
    # cannot change the argmax and are omitted.
    with torch.no_grad():
        similarities = all_image_features.to(device) @ class_text_features.T
    # argmax over the class axis always yields a 1-D tensor, even for a single image
    # (a bare squeeze() would collapse that case to a 0-d tensor and break scoring).
    predictions = similarities.argmax(dim=-1).cpu()

    accuracy = accuracy_score(all_labels.numpy(), predictions.numpy())
    print(f"Zero-shot Accuracy: {accuracy * 100:.2f}%")
    return accuracy

In [None]:
num_epochs = 15 # Increased epochs
best_accuracy = 0.0
output_dir = "/kaggle/working/"

os.makedirs(output_dir, exist_ok=True)

# The LR schedule must cover exactly the steps this loop will run. Rebuild it for the
# epoch count set above: with a schedule sized for fewer epochs, the cosine factor
# passes its minimum and rises again, so the learning rate would climb during the
# final epochs instead of decaying to zero.
total_steps = len(train_loader) * num_epochs
if total_steps > 0:
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )
else:
    scheduler = None

print("--- Initial Zero-Shot Evaluation ---")
# Pass all_class_names instead of pre-generated prompts
initial_accuracy = evaluate_zeroshot(model, test_loader_eval, tokenizer, all_class_names, device)
print("-----------------------------------\n")
best_accuracy = initial_accuracy # Initialize best accuracy with the pre-fine-tuning score

# NOTE(review): checkpoints are selected on the test split, so the reported best
# accuracy is optimistically biased; consider holding out part of train_df for selection.
for epoch in range(num_epochs):
    epoch_num = epoch + 1
    print(f"\n--- Starting Epoch {epoch_num}/{num_epochs} ---")
    # Pass scheduler and grad clip norm
    train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer, device, epoch_num, scheduler, grad_clip_norm=1.0)

    print(f"--- Evaluating after Epoch {epoch_num} ---")
    # Pass all_class_names instead of pre-generated prompts
    current_accuracy = evaluate_zeroshot(model, test_loader_eval, tokenizer, all_class_names, device)

    # Only save if accuracy *improves* compared to the current best
    if current_accuracy > best_accuracy:
        best_accuracy = current_accuracy
        print(f"*** New best accuracy: {best_accuracy*100:.2f}%. Saving model... ***")
        save_path = os.path.join(output_dir, f'flyp_clip_cub_best_epoch_{epoch_num}_acc{best_accuracy:.3f}.pt')
        # Save model state for potential resuming or inference
        torch.save({
            'epoch': epoch_num,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss, # Last epoch's train loss
            'accuracy': best_accuracy, # Best accuracy so far
            'model_name': model_name,
            'pretrained_dataset': pretrained_dataset
        }, save_path)
        print(f"Model saved to {save_path}")
    else:
         print(f"Accuracy ({current_accuracy*100:.2f}%) did not improve from best ({best_accuracy*100:.2f}%). Not saving.")

    print("-----------------------------------\n")

print(f"\nFine-tuning finished. Best zero-shot accuracy achieved: {best_accuracy*100:.2f}%")

--- Initial Zero-Shot Evaluation ---
Encoding text prompts for evaluation...
Text prompts encoded.


Evaluating:   0%|          | 0/91 [00:00<?, ?it/s]