In [1]:
import json
import sys
from typing import Callable, List, Tuple

import clip
import torch
import torch.nn.functional as F
import torchvision
import numpy as np
from scipy.stats import spearmanr, pearsonr
from matplotlib import pyplot as plt

sys.path.append("..")
from datasets import ImageDataset, TextDataset, create_dataloader
from models import Linear
from trainer import run_one_epoch
from utils import computing_subgroup_metrics, subgrouping
from prepare_text_datasets import prepare_waterbird, prepare_fairface, prepare_dspites

In [2]:
def get_model_output(
    model: torch.nn.Module,
    clip_model: torch.nn.Module,
    transform: Callable,
    image_data: List[dict],
    text_data: List[dict],
) -> Tuple[dict, dict]:
    """Run one evaluation-only pass of ``model`` over image data and over text data.

    Args:
        model: Classifier head evaluated on top of ``clip_model``.
        clip_model: CLIP backbone passed through to ``run_one_epoch``.
        transform: Image preprocessing callable (e.g. the one returned by
            ``clip.load``); applied to the image loader only.
        image_data: Records used to build an ``ImageDataset``.
        text_data: Records used to build a ``TextDataset``.

    Returns:
        ``(image_metrics, text_metrics)``: whatever ``run_one_epoch`` returns for
        the image pass and the text pass, in that order.
    """

    def _evaluate(dataset, modality: str, **loader_kwargs) -> dict:
        # Build a loader for `dataset` and run a single pass with no optimizer (eval=True).
        dataloader = create_dataloader(
            dataset=dataset, modality=modality, **loader_kwargs
        )
        return run_one_epoch(
            dataloader=dataloader,
            model=model,
            clip_model=clip_model,
            modality=modality,
            opt=None,
            epoch_idx=-1,
            eval=True,
            verbose=False,
        )

    image_metrics = _evaluate(
        ImageDataset(data=image_data), "image", transform=transform
    )
    # The text loader is built without a `transform` argument, as in the image-free path.
    text_metrics = _evaluate(TextDataset(data=text_data), "text")
    return image_metrics, text_metrics


def compute_correlation(
    data1_list: List,
    data2_list: List,
    visualization: bool = False,
    xlabel: str = "Image",
    ylabel: str = "Text",
) -> dict:
    """Print and return Spearman and Pearson correlations between two paired sequences.

    Args:
        data1_list: First sequence of numbers (x-axis of the optional plot).
        data2_list: Second sequence, paired element-wise with ``data1_list``
            (y-axis of the optional plot).
        visualization: If True, also show a scatter plot of the pairs.
        xlabel: x-axis label for the scatter plot; should describe ``data1_list``.
        ylabel: y-axis label for the scatter plot; should describe ``data2_list``.

    Returns:
        Dict with float values under ``spearman_corr``, ``spearman_pval``,
        ``pearson_corr`` and ``pearson_pval``.

    Raises:
        AssertionError: if the two sequences have different lengths.
    """
    assert len(data1_list) == len(data2_list), (
        f"length mismatch: {len(data1_list)} vs {len(data2_list)}"
    )
    data1 = np.asarray(data1_list)
    data2 = np.asarray(data2_list)
    spearman_corr, spearman_pval = spearmanr(data1, data2)
    pearson_corr, pearson_pval = pearsonr(data1, data2)
    # `.3g` keeps very small p-values visible (e.g. 3.2e-12) instead of rounding to 0.0000.
    print(f"Spearman correlation: {spearman_corr:.4f} (p-value: {spearman_pval:.3g})")
    print(f"Pearson correlation: {pearson_corr:.4f} (p-value: {pearson_pval:.3g})")
    if visualization:
        fig, ax = plt.subplots()
        ax.scatter(data1, data2)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        plt.show()
    return {
        "spearman_corr": float(spearman_corr),
        "spearman_pval": float(spearman_pval),
        "pearson_corr": float(pearson_corr),
        "pearson_pval": float(pearson_pval),
    }


def compute_subgroup_correlation(
    image_data: List,
    image_metrics: dict,
    text_data: List,
    text_metrics: dict,
    fields: List[str],
    visualization: bool = False,
) -> None:
    """Print correlations between per-subgroup text metrics and per-subgroup image accuracy.

    Subgroups are formed by ``subgrouping(data, fields)`` on each side. Two
    correlations are printed over the image-side subgroups:
    text accuracy vs. image accuracy, and mean true-label text probability vs.
    image accuracy.

    Args:
        image_data: Image records passed to ``subgrouping``.
        image_metrics: Mapping with ``"preds"`` and ``"labels"`` sequences for the
            image records (presumably the output of ``run_one_epoch``; confirm).
        text_data: Text records passed to ``subgrouping``.
        text_metrics: Mapping with ``"preds"``, ``"labels"`` and ``"logits"`` for the
            text records; logits are assumed to be (n_instances, n_classes).
        fields: Attribute names that define the subgroups.
        visualization: If True, ``compute_correlation`` also shows scatter plots.

    Raises:
        KeyError: presumably, if an image-side subgroup has no text-side metric.
    """

    def _instance_correct(metrics: dict) -> np.ndarray:
        # Boolean per-instance correctness: prediction equals label.
        return np.array(metrics["preds"]) == np.array(metrics["labels"])

    image_subgroups = subgrouping(image_data, fields)
    image_subgroup_accs = computing_subgroup_metrics(
        _instance_correct(image_metrics), image_subgroups
    )

    text_subgroups = subgrouping(text_data, fields)
    text_subgroup_accs = computing_subgroup_metrics(
        _instance_correct(text_metrics), text_subgroups
    )

    # Softmax probability assigned to each text instance's true label.
    text_labels = np.asarray(text_metrics["labels"])
    text_probs = torch.softmax(torch.tensor(text_metrics["logits"]), dim=1).numpy()
    text_instance_probs = text_probs[np.arange(len(text_labels)), text_labels]
    text_subgroup_probs = computing_subgroup_metrics(
        text_instance_probs, text_subgroups
    )

    # Image values go first so they land on compute_correlation's "Image" x-axis;
    # both correlation statistics are symmetric, so the printed numbers are unchanged.
    image_accs = [image_subgroup_accs[x] for x in image_subgroups]

    print("Text Acc - Image Acc Correlation:")
    compute_correlation(
        image_accs,
        [text_subgroup_accs[x] for x in image_subgroups],
        visualization=visualization,
    )
    print("Text Prob - Image Acc Correlation:")
    compute_correlation(
        image_accs,
        [text_subgroup_probs[x] for x in image_subgroups],
        visualization=visualization,
    )

# Waterbird

In [3]:
CLIP_MODEL = "ViT-B/32"
LINEAR_MODEL = "../pytorch_cache/iclrsubmission/models/waterbird_linear_model.pt"
DATA_PATH = "../../data/Waterbird/processed_attribute_dataset/attributes.jsonl"
FIELDS = ["species", "place"]
N_CLASS = 2
DEVICE = "cuda"
# Text-input variants evaluated against the same image split, in display order.
INPUT_TYPES = ["concat", "prompt", "ensemble"]

clip_model, transform = clip.load(name=CLIP_MODEL, device=DEVICE)
clip_model = clip_model.float()
model = Linear(clip_model.visual.output_dim, N_CLASS).to(DEVICE)
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
model.load_state_dict(torch.load(LINEAR_MODEL))


def filter_fn(x):
    """Return True for records whose ``split`` attribute is ``"val"``."""
    return x["attributes"]["split"] == "val"


def label_fn(x):
    """Return the record's ``waterbird`` attribute, used as its class label."""
    return x["attributes"]["waterbird"]


# Context manager closes the JSONL handle; JSON text is read as UTF-8 explicitly.
with open(DATA_PATH, encoding="utf-8") as data_file:
    image_data = [json.loads(line) for line in data_file]
image_data = [x for x in image_data if filter_fn(x)]
for item in image_data:
    item["label"] = label_fn(item)

text_data_by_type = {
    input_type: prepare_waterbird(data_path=DATA_PATH, input_type=input_type)
    for input_type in INPUT_TYPES
}
text_metrics_by_type = {}
for input_type in INPUT_TYPES:
    print(f"\n{input_type.capitalize()}:\n")
    # NOTE(review): the image pass inside get_model_output does not use the text data,
    # so image inference is repeated once per variant.
    image_metrics, text_metrics_by_type[input_type] = get_model_output(
        model, clip_model, transform, image_data, text_data_by_type[input_type]
    )
    compute_subgroup_correlation(
        image_data,
        image_metrics,
        text_data_by_type[input_type],
        text_metrics_by_type[input_type],
        fields=FIELDS,
    )

# Keep the per-variant names this cell has always defined, for any later cells.
text_data_concat = text_data_by_type["concat"]
text_data_prompt = text_data_by_type["prompt"]
text_data_ensemble = text_data_by_type["ensemble"]
text_metrics_concat = text_metrics_by_type["concat"]
text_metrics_prompt = text_metrics_by_type["prompt"]
text_metrics_ensemble = text_metrics_by_type["ensemble"]


Concat:

Text Acc - Image Acc Correlation:
Spearman correlation: 0.4167 (p-value: 0.0000)
Pearson correlation: 0.4355 (p-value: 0.0000)
Text Prob - Image Acc Correlation:
Spearman correlation: 0.5899 (p-value: 0.0000)
Pearson correlation: 0.5773 (p-value: 0.0000)

Prompt:

Text Acc - Image Acc Correlation:
Spearman correlation: 0.5607 (p-value: 0.0000)
Pearson correlation: 0.5742 (p-value: 0.0000)
Text Prob - Image Acc Correlation:
Spearman correlation: 0.6462 (p-value: 0.0000)
Pearson correlation: 0.6721 (p-value: 0.0000)

Ensemble:

Text Acc - Image Acc Correlation:
Spearman correlation: 0.6704 (p-value: 0.0000)
Pearson correlation: 0.6091 (p-value: 0.0000)
Text Prob - Image Acc Correlation:
Spearman correlation: 0.6465 (p-value: 0.0000)
Pearson correlation: 0.6776 (p-value: 0.0000)
