In [1]:
# Image Method

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, Dataset
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import spearmanr

from utils.dataset import get_dataset
from utils.model_utils import load_model, get_text_embeddings, get_image_embeddings
import os
import glob
from PIL import Image

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Identifier of the CLIP checkpoint handed to load_model.
model_name = "openai/clip-vit-base-patch32"

# Datasets to evaluate. Each entry pairs a dataset name with the prompt
# template that is filled with a class name to build that class's caption.
# The commented-out entries are other dataset/template pairs that are
# currently disabled.
datasets_info = [
    dict(name="NWPU-RESISC45", template="a satellite photo containing {}."),
    # dict(name="Stanford_dogs", template="a photo of {}, a type of dog."),
    # dict(name="CUB_200_2011", template="a photo of {}, a type of bird."),
    # dict(name="Flower102", template="a photo of {}, a type of flower."),
]

class PseudoImageDataset(Dataset):
    """Dataset over generated ("pseudo") PNG images stored flat in one folder.

    The class of an image is recovered from its file name: the name is split
    on ``'-'``, underscores are turned into spaces, and the shortest leading
    run of pieces that equals an entry of ``class_names`` decides the class
    (so ``golf_course-3.png`` maps to ``"golf course"`` and
    ``black-eyed_susan-0.png`` maps to ``"black-eyed susan"``). Files that
    match no class are skipped.

    Args:
        image_folder: Directory scanned (non-recursively) for ``*.png`` files.
        class_names: Ordered class names; the label of an image is the index
            of its class in this sequence.
        transform: Optional callable applied to the RGB PIL image in
            ``__getitem__``.
    """

    def __init__(self, image_folder, class_names, transform=None):
        self.image_paths = []
        self.labels = []
        self.transform = transform
        self.class_names = class_names

        # One name -> index table replaces a linear class_names.index() call
        # per image. setdefault keeps the first position of a repeated name,
        # which is what list.index() would return.
        index_of = {}
        for position, name in enumerate(class_names):
            index_of.setdefault(name, position)

        # glob yields files in arbitrary, OS-dependent order; sorting makes
        # the sample order reproducible across runs and machines.
        for image_path in sorted(glob.glob(os.path.join(image_folder, '*.png'))):
            label = self._label_from_filename(os.path.basename(image_path), index_of)
            if label is not None:
                self.image_paths.append(image_path)
                self.labels.append(label)

    @staticmethod
    def _label_from_filename(filename, index_of):
        """Return the class index encoded in `filename`, or None if there is none."""
        pieces = filename.split('-')
        # Try the first piece first, then grow the prefix one piece at a time
        # so that class names that themselves contain '-' are still
        # recognised instead of being silently dropped.
        for end in range(1, len(pieces) + 1):
            candidate = '-'.join(pieces[:end]).replace('_', ' ')
            if candidate in index_of:
                return index_of[candidate]
        return None

    def __len__(self):
        """Number of images that were matched to a class."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample `idx`; the image is RGB."""
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        # The context manager releases the file handle as soon as the pixel
        # data has been converted.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

def get_image_loader(dataset, class_indices, batch_size, shuffle=False):
    """Build a DataLoader over the samples of `dataset` whose label is in `class_indices`.

    Labels are read from ``dataset.targets`` when that attribute exists and
    from the second element of each ``dataset.imgs`` entry otherwise.

    Note: the transform of the *passed-in* dataset is reset to ``None`` as a
    side effect, so samples are delivered untransformed.

    Args:
        dataset: Indexable dataset exposing ``targets`` or ``imgs``.
        class_indices: Container of labels to keep.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the samples.

    Returns:
        A DataLoader whose batches are ``(images, labels)`` pairs of tuples
        (no tensor stacking takes place).
    """
    def _split_columns(batch):
        # Transpose [(image, label), ...] into (images, labels).
        columns = list(zip(*batch))
        return columns[0], columns[1]

    if hasattr(dataset, 'targets'):
        label_stream = dataset.targets
    else:
        label_stream = (label for _, label in dataset.imgs)

    keep = []
    for position, label in enumerate(label_stream):
        if label in class_indices:
            keep.append(position)

    subset = Subset(dataset, keep)

    # Drop any transform so the raw samples reach the collate function.
    subset.dataset.transform = None

    return DataLoader(
        subset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=_split_columns,
    )

def _collate_as_columns(batch):
    """Transpose a batch of (image, label) pairs into (images, labels) tuples."""
    columns = list(zip(*batch))
    return columns[0], columns[1]


def _stable_softmax(logits):
    """Row-wise softmax computed on max-shifted logits.

    Subtracting the row maximum leaves the result mathematically unchanged
    but keeps every exponent <= 0, so np.exp cannot overflow (dividing
    similarities by a small temperature makes large logits likely).
    """
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    exp_shifted = np.exp(shifted)
    return exp_shifted / np.sum(exp_shifted, axis=1, keepdims=True)


def _classwise_stats(targets, preds, confidences, num_classes):
    """Per-class actual accuracy and mean confidence of the samples predicted as that class.

    A class without any ground-truth sample gets an actual accuracy of NaN;
    a class that is never predicted gets a predicted accuracy of 0.0.
    """
    actual = []
    predicted = []
    for c in range(num_classes):
        in_class = (targets == c)
        # Guard the empty case explicitly: mean() of an empty array is NaN
        # as well, but additionally emits a RuntimeWarning.
        actual.append((preds[in_class] == c).mean() if in_class.any() else np.nan)

        assigned_confidences = confidences[preds == c]
        mean_confidence = assigned_confidences.mean() if assigned_confidences.size > 0 else np.nan
        predicted.append(0.0 if np.isnan(mean_confidence) else mean_confidence)
    return np.array(actual), np.array(predicted)


def run_experiment(dataset_name, template, use_pseudo_data):
    """Compare per-class zero-shot accuracy with per-class mean confidence.

    Images are classified by cosine similarity between their embedding and
    the embedding of one caption per class, followed by a temperature softmax.

    Args:
        dataset_name: Name passed to ``get_dataset``; also selects the
            ``pseudo_images/<dataset_name>`` folder.
        template: Format string with one ``{}`` placeholder for the class name.
        use_pseudo_data: If true, embed the images found in the pseudo-image
            folder; otherwise embed the images of the dataset's test split.

    Returns:
        ``(actual_accuracies, predicted_accuracies, spearman_result)`` where
        the first two are arrays with one entry per class and the last is
        the ``spearmanr`` result computed between them.
    """
    # 1. Load dataset (test split); its class list defines the label space.
    dataset = get_dataset(dataset_name, data_root='data', split='test')
    class_names = dataset.classes
    num_classes = len(class_names)

    if use_pseudo_data:
        pseudo_dataset = PseudoImageDataset(
            image_folder=f'pseudo_images/{dataset_name}',
            class_names=class_names,
            transform=None
        )
        dataloader = DataLoader(
            pseudo_dataset,
            batch_size=64,
            shuffle=False,
            pin_memory=True,
            collate_fn=_collate_as_columns
        )
    else:
        selected_class_indices = list(range(num_classes))
        dataloader = get_image_loader(dataset, selected_class_indices, 64, shuffle=False)

    # 2. Load model
    model_info = load_model(model_name, device=device)

    # 3. One caption per class -> text embeddings
    captions = [template.format(c) for c in class_names]
    text_embeddings = get_text_embeddings(captions, model_info, device, batch_size=64)

    # 4. Image embeddings together with their ground-truth labels
    image_embeddings, all_targets = get_image_embeddings(dataloader, model_info, device)
    all_targets = np.array(all_targets)

    # L2-normalise so that the dot product below is a cosine similarity.
    image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
    text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)

    # 5. Similarities -> class probabilities -> predictions
    temperature = 0.01
    similarities = np.matmul(image_embeddings, text_embeddings.T) / temperature
    probs = _stable_softmax(similarities)

    preds = np.argmax(probs, axis=1)
    confidences = np.max(probs, axis=1)

    # 6. Class-wise actual accuracy vs. confidence-based accuracy estimate
    actual_accuracies, predicted_accuracies = _classwise_stats(
        all_targets, preds, confidences, num_classes
    )

    # Rank correlation between the two per-class quantities
    spearman_corr = spearmanr(actual_accuracies, predicted_accuracies)

    # Return data for plotting outside
    return actual_accuracies, predicted_accuracies, spearman_corr

# For every configured dataset: run the experiment on pseudo and on real
# images, then draw both confidence-based estimates against the accuracy
# measured on the real images in a single scatter plot.
for dinfo in datasets_info:
    actual_pseudo, predicted_pseudo, spearman_pseudo = run_experiment(dinfo["name"], dinfo["template"], use_pseudo_data=True)
    actual_real, predicted_real, spearman_real = run_experiment(dinfo["name"], dinfo["template"], use_pseudo_data=False)

    # The pseudo-image estimate is judged against the accuracy on REAL data,
    # so its Spearman correlation is recomputed with actual_real.
    pseudo_spearman_corr = spearmanr(actual_real, predicted_pseudo)

    # Create a single plot for both pseudo and real
    fig, ax = plt.subplots(figsize=(8, 8))
    # Both point sets share actual_real as their y-values.
    ax.scatter(predicted_pseudo, actual_real, color='tab:orange', label='Pseudo Images', alpha=0.5)
    ax.scatter(predicted_real, actual_real, color='tab:blue', label='Real Images', alpha=0.5)

    ax.set_xlabel("Predicted Accuracy (Confidence)")
    ax.set_ylabel("Actual Accuracy (Real Data)")

    # An undefined correlation is reported as NaN rather than None, so NaN
    # has to be checked as well for the "N/A" fallback to take effect.
    pseudo_rho = pseudo_spearman_corr.correlation
    real_rho = spearman_real.correlation
    pseudo_corr_str = "N/A" if pseudo_rho is None or np.isnan(pseudo_rho) else f"{pseudo_rho:.2f}"
    real_corr_str = "N/A" if real_rho is None or np.isnan(real_rho) else f"{real_rho:.2f}"
    ax.set_title(f"Calibration Approach: {dinfo['name']}\nSpearman (Pseudo vs Real Accuracy): {pseudo_corr_str}, Spearman (Real vs Real Accuracy): {real_corr_str}")
    ax.grid(True)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.legend()

    # Save the figure; create the output directory first because savefig
    # does not create missing directories.
    os.makedirs("figures", exist_ok=True)
    plot_filename = f"figures/{dinfo['name']}_combined_pred_vs_actual.png"
    fig.savefig(plot_filename)
    plt.close(fig)

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Processed batch 1/1


Processing Images: 100%|██████████| 25/25 [00:33<00:00,  1.35s/it]


Processed batch 1/1


Processing Images: 100%|██████████| 493/493 [04:28<00:00,  1.84it/s]


In [11]:
# Text Method

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, Dataset
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import spearmanr

from utils.dataset import get_dataset
from utils.model_utils import load_model, get_text_embeddings, get_image_embeddings
import os
import glob
from PIL import Image

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Identifier of the CLIP checkpoint handed to load_model.
model_name = "openai/clip-vit-base-patch32"

# Datasets to evaluate. Each entry pairs a dataset name with the prompt
# template that is filled with a class name to build that class's caption.
# The commented-out entries are other dataset/template pairs that are
# currently disabled.
datasets_info = [
    # dict(name="NWPU-RESISC45", template="a satellite photo containing {}."),
    # dict(name="Stanford_dogs", template="a photo of {}, a type of dog."),
    dict(name="CUB_200_2011", template="a photo of {}, a type of bird."),
    dict(name="Flower102", template="a photo of {}, a type of flower."),
]

from utils.utils import CaptionGenerator
from tqdm import tqdm

def load_descriptive_captions(dataset_name, n_classes, num_captions=40):
    """Gather the alternative captions of every class with class-index labels.

    Args:
        dataset_name: Dataset identifier forwarded to ``CaptionGenerator``.
        n_classes: Ordered sequence of class *names* (despite the parameter
            name this is not a count). The label of a caption is the index
            of its class in this sequence.
        num_captions: Number of captions per class requested from
            ``CaptionGenerator``.

    Returns:
        ``(captions, labels)``: a flat list with the captions of all classes,
        and a parallel list holding the class index of each caption.
    """
    capGenerator = CaptionGenerator(dataset_name=dataset_name, class_names=n_classes, num_captions=num_captions)
    alter_caption_list = []
    labels = []
    for class_name in tqdm(n_classes, desc="Generating Alternative Captions"):
        alter_captions = capGenerator.get_alternative_captions(class_name)
        # Resolve the index once per class instead of once per caption.
        class_index = n_classes.index(class_name)
        alter_caption_list.extend(alter_captions)
        labels.extend([class_index] * len(alter_captions))
    return alter_caption_list, labels


class PseudoImageDataset(Dataset):
    """Dataset over generated ("pseudo") PNG images stored flat in one folder.

    The class of an image is recovered from its file name: the name is split
    on ``'-'``, underscores are turned into spaces, and the shortest leading
    run of pieces that equals an entry of ``class_names`` decides the class
    (so ``golf_course-3.png`` maps to ``"golf course"`` and
    ``black-eyed_susan-0.png`` maps to ``"black-eyed susan"``). Files that
    match no class are skipped.

    Args:
        image_folder: Directory scanned (non-recursively) for ``*.png`` files.
        class_names: Ordered class names; the label of an image is the index
            of its class in this sequence.
        transform: Optional callable applied to the RGB PIL image in
            ``__getitem__``.
    """

    def __init__(self, image_folder, class_names, transform=None):
        self.image_paths = []
        self.labels = []
        self.transform = transform
        self.class_names = class_names

        # One name -> index table replaces a linear class_names.index() call
        # per image. setdefault keeps the first position of a repeated name,
        # which is what list.index() would return.
        index_of = {}
        for position, name in enumerate(class_names):
            index_of.setdefault(name, position)

        # glob yields files in arbitrary, OS-dependent order; sorting makes
        # the sample order reproducible across runs and machines.
        for image_path in sorted(glob.glob(os.path.join(image_folder, '*.png'))):
            label = self._label_from_filename(os.path.basename(image_path), index_of)
            if label is not None:
                self.image_paths.append(image_path)
                self.labels.append(label)

    @staticmethod
    def _label_from_filename(filename, index_of):
        """Return the class index encoded in `filename`, or None if there is none."""
        pieces = filename.split('-')
        # Try the first piece first, then grow the prefix one piece at a time
        # so that class names that themselves contain '-' are still
        # recognised instead of being silently dropped.
        for end in range(1, len(pieces) + 1):
            candidate = '-'.join(pieces[:end]).replace('_', ' ')
            if candidate in index_of:
                return index_of[candidate]
        return None

    def __len__(self):
        """Number of images that were matched to a class."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample `idx`; the image is RGB."""
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        # The context manager releases the file handle as soon as the pixel
        # data has been converted.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

def get_image_loader(dataset, class_indices, batch_size, shuffle=False):
    """Build a DataLoader over the samples of `dataset` whose label is in `class_indices`.

    Labels are read from ``dataset.targets`` when that attribute exists and
    from the second element of each ``dataset.imgs`` entry otherwise.

    Note: the transform of the *passed-in* dataset is reset to ``None`` as a
    side effect, so samples are delivered untransformed.

    Args:
        dataset: Indexable dataset exposing ``targets`` or ``imgs``.
        class_indices: Container of labels to keep.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the samples.

    Returns:
        A DataLoader whose batches are ``(images, labels)`` pairs of tuples
        (no tensor stacking takes place).
    """
    def _split_columns(batch):
        # Transpose [(image, label), ...] into (images, labels).
        columns = list(zip(*batch))
        return columns[0], columns[1]

    if hasattr(dataset, 'targets'):
        label_stream = dataset.targets
    else:
        label_stream = (label for _, label in dataset.imgs)

    keep = []
    for position, label in enumerate(label_stream):
        if label in class_indices:
            keep.append(position)

    subset = Subset(dataset, keep)

    # Drop any transform so the raw samples reach the collate function.
    subset.dataset.transform = None

    return DataLoader(
        subset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=_split_columns,
    )

def _stable_softmax(logits):
    """Row-wise softmax computed on max-shifted logits.

    Subtracting the row maximum leaves the result mathematically unchanged
    but keeps every exponent <= 0, so np.exp cannot overflow.
    """
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    exp_shifted = np.exp(shifted)
    return exp_shifted / np.sum(exp_shifted, axis=1, keepdims=True)


def _text_confidence_stats(targets, image_preds, text_preds, text_confidences, num_classes):
    """Per-class image accuracy and mean confidence of the captions predicted as that class.

    A class without any ground-truth image gets an actual accuracy of NaN;
    a class that no descriptive caption is assigned to gets a predicted
    accuracy of 0.0.
    """
    actual = []
    predicted = []
    for c in range(num_classes):
        in_class = (targets == c)
        # Guard the empty case explicitly: mean() of an empty array is NaN
        # as well, but additionally emits a RuntimeWarning.
        actual.append((image_preds[in_class] == c).mean() if in_class.any() else np.nan)

        assigned_confidences = text_confidences[text_preds == c]
        mean_confidence = assigned_confidences.mean() if assigned_confidences.size > 0 else np.nan
        predicted.append(0.0 if np.isnan(mean_confidence) else mean_confidence)
    return np.array(actual), np.array(predicted)


def run_experiment(dataset_name, template, use_pseudo_data):
    """Estimate per-class accuracy from text alone and compare it with real accuracy.

    The "predicted accuracy" of a class is the mean softmax confidence of the
    descriptive captions that are classified as that class when matched
    against the template captions. The "actual accuracy" is the zero-shot
    accuracy of the real test images of that class.

    Args:
        dataset_name: Name passed to ``get_dataset`` and to
            ``load_descriptive_captions``.
        template: Format string with one ``{}`` placeholder for the class name.
        use_pseudo_data: Accepted for signature compatibility; this text-based
            variant does not read it.

    Returns:
        ``(actual_accuracies, predicted_accuracies, spearman_result)`` where
        the first two are arrays with one entry per class and the last is
        the ``spearmanr`` result computed between them.
    """
    # 1. Load dataset (test split); its class list defines the label space.
    dataset = get_dataset(dataset_name, data_root='data', split='test')
    class_names = dataset.classes
    num_classes = len(class_names)

    selected_class_indices = list(range(num_classes))
    dataloader = get_image_loader(dataset, selected_class_indices, 64, shuffle=False)

    # Descriptive captions stand in for images; their labels are not needed
    # because only the predicted class of each caption is used below.
    descriptive_text, _ = load_descriptive_captions(dataset_name, class_names)

    # 2. Load model
    model_info = load_model(model_name, device=device)

    # 3. One template caption per class -> text embeddings
    captions = [template.format(c) for c in class_names]
    text_embeddings = get_text_embeddings(captions, model_info, device, batch_size=64)

    # 4. Image embeddings together with their ground-truth labels
    image_embeddings, all_targets = get_image_embeddings(dataloader, model_info, device)
    all_targets = np.array(all_targets)

    descriptive_text_embeddings = get_text_embeddings(descriptive_text, model_info, device, batch_size=64)

    # L2-normalise so that the dot products below are cosine similarities.
    image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
    text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)
    descriptive_text_embeddings = descriptive_text_embeddings / np.linalg.norm(descriptive_text_embeddings, axis=1, keepdims=True)

    # 5. Similarities and predictions
    temperature = 0.1

    # Descriptive caption vs. class caption: (num_descriptive, num_classes)
    similarities = np.matmul(descriptive_text_embeddings, text_embeddings.T) / temperature
    probs = _stable_softmax(similarities)
    preds = np.argmax(probs, axis=1)
    confidences = np.max(probs, axis=1)

    # Real image vs. class caption: supplies the actual accuracy
    img_text_similarities = np.matmul(image_embeddings, text_embeddings.T) / temperature
    img_text_probs = _stable_softmax(img_text_similarities)
    img_text_preds = np.argmax(img_text_probs, axis=1)

    # 6. Class-wise actual accuracy vs. text-confidence estimate
    actual_accuracies, predicted_accuracies = _text_confidence_stats(
        all_targets, img_text_preds, preds, confidences, num_classes
    )

    # Rank correlation between the two per-class quantities
    spearman_corr = spearmanr(actual_accuracies, predicted_accuracies)

    # Return data for plotting outside
    return actual_accuracies, predicted_accuracies, spearman_corr

# For every configured dataset: run the text-based experiment and plot the
# text-confidence estimate of each class against its accuracy on real images.
for dinfo in datasets_info:
    actual_real, predicted_real, spearman_real = run_experiment(dinfo["name"], dinfo["template"], use_pseudo_data=False)

    # One scatter plot per dataset
    fig, ax = plt.subplots(figsize=(8, 8))

    ax.scatter(predicted_real, actual_real, color='tab:blue', label='Real Images', alpha=0.5)

    ax.set_xlabel("Predicted Accuracy (Confidence)")
    ax.set_ylabel("Actual Accuracy (Real Data)")

    # An undefined correlation is reported as NaN rather than None, so NaN
    # has to be checked as well for the "N/A" fallback to take effect.
    real_rho = spearman_real.correlation
    real_corr_str = "N/A" if real_rho is None or np.isnan(real_rho) else f"{real_rho:.2f}"
    ax.set_title(f"Calibration Approach: {dinfo['name']}\nSpearman (Text Confidence vs Real Accuracy): {real_corr_str}")
    ax.grid(True)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.legend()

    # Save the figure; create the output directory first because savefig
    # does not create missing directories.
    os.makedirs("figures", exist_ok=True)
    plot_filename = f"figures/text_{dinfo['name']}_combined_pred_vs_actual.png"
    fig.savefig(plot_filename)
    plt.close(fig)

Loaded data from cache: cache\gpt-4o-mini-2024-07-18\CUB_200_2011_global_traits.json
Using OpenAI model: gpt-4o-mini-2024-07-18
Configured to generate 40 captions.
Meta Prompt: You are an AI assistant that generates creative and diverse image captions suitable for use with image generation models like DALL-E. Given a subject, provide 40 distinct, diverse and descriptive captions, considering the following global taxonomical traits when generating captions: ['Feathered body', 'Beak shape and size', 'Wing structure and length', 'Color patterns and plumage', 'Size and weight', 'Tail shape and length', 'Behavior (e.g. migratory, territorial)', 'Habitat preference (e.g. aquatic, forest, grassland)', 'Vocalizations and songs', 'Nesting habits (e.g. ground, trees, cliffs)'].


Generating Alternative Captions: 100%|██████████| 200/200 [00:00<00:00, 1908.40it/s]

Loaded data from cache: cache\gpt-4o-mini-2024-07-18\CUB_200_2011\40\Black_footed_Albatross.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\CUB_200_2011\40\Laysan_Albatross.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\CUB_200_2011\40\Sooty_Albatross.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\CUB_200_2011\40\Groove_billed_Ani.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\CUB_200_2011\40\Crested_Auklet.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\CUB_200_2011\40\Least_Auklet.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\CUB_200_2011\40\Parakeet_Auklet.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\CUB_200_2011\40\Rhinoceros_Auklet.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\CUB_200_2011\40\Brewer_Blackbird.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\CUB_200_2011\40\Red_winged_Blackbird.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\CUB_200_2011\40\Rusty_Blackbird.j




Processed batch 1/4
Processed batch 2/4
Processed batch 3/4
Processed batch 4/4


Processing Images: 100%|██████████| 185/185 [01:27<00:00,  2.13it/s]


Processed batch 1/117
Processed batch 2/117
Processed batch 3/117
Processed batch 4/117
Processed batch 5/117
Processed batch 6/117
Processed batch 7/117
Processed batch 8/117
Processed batch 9/117
Processed batch 10/117
Processed batch 11/117
Processed batch 12/117
Processed batch 13/117
Processed batch 14/117
Processed batch 15/117
Processed batch 16/117
Processed batch 17/117
Processed batch 18/117
Processed batch 19/117
Processed batch 20/117
Processed batch 21/117
Processed batch 22/117
Processed batch 23/117
Processed batch 24/117
Processed batch 25/117
Processed batch 26/117
Processed batch 27/117
Processed batch 28/117
Processed batch 29/117
Processed batch 30/117
Processed batch 31/117
Processed batch 32/117
Processed batch 33/117
Processed batch 34/117
Processed batch 35/117
Processed batch 36/117
Processed batch 37/117
Processed batch 38/117
Processed batch 39/117
Processed batch 40/117
Processed batch 41/117
Processed batch 42/117
Processed batch 43/117
Processed batch 44/1

Generating Alternative Captions: 100%|██████████| 102/102 [00:00<00:00, 3232.09it/s]

Loaded data from cache: cache\gpt-4o-mini-2024-07-18\Flower102\40\alpine_sea_holly.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\Flower102\40\anthurium.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\Flower102\40\artichoke.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\Flower102\40\azalea.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\Flower102\40\ball_moss.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\Flower102\40\balloon_flower.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\Flower102\40\barbeton_daisy.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\Flower102\40\bearded_iris.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\Flower102\40\bee_balm.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\Flower102\40\bird_of_paradise.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\Flower102\40\bishop_of_llandaff.json
Loaded data from cache: cache\gpt-4o-mini-2024-07-18\Flower102\40\black-eyed_




Processed batch 1/2
Processed batch 2/2


Processing Images: 100%|██████████| 128/128 [01:08<00:00,  1.86it/s]


Processed batch 1/60
Processed batch 2/60
Processed batch 3/60
Processed batch 4/60
Processed batch 5/60
Processed batch 6/60
Processed batch 7/60
Processed batch 8/60
Processed batch 9/60
Processed batch 10/60
Processed batch 11/60
Processed batch 12/60
Processed batch 13/60
Processed batch 14/60
Processed batch 15/60
Processed batch 16/60
Processed batch 17/60
Processed batch 18/60
Processed batch 19/60
Processed batch 20/60
Processed batch 21/60
Processed batch 22/60
Processed batch 23/60
Processed batch 24/60
Processed batch 25/60
Processed batch 26/60
Processed batch 27/60
Processed batch 28/60
Processed batch 29/60
Processed batch 30/60
Processed batch 31/60
Processed batch 32/60
Processed batch 33/60
Processed batch 34/60
Processed batch 35/60
Processed batch 36/60
Processed batch 37/60
Processed batch 38/60
Processed batch 39/60
Processed batch 40/60
Processed batch 41/60
Processed batch 42/60
Processed batch 43/60
Processed batch 44/60
Processed batch 45/60
Processed batch 46/