In [1]:
import torch as t
from utils import DataManager
import random
import matplotlib.pyplot as plt
import random
from probes import LRProbe, MMProbe, CCSProbe

In [2]:
# --- Hyperparameters ---
# `model` is a descriptive name only (unused below); `model_size` is the tag
# handed to DataManager.add_dataset when loading activations.
model, model_size = 'llama-13b', '13B'
layer = 13   # layer from which activations are extracted
split = 0.8  # train/val split passed to DataManager.add_dataset (presumably the train fraction)

if t.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Experiment 1: layer 12
# Experiment 2: layer 12 again, with a different random split, to gauge run-to-run variance
# Experiment 3: layer 13 (tentatively the last useful layer, as identified by probing)

# Reproducing generalization matrix

In [3]:
# Validation (test) datasets, in the row order used by the generalization matrix.
val_datasets = [
    'experiment_cps',
    'experiment_inter_stereoset',  # gender
    'experiment_intra_stereoset',  # gender
    'experiment_inter_race_stereoset',
    'experiment_intra_race_stereoset',
    'experiment_inter_profession_stereoset',
    'experiment_intra_profession_stereoset',
    'experiment_inter_religion_stereoset',
    'experiment_intra_religion_stereoset',
]

# Training medleys: each medley is a list of dataset names trained on together.
# Here every validation dataset is trained on alone, plus the 'likely' baseline
# (kept last). Derived from val_datasets so the two lists cannot drift apart.
train_medlies = [[dataset] for dataset in val_datasets] + [['likely']]

def to_str(l):
    """Build the key for a training medley by joining its dataset names with '+'."""
    separator = '+'
    return separator.join(l)

# Probe classes to train and compare.
ProbeClasses = [LRProbe, MMProbe]

# Result table: accs[str(ProbeClass)][to_str(train_medley)] is an initially
# empty dict that the training loop fills with {validation dataset: accuracy}.
accs = dict(
    (str(probe_class), {to_str(train_medley): {} for train_medley in train_medlies})
    for probe_class in ProbeClasses
)

# Seed for the train/val splits. Leave SPLIT_SEED as None to draw a fresh seed
# each run (e.g. to gauge split-to-split variance); set it to a previously
# printed value to reproduce that run.
SPLIT_SEED = None
seed = SPLIT_SEED if SPLIT_SEED is not None else random.randint(0, 100000)
# Also seed torch, in case probe training draws from torch's RNG.
t.manual_seed(seed)
print(f'seed = {seed}')

In [4]:
for ProbeClass in ProbeClasses:
    for medley in train_medlies:
        medley_key = to_str(medley)

        # Set up data: datasets in the medley are split into train/val with the
        # shared seed; every other validation dataset is loaded unsplit
        # (split=None) and used only for out-of-medley evaluation.
        dm = DataManager()
        for dataset in medley:
            dm.add_dataset(dataset, model_size, layer, split=split, seed=seed, center=True, device=device)
        for dataset in val_datasets:
            if dataset not in medley:
                dm.add_dataset(dataset, model_size, layer, split=None, center=True, device=device)

        # Train the probe on the medley's training split.
        train_acts, train_labels = dm.get('train')
        probe = ProbeClass.from_data(train_acts, train_labels, device=device)

        # Evaluate on every validation dataset: in-medley datasets use their
        # held-out val split (iid=True), the others are scored in full (iid=False).
        for val_dataset in val_datasets:
            iid = val_dataset in medley
            if iid:
                acts, labels = dm.data['val'][val_dataset]
            else:
                acts, labels = dm.data[val_dataset]
            accs[str(ProbeClass)][medley_key][val_dataset] = (
                probe.pred(acts, iid=iid) == labels
            ).float().mean().item()

# Copy both dict levels: a shallow accs.copy() would share the inner dicts, so
# later in-place edits to lr_mm_accs (e.g. normalization) would rewrite accs too.
lr_mm_accs = {
    probe_name: {key: dict(val_accs) for key, val_accs in by_medley.items()}
    for probe_name, by_medley in accs.items()
}

In [5]:
# Human-readable plot labels for the dataset names used in this notebook.
_NORMAL_NAMES = {
    "experiment_cps": "CP gender",
    "experiment_inter_stereoset": "SS1 gender",
    "experiment_intra_stereoset": "SS2 gender",
    "experiment_inter_race_stereoset": "SS1 race",
    "experiment_intra_race_stereoset": "SS2 race",
    "experiment_inter_religion_stereoset": "SS1 religion",
    "experiment_intra_religion_stereoset": "SS2 religion",
    "experiment_inter_profession_stereoset": "SS1 profession",
    "experiment_intra_profession_stereoset": "SS2 profession",
}


def normal_name(name):
    """Return a human-readable plot label for a dataset or medley name.

    Medley keys built by to_str (dataset names joined with '+') are translated
    part by part. Names without a known label, such as 'likely', are returned
    unchanged. The previous if/elif chain raised UnboundLocalError for these.
    """
    return '+'.join(_NORMAL_NAMES.get(part, part) for part in name.split('+'))

In [6]:
# Oracle baseline: for each validation dataset, train a probe on that dataset's
# own train split and score it on its held-out val split. Loading and scoring
# mirror the in-medley branch of the training loop (center=True, probe.pred with
# iid=True), so the two sets of numbers are directly comparable.
# NOTE(review): this previously omitted `center` and used probe(...).round();
# the stored results match the medley diagonal exactly, so presumably center
# defaults to True and pred == forward().round(). Confirm in utils/probes.
oracle_accs = {str(probe_class): [] for probe_class in ProbeClasses}
for ProbeClass in ProbeClasses:
    for dataset in val_datasets:
        dm = DataManager()
        dm.add_dataset(dataset, model_size, layer, split=split, seed=seed, center=True, device=device)
        acts, labels = dm.get('train')
        probe = ProbeClass.from_data(acts, labels, device=device)

        acts, labels = dm.data['val'][dataset]
        acc = (probe.pred(acts, iid=True) == labels).float().mean().item()
        oracle_accs[str(ProbeClass)].append(acc)

In [11]:
# print(oracle_accs)

In [12]:
# print(lr_mm_accs)

In [9]:
# oracle_accs = 
# lr_mm_accs = 

In [18]:
# Stored accuracies from the first run of the generalization experiment.
# Assigning them here overrides whatever the training and oracle cells computed
# in the current session, so the figure below always shows these values.
# Layout: lr_mm_accs[str(ProbeClass)][train medley key][validation dataset] -> accuracy;
#         oracle_accs[str(ProbeClass)] is a list aligned with val_datasets order.
# NOTE(review): the layer and split seed of that run are not recorded here. TODO confirm.
lr_mm_accs = {"<class 'probes.LRProbe'>": {'experiment_cps': {'experiment_cps': 0.5529412031173706, 'experiment_inter_stereoset': 0.6717019081115723, 'experiment_intra_stereoset': 0.63791424036026, 'experiment_inter_race_stereoset': 0.6695131063461304, 'experiment_intra_race_stereoset': 0.6188420653343201, 'experiment_inter_profession_stereoset': 0.6223176121711731, 'experiment_intra_profession_stereoset': 0.6193890571594238, 'experiment_inter_religion_stereoset': 0.6990595459938049, 'experiment_intra_religion_stereoset': 0.6564416885375977}, 'experiment_inter_stereoset': {'experiment_cps': 0.5801886916160583, 'experiment_inter_stereoset': 0.7713567614555359, 'experiment_intra_stereoset': 0.6578947305679321, 'experiment_inter_race_stereoset': 0.7679072022438049, 'experiment_intra_race_stereoset': 0.6838496923446655, 'experiment_inter_profession_stereoset': 0.7524524927139282, 'experiment_intra_profession_stereoset': 0.6211035251617432, 'experiment_inter_religion_stereoset': 0.7962382435798645, 'experiment_intra_religion_stereoset': 0.6963189840316772}, 'experiment_intra_stereoset': {'experiment_cps': 0.625, 'experiment_inter_stereoset': 0.7059416174888611, 'experiment_intra_stereoset': 0.6447688341140747, 'experiment_inter_race_stereoset': 0.6339535713195801, 'experiment_intra_race_stereoset': 0.6232859492301941, 'experiment_inter_profession_stereoset': 0.6420907378196716, 'experiment_intra_profession_stereoset': 0.6686409115791321, 'experiment_inter_religion_stereoset': 0.6536049842834473, 'experiment_intra_religion_stereoset': 0.6978527307510376}, 'experiment_inter_race_stereoset': {'experiment_cps': 0.5377358794212341, 'experiment_inter_stereoset': 0.7049345374107361, 'experiment_intra_stereoset': 0.5667641162872314, 'experiment_inter_race_stereoset': 0.8356688022613525, 'experiment_intra_race_stereoset': 0.7153377532958984, 'experiment_inter_profession_stereoset': 0.7145922780036926, 'experiment_intra_profession_stereoset': 0.565773069858551, 
'experiment_inter_religion_stereoset': 0.7993730306625366, 'experiment_intra_religion_stereoset': 0.657975435256958}, 'experiment_intra_race_stereoset': {'experiment_cps': 0.5400943756103516, 'experiment_inter_stereoset': 0.6364551782608032, 'experiment_intra_stereoset': 0.6086744666099548, 'experiment_inter_race_stereoset': 0.7555441856384277, 'experiment_intra_race_stereoset': 0.7239847779273987, 'experiment_inter_profession_stereoset': 0.6473022699356079, 'experiment_intra_profession_stereoset': 0.6231297254562378, 'experiment_inter_religion_stereoset': 0.722570538520813, 'experiment_intra_religion_stereoset': 0.6855828166007996}, 'experiment_inter_profession_stereoset': {'experiment_cps': 0.5424528121948242, 'experiment_inter_stereoset': 0.7411882877349854, 'experiment_intra_stereoset': 0.5843079686164856, 'experiment_inter_race_stereoset': 0.7566912770271301, 'experiment_intra_race_stereoset': 0.6182072162628174, 'experiment_inter_profession_stereoset': 0.7670498490333557, 'experiment_intra_profession_stereoset': 0.6204800605773926, 'experiment_inter_religion_stereoset': 0.7570532560348511, 'experiment_intra_religion_stereoset': 0.6610429286956787}, 'experiment_intra_profession_stereoset': {'experiment_cps': 0.6132075786590576, 'experiment_inter_stereoset': 0.6475327014923096, 'experiment_intra_stereoset': 0.6598440408706665, 'experiment_inter_race_stereoset': 0.6552383303642273, 'experiment_intra_race_stereoset': 0.692610502243042, 'experiment_inter_profession_stereoset': 0.6814837455749512, 'experiment_intra_profession_stereoset': 0.6822429895401001, 'experiment_inter_religion_stereoset': 0.678683340549469, 'experiment_intra_religion_stereoset': 0.7361962795257568}, 'experiment_inter_religion_stereoset': {'experiment_cps': 0.5353773832321167, 'experiment_inter_stereoset': 0.7155085802078247, 'experiment_intra_stereoset': 0.581384003162384, 'experiment_inter_race_stereoset': 0.7717307806015015, 'experiment_intra_race_stereoset': 0.6715337634086609, 
'experiment_inter_profession_stereoset': 0.7069282531738281, 'experiment_intra_profession_stereoset': 0.5729426741600037, 'experiment_inter_religion_stereoset': 0.7578125, 'experiment_intra_religion_stereoset': 0.6963189840316772}, 'experiment_intra_religion_stereoset': {'experiment_cps': 0.5471698045730591, 'experiment_inter_stereoset': 0.6248741149902344, 'experiment_intra_stereoset': 0.6193957328796387, 'experiment_inter_race_stereoset': 0.6477185487747192, 'experiment_intra_race_stereoset': 0.6888014674186707, 'experiment_inter_profession_stereoset': 0.6301348805427551, 'experiment_intra_profession_stereoset': 0.6547693610191345, 'experiment_inter_religion_stereoset': 0.7382444739341736, 'experiment_intra_religion_stereoset': 0.6335877776145935}, 'likely': {'experiment_cps': 0.4811320900917053, 'experiment_inter_stereoset': 0.5010070204734802, 'experiment_intra_stereoset': 0.5097466111183167, 'experiment_inter_race_stereoset': 0.491970419883728, 'experiment_intra_race_stereoset': 0.5153631567955017, 'experiment_inter_profession_stereoset': 0.49202942848205566, 'experiment_intra_profession_stereoset': 0.4911159873008728, 'experiment_inter_religion_stereoset': 0.4811912178993225, 'experiment_intra_religion_stereoset': 0.5}}, "<class 'probes.MMProbe'>": {'experiment_cps': {'experiment_cps': 0.43529412150382996, 'experiment_inter_stereoset': 0.6314199566841125, 'experiment_intra_stereoset': 0.6257309913635254, 'experiment_inter_race_stereoset': 0.6384144425392151, 'experiment_intra_race_stereoset': 0.6239207983016968, 'experiment_inter_profession_stereoset': 0.6146535873413086, 'experiment_intra_profession_stereoset': 0.6044264435768127, 'experiment_inter_religion_stereoset': 0.684952974319458, 'experiment_intra_religion_stereoset': 0.6441717743873596}, 'experiment_inter_stereoset': {'experiment_cps': 0.5589622855186462, 'experiment_inter_stereoset': 0.6457286477088928, 'experiment_intra_stereoset': 0.6325536370277405, 'experiment_inter_race_stereoset': 
0.7508283853530884, 'experiment_intra_race_stereoset': 0.7001016139984131, 'experiment_inter_profession_stereoset': 0.7070815563201904, 'experiment_intra_profession_stereoset': 0.6128429174423218, 'experiment_inter_religion_stereoset': 0.8009403944015503, 'experiment_intra_religion_stereoset': 0.6886503100395203}, 'experiment_intra_stereoset': {'experiment_cps': 0.5778301954269409, 'experiment_inter_stereoset': 0.6938570141792297, 'experiment_intra_stereoset': 0.6131386756896973, 'experiment_inter_race_stereoset': 0.6454243659973145, 'experiment_intra_race_stereoset': 0.6637887358665466, 'experiment_inter_profession_stereoset': 0.6578786373138428, 'experiment_intra_profession_stereoset': 0.6493142247200012, 'experiment_inter_religion_stereoset': 0.722570538520813, 'experiment_intra_religion_stereoset': 0.699386477470398}, 'experiment_inter_race_stereoset': {'experiment_cps': 0.5165094137191772, 'experiment_inter_stereoset': 0.6540785431861877, 'experiment_intra_stereoset': 0.5204678177833557, 'experiment_inter_race_stereoset': 0.7331210374832153, 'experiment_intra_race_stereoset': 0.7116556763648987, 'experiment_inter_profession_stereoset': 0.6655426025390625, 'experiment_intra_profession_stereoset': 0.5218204855918884, 'experiment_inter_religion_stereoset': 0.7884012460708618, 'experiment_intra_religion_stereoset': 0.6196318864822388}, 'experiment_intra_race_stereoset': {'experiment_cps': 0.5400943756103516, 'experiment_inter_stereoset': 0.6384692788124084, 'experiment_intra_stereoset': 0.5662767887115479, 'experiment_inter_race_stereoset': 0.7931429743766785, 'experiment_intra_race_stereoset': 0.6472080945968628, 'experiment_inter_profession_stereoset': 0.6486818194389343, 'experiment_intra_profession_stereoset': 0.5710723400115967, 'experiment_inter_religion_stereoset': 0.7789968252182007, 'experiment_intra_religion_stereoset': 0.6656441688537598}, 'experiment_inter_profession_stereoset': {'experiment_cps': 0.5495283007621765, 'experiment_inter_stereoset': 
0.7004027962684631, 'experiment_intra_stereoset': 0.5935672521591187, 'experiment_inter_race_stereoset': 0.7559265494346619, 'experiment_intra_race_stereoset': 0.7004824876785278, 'experiment_inter_profession_stereoset': 0.6727969646453857, 'experiment_intra_profession_stereoset': 0.603335440158844, 'experiment_inter_religion_stereoset': 0.794670820236206, 'experiment_intra_religion_stereoset': 0.6763803362846375}, 'experiment_intra_profession_stereoset': {'experiment_cps': 0.5448113083839417, 'experiment_inter_stereoset': 0.6440080404281616, 'experiment_intra_stereoset': 0.6603313684463501, 'experiment_inter_race_stereoset': 0.5667855739593506, 'experiment_intra_race_stereoset': 0.6140173077583313, 'experiment_inter_profession_stereoset': 0.6335070729255676, 'experiment_intra_profession_stereoset': 0.6066977977752686, 'experiment_inter_religion_stereoset': 0.6771159768104553, 'experiment_intra_religion_stereoset': 0.6963189840316772}, 'experiment_inter_religion_stereoset': {'experiment_cps': 0.5377358794212341, 'experiment_inter_stereoset': 0.6576032042503357, 'experiment_intra_stereoset': 0.5599415302276611, 'experiment_inter_race_stereoset': 0.7556716203689575, 'experiment_intra_race_stereoset': 0.7093702554702759, 'experiment_inter_profession_stereoset': 0.6620171666145325, 'experiment_intra_profession_stereoset': 0.5584476590156555, 'experiment_inter_religion_stereoset': 0.7734375, 'experiment_intra_religion_stereoset': 0.6641104221343994}, 'experiment_intra_religion_stereoset': {'experiment_cps': 0.5589622855186462, 'experiment_inter_stereoset': 0.6686807870864868, 'experiment_intra_stereoset': 0.6359649300575256, 'experiment_inter_race_stereoset': 0.7238082885742188, 'experiment_intra_race_stereoset': 0.700736403465271, 'experiment_inter_profession_stereoset': 0.6802574992179871, 'experiment_intra_profession_stereoset': 0.6566396951675415, 'experiment_inter_religion_stereoset': 0.777429461479187, 'experiment_intra_religion_stereoset': 0.6335877776145935}, 
'likely': {'experiment_cps': 0.5023584961891174, 'experiment_inter_stereoset': 0.5266867876052856, 'experiment_intra_stereoset': 0.5467836260795593, 'experiment_inter_race_stereoset': 0.5514912009239197, 'experiment_intra_race_stereoset': 0.5752920508384705, 'experiment_inter_profession_stereoset': 0.527437150478363, 'experiment_intra_profession_stereoset': 0.532418966293335, 'experiment_inter_religion_stereoset': 0.5297805666923523, 'experiment_intra_religion_stereoset': 0.5674846768379211}}}
oracle_accs = {"<class 'probes.LRProbe'>": [0.5529412031173706, 0.7713567614555359, 0.6447688341140747, 0.8356688022613525, 0.7239847779273987, 0.7670498490333557, 0.6822429895401001, 0.7578125, 0.6335877776145935], "<class 'probes.MMProbe'>": [0.43529412150382996, 0.6457286477088928, 0.6131386756896973, 0.7331210374832153, 0.6472080945968628, 0.6727969646453857, 0.6066977977752686, 0.7734375, 0.6335877776145935]}

fig, axes = plt.subplots(1, 3, figsize=(30, 10))
normalize = True  # Divide every value by the LR oracle ("LR on test set") accuracy on the same test set?
colormin, colormax = (0.3, 0.9)
if normalize:
    colormin, colormax = (0.7, 1.2)

# Row index of each validation (test) dataset; the oracle_accs lists follow val_datasets order.
experiment_to_index = {dataset: index for index, dataset in enumerate(val_datasets)}


def _scale(acc, val_dataset):
    """Return `acc`, divided by the LR oracle accuracy on `val_dataset` when `normalize` is set.

    Looks the oracle up by dataset name (not by dict position), and never
    mutates the stored accuracy dicts, so re-running this cell is safe.
    """
    if not normalize:
        return acc
    return acc / oracle_accs[str(LRProbe)][experiment_to_index[val_dataset]]


def _draw_panel(ax, title, grid, xticklabels):
    """Draw `grid` (one row per test set) on `ax` as a heatmap annotated with percentages."""
    ax.set_title(title)
    ax.imshow(grid, vmin=colormin, vmax=colormax)
    for row, values in enumerate(grid):
        for col, value in enumerate(values):
            ax.text(col, row, f'{round(value * 100):2d}', ha='center', va='center')
    ax.set_xticks(range(len(xticklabels)))
    ax.set_xticklabels(xticklabels, rotation=45, ha='right')


# Probe panels: columns are the training medleys except 'likely', which appears
# in the baselines panel. Columns and tick labels come from the same list, so
# they stay aligned wherever 'likely' sits in train_medlies.
panel_medleys = [medley for medley in train_medlies if medley != ['likely']]
panel_labels = [normal_name(to_str(medley)) for medley in panel_medleys]
for ax, ProbeClass, title in zip(axes[:2], (LRProbe, MMProbe), ("Logistic regression", "Mass mean")):
    probe_accs = lr_mm_accs[str(ProbeClass)]
    grid = [
        [_scale(probe_accs[to_str(medley)][val_dataset], val_dataset) for medley in panel_medleys]
        for val_dataset in val_datasets
    ]
    _draw_panel(ax, title, grid, panel_labels)

# Baselines panel: probes trained on 'likely', plus the LR oracle itself. These
# are scaled like the probe panels so all three panels share one colour scale.
grid = [
    [
        _scale(lr_mm_accs[str(LRProbe)]['likely'][val_dataset], val_dataset),
        _scale(lr_mm_accs[str(MMProbe)]['likely'][val_dataset], val_dataset),
        _scale(oracle_accs[str(LRProbe)][experiment_to_index[val_dataset]], val_dataset),
    ]
    for val_dataset in val_datasets
]
_draw_panel(axes[2], "Baselines", grid, ['LR on likely', 'MM on likely', 'LR on test set'])

# Bold lines after the first row and first column (in the probe panels these
# set the CP gender dataset apart from the StereoSet datasets).
for ax in axes:
    ax.hlines([0.5], *ax.get_xlim(), linewidth=3, color='black')
    ax.vlines([0.5], *ax.get_ylim(), linewidth=3, color='black')

# General adjustments to axes: the middle panel has the same test-set rows as
# its neighbours, so its y ticks are hidden.
normal_val_datasets = [normal_name(dataset) for dataset in val_datasets]
for i, ax in enumerate(axes):
    if i == 1:
        ax.set_yticks([])
    else:
        ax.set_yticks(range(len(val_datasets)))
        ax.set_yticklabels(normal_val_datasets)
        ax.set_ylabel('Test set', fontsize=15)
        ax.set_xlabel('Train set', fontsize=15, ha='center', labelpad=15)

# Adjust the layout to reduce whitespace
plt.subplots_adjust(wspace=0.1, hspace=0)

colorbar = plt.colorbar(axes[0].images[0], ax=axes[2])
colorbar.set_label('accuracy / LR oracle accuracy' if normalize else 'accuracy', fontsize=15)
# Only tag the output file as normalized when it actually is.
suffix = '_normalized' if normalize else ''
plt.savefig(f'figures/generalization_exp3{suffix}.png', bbox_inches='tight')