In [1]:
import json
import sys

# Put the parent directory on the import path so the `diagnose` module that
# lives one level up can be imported; this must run before the import below.
sys.path.append("..")
from diagnose import (
    prepare_waterbird,
    prepare_fairface,
    prepare_dsprites,
    discover_slices,
)

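# Waterbird

Slice discovery over `waterbird` × `waterplace` on the validation split, using the ViT-B/32 CLIP backbone and the Waterbird linear model.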

In [2]:
# Slice discovery for Waterbird: restrict to validation-split records, use the
# raw "waterbird" attribute as the label, and slice on (waterbird, waterplace).
discover_slices(
    # --- which records, which label, which slicing attributes ---
    filter_fn=lambda i, x: x["attributes"]["split"] == "val",
    label_fn=lambda x: x["attributes"]["waterbird"],
    fields=["waterbird", "waterplace"],
    prepare_fn=prepare_waterbird,
    # --- model and data sources ---
    clip_model_name="ViT-B/32",
    linear_model_path="../pytorch_cache/iclrsubmission/models/waterbird_linear_model.pt",
    data_path="../../data/Waterbird/processed_attribute_dataset/attributes.jsonl",
)

Text Acc - Image Acc Correlation:
Spearman correlation: 1.0000 (p-value: 0.0000)
Pearson correlation: 0.9799 (p-value: 0.0201)
Text Prob - Image Acc Correlation:
Spearman correlation: 1.0000 (p-value: 0.0000)
Pearson correlation: 0.9953 (p-value: 0.0047)


[((('waterbird', 1), ('waterplace', 0)),
  (0.3233082706766917, 0.2403532608695652, 0.32583714)),
 ((('waterbird', 0), ('waterplace', 1)),
  (0.6523605150214592, 0.7508116883116883, 0.7029176)),
 ((('waterbird', 1), ('waterplace', 1)),
  (0.9548872180451128, 0.9490489130434783, 0.9306395)),
 ((('waterbird', 0), ('waterplace', 0)), (0.9978586723768736, 1.0, 0.9957342))]

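# FairFace

Slice discovery over `race` groups on the validation split; the label is gender (1 = Female, 0 otherwise).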

In [4]:
# Slice discovery for FairFace: validation-split records only, binary target
# (1 when gender is "Female", otherwise 0), sliced by race.
discover_slices(
    # --- which records, which label, which slicing attributes ---
    filter_fn=lambda i, x: x["attributes"]["split"] == "val",
    label_fn=lambda x: 1 if x["attributes"]["gender"] == "Female" else 0,
    fields=["race"],
    prepare_fn=prepare_fairface,
    # --- model and data sources ---
    clip_model_name="ViT-B/32",
    linear_model_path="../pytorch_cache/iclrsubmission/models/fairface_linear_model.pt",
    data_path="../../data/FairFace/processed_attribute_dataset/attributes.jsonl",
)

Text Acc - Image Acc Correlation:
Spearman correlation: 0.5045 (p-value: 0.2482)
Pearson correlation: 0.2396 (p-value: 0.6048)
Text Prob - Image Acc Correlation:
Spearman correlation: 0.8214 (p-value: 0.0234)
Pearson correlation: 0.9091 (p-value: 0.0046)


[((('race', 'Black'),), (0.8997429305912596, 0.9958333333333333, 0.9133974)),
 ((('race', 'Southeast Asian'),),
  (0.9342756183745583, 0.9972222222222222, 0.9332894)),
 ((('race', 'East Asian'),),
  (0.9419354838709677, 0.9916666666666667, 0.92771006)),
 ((('race', 'Indian'),), (0.9445910290237467, 0.99375, 0.9267614)),
 ((('race', 'Latino_Hispanic'),),
  (0.9537892791127541, 0.9979166666666667, 0.9463455)),
 ((('race', 'White'),), (0.9597122302158273, 0.9965277777777778, 0.9427252)),
 ((('race', 'Middle Eastern'),),
  (0.967741935483871, 0.9979166666666667, 0.95262563))]

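# dSprites

Slice discovery over `color` × `label` on the 2-class TriangleSquare variant of dSprites, restricted to the saved validation indices.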

In [5]:
# Load the saved train/validation index split for the 2-class dSprites model.
# A context manager closes the file instead of leaking the handle.
# Assumes the JSON holds exactly two lists of hashable indices
# ([train, val]) — the unpacking below requires two elements.
with open(
    "../pytorch_cache/iclrsubmission/models/dsprites_train_val_idxs_2class.json"
) as idx_file:
    train_idxs, val_idxs = json.load(idx_file)

# `filter_fn` is called per record; a set gives O(1) membership instead of a
# linear scan of the list for every record.
val_idx_set = set(val_idxs)

discover_slices(
    clip_model_name="ViT-B/32",
    linear_model_path="../pytorch_cache/iclrsubmission/models/dsprites_linear_model_2class.pt",
    data_path="../../data/TriangleSquare/processed_attribute_dataset/attributes.jsonl",
    # keep only records whose position index is in the saved validation split
    filter_fn=lambda i, x: i in val_idx_set,
    label_fn=lambda x: x["attributes"]["label"],
    prepare_fn=prepare_dsprites,
    fields=["color", "label"],
)

Text Acc - Image Acc Correlation:
Spearman correlation: 0.6667 (p-value: 0.0179)
Pearson correlation: 0.9166 (p-value: 0.0000)
Text Prob - Image Acc Correlation:
Spearman correlation: 0.6961 (p-value: 0.0119)
Pearson correlation: 0.8079 (p-value: 0.0015)


[((('color', 'orange'), ('label', 0)),
  (0.033734939759036145, 0.020833333333333332, 0.29899698)),
 ((('color', 'green'), ('label', 1)), (0.06165228113440197, 0.0, 0.16407745)),
 ((('color', 'blue'), ('label', 1)),
  (0.6575342465753424, 0.7791666666666667, 0.56714785)),
 ((('color', 'cyan'), ('label', 0)), (0.9761904761904762, 1.0, 0.83357334)),
 ((('color', 'pink'), ('label', 1)),
  (0.9861431870669746, 0.5541666666666667, 0.5043879)),
 ((('color', 'red'), ('label', 0)),
  (0.9939759036144579, 0.9833333333333333, 0.74450433)),
 ((('color', 'red'), ('label', 1)), (0.9954337899543378, 0.775, 0.5650696)),
 ((('color', 'blue'), ('label', 0)), (1.0, 1.0, 0.8061522)),
 ((('color', 'cyan'), ('label', 1)), (1.0, 0.775, 0.5658888)),
 ((('color', 'green'), ('label', 0)), (1.0, 1.0, 0.97125906)),
 ((('color', 'orange'), ('label', 1)), (1.0, 1.0, 0.93163455)),
 ((('color', 'pink'), ('label', 0)), (1.0, 1.0, 0.86252326))]