In [5]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, Dataset
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import spearmanr

from utils.dataset import get_dataset
from utils.model_utils import load_model, get_text_embeddings, get_image_embeddings
import os
import glob
from PIL import Image

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "openai/clip-vit-base-patch32"

# Define datasets and templates
datasets_info = [
    {"name": "NWPU-RESISC45", "template": "a satellite photo containing {}."},
    {"name": "Stanford_dogs", "template": "a photo of {}, a type of dog."},
    {"name": "CUB_200_2011", "template": "a photo of {}, a type of bird."},
    {"name": "Flower102", "template": "a photo of {}, a type of flower."},
]

class PseudoImageDataset(Dataset):
    def __init__(self, image_folder, class_names, transform=None):
        self.image_paths = []
        self.labels = []
        self.transform = transform
        self.class_names = set(class_names)  # For faster lookup

        for image_path in glob.glob(os.path.join(image_folder, '*.png')):
            filename = os.path.basename(image_path)
            class_name = filename.split('-')[0].replace('_', ' ')
            if class_name in self.class_names:
                self.image_paths.append(image_path)
                self.labels.append(class_name)
        # convert labels to indices
        self.labels = [class_names.index(label) for label in self.labels]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

def get_image_loader(dataset, class_indices, batch_size, shuffle=False):
    if hasattr(dataset, 'targets'):
        selected_indices = [i for i, label in enumerate(dataset.targets) if label in class_indices]
    else:
        selected_indices = [i for i, (_, label) in enumerate(dataset.imgs) if label in class_indices]
    selected_dataset = Subset(dataset, selected_indices)

    # Set no transform for simplicity (depends on dataset loading code)
    selected_dataset.dataset.transform = None

    loader = DataLoader(
        selected_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=lambda batch: (list(zip(*batch))[0], list(zip(*batch))[1])  # Custom collate
    )
    return loader

def run_experiment(dataset_name, template):
    # 1. Load dataset (test split)
    dataset = get_dataset(dataset_name, data_root='data', split='test')
    class_names = dataset.classes
    num_classes = len(class_names)
    # Create dataloader
    if USE_PSEUDO_DATA:
        PSEUDO_IMAGES_FOLDER = f'pseudo_images/{dataset_name}'
        pseudo_dataset = PseudoImageDataset(
            image_folder=PSEUDO_IMAGES_FOLDER,
            class_names=dataset.classes,
            transform=None  # Add transformations if required
        )

        dataloader = DataLoader(
            pseudo_dataset,
            batch_size=64,
            shuffle=False,
            pin_memory=True,
            collate_fn=lambda batch: (list(zip(*batch))[0], list(zip(*batch))[1])  # Custom collate function
        )
    else:
        selected_class_indices = list(range(num_classes))
        dataloader = get_image_loader(dataset, selected_class_indices, 64, shuffle=False)

    # 2. Load model
    model_info = load_model(model_name, device=device)

    # 3. Create captions and compute text embeddings
    captions = [template.format(c) for c in class_names]
    text_embeddings = get_text_embeddings(captions, model_info, device, batch_size=64)

    # 4. Compute image embeddings
    image_embeddings, all_targets = get_image_embeddings(dataloader, model_info, device)
    all_targets = np.array(all_targets)

    # Normalize embeddings
    image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
    text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)

    # 5. Compute similarities and predictions
    temperature = 0.01
    similarities = np.matmul(image_embeddings, text_embeddings.T) / temperature
    probs = np.exp(similarities) / np.sum(np.exp(similarities), axis=1, keepdims=True)

    preds = np.argmax(probs, axis=1)
    confidences = np.max(probs, axis=1)

    # Compute class-wise actual and predicted accuracies
    actual_accuracies = []
    predicted_accuracies = []
    for c in range(num_classes):
        cur_class_indices = (all_targets == c)
        cur_class_pred = preds[cur_class_indices]
        cur_class_accuracy = (cur_class_pred == c).mean()
        if np.isnan(cur_class_accuracy):
            continue
        
        predicted_cur_class_indices = (preds == c)
        predicted_confidences = confidences[predicted_cur_class_indices]
        predicted_acc = predicted_confidences.mean() if predicted_confidences.size > 0 else np.nan

        actual_accuracies.append(cur_class_accuracy)
        if not np.isnan(predicted_acc):
            predicted_accuracies.append(predicted_acc)
        else:
            predicted_accuracies.append(0.0)
            
    # Compute Spearman correlation
    spearman_corr = spearmanr(np.array(actual_accuracies), np.array(predicted_accuracies))
    print(f"Dataset: {dataset_name}, Spearman correlation: {spearman_corr}")

    # Plot and save
    fig, ax = plt.subplots(figsize=(8,8))
    if USE_PSEUDO_DATA: # orange
        color = 'tab:orange'
    else: # blue
        color = 'tab:blue'
    ax.scatter(predicted_accuracies, actual_accuracies, color=color)
    ax.set_xlabel("Predicted Accuracy (Confidence)")
    ax.set_ylabel("Actual Accuracy")
    if USE_PSEUDO_DATA:
        title_name = f"Calibration Approach: {dataset_name} with Pseudo Images\n Spearman Correlation: {spearman_corr[0]:.2f}"
    else:
        title_name = f"Calibration Approach: {dataset_name} with Real Images\n Spearman Correlation: {spearman_corr[0]:.2f}"
    ax.set_title(title_name)
    ax.grid(True)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    if USE_PSEUDO_DATA:
        plot_filename = f"figures/{dataset_name}_pseudo_pred_vs_actual.png"
    else:
        plot_filename = f"figures/{dataset_name}_pred_vs_actual.png"
    plt.savefig(plot_filename)
    plt.close(fig)

# Run experiments for each dataset

USE_PSEUDO_DATA = True
for dinfo in datasets_info:
    run_experiment(dinfo["name"], dinfo["template"])
USE_PSEUDO_DATA = False
for dinfo in datasets_info:
    run_experiment(dinfo["name"], dinfo["template"])


Processed batch 1/1


Processing Images: 100%|██████████| 23/23 [00:22<00:00,  1.02it/s]


Dataset: NWPU-RESISC45, Spearman correlation: SignificanceResult(statistic=0.8166573537765787, pvalue=7.952020772809736e-12)
Processed batch 1/2
Processed batch 2/2


Processing Images: 100%|██████████| 59/59 [00:54<00:00,  1.09it/s]


Dataset: Stanford_dogs, Spearman correlation: SignificanceResult(statistic=0.7259391551199936, pvalue=6.5836049796789e-21)
Processed batch 1/4
Processed batch 2/4
Processed batch 3/4
Processed batch 4/4


Processing Images: 100%|██████████| 98/98 [04:19<00:00,  2.65s/it]
  cur_class_accuracy = (cur_class_pred == c).mean()
  ret = ret.dtype.type(ret / rcount)


Dataset: CUB_200_2011, Spearman correlation: SignificanceResult(statistic=0.833884559388305, pvalue=9.652183689450992e-53)
Processed batch 1/2
Processed batch 2/2


Processing Images: 100%|██████████| 46/46 [00:45<00:00,  1.01it/s]


Dataset: Flower102, Spearman correlation: SignificanceResult(statistic=0.6967831449517579, pvalue=4.444976788328451e-15)
Processed batch 1/1


Processing Images: 100%|██████████| 493/493 [04:48<00:00,  1.71it/s]


Dataset: NWPU-RESISC45, Spearman correlation: SignificanceResult(statistic=0.8180500658761528, pvalue=6.850790104513321e-12)
Processed batch 1/2
Processed batch 2/2


Processing Images: 100%|██████████| 322/322 [03:43<00:00,  1.44it/s]


Dataset: Stanford_dogs, Spearman correlation: SignificanceResult(statistic=0.8170845197583164, pvalue=5.3325877013584045e-30)
Processed batch 1/4
Processed batch 2/4
Processed batch 3/4
Processed batch 4/4


Processing Images: 100%|██████████| 185/185 [02:31<00:00,  1.22it/s]


Dataset: CUB_200_2011, Spearman correlation: SignificanceResult(statistic=0.8781028005954786, pvalue=2.6330984361777895e-65)
Processed batch 1/2
Processed batch 2/2


Processing Images: 100%|██████████| 128/128 [01:54<00:00,  1.12it/s]

Dataset: Flower102, Spearman correlation: SignificanceResult(statistic=0.8870509216952215, pvalue=2.4270314261147215e-35)





In [None]:



# Run experiments for each dataset
# for dinfo in datasets_info:
#     run_experiment(dinfo["name"], dinfo["template"])
dataset_name = dinfo["name"]
template = dinfo["template"]
# def run_experiment(dataset_name, template):
    # 1. Load dataset (test split)
dataset = get_dataset(dataset_name, data_root='data', split='test')
class_names = dataset.classes
num_classes = len(class_names)
from utils.utils import CaptionGenerator

# Initialize CaptionGenerator and generate alternative captions
capGenerator = CaptionGenerator(dataset_name=dataset_name, class_names=class_names, num_captions=40)

alter_caption_list = []
labels = []
all_targets = []
for class_name in class_names:
    alter_captions = capGenerator.get_alternative_captions(class_name)
    alter_caption_list.extend(alter_captions)
    labels.extend([class_name] * len(alter_captions))
    all_targets.extend([class_names.index(class_name)] * len(alter_captions))

selected_class_indices = list(range(num_classes))
dataloader = get_image_loader(dataset, selected_class_indices, 64, shuffle=False)

# 2. Load model
model_info = load_model(model_name, device=device)

# 3. Create captions and compute text embeddings
captions = [template.format(c) for c in class_names]
text_embeddings = get_text_embeddings(captions, model_info, device, batch_size=64)

# 4. Compute alternative caption embeddings
descriptive_text_embeddings = get_text_embeddings(alter_caption_list, model_info, device, batch_size=64)

# image_embeddings, all_targets = get_image_embeddings(dataloader, model_info, device)
# all_targets = np.array(all_targets)

# Normalize embeddings
# image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
descriptive_text_embeddings = descriptive_text_embeddings / np.linalg.norm(descriptive_text_embeddings, axis=1, keepdims=True)
text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)

# descriptive text shape: (num_classes * num_captions, embedding_dim)
# text_embeddings shape: (num_classes, embedding_dim)

# 5. Compute similarities and predictions
temperature = 0.07
similarities = np.matmul(descriptive_text_embeddings, text_embeddings.T) / temperature
probs = np.exp(similarities) / np.sum(np.exp(similarities), axis=1, keepdims=True)

preds = np.argmax(probs, axis=1)
confidences = np.max(probs, axis=1)

# Compute class-wise actual and predicted accuracies
actual_accuracies = []
predicted_accuracies = []
for c in range(num_classes):
    cur_class_indices = (all_targets == c)
    cur_class_pred = preds[cur_class_indices]
    cur_class_accuracy = (cur_class_pred == c).mean()
    if np.isnan(cur_class_accuracy):
        continue
    
    predicted_cur_class_indices = (preds == c)
    predicted_confidences = confidences[predicted_cur_class_indices]
    predicted_acc = predicted_confidences.mean() if predicted_confidences.size > 0 else np.nan

    actual_accuracies.append(cur_class_accuracy)
    if not np.isnan(predicted_acc):
        predicted_accuracies.append(predicted_acc)
    else:
        predicted_accuracies.append(0.0)
        
# Compute Spearman correlation
spearman_corr = spearmanr(np.array(actual_accuracies), np.array(predicted_accuracies))
print(f"Dataset: {dataset_name}, Spearman correlation: {spearman_corr}")

# Plot and save
fig, ax = plt.subplots(figsize=(8,8))
ax.scatter(predicted_accuracies, actual_accuracies)
ax.set_xlabel("Predicted Accuracy (Confidence)")
ax.set_ylabel("Actual Accuracy")
ax.set_title(f"{dataset_name}: Predicted vs. Actual Accuracy per Class\n Spearman Correlation: {spearman_corr[0]:.2f}")
ax.grid(True)

plot_filename = f"{dataset_name}_pred_vs_actual.png"
plt.savefig(plot_filename)
plt.close(fig)


Loaded data from cache: cache\gpt-4o-mini-2024-07-18\Flower102_global_traits.json
Using OpenAI model: gpt-4o-mini-2024-07-18
Configured to generate 40 captions.
Meta Prompt: You are an AI assistant that generates creative and diverse image captions suitable for use with image generation models like DALL-E. Given a subject, provide 40 distinct, diverse and descriptive captions, considering the following global taxonomical traits when generating captions: ['Petal arrangement', 'Leaf shape', 'Flower color', 'Growth habit (herbaceous vs. woody)', 'Stem structure', 'Pollination mechanism', 'Geographic distribution', 'Habitat preference (aquatic, terrestrial)', 'Fruit type (dry, fleshy)', 'Root system (fibrous vs. taproot)', 'Seasonality (annual, perennial)', 'Height and size variation', 'Reproductive structures (e.g., presence of sepals, stamens)', 'Symmetry of flowers (radial vs. bilateral)', 'Response to climate conditions', 'Association with specific insects or animals'].
Loaded data fro

  cur_class_accuracy = (cur_class_pred == c).mean()
  ret = ret.dtype.type(ret / rcount)


In [60]:
probs

array([[9.99998987e-01, 7.78778277e-26, 7.06572385e-17, ...,
        8.25450803e-21, 4.14319259e-21, 9.51531882e-28],
       [1.00000000e+00, 8.36850962e-22, 2.20842442e-13, ...,
        3.51010122e-17, 1.05598439e-16, 1.12418414e-23],
       [9.99999642e-01, 1.40198842e-20, 7.73849775e-14, ...,
        2.03927628e-18, 1.29697553e-17, 3.74584931e-25],
       ...,
       [9.05913343e-21, 2.80029886e-17, 1.04934279e-14, ...,
        1.21163184e-11, 1.18866200e-14, 9.99999166e-01],
       [6.35460261e-21, 1.46658659e-17, 7.36505312e-15, ...,
        3.08224557e-11, 6.75076789e-16, 9.99908686e-01],
       [4.93357346e-22, 1.46807395e-18, 1.36964031e-15, ...,
        1.73707305e-12, 1.12639857e-16, 9.99999821e-01]], dtype=float32)

In [4]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, Dataset
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import spearmanr

from utils.dataset import get_dataset
from utils.model_utils import load_model, get_text_embeddings, get_image_embeddings
import os
import glob
from PIL import Image

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "openai/clip-vit-base-patch32"

datasets_info = [
    # {"name": "NWPU-RESISC45", "template": "a satellite photo containing {}."},
    # {"name": "Stanford_dogs", "template": "a photo of {}, a type of dog."},
    {"name": "CUB_200_2011", "template": "a photo of {}, a type of bird."},
    # {"name": "Flower102", "template": "a photo of {}, a type of flower."},
]

class PseudoImageDataset(Dataset):
    def __init__(self, image_folder, class_names, transform=None):
        self.image_paths = []
        self.labels = []
        self.transform = transform
        self.class_names = class_names

        # Store class_names in a set for faster lookup
        class_name_set = set(class_names)

        for image_path in glob.glob(os.path.join(image_folder, '*.png')):
            filename = os.path.basename(image_path)
            class_name = filename.split('-')[0].replace('_', ' ')
            if class_name in class_name_set:
                self.image_paths.append(image_path)
                self.labels.append(class_name)

        # Convert labels to indices
        self.labels = [class_names.index(label) for label in self.labels]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

def get_image_loader(dataset, class_indices, batch_size, shuffle=False):
    if hasattr(dataset, 'targets'):
        selected_indices = [i for i, label in enumerate(dataset.targets) if label in class_indices]
    else:
        selected_indices = [i for i, (_, label) in enumerate(dataset.imgs) if label in class_indices]
    selected_dataset = Subset(dataset, selected_indices)

    # Ensure no transform or compatible transform is applied
    selected_dataset.dataset.transform = None

    loader = DataLoader(
        selected_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=lambda batch: (list(zip(*batch))[0], list(zip(*batch))[1])
    )
    return loader

def run_experiment(dataset_name, template, use_pseudo_data):
    # 1. Load dataset (test split)
    dataset = get_dataset(dataset_name, data_root='data', split='test')
    class_names = dataset.classes
    num_classes = len(class_names)

    if use_pseudo_data:
        PSEUDO_IMAGES_FOLDER = f'pseudo_images/{dataset_name}'
        pseudo_dataset = PseudoImageDataset(
            image_folder=PSEUDO_IMAGES_FOLDER,
            class_names=class_names,
            transform=None
        )
        dataloader = DataLoader(
            pseudo_dataset,
            batch_size=64,
            shuffle=False,
            pin_memory=True,
            collate_fn=lambda batch: (list(zip(*batch))[0], list(zip(*batch))[1])
        )
    else:
        selected_class_indices = list(range(num_classes))
        dataloader = get_image_loader(dataset, selected_class_indices, 64, shuffle=False)

    # 2. Load model
    model_info = load_model(model_name, device=device)

    # 3. Create captions and compute text embeddings
    captions = [template.format(c) for c in class_names]
    text_embeddings = get_text_embeddings(captions, model_info, device, batch_size=64)

    # 4. Compute image embeddings
    image_embeddings, all_targets = get_image_embeddings(dataloader, model_info, device)
    all_targets = np.array(all_targets)

    # Normalize embeddings
    image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
    text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)

    # 5. Compute similarities and predictions
    temperature = 0.01
    similarities = np.matmul(image_embeddings, text_embeddings.T) / temperature
    probs = np.exp(similarities) / np.sum(np.exp(similarities), axis=1, keepdims=True)

    preds = np.argmax(probs, axis=1)
    confidences = np.max(probs, axis=1)

    # Compute class-wise actual and predicted accuracies
    actual_accuracies = []
    predicted_accuracies = []
    for c in range(num_classes):
        cur_class_indices = (all_targets == c)
        if cur_class_indices.sum() == 0:
            continue
        cur_class_pred = preds[cur_class_indices]
        cur_class_accuracy = (cur_class_pred == c).mean()
        if np.isnan(cur_class_accuracy):
            continue
        
        predicted_cur_class_indices = (preds == c)
        predicted_confidences = confidences[predicted_cur_class_indices]
        predicted_acc = predicted_confidences.mean() if predicted_confidences.size > 0 else np.nan

        actual_accuracies.append(cur_class_accuracy)
        predicted_accuracies.append(predicted_acc if not np.isnan(predicted_acc) else 0.0)
            
    # Compute Spearman correlation
    spearman_corr = spearmanr(np.array(actual_accuracies), np.array(predicted_accuracies))

    # Return data for plotting outside
    return np.array(actual_accuracies), np.array(predicted_accuracies), spearman_corr


# Now we call run_experiment for pseudo and real data and plot them together
for dinfo in datasets_info:
    actual_pseudo, predicted_pseudo, spearman_pseudo = run_experiment(dinfo["name"], dinfo["template"], use_pseudo_data=True)
    actual_real, predicted_real, spearman_real = run_experiment(dinfo["name"], dinfo["template"], use_pseudo_data=False)

    # Create a single plot for both pseudo and real
    fig, ax = plt.subplots(figsize=(8,8))
    ax.scatter(predicted_pseudo, actual_pseudo, color='tab:orange', label='Pseudo Images', alpha = 0.5)
    ax.scatter(predicted_real, actual_real, color='tab:blue', label='Real Images', alpha=0.5)

    ax.set_xlabel("Predicted Accuracy (Confidence)")
    ax.set_ylabel("Actual Accuracy")
    pseudo_corr_str = f"{spearman_pseudo.correlation:.2f}" if spearman_pseudo.correlation is not None else "N/A"
    real_corr_str = f"{spearman_real.correlation:.2f}" if spearman_real.correlation is not None else "N/A"
    ax.set_title(f"Calibration Approach: {dinfo['name']}\nSpearman (Pseudo): {pseudo_corr_str}, Spearman (Real): {real_corr_str}")
    ax.grid(True)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.legend()

    # Save the figure
    plot_filename = f"figures/{dinfo['name']}_combined_pred_vs_actual.png"
    plt.savefig(plot_filename)
    plt.close(fig)


Processed batch 1/4
Processed batch 2/4
Processed batch 3/4
Processed batch 4/4


Processing Images: 100%|██████████| 98/98 [05:33<00:00,  3.41s/it]


Processed batch 1/4
Processed batch 2/4
Processed batch 3/4
Processed batch 4/4


Processing Images: 100%|██████████| 185/185 [02:31<00:00,  1.22it/s]


In [3]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, Dataset
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import spearmanr

from utils.dataset import get_dataset
from utils.model_utils import load_model, get_text_embeddings, get_image_embeddings
import os
import glob
from PIL import Image

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "openai/clip-vit-base-patch32"

# Define datasets and templates
datasets_info = [
    # {"name": "NWPU-RESISC45", "template": "a satellite photo containing {}."},
    # {"name": "Stanford_dogs", "template": "a photo of {}, a type of dog."},
    {"name": "CUB_200_2011", "template": "a photo of {}, a type of bird."},
    # {"name": "Flower102", "template": "a photo of {}, a type of flower."},
]

class PseudoImageDataset(Dataset):
    def __init__(self, image_folder, class_names, transform=None):
        self.image_paths = []
        self.labels = []
        self.transform = transform
        self.class_names = set(class_names)  # For faster lookup

        for image_path in glob.glob(os.path.join(image_folder, '*.png')):
            filename = os.path.basename(image_path)
            class_name = filename.split('-')[0].replace('_', ' ')
            if class_name in self.class_names:
                self.image_paths.append(image_path)
                self.labels.append(class_name)
        # convert labels to indices
        self.labels = [class_names.index(label) for label in self.labels]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

def get_image_loader(dataset, class_indices, batch_size, shuffle=False):
    if hasattr(dataset, 'targets'):
        selected_indices = [i for i, label in enumerate(dataset.targets) if label in class_indices]
    else:
        selected_indices = [i for i, (_, label) in enumerate(dataset.imgs) if label in class_indices]
    selected_dataset = Subset(dataset, selected_indices)

    # Set no transform for simplicity (depends on dataset loading code)
    selected_dataset.dataset.transform = None

    loader = DataLoader(
        selected_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=lambda batch: (list(zip(*batch))[0], list(zip(*batch))[1])  # Custom collate
    )
    return loader

def run_experiment(dataset_name, template):
    # 1. Load dataset (test split)
    dataset = get_dataset(dataset_name, data_root='data', split='test')
    class_names = dataset.classes
    num_classes = len(class_names)
    # Create dataloader
    if USE_PSEUDO_DATA:
        PSEUDO_IMAGES_FOLDER = f'pseudo_images/{dataset_name}'
        pseudo_dataset = PseudoImageDataset(
            image_folder=PSEUDO_IMAGES_FOLDER,
            class_names=dataset.classes,
            transform=None  # Add transformations if required
        )

        dataloader = DataLoader(
            pseudo_dataset,
            batch_size=64,
            shuffle=False,
            pin_memory=True,
            collate_fn=lambda batch: (list(zip(*batch))[0], list(zip(*batch))[1])  # Custom collate function
        )
    else:
        selected_class_indices = list(range(num_classes))
        dataloader = get_image_loader(dataset, selected_class_indices, 64, shuffle=False)

    # 2. Load model
    model_info = load_model(model_name, device=device)

    # 3. Create captions and compute text embeddings
    captions = [template.format(c) for c in class_names]
    text_embeddings = get_text_embeddings(captions, model_info, device, batch_size=64)

    # 4. Compute image embeddings
    image_embeddings, all_targets = get_image_embeddings(dataloader, model_info, device)
    all_targets = np.array(all_targets)

    # Normalize embeddings
    image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
    text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)

    # 5. Compute similarities and predictions
    # temperature = 0.01
    for temperature in [0.005, 0.01, 0.05, 0.1]:
        similarities = np.matmul(image_embeddings, text_embeddings.T) / temperature
        probs = np.exp(similarities) / np.sum(np.exp(similarities), axis=1, keepdims=True)

        preds = np.argmax(probs, axis=1)
        confidences = np.max(probs, axis=1)

        # Compute class-wise actual and predicted accuracies
        actual_accuracies = []
        predicted_accuracies = []
        for c in range(num_classes):
            cur_class_indices = (all_targets == c)
            cur_class_pred = preds[cur_class_indices]
            cur_class_accuracy = (cur_class_pred == c).mean()
            if np.isnan(cur_class_accuracy):
                continue
            
            predicted_cur_class_indices = (preds == c)
            predicted_confidences = confidences[predicted_cur_class_indices]
            predicted_acc = predicted_confidences.mean() if predicted_confidences.size > 0 else np.nan

            actual_accuracies.append(cur_class_accuracy)
            if not np.isnan(predicted_acc):
                predicted_accuracies.append(predicted_acc)
            else:
                predicted_accuracies.append(0.0)
                
        # Compute Spearman correlation
        spearman_corr = spearmanr(np.array(actual_accuracies), np.array(predicted_accuracies))
        print(f"Dataset: {dataset_name}, Spearman correlation: {spearman_corr}")

        # Plot and save
        fig, ax = plt.subplots(figsize=(8,8))
        if USE_PSEUDO_DATA: # orange
            color = 'tab:orange'
        else: # blue
            color = 'tab:blue'
        ax.scatter(predicted_accuracies, actual_accuracies, color=color)
        ax.set_xlabel("Predicted Accuracy (Confidence)")
        ax.set_ylabel("Actual Accuracy")
        if USE_PSEUDO_DATA:
            title_name = f"Calibration Approach: {dataset_name} with Pseudo Images (Temperature = {1/temperature})\n Spearman Correlation: {spearman_corr[0]:.2f}"
        else:
            title_name = f"Calibration Approach: {dataset_name} with Real Images (Temperature = {1/temperature})\n Spearman Correlation: {spearman_corr[0]:.2f}"
        ax.set_title(title_name)
        ax.grid(True)
        # ax.set_xlim([0, 1])
        # ax.set_ylim([0, 1])
        if USE_PSEUDO_DATA:
            plot_filename = f"figures/{dataset_name}_pseudo_pred_vs_actual_{temperature}.png"
        else:
            plot_filename = f"figures/{dataset_name}_pred_vs_actual_{temperature}.png"
        plt.savefig(plot_filename)
        plt.close(fig)

# Run experiments for each dataset

# USE_PSEUDO_DATA = True
# for dinfo in datasets_info:
#     run_experiment(dinfo["name"], dinfo["template"])
USE_PSEUDO_DATA = False
for dinfo in datasets_info:
    run_experiment(dinfo["name"], dinfo["template"])


Processed batch 1/4
Processed batch 2/4
Processed batch 3/4
Processed batch 4/4


Processing Images: 100%|██████████| 185/185 [01:14<00:00,  2.47it/s]


Dataset: CUB_200_2011, Spearman correlation: SignificanceResult(statistic=0.8906007894462, pvalue=1.1184070199550016e-69)
Dataset: CUB_200_2011, Spearman correlation: SignificanceResult(statistic=0.8781028005954786, pvalue=2.6330984361777895e-65)
Dataset: CUB_200_2011, Spearman correlation: SignificanceResult(statistic=0.6701227937060699, pvalue=1.964492082710588e-27)
Dataset: CUB_200_2011, Spearman correlation: SignificanceResult(statistic=0.5945655086098027, pvalue=1.662329032885547e-20)
