In [2]:
# %%
# =========================
# Import Necessary Libraries
# =========================
from utils.utils import CaptionGenerator
from utils.dataset import get_dataset
import pprint
from tqdm import tqdm
import os
from PIL import Image
import torch
from transformers import CLIPModel, CLIPTokenizer, CLIPFeatureExtractor
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import accuracy_score
from collections import defaultdict
import torch.nn.functional as F
import json

# %%
# =========================
# Configuration Parameters
# =========================
# Dataset identifier; passed to get_dataset() and to CaptionGenerator below.
DATASET_NAME = 'NWPU-RESISC45'
# Model identifier handed to load_model().
MODEL_NAME = "openai/clip-vit-base-patch32"
BATCH_SIZE = 64         # Batch size of the image DataLoader; adjust based on your GPU memory
KNN_BATCH_SIZE = 1024   # Images scored per step by compute_zero_shot_accuracy (its batch_size argument)
NUM_CAPTIONS = 32       # Number of alternative captions per class requested from CaptionGenerator

# %%
# =========================
# Load and Prepare Dataset
# =========================
from torchvision.datasets import ImageFolder  # NOTE(review): not referenced anywhere in the visible code

# Load dataset. The variable is called `cub_dataset`, but it holds whichever
# dataset DATASET_NAME selects.
cub_dataset = get_dataset(DATASET_NAME)
n_object = len(cub_dataset.classes)
print(f"Number of classes: {n_object}")
# The slice spans the whole class list, so n_classes contains every class name.
n_classes = cub_dataset.classes[:n_object]
pprint.pprint(n_classes)
# Overwrite class_to_idx so each class name maps to its position in `classes`
cub_dataset.class_to_idx = {class_name: idx for idx, class_name in enumerate(cub_dataset.classes)}

# %%
# =========================
# Generate Prompts and Captions
# =========================

# One fixed-template prompt per class; the class list itself doubles as the
# label list for these prompts (same object, not a copy).
standard_prompts = [f"a photo of a {name}" for name in n_classes]
standard_labels = n_classes

# Caption source for the alternative (non-template) prompts
capGenerator = CaptionGenerator(
    dataset_name=DATASET_NAME,
    class_names=n_classes,
    num_captions=NUM_CAPTIONS,
)

# Flat list of every alternative caption, with a parallel list naming the
# class each caption belongs to.
alter_caption_list, labels = [], []
for class_name in tqdm(n_classes, desc="Generating Alternative Captions"):
    alter_captions = capGenerator.get_alternative_captions(class_name)
    alter_caption_list += alter_captions
    labels += [class_name] * len(alter_captions)

# %%
# =========================
# Initialize CLIP Model and Tokenizers
# =========================
from utils.model_utils import get_text_embeddings, get_image_embeddings, load_model
# Use the GPU when one is available; this device is passed to the embedding helpers below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Value returned by the project helper for MODEL_NAME; it is forwarded unchanged
# to get_text_embeddings() and get_image_embeddings().
model_info = load_model(MODEL_NAME)

# %%
# =========================
# Compute Text Embeddings
# =========================
# Embed every alternative caption; `labels` already runs parallel to the
# caption list, so it is reused as the label list for these embeddings.
text_embeddings = get_text_embeddings(alter_caption_list, model_info, device)
text_labels = labels

# One record per alternative caption
text_data = [
    dict(embedding=vector, label=name, modality='text')
    for vector, name in zip(text_embeddings, text_labels)
]

# Embed the fixed-template prompts and tag them with their own modality
standard_embeddings = get_text_embeddings(standard_prompts, model_info, device)
standard_data = [
    dict(embedding=vector, label=name, modality='standard')
    for vector, name in zip(standard_embeddings, standard_labels)
]

# %%
# =========================
# Compute Image Embeddings for Real Images
# =========================
def get_image_loader(dataset, class_indices, batch_size, shuffle=False):
    """Build a DataLoader over the samples whose label is in ``class_indices``.

    Labels are read from ``dataset.targets`` when that attribute exists,
    otherwise from the second element of each entry in ``dataset.imgs``.

    Args:
        dataset: Indexable dataset exposing ``targets`` or ``imgs``.
        class_indices: Container of labels to keep.
        batch_size: Number of samples per batch.
        shuffle: Forwarded to the DataLoader.

    Returns:
        DataLoader whose batches are ``(first_fields, second_fields)`` tuples
        built from the first two fields of every sample.
    """
    def _first_two_fields(batch):
        # Transpose the batch into per-field tuples and keep only the first
        # two; anything after them (e.g. paths) is dropped.
        fields = list(zip(*batch))
        return fields[0], fields[1]

    if hasattr(dataset, 'targets'):
        label_stream = dataset.targets
    else:
        label_stream = (label for _, label in dataset.imgs)

    keep = [position for position, label in enumerate(label_stream) if label in class_indices]

    return DataLoader(
        Subset(dataset, keep),
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=_first_two_fields,
    )

# Integer label of every class we want to keep (all classes in n_classes)
selected_class_indices = [cub_dataset.class_to_idx[name] for name in n_classes]
image_loader = get_image_loader(
    cub_dataset,
    selected_class_indices,
    BATCH_SIZE,
    shuffle=False,
)

# Embed the real images; the helper hands back the embeddings together with
# the integer label of each image.
image_embeddings, image_labels = get_image_embeddings(image_loader, model_info, device)

# Replace each integer label with its class name
image_labels = [cub_dataset.classes[idx] for idx in image_labels]

# One record per image
image_data = [
    dict(embedding=vector, label=name, modality='image')
    for vector, name in zip(image_embeddings, image_labels)
]

# %%
# =========================
# Combine All Data
# =========================
# Concatenation order is fixed: alternative-caption records, then
# standard-prompt records, then image records.
combined_data = text_data + standard_data + image_data

# Parallel views over combined_data (same order, one entry per record).
combined_embeddings = np.array([item['embedding'] for item in combined_data])
combined_labels = [item['label'] for item in combined_data]
combined_modalities = [item['modality'] for item in combined_data]

# %%
# =========================
# Compute Zero-Shot Accuracy Function
# =========================
def compute_zero_shot_accuracy(image_embeddings, image_labels, text_embeddings, text_labels, batch_size=1024):
    """
    Compute zero-shot accuracy by matching image embeddings to text embeddings.

    Both embedding sets are L2-normalized, so the image/text score is a cosine
    similarity. All text embeddings sharing a label form one class; the class
    score of an image is the mean similarity to that class's text embeddings,
    and the predicted class is the one with the highest score.

    Args:
        image_embeddings (np.ndarray): Image embeddings of shape [num_images, embedding_dim].
        image_labels (List[str]): True labels for each image.
        text_embeddings (np.ndarray): Text embeddings of shape [num_text_embeddings, embedding_dim].
        text_labels (List[str]): Labels corresponding to each text embedding.
        batch_size (int): Number of images scored per step.

    Returns:
        Tuple[dict, float, dict, dict, List[dict]]: Per-class accuracy, overall accuracy, class_correct counts,
                                                     class_total counts, per-image predictions.
        Per-class accuracy is keyed by the labels found in ``image_labels`` (first-seen order).
        Each per-image prediction holds ``image_index``, ``true_label``, ``predicted_label``,
        ``top5_predicted_labels`` and ``top5_predicted_probs``; the two "top5" lists contain
        fewer than five entries when fewer than five classes are supplied.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_features = torch.from_numpy(image_embeddings).to(device)
    # Match the image dtype so the similarity matmul cannot fail on a
    # float32/float64 mismatch between the two inputs.
    text_features = torch.from_numpy(text_embeddings).to(device=device, dtype=image_features.dtype)

    # Normalize embeddings
    image_features = F.normalize(image_features, p=2, dim=-1)
    text_features = F.normalize(text_features, p=2, dim=-1)

    num_images = image_features.size(0)

    class_correct = defaultdict(int)
    class_total = defaultdict(int)

    per_image_predictions = []

    # Build mapping from class names to indices of text embeddings
    class_to_text_indices = defaultdict(list)
    for idx, class_name in enumerate(text_labels):
        class_to_text_indices[class_name].append(idx)

    # Class order (first appearance in text_labels) defines the column order
    # of the per-class similarity matrix below.
    class_names = list(class_to_text_indices)

    # For each class, the positions of its text embeddings
    class_text_indices = [
        torch.tensor(class_to_text_indices[class_name], device=device)
        for class_name in class_names
    ]

    # topk() raises if k exceeds the number of classes, so clamp it.
    topk = min(5, len(class_names))

    for start in tqdm(range(0, num_images, batch_size), desc="Computing Zero-Shot Predictions"):
        end = min(start + batch_size, num_images)
        batch_image = image_features[start:end]  # Shape: [batch_size, embedding_dim]

        # Compute similarities between batch images and text embeddings
        similarities = batch_image @ text_features.T  # Shape: [batch_size, num_text_embeddings]

        # Mean similarity per class, stacked to shape [batch_size, num_classes]
        batch_class_similarities = torch.stack(
            [similarities[:, indices].mean(dim=1) for indices in class_text_indices],
            dim=1,
        )

        # NOTE: no temperature/logit scale is applied, so these probabilities
        # are a softmax over raw cosine similarities and stay close to
        # uniform; the ranking of classes is unaffected.
        probs = F.softmax(batch_class_similarities, dim=1)  # Shape: [batch_size, num_classes]

        topk_probs, topk_indices = probs.topk(topk, dim=1)  # Shape: [batch_size, topk]

        # Move the whole batch to host once instead of once per row.
        batch_topk_probs = topk_probs.tolist()
        batch_topk_labels = [[class_names[idx] for idx in row] for row in topk_indices.tolist()]

        # Get the predicted labels (top-1)
        batch_predicted_labels = [row[0] for row in batch_topk_labels]

        # Update class_correct and class_total
        for i, (true_label, pred_label) in enumerate(zip(image_labels[start:end], batch_predicted_labels)):
            class_total[true_label] += 1
            if true_label == pred_label:
                class_correct[true_label] += 1

            # Collect per-image predictions
            per_image_predictions.append({
                'image_index': start + i,
                'true_label': true_label,
                'predicted_label': pred_label,
                'top5_predicted_labels': batch_topk_labels[i],
                'top5_predicted_probs': batch_topk_probs[i]
            })

    # Compute per-class accuracies. dict.fromkeys keeps first-seen order, so
    # the key order of the result is the same on every run.
    per_class_accuracy = {}
    for class_name in dict.fromkeys(image_labels):
        if class_total[class_name] > 0:
            per_class_accuracy[class_name] = class_correct[class_name] / class_total[class_name]
        else:
            # Only reachable when image_labels is longer than image_embeddings
            per_class_accuracy[class_name] = None  # No samples for this class

    # Compute overall accuracy
    total_correct = sum(class_correct.values())
    total_samples = sum(class_total.values())
    overall_accuracy = total_correct / total_samples if total_samples > 0 else 0

    return per_class_accuracy, overall_accuracy, class_correct, class_total, per_image_predictions

# %%
# =========================
# Compute Zero-Shot Accuracy Using Mean Text Embeddings
# =========================
# Organize text embeddings by class
class_to_text_embeddings = defaultdict(list)
for embedding, label in zip(text_embeddings, labels):
    class_to_text_embeddings[label].append(embedding)

# Compute per-class mean embeddings
# NOTE(review): embeddings are averaged exactly as get_text_embeddings returned
# them; if they are not unit-normalized, longer vectors weigh more in the mean — confirm.
mean_text_embeddings = []
mean_text_labels = []
for class_name in n_classes:
    embeddings = class_to_text_embeddings[class_name]
    if not embeddings:
        # Averaging an empty list would yield NaN and a ragged array further down.
        raise ValueError(f"No alternative captions available for class '{class_name}'")
    mean_embedding = np.mean(embeddings, axis=0)
    mean_text_embeddings.append(mean_embedding)
    mean_text_labels.append(class_name)

mean_text_embeddings = np.array(mean_text_embeddings)  # Shape: CxD

# Compute zero-shot accuracy using mean text embeddings
mean_per_class_accuracy, mean_overall_accuracy, mean_class_correct, mean_class_total, _ = compute_zero_shot_accuracy(
    image_embeddings=image_embeddings,
    image_labels=image_labels,
    text_embeddings=mean_text_embeddings,
    text_labels=mean_text_labels,
    batch_size=KNN_BATCH_SIZE
)

# Display per-class accuracies
print("\nPer-Class Zero-Shot Accuracy Using Mean Text Embeddings:")
for class_name in n_classes:
    # The accuracy dict only has entries for classes that occur in
    # image_labels; .get() lets a class without images reach the
    # "No samples" branch instead of raising KeyError.
    accuracy = mean_per_class_accuracy.get(class_name)
    if accuracy is not None:
        print(f"{class_name}: {accuracy:.2%} ({mean_class_correct[class_name]}/{mean_class_total[class_name]})")
    else:
        print(f"{class_name}: No samples")

print(f"\nOverall Zero-Shot Accuracy Using Mean Text Embeddings: {mean_overall_accuracy:.2%}")

# %%
# =========================
# Compute Zero-Shot Accuracy Using Standard Embeddings
# =========================
# Score the images against the single fixed-template prompt of each class
standard_per_class_accuracy, standard_overall_accuracy, standard_class_correct, standard_class_total, _ = compute_zero_shot_accuracy(
    image_embeddings=image_embeddings,
    image_labels=image_labels,
    text_embeddings=standard_embeddings,
    text_labels=standard_labels,
    batch_size=KNN_BATCH_SIZE
)

# Display per-class accuracies
print("\nPer-Class Zero-Shot Accuracy Using Standard Embeddings:")
for class_name in n_classes:
    # The accuracy dict only has entries for classes that occur in
    # image_labels; .get() lets a class without images reach the
    # "No samples" branch instead of raising KeyError.
    accuracy = standard_per_class_accuracy.get(class_name)
    if accuracy is not None:
        print(f"{class_name}: {accuracy:.2%} ({standard_class_correct[class_name]}/{standard_class_total[class_name]})")
    else:
        print(f"{class_name}: No samples")

print(f"\nOverall Zero-Shot Accuracy Using Standard Embeddings: {standard_overall_accuracy:.2%}")

# %%
# =========================
# Compare the Results
# =========================
# Side-by-side listing of both evaluations, one row per class
print("\nComparison of Per-Class Zero-Shot Accuracies:")
print(f"{'Class Name':30s} {'Mean Embedding Acc':20s} {'Standard Embedding Acc':20s}")
for class_name in n_classes:
    # Both accuracy dicts only have entries for classes that occur in
    # image_labels; .get() lets a class without images reach the
    # "No samples" branch instead of raising KeyError.
    mean_acc = mean_per_class_accuracy.get(class_name)
    standard_acc = standard_per_class_accuracy.get(class_name)
    if mean_acc is not None and standard_acc is not None:
        print(f"{class_name:30s} {mean_acc:.2%} ({mean_class_correct[class_name]}/{mean_class_total[class_name]}), "
              f"{standard_acc:.2%} ({standard_class_correct[class_name]}/{standard_class_total[class_name]})")
    else:
        print(f"{class_name:30s} No samples")

# Compare overall accuracies
print(f"\nOverall Zero-Shot Accuracy Using Mean Text Embeddings: {mean_overall_accuracy:.2%}")
print(f"Overall Zero-Shot Accuracy Using Standard Embeddings: {standard_overall_accuracy:.2%}")


Number of classes: 45
['airplane',
 'airport',
 'baseball diamond',
 'basketball court',
 'beach',
 'bridge',
 'chaparral',
 'church',
 'circular farmland',
 'cloud',
 'commercial area',
 'dense residential',
 'desert',
 'forest',
 'freeway',
 'golf course',
 'ground track field',
 'harbor',
 'industrial area',
 'intersection',
 'island',
 'lake',
 'meadow',
 'medium residential',
 'mobile home park',
 'mountain',
 'overpass',
 'palace',
 'parking lot',
 'railway',
 'railway station',
 'rectangular farmland',
 'river',
 'roundabout',
 'runway',
 'sea ice',
 'ship',
 'snowberg',
 'sparse residential',
 'stadium',
 'storage tank',
 'tennis court',
 'terrace',
 'thermal power station',
 'wetland']
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45_global_traits.json
Using OpenAI model: gpt-4o-mini-2024-07-18
Configured to generate 32 captions.
Meta Prompt: You are an AI assistant that generates creative and diverse image captions suitable for use with image generation mode

Generating Alternative Captions:  76%|███████▌  | 34/45 [00:00<00:00, 165.63it/s]

Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\airplane.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\airport.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\baseball_diamond.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\basketball_court.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\beach.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\bridge.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\chaparral.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\church.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\circular_farmland.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\cloud.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\commercial_area.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\N

Generating Alternative Captions: 100%|██████████| 45/45 [00:00<00:00, 167.29it/s]


Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\runway.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\sea_ice.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\ship.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\snowberg.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\sparse_residential.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\stadium.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\storage_tank.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\tennis_court.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\terrace.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\thermal_power_station.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\NWPU-RESISC45\32\wetland.json
Processed batch 1/23
Processed batch 2/23
Processed bat

Processing Images: 100%|██████████| 493/493 [04:40<00:00,  1.76it/s]
Computing Zero-Shot Predictions: 100%|██████████| 31/31 [00:02<00:00, 11.03it/s]



Per-Class Zero-Shot Accuracy Using Mean Text Embeddings:
airplane: 6.86% (48/700)
airport: 10.29% (72/700)
baseball diamond: 91.86% (643/700)
basketball court: 50.57% (354/700)
beach: 62.57% (438/700)
bridge: 17.00% (119/700)
chaparral: 18.14% (127/700)
church: 6.00% (42/700)
circular farmland: 83.29% (583/700)
cloud: 59.57% (417/700)
commercial area: 0.00% (0/700)
dense residential: 44.57% (312/700)
desert: 95.29% (667/700)
forest: 65.29% (457/700)
freeway: 69.43% (486/700)
golf course: 84.14% (589/700)
ground track field: 17.14% (120/700)
harbor: 81.00% (567/700)
industrial area: 69.43% (486/700)
intersection: 78.29% (548/700)
island: 80.57% (564/700)
lake: 66.57% (466/700)
meadow: 0.00% (0/700)
medium residential: 0.00% (0/700)
mobile home park: 98.14% (687/700)
mountain: 43.29% (303/700)
overpass: 15.86% (111/700)
palace: 18.00% (126/700)
parking lot: 96.29% (674/700)
railway: 13.00% (91/700)
railway station: 23.14% (162/700)
rectangular farmland: 91.14% (638/700)
river: 41.00% (2

Computing Zero-Shot Predictions: 100%|██████████| 31/31 [00:02<00:00, 10.51it/s]


Per-Class Zero-Shot Accuracy Using Standard Embeddings:
airplane: 17.57% (123/700)
airport: 73.29% (513/700)
baseball diamond: 89.43% (626/700)
basketball court: 47.57% (333/700)
beach: 58.00% (406/700)
bridge: 42.86% (300/700)
chaparral: 0.00% (0/700)
church: 12.29% (86/700)
circular farmland: 86.43% (605/700)
cloud: 59.86% (419/700)
commercial area: 29.57% (207/700)
dense residential: 80.29% (562/700)
desert: 89.00% (623/700)
forest: 74.29% (520/700)
freeway: 75.14% (526/700)
golf course: 97.86% (685/700)
ground track field: 22.57% (158/700)
harbor: 78.29% (548/700)
industrial area: 63.57% (445/700)
intersection: 78.00% (546/700)
island: 93.14% (652/700)
lake: 77.29% (541/700)
meadow: 0.14% (1/700)
medium residential: 1.43% (10/700)
mobile home park: 97.86% (685/700)
mountain: 41.57% (291/700)
overpass: 18.14% (127/700)
palace: 38.57% (270/700)
parking lot: 91.71% (642/700)
railway: 79.43% (556/700)
railway station: 4.00% (28/700)
rectangular farmland: 85.00% (595/700)
river: 56.29%




In [None]:
# New story
# P(cls|linearity score) = 