In [None]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoImageProcessor
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from datasets import load_dataset
from metrics import AlignmentMetrics  # Ensure this is correctly implemented/imported
import itertools
import numpy as np
import gc
import logging
from multiprocessing import Pool, cpu_count
from functools import partial
import timm
import torchvision.transforms as transforms
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision

# -------------------- Setup Logging --------------------
# Send INFO-and-above records to a stream handler, prefixed with a timestamp
# and the level name, then grab a module-scoped logger for the rest of the file.
_LOG_FORMAT = '%(asctime)s [%(levelname)s] %(message)s'
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format=_LOG_FORMAT,
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# -------------------- Configuration --------------------
# Vision model (ViT) settings.
VIT_MODEL_NAME = "timm/vit_base_patch16_224"
VIT_FEATURES_DIR = "vit_features"
VIT_CHECKPOINT_DIR = "./vit_checkpoints"
BATCH_SIZE = 128
NUM_EPOCHS = 100
SAVE_EVERY = 5  # checkpoint cadence in epochs: 100 / 5 = 20 checkpoints

# Language model (LLM) settings. Each entry is passed as `revision=` when the
# model is loaded. The schedule is step0, step1000, then every 7000 steps from
# step8000 through step134000, and finally step143000 (22 checkpoints).
LLM_FEATURES_DIR = "llm_features"
LLM_MODEL_CHECKPOINTS = {
    "pythia-160m": [
        "step0",
        "step1000",
        *(f"step{step}" for step in range(8000, 134001, 7000)),
        "step143000",
    ]
}

# Make sure every output directory exists before anything writes into it.
for _output_dir in (VIT_FEATURES_DIR, LLM_FEATURES_DIR, VIT_CHECKPOINT_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# Load an image-captioning dataset (Conceptual Captions)
logger.info("Loading Conceptual Captions dataset for alignment...")
# Only the first 1000 pairs of the split are used, for faster processing.
# NOTE(review): confirm the "labeled" config of this dataset actually exposes a
# "validation" split.
dataset = load_dataset("google-research-datasets/conceptual_captions", "labeled", split="validation[:1000]")
if "image" not in dataset.column_names:
    # Per its dataset card, the HF Conceptual Captions dataset provides image
    # URLs rather than decoded images. Fail fast with an actionable message
    # instead of a bare KeyError on the lookup below.
    raise KeyError(
        f"Dataset has no 'image' column (available columns: {dataset.column_names}). "
        "Download and decode the images referenced by 'image_url' first, or use a "
        "dataset that ships decoded images."
    )
images = dataset["image"]
captions = dataset["caption"]
logger.info(f"Loaded {len(images)} image-caption pairs.")

# -------------------- Feature Extraction Functions --------------------

def extract_vit_features(model, processor, images, batch_size=64):
    """Extract the final-layer CLS-token feature of a ViT for each image.

    Works with timm Vision Transformers (through ``forward_features``) and,
    as a fallback, with Hugging Face vision models that accept
    ``pixel_values`` and ``output_hidden_states``.

    Args:
        model: A timm ViT exposing ``forward_features``, or a Hugging Face
            vision model.
        processor: Callable mapping a list of images to a mapping that holds a
            ``"pixel_values"`` tensor (e.g. a Hugging Face image processor).
        images: Sliceable sequence of images accepted by ``processor``.
        batch_size: Number of images per forward pass.

    Returns:
        A CPU tensor with one row per image: the token-0 (CLS) embedding of
        the final token sequence.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("extract_vit_features() received no images.")

    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    features = []

    with torch.no_grad():
        for i in tqdm(range(0, len(images), batch_size), desc="Extracting VIT features", leave=False):
            batch_images = images[i:i + batch_size]
            # Image processors resize to a fixed size; there is nothing to pad,
            # so no `padding` argument is passed.
            inputs = processor(batch_images, return_tensors="pt")
            pixel_values = inputs["pixel_values"].to(device)
            if hasattr(model, "forward_features"):
                # timm models take a plain tensor and return logits from
                # forward(); forward_features returns the final token sequence.
                last_hidden_state = model.forward_features(pixel_values)
            else:
                outputs = model(pixel_values=pixel_values, output_hidden_states=True)
                last_hidden_state = outputs.hidden_states[-1]
            cls_tokens = last_hidden_state[:, 0, :]  # token 0 is the CLS token
            features.append(cls_tokens.cpu())

    return torch.cat(features, dim=0)

def _pool_hidden_states(hidden, attention_mask, pooling):
    """Reduce (batch, seq, dim) token states to (batch, dim), ignoring padding."""
    if pooling == "first":
        return hidden[:, 0, :]
    if pooling == "mean":
        mask = attention_mask.to(hidden.dtype).unsqueeze(-1)  # (batch, seq, 1)
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    # "last": index of the final real token; multiplying the mask by the
    # positions makes this correct for both left- and right-padded batches.
    positions = torch.arange(hidden.shape[1], device=hidden.device)
    last_idx = (attention_mask * positions).argmax(dim=1)
    return hidden[torch.arange(hidden.shape[0], device=hidden.device), last_idx]


def extract_llm_features(model, tokenizer, captions, batch_size=64, pooling="mean"):
    """Extract one pooled last-hidden-state feature per caption from an LLM.

    The model here is a causal LM, so the state at position 0 has attended only
    to the first token. Using it alone (the old behaviour, ``pooling="first"``)
    ignores almost all of the caption. By default the final hidden states are
    averaged over non-padding tokens.

    Args:
        model: Model returning ``hidden_states`` when called with
            ``output_hidden_states=True``.
        tokenizer: Tokenizer with a pad token set; its output must include
            ``attention_mask``.
        captions: Sliceable sequence of caption strings.
        batch_size: Number of captions per forward pass.
        pooling: ``"mean"`` (masked mean over tokens), ``"last"`` (last real
            token) or ``"first"`` (token 0).

    Returns:
        A CPU tensor with one pooled feature row per caption.

    Raises:
        ValueError: If ``captions`` is empty or ``pooling`` is unknown.
    """
    if pooling not in ("mean", "last", "first"):
        raise ValueError(f"Unknown pooling {pooling!r}; expected 'mean', 'last' or 'first'.")
    if len(captions) == 0:
        raise ValueError("extract_llm_features() received no captions.")

    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    features = []

    with torch.no_grad():
        for i in tqdm(range(0, len(captions), batch_size), desc="Extracting LLM features", leave=False):
            batch_captions = captions[i:i + batch_size]
            inputs = tokenizer(batch_captions, return_tensors="pt", padding=True, truncation=True, max_length=128)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model(**inputs, output_hidden_states=True)
            last_hidden_state = outputs.hidden_states[-1]
            pooled = _pool_hidden_states(last_hidden_state, inputs["attention_mask"], pooling)
            features.append(pooled.cpu())

    return torch.cat(features, dim=0)

# -------------------- Training the ViT Model on CIFAR-10 --------------------

def train_vit_on_cifar():
    """Fine-tune the pretrained timm ViT on CIFAR-10 and save periodic checkpoints.

    The model is created with a 10-class head (``num_classes=10``) and trained
    for NUM_EPOCHS epochs with AdamW and cross-entropy. A
    ``vit_epoch_{N}.pth`` state dict is written to VIT_CHECKPOINT_DIR every
    SAVE_EVERY epochs and after the final epoch.
    """
    model = timm.create_model(VIT_MODEL_NAME, pretrained=True, num_classes=10)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Resize to the backbone's input resolution and normalize with the mean/std
    # from its pretrained config. ToTensor() alone leaves pixels in [0, 1],
    # which does not match what the pretrained weights expect.
    data_config = timm.data.resolve_data_config({}, model=model)
    _, input_height, input_width = data_config["input_size"]
    transform = transforms.Compose([
        transforms.Resize((input_height, input_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=data_config["mean"], std=data_config["std"]),
    ])
    train_dataset = torchvision.datasets.CIFAR10(root="./cifar_data", train=True, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

    for epoch in range(NUM_EPOCHS):
        model.train()
        running_loss = 0.0
        # `batch_images` (not `images`) avoids confusion with the module-level
        # captioning images.
        for batch_images, labels in train_loader:
            batch_images, labels = batch_images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(batch_images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Mean per-batch loss; a raw sum grows with the number of batches.
        avg_loss = running_loss / max(len(train_loader), 1)
        epoch_num = epoch + 1

        # Save checkpoint every SAVE_EVERY epochs and after the last epoch.
        if epoch_num % SAVE_EVERY == 0 or epoch_num == NUM_EPOCHS:
            checkpoint_path = os.path.join(VIT_CHECKPOINT_DIR, f"vit_epoch_{epoch_num}.pth")
            torch.save(model.state_dict(), checkpoint_path)
            logger.info(f"Checkpoint saved: {checkpoint_path}")

        logger.info(f"Epoch [{epoch_num}/{NUM_EPOCHS}], Avg Loss: {avg_loss:.4f}")

# -------------------- Compute Alignment Scores --------------------

def compute_alignment_between_vit_llm(vit_features, llm_features, topk=50):
    """Return the mutual k-NN alignment between paired ViT and LLM features.

    Args:
        vit_features: 2-D tensor, one row per image.
        llm_features: 2-D tensor, one row per caption; row i is paired with
            row i of ``vit_features``.
        topk: Neighbourhood size passed to ``AlignmentMetrics.mutual_knn``.

    Returns:
        The value returned by ``AlignmentMetrics.mutual_knn`` on the
        L2-normalized features.

    Raises:
        ValueError: If either input is not 2-D or the row counts differ.
    """
    if vit_features.ndim != 2 or llm_features.ndim != 2:
        raise ValueError(
            f"Expected 2-D feature matrices, got shapes {tuple(vit_features.shape)} "
            f"and {tuple(llm_features.shape)}."
        )
    if vit_features.shape[0] != llm_features.shape[0]:
        raise ValueError(
            f"Feature row counts differ: {vit_features.shape[0]} (ViT) vs "
            f"{llm_features.shape[0]} (LLM); rows must be paired samples."
        )
    # NOTE(review): assumes metrics.AlignmentMetrics follows the platonic-rep
    # implementation, which ranks neighbours by raw inner product. Normalizing
    # makes that cosine similarity, so feature norms do not dominate.
    vit_normed = nn.functional.normalize(vit_features.float(), p=2, dim=-1)
    llm_normed = nn.functional.normalize(llm_features.float(), p=2, dim=-1)
    return AlignmentMetrics.mutual_knn(vit_normed, llm_normed, topk=topk)

def _free_model_memory():
    """Run the garbage collector and release cached CUDA blocks, if any."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _load_vit_checkpoint(vit_epoch):
    """Build the timm ViT and load the state dict saved for ``vit_epoch``."""
    checkpoint_path = os.path.join(VIT_CHECKPOINT_DIR, f"vit_epoch_{vit_epoch}.pth")
    model = timm.create_model(VIT_MODEL_NAME, pretrained=False, num_classes=10)
    # map_location lets checkpoints saved from a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    return model


def _extract_llm_features_per_checkpoint(llm_model_checkpoints, captions):
    """Extract and save caption features once per LLM checkpoint.

    Returns a dict keyed by ``(llm_name, checkpoint)`` in iteration order.
    """
    features_by_checkpoint = {}
    for llm_name, checkpoints in llm_model_checkpoints.items():
        # The tokenizer is loaded without a revision, so it is shared by every
        # checkpoint of this model; load it once.
        tokenizer = AutoTokenizer.from_pretrained(f"EleutherAI/{llm_name}")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        for checkpoint in checkpoints:
            llm_model = AutoModelForCausalLM.from_pretrained(
                f"EleutherAI/{llm_name}", revision=checkpoint
            )
            llm_features = extract_llm_features(llm_model, tokenizer, captions)
            del llm_model
            _free_model_memory()
            llm_save_path = os.path.join(LLM_FEATURES_DIR, f"{llm_name}_{checkpoint}.pt")
            torch.save(llm_features, llm_save_path)
            features_by_checkpoint[(llm_name, checkpoint)] = llm_features
    return features_by_checkpoint


def process_alignment_scores(vit_checkpoints, llm_model_checkpoints, images, captions):
    """Extract features from each model and compute alignment scores.

    Args:
        vit_checkpoints: Iterable of ViT epoch numbers with saved checkpoints.
        llm_model_checkpoints: Mapping of LLM name to a list of revisions.
        images: Images passed to ``extract_vit_features``.
        captions: Captions passed to ``extract_llm_features``, paired with ``images``.

    Returns:
        A DataFrame with one row per (ViT epoch, LLM checkpoint) pair and columns
        'VIT Epoch', 'LLM Name', 'LLM Checkpoint', 'Alignment Score'.
    """
    vit_checkpoints = list(vit_checkpoints)
    alignment_records = []
    if not vit_checkpoints:
        return pd.DataFrame(alignment_records)

    # LLM features depend only on the LLM checkpoint, not on the ViT epoch, so
    # extract them once up front instead of once per (ViT, LLM) pair.
    llm_features_by_checkpoint = _extract_llm_features_per_checkpoint(llm_model_checkpoints, captions)

    processor = AutoImageProcessor.from_pretrained(VIT_MODEL_NAME)
    for vit_epoch in vit_checkpoints:
        model = _load_vit_checkpoint(vit_epoch)
        vit_features = extract_vit_features(model, processor, images)
        del model
        _free_model_memory()
        vit_save_path = os.path.join(VIT_FEATURES_DIR, f"vit_epoch_{vit_epoch}.pt")
        torch.save(vit_features, vit_save_path)

        for (llm_name, checkpoint), llm_features in llm_features_by_checkpoint.items():
            alignment_score = compute_alignment_between_vit_llm(vit_features, llm_features)
            alignment_records.append({
                'VIT Epoch': vit_epoch,
                'LLM Name': llm_name,
                'LLM Checkpoint': checkpoint,
                'Alignment Score': alignment_score
            })
            logger.info(f"Alignment score between ViT (epoch {vit_epoch}) and {llm_name} ({checkpoint}): {alignment_score}")

    return pd.DataFrame(alignment_records)

# -------------------- Plotting the Results --------------------

def plot_alignment_scores(alignment_df):
    """Plot alignment score against LLM training step, one line per ViT epoch.

    Checkpoint labels such as ``"step8000"`` are converted to their integer
    step, so the x-axis reflects actual training progress; the configured
    checkpoints are unevenly spaced. If any label contains no digits, the raw
    labels are plotted instead. The input DataFrame is not modified.

    Args:
        alignment_df: DataFrame with 'LLM Checkpoint', 'Alignment Score' and
            'VIT Epoch' columns.
    """
    if alignment_df.empty:
        logger.warning("No alignment scores to plot.")
        return

    step_numbers = pd.to_numeric(
        alignment_df['LLM Checkpoint'].astype(str).str.extract(r'(\d+)', expand=False),
        errors='coerce',
    )
    if step_numbers.notna().all():
        plot_df = alignment_df.assign(**{'LLM Step': step_numbers})
        x_column = 'LLM Step'
    else:
        plot_df = alignment_df
        x_column = 'LLM Checkpoint'

    plt.figure(figsize=(15, 8))
    sns.lineplot(
        data=plot_df,
        x=x_column,
        y='Alignment Score',
        hue='VIT Epoch',
        marker='o'
    )
    plt.title('Alignment Scores between ViT and LLM Checkpoints')
    plt.xlabel('LLM Training Step')
    plt.ylabel('Alignment Score')
    plt.grid(True)
    plt.legend(title='ViT Epoch')
    plt.tight_layout()
    plt.show()

# -------------------- Main Execution --------------------

if __name__ == "__main__":
    train_vit_on_cifar()
    # Mirror the save condition in train_vit_on_cifar(): every SAVE_EVERY
    # epochs plus the final epoch, so every listed checkpoint exists on disk.
    vit_checkpoints = [
        epoch for epoch in range(1, NUM_EPOCHS + 1)
        if epoch % SAVE_EVERY == 0 or epoch == NUM_EPOCHS
    ]
    alignment_df = process_alignment_scores(vit_checkpoints, LLM_MODEL_CHECKPOINTS, images, captions)
    alignment_df.to_csv('alignment_scores_vit_llm.csv', index=False)
    plot_alignment_scores(alignment_df)


2024-10-25 23:25:00,965 [INFO] Loading Conceptual Captions dataset for alignment...


README.md:   0%|          | 0.00/14.2k [00:00<?, ?B/s]

train-00000-of-00003.parquet:   0%|          | 0.00/178M [00:00<?, ?B/s]