In [None]:
import json
from matplotlib import pyplot as plt
import seaborn as sns
from scipy.stats import spearmanr, pearsonr

# High-resolution output for both inline rendering and saved figures.
plt.rcParams.update({"figure.dpi": 300, "savefig.dpi": 300})
# Seaborn look: "whitegrid" style at the "notebook" scale
# (other available contexts: paper, talk, poster).
sns.set_theme(context="notebook", style="whitegrid")


def get_accuracy(file_name: str) -> dict:
    """Score a JSONL predictions file and return per-class accuracy.

    Each non-blank line of ``file_name`` is parsed as a JSON record. A record
    counts as correct when its lower-cased label is a substring of its
    lower-cased prediction.

    - If a record has no ``"label"`` key, the label is taken from
      ``"question_id"`` split on ``"-"``: field 2 when the file's base name
      contains ``"imagenet"``, otherwise field 1.
    - If a record has no ``"pred"`` key, its ``"text"`` field is used as the
      prediction.

    Prints the overall accuracy followed by the number of scored records
    (``nan 0`` when the file has no records).

    Args:
        file_name: Path to the JSONL predictions file.

    Returns:
        Dict mapping each label to its accuracy in [0, 1], ordered from the
        highest to the lowest accuracy (ties keep first-seen order).
    """
    import os
    from collections import defaultdict

    # Choose the question_id layout from the file's own name, not its parent
    # directories, so e.g. ".../imagenet_runs/cub.jsonl" is not misparsed.
    label_field = 2 if "imagenet" in os.path.basename(file_name) else 1

    hits_by_label = defaultdict(list)  # label -> list of per-record bool hits
    accs = []
    with open(file_name) as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines such as a trailing empty line
            item = json.loads(line)
            label = (
                item["label"]
                if "label" in item
                else item["question_id"].split("-")[label_field]
            )
            pred = item["pred"] if "pred" in item else item["text"]

            acc = label.lower() in pred.lower()
            hits_by_label[label].append(acc)
            accs.append(acc)

    # Guard against an empty file instead of raising ZeroDivisionError.
    overall = sum(accs) / len(accs) if accs else float("nan")
    print(overall, len(accs))

    class_acc = {label: sum(hits) / len(hits) for label, hits in hits_by_label.items()}
    # sorted() is stable, so classes with equal accuracy keep first-seen order.
    return dict(sorted(class_acc.items(), key=lambda kv: kv[1], reverse=True))

# Correlation between class frequency and per-class accuracy

In [None]:
dataset = "imagenet"
MODEL_TAG = "llava7b"  # alternative: "clipvitl336_1000classes"
TOP_K_CLASSES = 80  # number of most frequent classes shown in the bar chart

# Per-class frequency tables from the *_freq_pretrain / *_freq_instruct files.
# NOTE(review): each JSON file is assumed to be a list of [class_name, count] pairs
# (inferred from the pair unpacking) — confirm against the script that writes them.
with open(f"../data/tokenized/{dataset}_freq_pretrain.json") as f:
    class_freq_tuple_pretrain = dict(json.load(f))
with open(f"../data/tokenized/{dataset}_freq_instruct.json") as f:
    class_freq_tuple_instruct = dict(json.load(f))

# Combined frequency over the UNION of classes, so a class listed only in the
# instruct table is not dropped. dict.fromkeys keeps a deterministic order
# (pretrain classes first), which keeps tie-breaking in the sort below reproducible.
all_classes = dict.fromkeys([*class_freq_tuple_pretrain, *class_freq_tuple_instruct])
class_freq_tuple = {
    k: class_freq_tuple_pretrain.get(k, 0) + class_freq_tuple_instruct.get(k, 0)
    for k in all_classes
}

# Bar chart of the most frequent classes.
class_freq_tuple_new = tuple(
    sorted(class_freq_tuple.items(), key=lambda kv: kv[1], reverse=True)[:TOP_K_CLASSES]
)
plt.figure(figsize=(15, 5))
plt.bar(
    range(len(class_freq_tuple_new)),
    [count for _, count in class_freq_tuple_new],
    color="b",
)
plt.xticks(
    range(len(class_freq_tuple_new)),
    [name for name, _ in class_freq_tuple_new],
    rotation=90,
    fontsize=8,
)
plt.ylabel("Frequency")
plt.xlabel("Class")
plt.title(f"Top {TOP_K_CLASSES} {dataset} classes by frequency (pretrain + instruct)")
plt.show()

# Per-class accuracy of the selected model, highest first.
class_acc_sorted = get_accuracy(f"../main_results/outputs/{dataset}_{MODEL_TAG}.jsonl")

# Pair each class's frequency (0 when absent from the table) with its accuracy.
# NOTE(review): accuracy labels are lower-cased for the lookup while frequency keys
# are used as stored — this assumes the frequency files use lower-case class names.
xs = [class_freq_tuple.get(label.lower(), 0) for label in class_acc_sorted]
ys = list(class_acc_sorted.values())

print(spearmanr(xs, ys), pearsonr(xs, ys))

# Mean accuracy per class-frequency bin

In [None]:
from matplotlib import pyplot as plt

plt.figure(figsize=(15, 3))

# Half-open frequency bins [lo, hi). The last edge is infinity so classes with
# very high frequency land in a bin instead of being silently dropped.
bins = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 1000, 10000, 100000, float("inf")]
n_bins = len(bins) - 1

accs_per_bin = [[] for _ in range(n_bins)]
for freq, acc in zip(xs, ys):
    for i in range(n_bins):
        if bins[i] <= freq < bins[i + 1]:
            accs_per_bin[i].append(acc)
            break

n_classes = [len(accs) for accs in accs_per_bin]
# Exact mean for populated bins. Empty bins get NaN so matplotlib leaves a gap
# instead of plotting a misleading 0 accuracy.
ys_bins = [sum(accs) / len(accs) if accs else float("nan") for accs in accs_per_bin]

plt.plot(range(n_bins), ys_bins, marker="o")
plt.xticks(
    range(n_bins),
    [f"[{bins[i]}, {bins[i + 1]}) ({n_classes[i]} classes)" for i in range(n_bins)],
    rotation=45,
)
plt.xlabel("Frequency")
plt.ylabel("Accuracy")
plt.title("Mean per-class accuracy by class-frequency bin")
plt.show()