In [1]:
import torch as t
from utils import DataManager
import random
import matplotlib.pyplot as plt
import random
from probes import LRProbe, MMProbe, CCSProbe
from matplotlib import gridspec

In [2]:
# hyperparameters
model = 'llama-8b'    # alternative: llama-13b
model_size = '8B'     # alternative: 13B
layer = 12            # layer from which to extract activations
split = 0.8           # forwarded as `split=` to DataManager.add_dataset (presumably the train fraction -- confirm)

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if t.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Experiment log
#   Experiment 1: llama-13b, layer 12
#   Experiment 2: llama-13b, layer 12 (same settings, different split, to see the effect of randomness)
#   Experiment 3: llama-13b, layer 13 (believed to be the last useful layer, as identified by probing)
#   Experiment 4: llama-8b, layer 12

# Reproducing generalization matrix

In [4]:
# Datasets on which every trained probe is evaluated.
val_datasets = [
    'experiment_cps',
    'experiment_inter_stereoset',  # gender
    'experiment_intra_stereoset',  # gender
    'experiment_inter_race_stereoset',
    'experiment_intra_race_stereoset',
    'experiment_inter_profession_stereoset',
    'experiment_intra_profession_stereoset',
    'experiment_inter_religion_stereoset',
    'experiment_intra_religion_stereoset',
]

# Training medleys (lists of dataset names trained on together): each evaluation
# dataset on its own, in the same order, followed by the 'likely' dataset.
train_medlies = [[name] for name in val_datasets] + [['likely']]

def to_str(l):
    """Collapse a medley (iterable of dataset-name strings) into one '+'-separated key."""
    separator = '+'
    return separator.join(l)

# Probe classes to train and compare.
ProbeClasses = [
    LRProbe,
    MMProbe,
]

# Result table: accs[str(probe class)][to_str(train medley)][validation dataset] -> accuracy.
# The innermost dicts start out empty and are filled in by the training loop.
accs = {
    str(probe_class): {to_str(train_medley): {} for train_medley in train_medlies}
    for probe_class in ProbeClasses
}

# The seed is passed to DataManager.add_dataset together with `split`. It is drawn
# at random on every run, so print it: without the value a given run (and any
# results saved from it) cannot be reproduced later.
seed = random.randint(0, 100000)
print('seed =', seed)

In [5]:
# Train one probe per (probe class, training medley) pair and score it on every
# validation dataset, writing accs[str(ProbeClass)][to_str(medley)][val_dataset].
for ProbeClass in ProbeClasses:
    for medley in train_medlies:

        # set up data: datasets in the medley are added with a split (and the
        # shared seed); all remaining validation datasets are added with split=None
        dm = DataManager()
        for dataset in medley:
            dm.add_dataset(dataset, model_size, layer, split=split, seed=seed, center=True, device=device)
        for dataset in val_datasets:
            if dataset not in medley:
                dm.add_dataset(dataset, model_size, layer, split=None, center=True, device=device)

        # train probe
        train_acts, train_labels = dm.get('train')
        probe = ProbeClass.from_data(train_acts, train_labels, device=device)

        # evaluate
        medley_accs = accs[str(ProbeClass)][to_str(medley)]
        for val_dataset in val_datasets:
            # Datasets that were trained on are read from dm.data['val'][name] and
            # scored with iid=True; the others are read from dm.data[name] and
            # scored with iid=False.
            in_medley = val_dataset in medley
            if in_medley:
                acts, labels = dm.data['val'][val_dataset]
            else:
                acts, labels = dm.data[val_dataset]
            medley_accs[val_dataset] = (
                probe.pred(acts, iid=in_medley) == labels
            ).float().mean().item()

# Snapshot the results. dict.copy() is shallow and would leave the snapshot sharing
# its nested dicts with `accs` (so a later re-run of this loop would silently
# change it); copy all three levels instead.
lr_mm_accs = {
    probe_name: {medley_name: dict(scores) for medley_name, scores in by_medley.items()}
    for probe_name, by_medley in accs.items()
}

In [6]:
# Short display labels for the experiment datasets.
_DISPLAY_NAMES = {
    "experiment_cps": "CP gender",
    "experiment_inter_stereoset": "SS1 gender",
    "experiment_intra_stereoset": "SS2 gender",
    "experiment_inter_race_stereoset": "SS1 race",
    "experiment_intra_race_stereoset": "SS2 race",
    "experiment_inter_religion_stereoset": "SS1 religion",
    "experiment_intra_religion_stereoset": "SS2 religion",
    "experiment_inter_profession_stereoset": "SS1 profession",
    "experiment_intra_profession_stereoset": "SS2 profession",
}


def normal_name(name):
    """Return the short display label for a dataset name.

    Args:
        name: dataset identifier, e.g. "experiment_inter_race_stereoset".

    Returns:
        The label from the lookup table, or `name` itself when it has no entry
        (this covers "likely", which is shown as-is). Unknown names used to
        raise UnboundLocalError; they are now passed through unchanged.
    """
    return _DISPLAY_NAMES.get(name, name)

In [7]:
# Oracle probe results: for each probe class, train on a single dataset's train
# split and score on that same dataset's 'val' split. Accuracies are stored per
# probe class as a list ordered like val_datasets.
oracle_accs = {str(probe_class): [] for probe_class in ProbeClasses}
for ProbeClass in ProbeClasses:
    class_key = str(ProbeClass)
    for dataset in val_datasets:
        dm = DataManager()
        dm.add_dataset(dataset, model_size, layer, split=split, seed=seed, device=device)

        # fit on the training portion
        acts, labels = dm.get('train')
        probe = ProbeClass.from_data(acts, labels, device=device)

        # score on the 'val' portion; the probe output is rounded to get predictions
        # NOTE(review): the generalization loop calls probe.pred(...) instead --
        # presumably equivalent to probe(...).round(); confirm in probes.py.
        acts, labels = dm.data['val'][dataset]
        predictions = probe(acts, iid=True).round()
        acc = (predictions == labels).float().mean().item()
        oracle_accs[class_key].append(acc)

In [10]:
# print(oracle_accs)

In [11]:
# print(lr_mm_accs)

In [6]:
# Data from exp2
# oracle_accs = {"<class 'probes.LRProbe'>": [0.43529412150382996, 0.766331672668457, 0.6131386756896973, 0.8337579965591431, 0.7252538204193115, 0.7969349026679993, 0.6970404982566833, 0.8046875, 0.7022900581359863], "<class 'probes.MMProbe'>": [0.3529411852359772, 0.643216073513031, 0.5936739444732666, 0.7445859909057617, 0.6414974331855774, 0.6865900754928589, 0.6292834877967834, 0.8359375, 0.6870229244232178]}
# lr_mm_accs = {"<class 'probes.LRProbe'>": {'experiment_cps': {'experiment_cps': 0.43529412150382996, 'experiment_inter_stereoset': 0.6782477498054504, 'experiment_intra_stereoset': 0.6315789222717285, 'experiment_inter_race_stereoset': 0.696533203125, 'experiment_intra_race_stereoset': 0.6499492526054382, 'experiment_inter_profession_stereoset': 0.639178454875946, 'experiment_intra_profession_stereoset': 0.602867841720581, 'experiment_inter_religion_stereoset': 0.702194333076477, 'experiment_intra_religion_stereoset': 0.6426380276679993}, 'experiment_inter_stereoset': {'experiment_cps': 0.5966981053352356, 'experiment_inter_stereoset': 0.766331672668457, 'experiment_intra_stereoset': 0.6539961099624634, 'experiment_inter_race_stereoset': 0.7825643420219421, 'experiment_intra_race_stereoset': 0.7127984166145325, 'experiment_inter_profession_stereoset': 0.7434089779853821, 'experiment_intra_profession_stereoset': 0.6200124621391296, 'experiment_inter_religion_stereoset': 0.7931034564971924, 'experiment_intra_religion_stereoset': 0.6947852373123169}, 'experiment_intra_stereoset': {'experiment_cps': 0.6061320900917053, 'experiment_inter_stereoset': 0.673212468624115, 'experiment_intra_stereoset': 0.6131386756896973, 'experiment_inter_race_stereoset': 0.5986489653587341, 'experiment_intra_race_stereoset': 0.6145251393318176, 'experiment_inter_profession_stereoset': 0.6336603164672852, 'experiment_intra_profession_stereoset': 0.6779925227165222, 'experiment_inter_religion_stereoset': 0.6253918409347534, 'experiment_intra_religion_stereoset': 0.699386477470398}, 'experiment_inter_race_stereoset': {'experiment_cps': 0.5448113083839417, 'experiment_inter_stereoset': 0.708459198474884, 'experiment_intra_stereoset': 0.5760233998298645, 'experiment_inter_race_stereoset': 0.8337579965591431, 'experiment_intra_race_stereoset': 0.702133059501648, 'experiment_inter_profession_stereoset': 0.7185775637626648, 'experiment_intra_profession_stereoset': 0.5660848021507263, 
'experiment_inter_religion_stereoset': 0.8025078177452087, 'experiment_intra_religion_stereoset': 0.657975435256958}, 'experiment_intra_race_stereoset': {'experiment_cps': 0.5400943756103516, 'experiment_inter_stereoset': 0.6359516382217407, 'experiment_intra_stereoset': 0.5891813039779663, 'experiment_inter_race_stereoset': 0.7538872957229614, 'experiment_intra_race_stereoset': 0.7252538204193115, 'experiment_inter_profession_stereoset': 0.6437768340110779, 'experiment_intra_profession_stereoset': 0.6200124621391296, 'experiment_inter_religion_stereoset': 0.7319748997688293, 'experiment_intra_religion_stereoset': 0.7085889577865601}, 'experiment_inter_profession_stereoset': {'experiment_cps': 0.5660377740859985, 'experiment_inter_stereoset': 0.7457200288772583, 'experiment_intra_stereoset': 0.595516562461853, 'experiment_inter_race_stereoset': 0.7659953832626343, 'experiment_intra_race_stereoset': 0.6342052221298218, 'experiment_inter_profession_stereoset': 0.7969349026679993, 'experiment_intra_profession_stereoset': 0.6337282061576843, 'experiment_inter_religion_stereoset': 0.7633228898048401, 'experiment_intra_religion_stereoset': 0.653374195098877}, 'experiment_intra_profession_stereoset': {'experiment_cps': 0.5683962106704712, 'experiment_inter_stereoset': 0.6258811950683594, 'experiment_intra_stereoset': 0.6471735239028931, 'experiment_inter_race_stereoset': 0.6326790452003479, 'experiment_intra_race_stereoset': 0.6956577301025391, 'experiment_inter_profession_stereoset': 0.6404046416282654, 'experiment_intra_profession_stereoset': 0.6970404982566833, 'experiment_inter_religion_stereoset': 0.6316614151000977, 'experiment_intra_religion_stereoset': 0.7223926186561584}, 'experiment_inter_religion_stereoset': {'experiment_cps': 0.5542452931404114, 'experiment_inter_stereoset': 0.7079557180404663, 'experiment_intra_stereoset': 0.5882066488265991, 'experiment_inter_race_stereoset': 0.7709660530090332, 'experiment_intra_race_stereoset': 0.6791518926620483, 
'experiment_inter_profession_stereoset': 0.698651134967804, 'experiment_intra_profession_stereoset': 0.5729426741600037, 'experiment_inter_religion_stereoset': 0.8046875, 'experiment_intra_religion_stereoset': 0.6702454090118408}, 'experiment_intra_religion_stereoset': {'experiment_cps': 0.5471698045730591, 'experiment_inter_stereoset': 0.6359516382217407, 'experiment_intra_stereoset': 0.6398635506629944, 'experiment_inter_race_stereoset': 0.6408360600471497, 'experiment_intra_race_stereoset': 0.6946419477462769, 'experiment_inter_profession_stereoset': 0.6074494123458862, 'experiment_intra_profession_stereoset': 0.6441708207130432, 'experiment_inter_religion_stereoset': 0.7053291201591492, 'experiment_intra_religion_stereoset': 0.7022900581359863}, 'likely': {'experiment_cps': 0.5047169923782349, 'experiment_inter_stereoset': 0.47985902428627014, 'experiment_intra_stereoset': 0.49707603454589844, 'experiment_inter_race_stereoset': 0.4640581011772156, 'experiment_intra_race_stereoset': 0.5062214732170105, 'experiment_inter_profession_stereoset': 0.48099327087402344, 'experiment_intra_profession_stereoset': 0.48706361651420593, 'experiment_inter_religion_stereoset': 0.44984325766563416, 'experiment_intra_religion_stereoset': 0.47852760553359985}}, "<class 'probes.MMProbe'>": {'experiment_cps': {'experiment_cps': 0.3529411852359772, 'experiment_inter_stereoset': 0.6112789511680603, 'experiment_intra_stereoset': 0.5945419073104858, 'experiment_inter_race_stereoset': 0.6492480039596558, 'experiment_intra_race_stereoset': 0.6245556473731995, 'experiment_inter_profession_stereoset': 0.6065297722816467, 'experiment_intra_profession_stereoset': 0.5777743458747864, 'experiment_inter_religion_stereoset': 0.6755485534667969, 'experiment_intra_religion_stereoset': 0.6180981397628784}, 'experiment_inter_stereoset': {'experiment_cps': 0.5683962106704712, 'experiment_inter_stereoset': 0.643216073513031, 'experiment_intra_stereoset': 0.6364522576332092, 
'experiment_inter_race_stereoset': 0.739357590675354, 'experiment_intra_race_stereoset': 0.7192737460136414, 'experiment_inter_profession_stereoset': 0.7066217064857483, 'experiment_intra_profession_stereoset': 0.6243765950202942, 'experiment_inter_religion_stereoset': 0.7993730306625366, 'experiment_intra_religion_stereoset': 0.7147238850593567}, 'experiment_intra_stereoset': {'experiment_cps': 0.5707547068595886, 'experiment_inter_stereoset': 0.7049345374107361, 'experiment_intra_stereoset': 0.5936739444732666, 'experiment_inter_race_stereoset': 0.6126688718795776, 'experiment_intra_race_stereoset': 0.6253174543380737, 'experiment_inter_profession_stereoset': 0.6514408588409424, 'experiment_intra_profession_stereoset': 0.6687967777252197, 'experiment_inter_religion_stereoset': 0.6692789793014526, 'experiment_intra_religion_stereoset': 0.7009202241897583}, 'experiment_inter_race_stereoset': {'experiment_cps': 0.5235849022865295, 'experiment_inter_stereoset': 0.6530715227127075, 'experiment_intra_stereoset': 0.5355750322341919, 'experiment_inter_race_stereoset': 0.7445859909057617, 'experiment_intra_race_stereoset': 0.7120366096496582, 'experiment_inter_profession_stereoset': 0.6702942848205566, 'experiment_intra_profession_stereoset': 0.528678297996521, 'experiment_inter_religion_stereoset': 0.7899686098098755, 'experiment_intra_religion_stereoset': 0.6180981397628784}, 'experiment_intra_race_stereoset': {'experiment_cps': 0.5353773832321167, 'experiment_inter_stereoset': 0.6565961837768555, 'experiment_intra_stereoset': 0.5501949191093445, 'experiment_inter_race_stereoset': 0.8083099126815796, 'experiment_intra_race_stereoset': 0.6414974331855774, 'experiment_inter_profession_stereoset': 0.6603310704231262, 'experiment_intra_profession_stereoset': 0.5612531304359436, 'experiment_inter_religion_stereoset': 0.7742946743965149, 'experiment_intra_religion_stereoset': 0.6671779155731201}, 'experiment_inter_profession_stereoset': {'experiment_cps': 0.5613207817077637, 
'experiment_inter_stereoset': 0.7104732990264893, 'experiment_intra_stereoset': 0.5994151830673218, 'experiment_inter_race_stereoset': 0.765230655670166, 'experiment_intra_race_stereoset': 0.7097511291503906, 'experiment_inter_profession_stereoset': 0.6865900754928589, 'experiment_intra_profession_stereoset': 0.6059850454330444, 'experiment_inter_religion_stereoset': 0.8025078177452087, 'experiment_intra_religion_stereoset': 0.6809815764427185}, 'experiment_intra_profession_stereoset': {'experiment_cps': 0.551886796951294, 'experiment_inter_stereoset': 0.6535750031471252, 'experiment_intra_stereoset': 0.6666666865348816, 'experiment_inter_race_stereoset': 0.577746570110321, 'experiment_intra_race_stereoset': 0.6206195950508118, 'experiment_inter_profession_stereoset': 0.6350398659706116, 'experiment_intra_profession_stereoset': 0.6292834877967834, 'experiment_inter_religion_stereoset': 0.6833855509757996, 'experiment_intra_religion_stereoset': 0.7147238850593567}, 'experiment_inter_religion_stereoset': {'experiment_cps': 0.5400943756103516, 'experiment_inter_stereoset': 0.6787512302398682, 'experiment_intra_stereoset': 0.5745614171028137, 'experiment_inter_race_stereoset': 0.7872800827026367, 'experiment_intra_race_stereoset': 0.7277806401252747, 'experiment_inter_profession_stereoset': 0.6828632950782776, 'experiment_intra_profession_stereoset': 0.5581359267234802, 'experiment_inter_religion_stereoset': 0.8359375, 'experiment_intra_religion_stereoset': 0.6656441688537598}, 'experiment_intra_religion_stereoset': {'experiment_cps': 0.5613207817077637, 'experiment_inter_stereoset': 0.6993957757949829, 'experiment_intra_stereoset': 0.6384015679359436, 'experiment_inter_race_stereoset': 0.7756818532943726, 'experiment_intra_race_stereoset': 0.7470797300338745, 'experiment_inter_profession_stereoset': 0.6961986422538757, 'experiment_intra_profession_stereoset': 0.6527431607246399, 'experiment_inter_religion_stereoset': 0.8181818127632141, 
'experiment_intra_religion_stereoset': 0.6870229244232178}, 'likely': {'experiment_cps': 0.5165094137191772, 'experiment_inter_stereoset': 0.5161128044128418, 'experiment_intra_stereoset': 0.5263158082962036, 'experiment_inter_race_stereoset': 0.5373438596725464, 'experiment_intra_race_stereoset': 0.5627222061157227, 'experiment_inter_profession_stereoset': 0.518853485584259, 'experiment_intra_profession_stereoset': 0.5182356834411621, 'experiment_inter_religion_stereoset': 0.5109717845916748, 'experiment_intra_religion_stereoset': 0.5506134629249573}}}

In [None]:
# Data from exp4
# oracle_accs = {"<class 'probes.LRProbe'>": [0.48235294222831726, 0.4648241102695465, 0.4768856465816498, 0.8299363255500793, 0.7055837512016296, 0.7954022884368896, 0.6526479721069336, 0.8125, 0.6641221642494202], "<class 'probes.MMProbe'>": [0.5176470875740051, 0.4899497330188751, 0.4866180121898651, 0.7394904494285583, 0.6408629417419434, 0.6934866309165955, 0.6020249128341675, 0.75, 0.6335877776145935]}
# lr_mm_accs = {"<class 'probes.LRProbe'>": {'experiment_cps': {'experiment_cps': 0.48235294222831726, 'experiment_inter_stereoset': 0.4904330372810364, 'experiment_intra_stereoset': 0.514132559299469, 'experiment_inter_race_stereoset': 0.6664541959762573, 'experiment_intra_race_stereoset': 0.6529964804649353, 'experiment_inter_profession_stereoset': 0.619558572769165, 'experiment_intra_profession_stereoset': 0.5978803038597107, 'experiment_inter_religion_stereoset': 0.6912225484848022, 'experiment_intra_religion_stereoset': 0.6671779155731201}, 'experiment_inter_stereoset': {'experiment_cps': 0.525943398475647, 'experiment_inter_stereoset': 0.4648241102695465, 'experiment_intra_stereoset': 0.5185185074806213, 'experiment_inter_race_stereoset': 0.5233239531517029, 'experiment_intra_race_stereoset': 0.5502793192863464, 'experiment_inter_profession_stereoset': 0.49754753708839417, 'experiment_intra_profession_stereoset': 0.531951367855072, 'experiment_inter_religion_stereoset': 0.5062695741653442, 'experiment_intra_religion_stereoset': 0.48159506916999817}, 'experiment_intra_stereoset': {'experiment_cps': 0.5212264060974121, 'experiment_inter_stereoset': 0.5080564022064209, 'experiment_intra_stereoset': 0.4768856465816498, 'experiment_inter_race_stereoset': 0.5662757754325867, 'experiment_intra_race_stereoset': 0.5022854208946228, 'experiment_inter_profession_stereoset': 0.5482832789421082, 'experiment_intra_profession_stereoset': 0.48831048607826233, 'experiment_inter_religion_stereoset': 0.5595611333847046, 'experiment_intra_religion_stereoset': 0.5030674934387207}, 'experiment_inter_race_stereoset': {'experiment_cps': 0.5188679099082947, 'experiment_inter_stereoset': 0.4853977859020233, 'experiment_intra_stereoset': 0.4917154014110565, 'experiment_inter_race_stereoset': 0.8299363255500793, 'experiment_intra_race_stereoset': 0.7068309187889099, 'experiment_inter_profession_stereoset': 0.7317596673965454, 'experiment_intra_profession_stereoset': 0.569513738155365, 
'experiment_inter_religion_stereoset': 0.8275861740112305, 'experiment_intra_religion_stereoset': 0.6595091819763184}, 'experiment_intra_race_stereoset': {'experiment_cps': 0.5424528121948242, 'experiment_inter_stereoset': 0.4989929497241974, 'experiment_intra_stereoset': 0.5175438523292542, 'experiment_inter_race_stereoset': 0.7354065179824829, 'experiment_intra_race_stereoset': 0.7055837512016296, 'experiment_inter_profession_stereoset': 0.6126609444618225, 'experiment_intra_profession_stereoset': 0.6262469291687012, 'experiment_inter_religion_stereoset': 0.722570538520813, 'experiment_intra_religion_stereoset': 0.7131901383399963}, 'experiment_inter_profession_stereoset': {'experiment_cps': 0.573113203048706, 'experiment_inter_stereoset': 0.4904330372810364, 'experiment_intra_stereoset': 0.4844054579734802, 'experiment_inter_race_stereoset': 0.7591128945350647, 'experiment_intra_race_stereoset': 0.6168105602264404, 'experiment_inter_profession_stereoset': 0.7954022884368896, 'experiment_intra_profession_stereoset': 0.6234413981437683, 'experiment_inter_religion_stereoset': 0.777429461479187, 'experiment_intra_religion_stereoset': 0.6671779155731201}, 'experiment_intra_profession_stereoset': {'experiment_cps': 0.573113203048706, 'experiment_inter_stereoset': 0.4929506480693817, 'experiment_intra_stereoset': 0.4941520392894745, 'experiment_inter_race_stereoset': 0.6293652653694153, 'experiment_intra_race_stereoset': 0.689309298992157, 'experiment_inter_profession_stereoset': 0.6488350629806519, 'experiment_intra_profession_stereoset': 0.6526479721069336, 'experiment_inter_religion_stereoset': 0.6363636255264282, 'experiment_intra_religion_stereoset': 0.7377300262451172}, 'experiment_inter_religion_stereoset': {'experiment_cps': 0.5566037893295288, 'experiment_inter_stereoset': 0.48892244696617126, 'experiment_intra_stereoset': 0.5219298005104065, 'experiment_inter_race_stereoset': 0.7633188366889954, 'experiment_intra_race_stereoset': 0.6648044586181641, 
'experiment_inter_profession_stereoset': 0.6926732063293457, 'experiment_intra_profession_stereoset': 0.5684227347373962, 'experiment_inter_religion_stereoset': 0.8125, 'experiment_intra_religion_stereoset': 0.6794478297233582}, 'experiment_intra_religion_stereoset': {'experiment_cps': 0.5471698045730591, 'experiment_inter_stereoset': 0.5040282011032104, 'experiment_intra_stereoset': 0.49561405181884766, 'experiment_inter_race_stereoset': 0.6653071045875549, 'experiment_intra_race_stereoset': 0.692610502243042, 'experiment_inter_profession_stereoset': 0.6092888116836548, 'experiment_intra_profession_stereoset': 0.6412094831466675, 'experiment_inter_religion_stereoset': 0.7084639072418213, 'experiment_intra_religion_stereoset': 0.6641221642494202}, 'likely': {'experiment_cps': 0.48349058628082275, 'experiment_inter_stereoset': 0.5130916237831116, 'experiment_intra_stereoset': 0.5102339386940002, 'experiment_inter_race_stereoset': 0.4821564853191376, 'experiment_intra_race_stereoset': 0.5096495747566223, 'experiment_inter_profession_stereoset': 0.4816063940525055, 'experiment_intra_profession_stereoset': 0.5010910630226135, 'experiment_inter_religion_stereoset': 0.46865203976631165, 'experiment_intra_religion_stereoset': 0.4938650131225586}}, "<class 'probes.MMProbe'>": {'experiment_cps': {'experiment_cps': 0.5176470875740051, 'experiment_inter_stereoset': 0.5015105605125427, 'experiment_intra_stereoset': 0.5063352584838867, 'experiment_inter_race_stereoset': 0.6752485036849976, 'experiment_intra_race_stereoset': 0.6424581408500671, 'experiment_inter_profession_stereoset': 0.6146535873413086, 'experiment_intra_profession_stereoset': 0.5826060175895691, 'experiment_inter_religion_stereoset': 0.7006269693374634, 'experiment_intra_religion_stereoset': 0.6595091819763184}, 'experiment_inter_stereoset': {'experiment_cps': 0.4952830374240875, 'experiment_inter_stereoset': 0.4899497330188751, 'experiment_intra_stereoset': 0.5160818696022034, 
'experiment_inter_race_stereoset': 0.49311748147010803, 'experiment_intra_race_stereoset': 0.5163788795471191, 'experiment_inter_profession_stereoset': 0.4665849208831787, 'experiment_intra_profession_stereoset': 0.49984416365623474, 'experiment_inter_religion_stereoset': 0.4921630024909973, 'experiment_intra_religion_stereoset': 0.4892638027667999}, 'experiment_intra_stereoset': {'experiment_cps': 0.4858490526676178, 'experiment_inter_stereoset': 0.5, 'experiment_intra_stereoset': 0.4866180121898651, 'experiment_inter_race_stereoset': 0.5423145294189453, 'experiment_intra_race_stereoset': 0.5426612496376038, 'experiment_inter_profession_stereoset': 0.5032188892364502, 'experiment_intra_profession_stereoset': 0.5007793307304382, 'experiment_inter_religion_stereoset': 0.5344827175140381, 'experiment_intra_religion_stereoset': 0.49846625328063965}, 'experiment_inter_race_stereoset': {'experiment_cps': 0.5235849022865295, 'experiment_inter_stereoset': 0.5010070204734802, 'experiment_intra_stereoset': 0.49512672424316406, 'experiment_inter_race_stereoset': 0.7394904494285583, 'experiment_intra_race_stereoset': 0.7120366096496582, 'experiment_inter_profession_stereoset': 0.669681191444397, 'experiment_intra_profession_stereoset': 0.5285224914550781, 'experiment_inter_religion_stereoset': 0.7915360331535339, 'experiment_intra_religion_stereoset': 0.6242331266403198}, 'experiment_intra_race_stereoset': {'experiment_cps': 0.5353773832321167, 'experiment_inter_stereoset': 0.5060423016548157, 'experiment_intra_stereoset': 0.49658870697021484, 'experiment_inter_race_stereoset': 0.809457004070282, 'experiment_intra_race_stereoset': 0.6408629417419434, 'experiment_inter_profession_stereoset': 0.6624770164489746, 'experiment_intra_profession_stereoset': 0.5589152574539185, 'experiment_inter_religion_stereoset': 0.7821316719055176, 'experiment_intra_religion_stereoset': 0.6625766754150391}, 'experiment_inter_profession_stereoset': {'experiment_cps': 0.551886796951294, 
'experiment_inter_stereoset': 0.482880175113678, 'experiment_intra_stereoset': 0.49561405181884766, 'experiment_inter_race_stereoset': 0.7672699093818665, 'experiment_intra_race_stereoset': 0.7111477851867676, 'experiment_inter_profession_stereoset': 0.6934866309165955, 'experiment_intra_profession_stereoset': 0.6038030385971069, 'experiment_inter_religion_stereoset': 0.8025078177452087, 'experiment_intra_religion_stereoset': 0.6779140830039978}, 'experiment_intra_profession_stereoset': {'experiment_cps': 0.5495283007621765, 'experiment_inter_stereoset': 0.5055387616157532, 'experiment_intra_stereoset': 0.4995126724243164, 'experiment_inter_race_stereoset': 0.583482027053833, 'experiment_intra_race_stereoset': 0.6166836023330688, 'experiment_inter_profession_stereoset': 0.64071124792099, 'experiment_intra_profession_stereoset': 0.6020249128341675, 'experiment_inter_religion_stereoset': 0.6833855509757996, 'experiment_intra_religion_stereoset': 0.7223926186561584}, 'experiment_inter_religion_stereoset': {'experiment_cps': 0.5448113083839417, 'experiment_inter_stereoset': 0.503524661064148, 'experiment_intra_stereoset': 0.5014619827270508, 'experiment_inter_race_stereoset': 0.771093487739563, 'experiment_intra_race_stereoset': 0.7220670580863953, 'experiment_inter_profession_stereoset': 0.6765788197517395, 'experiment_intra_profession_stereoset': 0.5625, 'experiment_inter_religion_stereoset': 0.75, 'experiment_intra_religion_stereoset': 0.6840490698814392}, 'experiment_intra_religion_stereoset': {'experiment_cps': 0.551886796951294, 'experiment_inter_stereoset': 0.49546828866004944, 'experiment_intra_stereoset': 0.5058479309082031, 'experiment_inter_race_stereoset': 0.7714758515357971, 'experiment_intra_race_stereoset': 0.7310817837715149, 'experiment_inter_profession_stereoset': 0.7024831771850586, 'experiment_intra_profession_stereoset': 0.6521196961402893, 'experiment_inter_religion_stereoset': 0.8072100281715393, 'experiment_intra_religion_stereoset': 
0.6335877776145935}, 'likely': {'experiment_cps': 0.5141509771347046, 'experiment_inter_stereoset': 0.5065458416938782, 'experiment_intra_stereoset': 0.4985380172729492, 'experiment_inter_race_stereoset': 0.5365791320800781, 'experiment_intra_race_stereoset': 0.5622143149375916, 'experiment_inter_profession_stereoset': 0.518853485584259, 'experiment_intra_profession_stereoset': 0.5196384191513062, 'experiment_inter_religion_stereoset': 0.5156739950180054, 'experiment_intra_religion_stereoset': 0.5506134629249573}}}

In [20]:
# Data from exp4
oracle_accs = {"<class 'probes.LRProbe'>": [0.48235294222831726, 0.4648241102695465, 0.4768856465816498, 0.8299363255500793, 0.7055837512016296, 0.7954022884368896, 0.6526479721069336, 0.8125, 0.6641221642494202], "<class 'probes.MMProbe'>": [0.5176470875740051, 0.4899497330188751, 0.4866180121898651, 0.7394904494285583, 0.6408629417419434, 0.6934866309165955, 0.6020249128341675, 0.75, 0.6335877776145935]}
lr_mm_accs = {"<class 'probes.LRProbe'>": {'experiment_cps': {'experiment_cps': 0.48235294222831726, 'experiment_inter_stereoset': 0.4904330372810364, 'experiment_intra_stereoset': 0.514132559299469, 'experiment_inter_race_stereoset': 0.6664541959762573, 'experiment_intra_race_stereoset': 0.6529964804649353, 'experiment_inter_profession_stereoset': 0.619558572769165, 'experiment_intra_profession_stereoset': 0.5978803038597107, 'experiment_inter_religion_stereoset': 0.6912225484848022, 'experiment_intra_religion_stereoset': 0.6671779155731201}, 'experiment_inter_stereoset': {'experiment_cps': 0.525943398475647, 'experiment_inter_stereoset': 0.4648241102695465, 'experiment_intra_stereoset': 0.5185185074806213, 'experiment_inter_race_stereoset': 0.5233239531517029, 'experiment_intra_race_stereoset': 0.5502793192863464, 'experiment_inter_profession_stereoset': 0.49754753708839417, 'experiment_intra_profession_stereoset': 0.531951367855072, 'experiment_inter_religion_stereoset': 0.5062695741653442, 'experiment_intra_religion_stereoset': 0.48159506916999817}, 'experiment_intra_stereoset': {'experiment_cps': 0.5212264060974121, 'experiment_inter_stereoset': 0.5080564022064209, 'experiment_intra_stereoset': 0.4768856465816498, 'experiment_inter_race_stereoset': 0.5662757754325867, 'experiment_intra_race_stereoset': 0.5022854208946228, 'experiment_inter_profession_stereoset': 0.5482832789421082, 'experiment_intra_profession_stereoset': 0.48831048607826233, 'experiment_inter_religion_stereoset': 0.5595611333847046, 'experiment_intra_religion_stereoset': 0.5030674934387207}, 'experiment_inter_race_stereoset': {'experiment_cps': 0.5188679099082947, 'experiment_inter_stereoset': 0.4853977859020233, 'experiment_intra_stereoset': 0.4917154014110565, 'experiment_inter_race_stereoset': 0.8299363255500793, 'experiment_intra_race_stereoset': 0.7068309187889099, 'experiment_inter_profession_stereoset': 0.7317596673965454, 'experiment_intra_profession_stereoset': 0.569513738155365, 
'experiment_inter_religion_stereoset': 0.8275861740112305, 'experiment_intra_religion_stereoset': 0.6595091819763184}, 'experiment_intra_race_stereoset': {'experiment_cps': 0.5424528121948242, 'experiment_inter_stereoset': 0.4989929497241974, 'experiment_intra_stereoset': 0.5175438523292542, 'experiment_inter_race_stereoset': 0.7354065179824829, 'experiment_intra_race_stereoset': 0.7055837512016296, 'experiment_inter_profession_stereoset': 0.6126609444618225, 'experiment_intra_profession_stereoset': 0.6262469291687012, 'experiment_inter_religion_stereoset': 0.722570538520813, 'experiment_intra_religion_stereoset': 0.7131901383399963}, 'experiment_inter_profession_stereoset': {'experiment_cps': 0.573113203048706, 'experiment_inter_stereoset': 0.4904330372810364, 'experiment_intra_stereoset': 0.4844054579734802, 'experiment_inter_race_stereoset': 0.7591128945350647, 'experiment_intra_race_stereoset': 0.6168105602264404, 'experiment_inter_profession_stereoset': 0.7954022884368896, 'experiment_intra_profession_stereoset': 0.6234413981437683, 'experiment_inter_religion_stereoset': 0.777429461479187, 'experiment_intra_religion_stereoset': 0.6671779155731201}, 'experiment_intra_profession_stereoset': {'experiment_cps': 0.573113203048706, 'experiment_inter_stereoset': 0.4929506480693817, 'experiment_intra_stereoset': 0.4941520392894745, 'experiment_inter_race_stereoset': 0.6293652653694153, 'experiment_intra_race_stereoset': 0.689309298992157, 'experiment_inter_profession_stereoset': 0.6488350629806519, 'experiment_intra_profession_stereoset': 0.6526479721069336, 'experiment_inter_religion_stereoset': 0.6363636255264282, 'experiment_intra_religion_stereoset': 0.7377300262451172}, 'experiment_inter_religion_stereoset': {'experiment_cps': 0.5566037893295288, 'experiment_inter_stereoset': 0.48892244696617126, 'experiment_intra_stereoset': 0.5219298005104065, 'experiment_inter_race_stereoset': 0.7633188366889954, 'experiment_intra_race_stereoset': 0.6648044586181641, 
'experiment_inter_profession_stereoset': 0.6926732063293457, 'experiment_intra_profession_stereoset': 0.5684227347373962, 'experiment_inter_religion_stereoset': 0.8125, 'experiment_intra_religion_stereoset': 0.6794478297233582}, 'experiment_intra_religion_stereoset': {'experiment_cps': 0.5471698045730591, 'experiment_inter_stereoset': 0.5040282011032104, 'experiment_intra_stereoset': 0.49561405181884766, 'experiment_inter_race_stereoset': 0.6653071045875549, 'experiment_intra_race_stereoset': 0.692610502243042, 'experiment_inter_profession_stereoset': 0.6092888116836548, 'experiment_intra_profession_stereoset': 0.6412094831466675, 'experiment_inter_religion_stereoset': 0.7084639072418213, 'experiment_intra_religion_stereoset': 0.6641221642494202}, 'likely': {'experiment_cps': 0.48349058628082275, 'experiment_inter_stereoset': 0.5130916237831116, 'experiment_intra_stereoset': 0.5102339386940002, 'experiment_inter_race_stereoset': 0.4821564853191376, 'experiment_intra_race_stereoset': 0.5096495747566223, 'experiment_inter_profession_stereoset': 0.4816063940525055, 'experiment_intra_profession_stereoset': 0.5010910630226135, 'experiment_inter_religion_stereoset': 0.46865203976631165, 'experiment_intra_religion_stereoset': 0.4938650131225586}}, "<class 'probes.MMProbe'>": {'experiment_cps': {'experiment_cps': 0.5176470875740051, 'experiment_inter_stereoset': 0.5015105605125427, 'experiment_intra_stereoset': 0.5063352584838867, 'experiment_inter_race_stereoset': 0.6752485036849976, 'experiment_intra_race_stereoset': 0.6424581408500671, 'experiment_inter_profession_stereoset': 0.6146535873413086, 'experiment_intra_profession_stereoset': 0.5826060175895691, 'experiment_inter_religion_stereoset': 0.7006269693374634, 'experiment_intra_religion_stereoset': 0.6595091819763184}, 'experiment_inter_stereoset': {'experiment_cps': 0.4952830374240875, 'experiment_inter_stereoset': 0.4899497330188751, 'experiment_intra_stereoset': 0.5160818696022034, 
'experiment_inter_race_stereoset': 0.49311748147010803, 'experiment_intra_race_stereoset': 0.5163788795471191, 'experiment_inter_profession_stereoset': 0.4665849208831787, 'experiment_intra_profession_stereoset': 0.49984416365623474, 'experiment_inter_religion_stereoset': 0.4921630024909973, 'experiment_intra_religion_stereoset': 0.4892638027667999}, 'experiment_intra_stereoset': {'experiment_cps': 0.4858490526676178, 'experiment_inter_stereoset': 0.5, 'experiment_intra_stereoset': 0.4866180121898651, 'experiment_inter_race_stereoset': 0.5423145294189453, 'experiment_intra_race_stereoset': 0.5426612496376038, 'experiment_inter_profession_stereoset': 0.5032188892364502, 'experiment_intra_profession_stereoset': 0.5007793307304382, 'experiment_inter_religion_stereoset': 0.5344827175140381, 'experiment_intra_religion_stereoset': 0.49846625328063965}, 'experiment_inter_race_stereoset': {'experiment_cps': 0.5235849022865295, 'experiment_inter_stereoset': 0.5010070204734802, 'experiment_intra_stereoset': 0.49512672424316406, 'experiment_inter_race_stereoset': 0.7394904494285583, 'experiment_intra_race_stereoset': 0.7120366096496582, 'experiment_inter_profession_stereoset': 0.669681191444397, 'experiment_intra_profession_stereoset': 0.5285224914550781, 'experiment_inter_religion_stereoset': 0.7915360331535339, 'experiment_intra_religion_stereoset': 0.6242331266403198}, 'experiment_intra_race_stereoset': {'experiment_cps': 0.5353773832321167, 'experiment_inter_stereoset': 0.5060423016548157, 'experiment_intra_stereoset': 0.49658870697021484, 'experiment_inter_race_stereoset': 0.809457004070282, 'experiment_intra_race_stereoset': 0.6408629417419434, 'experiment_inter_profession_stereoset': 0.6624770164489746, 'experiment_intra_profession_stereoset': 0.5589152574539185, 'experiment_inter_religion_stereoset': 0.7821316719055176, 'experiment_intra_religion_stereoset': 0.6625766754150391}, 'experiment_inter_profession_stereoset': {'experiment_cps': 0.551886796951294, 
'experiment_inter_stereoset': 0.482880175113678, 'experiment_intra_stereoset': 0.49561405181884766, 'experiment_inter_race_stereoset': 0.7672699093818665, 'experiment_intra_race_stereoset': 0.7111477851867676, 'experiment_inter_profession_stereoset': 0.6934866309165955, 'experiment_intra_profession_stereoset': 0.6038030385971069, 'experiment_inter_religion_stereoset': 0.8025078177452087, 'experiment_intra_religion_stereoset': 0.6779140830039978}, 'experiment_intra_profession_stereoset': {'experiment_cps': 0.5495283007621765, 'experiment_inter_stereoset': 0.5055387616157532, 'experiment_intra_stereoset': 0.4995126724243164, 'experiment_inter_race_stereoset': 0.583482027053833, 'experiment_intra_race_stereoset': 0.6166836023330688, 'experiment_inter_profession_stereoset': 0.64071124792099, 'experiment_intra_profession_stereoset': 0.6020249128341675, 'experiment_inter_religion_stereoset': 0.6833855509757996, 'experiment_intra_religion_stereoset': 0.7223926186561584}, 'experiment_inter_religion_stereoset': {'experiment_cps': 0.5448113083839417, 'experiment_inter_stereoset': 0.503524661064148, 'experiment_intra_stereoset': 0.5014619827270508, 'experiment_inter_race_stereoset': 0.771093487739563, 'experiment_intra_race_stereoset': 0.7220670580863953, 'experiment_inter_profession_stereoset': 0.6765788197517395, 'experiment_intra_profession_stereoset': 0.5625, 'experiment_inter_religion_stereoset': 0.75, 'experiment_intra_religion_stereoset': 0.6840490698814392}, 'experiment_intra_religion_stereoset': {'experiment_cps': 0.551886796951294, 'experiment_inter_stereoset': 0.49546828866004944, 'experiment_intra_stereoset': 0.5058479309082031, 'experiment_inter_race_stereoset': 0.7714758515357971, 'experiment_intra_race_stereoset': 0.7310817837715149, 'experiment_inter_profession_stereoset': 0.7024831771850586, 'experiment_intra_profession_stereoset': 0.6521196961402893, 'experiment_inter_religion_stereoset': 0.8072100281715393, 'experiment_intra_religion_stereoset': 
0.6335877776145935}, 'likely': {'experiment_cps': 0.5141509771347046, 'experiment_inter_stereoset': 0.5065458416938782, 'experiment_intra_stereoset': 0.4985380172729492, 'experiment_inter_race_stereoset': 0.5365791320800781, 'experiment_intra_race_stereoset': 0.5622143149375916, 'experiment_inter_profession_stereoset': 0.518853485584259, 'experiment_intra_profession_stereoset': 0.5196384191513062, 'experiment_inter_religion_stereoset': 0.5156739950180054, 'experiment_intra_religion_stereoset': 0.5506134629249573}}}

# One row of three panels: two equally wide heatmaps followed by a narrower
# baselines panel (the 0.417 width ratio fixes the proportions of the baselines graph).
fig = plt.figure(figsize=(20, 10))
gs = gridspec.GridSpec(1, 3, width_ratios=[1, 1, 0.417])
axes = [plt.subplot(gs[col]) for col in range(3)]

# Font sizes used for the different text elements of the figure
title_size, axlabel_size = 20, 18
ticklabel_size = num_size = 12

# Display switches
normalize = True  # Do you want to normalize all values to LR on test set?
ordered = True  # Do you want to order the train/val datasets by CPS/SS1/SS2 instead of bias type?

# Colour-scale bounds: a wider, higher range is used for normalized values
colormin, colormax = (0.6, 1.2) if normalize else (0.4, 0.9)
    
# Dataset names split into the groups that the two orderings below rearrange.
_cps_dataset = 'experiment_cps'
_inter_datasets = [
    'experiment_inter_stereoset',  # gender
    'experiment_inter_race_stereoset',
    'experiment_inter_profession_stereoset',
    'experiment_inter_religion_stereoset',
]
_intra_datasets = [
    'experiment_intra_stereoset',  # gender
    'experiment_intra_race_stereoset',
    'experiment_intra_profession_stereoset',
    'experiment_intra_religion_stereoset',
]

if ordered:
    # CPS first, then every "inter" dataset, then every "intra" dataset
    val_datasets = [_cps_dataset, *_inter_datasets, *_intra_datasets]
else:
    # CPS first, then inter/intra pairs, one bias type after another
    val_datasets = [_cps_dataset] + [
        name
        for inter_intra_pair in zip(_inter_datasets, _intra_datasets)
        for name in inter_intra_pair
    ]

# Each validation dataset doubles as a single-dataset training medley,
# in the same order, with the 'likely' medley appended at the end.
train_medlies = [[name] for name in val_datasets] + [['likely']]

# Maps each experiment name to the position of its entry in oracle_accs
experiment_to_index = {
    name: position
    for position, name in enumerate((
        'experiment_cps',
        'experiment_inter_stereoset',
        'experiment_intra_stereoset',
        'experiment_inter_race_stereoset',
        'experiment_intra_race_stereoset',
        'experiment_inter_profession_stereoset',
        'experiment_intra_profession_stereoset',
        'experiment_inter_religion_stereoset',
        'experiment_intra_religion_stereoset',
    ))
}

def _plot_generalization_panel(ax, title, probe_accs):
    """Draw one train-set x test-set accuracy heatmap on ``ax``.

    Reads the notebook-level settings ``normalize``, ``colormin``/``colormax``,
    the font sizes, ``val_datasets``, ``train_medlies``, ``experiment_to_index``
    and ``oracle_accs`` at call time.

    Args:
        ax: matplotlib axes to draw on.
        title: panel title.
        probe_accs: nested mapping ``{train medley name: {val dataset: accuracy}}``
            for a single probe class. It is never modified.

    Returns:
        The plotted grid as a list of rows (one row per entry of
        ``val_datasets``, one column per training medley except 'likely').
    """
    ax.set_title(title, fontsize=title_size)

    if normalize:
        # Divide each accuracy by the LR accuracy on the same validation set.
        # The oracle entry is looked up by dataset name through
        # experiment_to_index rather than by the position of the key in the
        # dict, and a new mapping is built so lr_mm_accs itself keeps the raw
        # accuracies. Pairs involving a name that is not in
        # experiment_to_index (e.g. the 'likely' medley) are left as they are.
        oracle_lr = oracle_accs[str(LRProbe)]
        probe_accs = {
            train_set: {
                val_set: (
                    acc / oracle_lr[experiment_to_index[val_set]]
                    if train_set in experiment_to_index and val_set in experiment_to_index
                    else acc
                )
                for val_set, acc in sub_exps.items()
            }
            for train_set, sub_exps in probe_accs.items()
        }

    # The 'likely' medley is not shown in these panels. Columns and tick labels
    # are both derived from this one filtered list so they cannot disagree.
    shown_medlies = [medley for medley in train_medlies if medley != ['likely']]
    grid = [
        [probe_accs[to_str(medley)][val_dataset] for medley in shown_medlies]
        for val_dataset in val_datasets
    ]

    ax.imshow(grid, vmin=colormin, vmax=colormax)
    # Print every cell's value as a whole-number percentage
    for row_idx, row in enumerate(grid):
        for col_idx, value in enumerate(row):
            ax.text(col_idx, row_idx, f'{round(value * 100):2d}', ha='center', va='center', fontsize=num_size)
    ax.set_xticks(range(len(shown_medlies)))
    ax.set_xticklabels([normal_name(to_str(medley)) for medley in shown_medlies], rotation=45, ha='right', fontsize=ticklabel_size)
    return grid


# Subplot for Logistic Regression
ax = axes[0]
ax_accs = lr_mm_accs[str(LRProbe)]
grid = _plot_generalization_panel(ax, "Logistic regression", ax_accs)


# Subplot for Mass Mean
ax = axes[1]
ax_accs = lr_mm_accs[str(MMProbe)]
grid = _plot_generalization_panel(ax, "Mass mean", ax_accs)


# Subplot for Baselines
# One row per validation dataset; columns are the LR and MM probes trained on
# 'likely', followed by the LR oracle accuracy for that dataset.
ax = axes[2]
ax.set_title("Baselines", fontsize=title_size)
oracle_lr = oracle_accs[str(LRProbe)]
# The oracle entry is looked up by dataset name via experiment_to_index.
# Indexing by the row position in val_datasets is only correct for the
# unordered layout; with `ordered` the rows are re-ordered and a positional
# index would pair most rows with another dataset's oracle accuracy.
# NOTE(review): these baseline values are plotted unnormalized even when
# `normalize` is True, although colormin/colormax then use the normalized
# range — confirm this is intended.
grid = [[lr_mm_accs[str(LRProbe)]['likely'][val_dataset],
         lr_mm_accs[str(MMProbe)]['likely'][val_dataset],
         oracle_lr[experiment_to_index[val_dataset]]] for val_dataset in val_datasets]

ax.imshow(grid, vmin=colormin, vmax=colormax)
# Print every cell's value as a whole-number percentage
for i in range(len(grid)):
    for j in range(len(grid[0])):
        ax.text(j, i, f'{round(grid[i][j] * 100):2d}', ha='center', va='center', fontsize=num_size)
ax.set_xticks(range(3))
ax.set_xticklabels(['LR on likely', 'MM on likely', 'LR on test set'], rotation=45, ha='right', fontsize=ticklabel_size)


# Bold lines to separate datasets
# 0.5 separates the first dataset (CPS) from the rest in both layouts.
# 4.5 separates the inter block from the intra block and is only a group
# boundary in the `ordered` layout, so it is drawn only then — on every panel,
# including the baselines panel.
boundaries = [0.5, 4.5] if ordered else [0.5]
for i, ax in enumerate(axes):
    is_baselines = (i == 2)
    for boundary in boundaries:
        ax.hlines([boundary], *ax.get_xlim(), linewidth=3, color='black')
        # The baselines columns are not datasets, so that panel gets no vertical separators
        if not is_baselines:
            ax.vlines([boundary], *ax.get_ylim(), linewidth=3, color='black')

# Per-panel axis decoration: only the first panel carries the y tick labels,
# and the second panel carries the x-axis label.
normal_val_datasets = list(map(normal_name, val_datasets))
for i, ax in enumerate(axes):
    if i == 0:
        # First panel: label every row with its test dataset
        ax.set_yticks(range(len(val_datasets)))
        ax.set_yticklabels(normal_val_datasets, fontsize=ticklabel_size)
        ax.set_ylabel('Test set', fontsize=axlabel_size)
        continue
    if i == 1:
        # Second panel: x-axis label, placed at explicit label coordinates
        ax.set_xlabel('Train set', labelpad=15, fontsize=axlabel_size)
        ax.xaxis.set_label_coords(0.1, -0.2)
    # Every panel except the first hides its y ticks
    ax.set_yticks([])

import os

# Adjust the layout to reduce whitespace
fig.subplots_adjust(wspace=0.025, hspace=0)

# One colorbar for the whole figure, attached to the last panel. The first
# panel's image serves as the mappable; every panel is drawn with the same
# colormin/colormax.
plt.colorbar(axes[0].images[0], ax=axes[2], aspect=30, shrink=0.82)

# The output file name records which display switches were active
figure_name = 'figures/generalization/exp4/generalization_exp4'
if normalize:
    figure_name += "_normalized"
if ordered:
    figure_name += "_ordered"

# savefig does not create missing directories, so make sure the target exists
os.makedirs(os.path.dirname(figure_name), exist_ok=True)
plt.savefig(figure_name + ".png", bbox_inches='tight', dpi=300)