In [1]:
import pathlib
import numpy as np


def get_counts(folder, pattern="*.npy"):
    """Sum all count arrays saved in ``folder``.

    Every regular file in ``folder`` matching ``pattern`` is loaded with
    ``np.load`` and the arrays are added element-wise, so they must all
    have compatible shapes/dtypes.

    Args:
        folder: Directory containing the per-partition count files.
        pattern: Glob pattern selecting the files to sum. Defaults to
            ``"*.npy"`` so stray files or subdirectories in the folder are
            never passed to ``np.load``.

    Returns:
        The element-wise sum of all loaded arrays.

    Raises:
        FileNotFoundError: If no regular file in ``folder`` matches ``pattern``.
    """
    # Sorted so the summation order (and hence float rounding) is reproducible;
    # is_file() drops directories whose names happen to match the pattern.
    files = sorted(p for p in pathlib.Path(folder).glob(pattern) if p.is_file())
    if not files:
        raise FileNotFoundError(f"no files matching {pattern!r} in {str(folder)!r}")

    # np.load returns a freshly allocated array, so accumulating in place
    # into the first one does not touch anything on disk.
    counts_total = np.load(files[0])
    for file in files[1:]:
        counts_total += np.load(file)
    return counts_total

In [6]:
# Each folder holds .npy files of per-partition caption counts (see
# dataset_synset_counts); get_counts sums a folder into a single array.
# NOTE(review): the three placeholders are identical — point each one at the
# folder for its own dataset before running.
COUNTS_DIR_2B = "YOUR_PATH_TO_COUNTS_DIRECTORY"
COUNTS_DIR_400M = "YOUR_PATH_TO_COUNTS_DIRECTORY"
COUNTS_DIR_COYO = "YOUR_PATH_TO_COUNTS_DIRECTORY"

counts_2B = get_counts(COUNTS_DIR_2B)
counts_400M = get_counts(COUNTS_DIR_400M)
counts_coyo = get_counts(COUNTS_DIR_COYO)

In [None]:
from src import hierarchy, prompt

# Build two lookups keyed by synset name: its position in the lemma-sorted
# list (synset_to_idx) and its cleaned first lemma (synset_to_lemma).
# Entries are ordered by cleaned lemma, ties broken by comparing the synset
# objects themselves. synset_to_idx is later used to index the count arrays,
# so this order presumably must match the order used when those arrays were
# produced (see dataset_synset_counts) — verify before changing it.
h = hierarchy.Hierarchy("../wordnet_classes/imagenet.txt")
all_lemmas = sorted(
    (prompt.clean_lemma(synset.lemmas()[0].name()), synset)
    for synset in h.get_all_synsets(True)
)

synset_to_idx = {synset.name(): idx for idx, (_lemma, synset) in enumerate(all_lemmas)}
synset_to_lemma = {synset.name(): lemma for lemma, synset in all_lemmas}

In [19]:
import pickle
from scipy.stats import spearmanr
import matplotlib.pyplot as plt

def _single_word_keys(npz, lemmas):
    """Sorted keys of ``npz``, excluding "average" and multi-word lemmas."""
    return [
        key for key in sorted(npz.keys())
        if key != "average" and len(lemmas[key].split()) == 1
    ]


def _metric_vs_counts_corr(path, counts, lemmas, indices):
    """Spearman correlation between one per-synset metric file and ``counts``."""
    # NpzFile holds the archive open until closed; the context manager
    # releases it once the needed arrays have been read into memory.
    with np.load(path) as npz:
        keys = _single_word_keys(npz, lemmas)
        values = [npz[key] for key in keys]
    key_counts = [counts[indices[key]] for key in keys]
    return spearmanr(values, key_counts)


def get_corr(model_name, counts, metrics_dir="YOUR_PATH_TO_METRICS",
             lemmas=None, indices=None):
    """Correlate a model's per-synset metrics with per-synset caption counts.

    Two metric files are read from ``{metrics_dir}/{model_name}/``:
    ``subtree_in_prob_32.npz`` (reported as ISP) and ``subtree_is.npz``
    (reported as SCS). Their keys are synset names. The "average" entry and
    synsets whose cleaned lemma has more than one word are excluded.

    Args:
        model_name: Sub-directory of ``metrics_dir`` holding the model's metrics.
        counts: Array of counts indexed by the positions in ``indices``.
        metrics_dir: Root directory of the metric files.
        lemmas: Mapping synset name -> cleaned lemma; defaults to the
            notebook-level ``synset_to_lemma``.
        indices: Mapping synset name -> index into ``counts``; defaults to
            the notebook-level ``synset_to_idx``.

    Returns:
        ``(isp_result, scs_result)``: the ``spearmanr`` results for the two files.

    Raises:
        KeyError: If a metric key is missing from ``lemmas`` or ``indices``.
    """
    if lemmas is None:
        lemmas = synset_to_lemma
    if indices is None:
        indices = synset_to_idx

    model_dir = f"{metrics_dir}/{model_name}"
    isp_corr = _metric_vs_counts_corr(
        f"{model_dir}/subtree_in_prob_32.npz", counts, lemmas, indices)
    scs_corr = _metric_vs_counts_corr(
        f"{model_dir}/subtree_is.npz", counts, lemmas, indices)
    return isp_corr, scs_corr

In [24]:
# Display names for the table rows, one per model id iterated in make_table
# (same order).
pretty_names = ["GLIDE", "LDM", "SD 1.4", "SD 2.0", "unCLIP"]

In [35]:
def make_table(models=None, names=None):
    """Print tab-separated Spearman correlation tables (ISP, then SCS).

    For each model, ``get_corr`` is evaluated against the 400M, 2B and COYO
    count arrays, in that column order. Each cell is printed as
    ``statistic (pvalue)`` with two decimals, so small p-values display as 0.00.

    Args:
        models: Model ids passed to ``get_corr``; defaults to the five ``*_75``
            models below.
        names: Row labels, one per model; defaults to ``pretty_names``.

    Raises:
        ValueError: If there are fewer names than models.
    """
    if models is None:
        models = ["glide_75", "ldm_75", "sd14_75", "sd20_75", "unclip_75"]
    if names is None:
        names = pretty_names
    if len(names) < len(models):
        raise ValueError(f"need a name for each of {len(models)} models, got {len(names)}")
    dataset_counts = [counts_400M, counts_2B, counts_coyo]

    # get_corr returns both the ISP and SCS result for a (model, dataset) pair,
    # so compute every pair once and reuse it for both tables instead of
    # reloading all metric files a second time.
    corrs = [[get_corr(model, counts) for counts in dataset_counts] for model in models]

    for table_no, (title, which) in enumerate((("ISP", 0), ("SCS", 1))):
        if table_no:
            print()  # blank line between the two tables
        print(title)
        # Extra tabs in the header keep the columns aligned with the wider row cells.
        print("Model", "400M\t", "2B\t", "COYO\t", sep="\t")
        for name, model_corrs in zip(names, corrs):
            cells = [
                f"{pair[which].statistic:.2f} ({pair[which].pvalue:.2f})"
                for pair in model_corrs
            ]
            print(name, *cells, sep="\t")

In [36]:
# Print both correlation tables (ISP first, then SCS); the captured output follows.
make_table()

ISP
Model	400M		2B		COYO	
GLIDE	0.19 (0.00)	0.18 (0.00)	0.16 (0.00)
LDM	0.29 (0.00)	0.27 (0.00)	0.24 (0.00)
SD 1.4	0.06 (0.15)	0.04 (0.34)	0.01 (0.81)
SD 2.0	0.10 (0.01)	0.08 (0.04)	0.05 (0.18)
unCLIP	0.02 (0.63)	0.00 (0.91)	-0.02 (0.61)

SCS
Model	400M		2B		COYO	
GLIDE	0.28 (0.00)	0.29 (0.00)	0.29 (0.00)
LDM	0.15 (0.01)	0.16 (0.00)	0.17 (0.00)
SD 1.4	0.00 (0.97)	0.01 (0.83)	0.03 (0.64)
SD 2.0	0.07 (0.21)	0.08 (0.16)	0.08 (0.12)
unCLIP	0.04 (0.44)	0.05 (0.32)	0.08 (0.17)
