In [1]:
import torch as t
from utils import DataManager
import random
import matplotlib.pyplot as plt
import random
from probes import LRProbe, MMProbe, CCSProbe
from matplotlib import gridspec

In [2]:
# Run configuration shared by every cell below.
model = 'llama-3-8b'  # alternative: llama-13b
model_size = '8B'  # alternative: 13B
layer = 12  # layer from which to extract activations
split = 0.8  # forwarded as `split=` when a training dataset is added (presumably the train fraction -- confirm in utils.DataManager)

# Run on the first CUDA device when one is available, otherwise on the CPU.
if t.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Experiment 1: llama-13b, layer 12
# Experiment 2: llama-13b, layer 12 (just to try a different split, see the effect of randomness)
# Experiment 3: llama-13b, layer 13 (the last useful layer, I think, identified by probing)
# Experiment 4: llama-8b, layer 12 (oops, this one still used llama-13b activations because of a hardcoded line in utils.py)
# Experiment 5: llama-8b, layer 12 (regenerated the activations for SS1 and SS2 gender + changed utils.py)

# Reproducing generalization matrix

In [3]:
# Datasets every probe is evaluated on.
val_datasets = [
    'experiment_cps',
    'experiment_inter_stereoset',  # gender
    'experiment_intra_stereoset',  # gender
    'experiment_inter_race_stereoset',
    'experiment_intra_race_stereoset',
    'experiment_inter_profession_stereoset',
    'experiment_intra_profession_stereoset',
    'experiment_inter_religion_stereoset',
    'experiment_intra_religion_stereoset',
]

# Training medleys: one single-dataset medley per validation dataset (same order),
# followed by the 'likely' medley, which is trained on but never validated on.
train_medlies = [[dataset_name] for dataset_name in val_datasets] + [['likely']]

def to_str(l):
    """Build the dictionary key for a training medley by joining its dataset names with '+'."""
    separator = '+'
    return separator.join(l)

# Probe classes that are trained and evaluated below.
ProbeClasses = [
    LRProbe,
    MMProbe,
]

# accs[str(probe class)][training medley key][validation dataset] -> accuracy.
# The innermost dicts start empty and are filled in by the training loop.
accs = {
    str(probe_class): {to_str(train_medley): {} for train_medley in train_medlies}
    for probe_class in ProbeClasses
}

# A new split seed is drawn on every run. Print it so that a given train/val split
# can be reproduced later by hard-coding the reported value.
seed = random.randint(0, 100000)
print(f"Train/val split seed: {seed}")

In [4]:
# Train one probe per (probe class, training medley) pair and record its accuracy
# on every validation dataset in `accs`.
for ProbeClass in ProbeClasses:
    for medley in train_medlies:

        # set up data: datasets of the medley are added with a train/val split,
        # every other validation dataset is added without one (split=None)
        dm = DataManager()
        for dataset in medley:
            dm.add_dataset(dataset, model_size, layer, split=split, seed=seed, center=True, device=device)
        for dataset in val_datasets:
            if dataset not in medley:
                dm.add_dataset(dataset, model_size, layer, split=None, center=True, device=device)

        # train probe
        train_acts, train_labels = dm.get('train')
        probe = ProbeClass.from_data(train_acts, train_labels, device=device)

        # evaluate
        for val_dataset in val_datasets:
            if val_dataset in medley:
                # dataset was trained on: score the held-out 'val' part, with iid=True
                acts, labels = dm.data['val'][val_dataset]
                accs[str(ProbeClass)][to_str(medley)][val_dataset] = (
                    probe.pred(acts, iid=True) == labels
                ).float().mean().item()
            else:
                # dataset was not trained on: score all of it, with iid=False
                acts, labels = dm.data[val_dataset]
                accs[str(ProbeClass)][to_str(medley)][val_dataset] = (
                    probe.pred(acts, iid=False) == labels
                ).float().mean().item()
        print("Finished train dataset", medley)

# Copy all three dict levels. `accs.copy()` would only copy the outer dict, so the
# per-medley dicts would be shared and any later in-place edit of `lr_mm_accs`
# would silently change `accs` as well.
lr_mm_accs = {
    probe_key: {medley_key: dict(by_val_set) for medley_key, by_val_set in by_medley.items()}
    for probe_key, by_medley in accs.items()
}

Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Finished train dataset ['experiment_cps']
Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b

In [None]:
# print(val_dataset)
# print("Activations size:", acts.size())
# print("Labels size:", labels.size())

In [5]:
def normal_name(name):
    """Return the short label used in figures for a dataset identifier.

    Args:
        name: Internal dataset identifier, e.g. "experiment_inter_race_stereoset".

    Returns:
        The display label ("CP gender", "SS1 race", ...). Identifiers without a
        dedicated label, such as "likely", are returned unchanged. The previous
        if/elif chain raised UnboundLocalError for any identifier it did not list.
    """
    display_names = {
        "experiment_cps": "CP gender",
        "experiment_inter_stereoset": "SS1 gender",
        "experiment_intra_stereoset": "SS2 gender",
        "experiment_inter_race_stereoset": "SS1 race",
        "experiment_intra_race_stereoset": "SS2 race",
        "experiment_inter_religion_stereoset": "SS1 religion",
        "experiment_intra_religion_stereoset": "SS2 religion",
        "experiment_inter_profession_stereoset": "SS1 profession",
        "experiment_intra_profession_stereoset": "SS2 profession",
    }
    return display_names.get(name, name)

In [7]:
# get oracle probe results
oracle_accs = {str(probe_class) : [] for probe_class in ProbeClasses}
for ProbeClass in ProbeClasses:
    for dataset in val_datasets:
        dm = DataManager()
        dm.add_dataset(dataset, model_size, layer, split=split, seed=seed, device=device)
        acts, labels = dm.get('train')
        probe = ProbeClass.from_data(acts, labels, device=device)

        acts, labels = dm.data['val'][dataset]
        acc = (probe(acts, iid=True).round() == labels).float().mean().item()
        oracle_accs[str(ProbeClass)].append(acc)

Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Collecting acts from this directory: /scratch-shared/tpungas/geometry-of-truth/acts/llama-3-8b/
Collecting acts from this directory: /sc

In [9]:
# print(oracle_accs)

In [10]:
# print(lr_mm_accs)

In [15]:
# Three panels side by side; the third (baselines) panel gets a smaller width ratio
# so that its proportions match the two probe panels.
gs = gridspec.GridSpec(1, 3, width_ratios=[1, 1, 0.417])
fig = plt.figure(figsize=(20, 10))
axes = [plt.subplot(gs[panel]) for panel in range(3)]

# Font sizes
title_size, axlabel_size = 20, 18
ticklabel_size, num_size = 12, 12

normalize = False  # Normalize all values to "LR on test set"?
ordered = True  # Group the StereoSet datasets as all inter sets, then all intra sets?

# Colour scale limits; the normalized plot uses a shifted range.
colormin, colormax = (0.6, 1.2) if normalize else (0.3, 0.9)

# NOTE: this rebinds the notebook-level `val_datasets` and `train_medlies`.
if ordered:
    # CP first, then every inter set, then every intra set.
    val_datasets = [
        'experiment_cps',
        'experiment_inter_stereoset',  # gender
        'experiment_inter_race_stereoset',
        'experiment_inter_profession_stereoset',
        'experiment_inter_religion_stereoset',
        'experiment_intra_stereoset',  # gender
        'experiment_intra_race_stereoset',
        'experiment_intra_profession_stereoset',
        'experiment_intra_religion_stereoset',
    ]
else:
    # CP first, then inter/intra interleaved per bias type.
    val_datasets = [
        'experiment_cps',
        'experiment_inter_stereoset',  # gender
        'experiment_intra_stereoset',  # gender
        'experiment_inter_race_stereoset',
        'experiment_intra_race_stereoset',
        'experiment_inter_profession_stereoset',
        'experiment_intra_profession_stereoset',
        'experiment_inter_religion_stereoset',
        'experiment_intra_religion_stereoset',
    ]

# In both layouts the training medleys follow the validation order, with 'likely' last.
train_medlies = [[dataset_name] for dataset_name in val_datasets] + [['likely']]

# Define mapping for the type of experiments to indices in oracle_accs
experiment_to_index = {
    experiment: position
    for position, experiment in enumerate([
        'experiment_cps',
        'experiment_inter_stereoset',
        'experiment_intra_stereoset',
        'experiment_inter_race_stereoset',
        'experiment_intra_race_stereoset',
        'experiment_inter_profession_stereoset',
        'experiment_intra_profession_stereoset',
        'experiment_inter_religion_stereoset',
        'experiment_intra_religion_stereoset',
    ])
}

# Subplot for Logistic Regression
ax = axes[0]
ax.set_title("Logistic regression", fontsize=title_size)
ax_accs = lr_mm_accs[str(LRProbe)]
if normalize:
    # Divide every accuracy by the LR oracle accuracy of the same validation set.
    # The oracle entry is looked up by dataset name through experiment_to_index, and
    # the result goes into a new nested dict: writing back into lr_mm_accs would
    # normalize the stored values again each time this cell is re-run.
    lr_oracle_accs = oracle_accs[str(LRProbe)]
    ax_accs = {
        train_set: {
            val_set: (
                acc / lr_oracle_accs[experiment_to_index[val_set]]
                if train_set in experiment_to_index and val_set in experiment_to_index
                else acc
            )
            for val_set, acc in sub_exps.items()
        }
        for train_set, sub_exps in ax_accs.items()
    }

# One row per validation dataset, one column per training medley; the 'likely'
# medley is left out here because it is shown in the baselines panel.
shown_medlies = [medley for medley in train_medlies if medley != ['likely']]
grid = [
    [ax_accs[to_str(medley)][val_dataset] for medley in shown_medlies]
    for val_dataset in val_datasets
]

ax.imshow(grid, vmin=colormin, vmax=colormax)
# Write each value, as a rounded percentage, into its cell
for i, row in enumerate(grid):
    for j, value in enumerate(row):
        ax.text(j, i, f'{round(value * 100):2d}', ha='center', va='center', fontsize=num_size)
# Ticks are derived from the same medley list as the grid columns
ax.set_xticks(range(len(shown_medlies)))
ax.set_xticklabels([normal_name(to_str(medley)) for medley in shown_medlies], rotation=45, ha='right', fontsize=ticklabel_size)



# Subplot for Mass Mean
ax = axes[1]
ax.set_title("Mass mean", fontsize=title_size)
ax_accs = lr_mm_accs[str(MMProbe)]

if normalize:
    # Mass-mean accuracies are also divided by the *LR* oracle accuracy of the same
    # validation set ("LR on test set"). The oracle entry is looked up by dataset
    # name through experiment_to_index, and the result goes into a new nested dict:
    # writing back into lr_mm_accs would normalize the stored values again each
    # time this cell is re-run.
    lr_oracle_accs = oracle_accs[str(LRProbe)]
    ax_accs = {
        train_set: {
            val_set: (
                acc / lr_oracle_accs[experiment_to_index[val_set]]
                if train_set in experiment_to_index and val_set in experiment_to_index
                else acc
            )
            for val_set, acc in sub_exps.items()
        }
        for train_set, sub_exps in ax_accs.items()
    }

# One row per validation dataset, one column per training medley; the 'likely'
# medley is left out here because it is shown in the baselines panel.
shown_medlies = [medley for medley in train_medlies if medley != ['likely']]
grid = [
    [ax_accs[to_str(medley)][val_dataset] for medley in shown_medlies]
    for val_dataset in val_datasets
]

ax.imshow(grid, vmin=colormin, vmax=colormax)
# Write each value, as a rounded percentage, into its cell
for i, row in enumerate(grid):
    for j, value in enumerate(row):
        ax.text(j, i, f'{round(value * 100):2d}', ha='center', va='center', fontsize=num_size)
# Ticks are derived from the same medley list as the grid columns
ax.set_xticks(range(len(shown_medlies)))
ax.set_xticklabels([normal_name(to_str(medley)) for medley in shown_medlies], rotation=45, ha='right', fontsize=ticklabel_size)



# Subplot for Baselines
ax = axes[2]
ax.set_title("Baselines", fontsize=title_size)
# Columns: LR trained on 'likely', MM trained on 'likely', LR oracle.
# oracle_accs is a positional list, so its entry is looked up through
# experiment_to_index. Indexing it by the row number would pair rows with the wrong
# dataset's oracle accuracy whenever val_datasets has been reordered for plotting.
grid = [[lr_mm_accs[str(LRProbe)]['likely'][val_dataset],
         lr_mm_accs[str(MMProbe)]['likely'][val_dataset],
         oracle_accs[str(LRProbe)][experiment_to_index[val_dataset]]] for val_dataset in val_datasets]

ax.imshow(grid, vmin=colormin, vmax=colormax)
# Write each value, as a rounded percentage, into its cell
for i, row in enumerate(grid):
    for j, value in enumerate(row):
        ax.text(j, i, f'{round(value * 100):2d}', ha='center', va='center', fontsize=num_size)
ax.set_xticks(range(3))
ax.set_xticklabels(['LR on likely', 'MM on likely', 'LR on test set'], rotation=45, ha='right', fontsize=ticklabel_size)


# Bold lines to separate datasets.
# A separator after the first row/column is always drawn. The one after index 4 only
# makes sense in the `ordered` layout (where rows 1-4 and 5-8 form two groups), so it
# is drawn on every panel, including the baselines panel, only when `ordered` is set.
for i, ax in enumerate(axes):
    is_baselines = (i == 2)  # the baselines panel gets horizontal separators only
    ax.hlines([0.5], *ax.get_xlim(), linewidth=3, color='black')
    if not is_baselines:
        ax.vlines([0.5], *ax.get_ylim(), linewidth=3, color='black')
    if ordered:
        ax.hlines([4.5], *ax.get_xlim(), linewidth=3, color='black')
        if not is_baselines:
            ax.vlines([4.5], *ax.get_ylim(), linewidth=3, color='black')

# Axis labelling: only the left-most panel carries y tick labels, and the shared
# x label is attached to the middle panel.
normal_val_datasets = list(map(normal_name, val_datasets))
for i, ax in enumerate(axes):
    if i == 0:
        ax.set_yticks(range(len(val_datasets)))
        ax.set_yticklabels(normal_val_datasets, fontsize=ticklabel_size)
        ax.set_ylabel('Test set', fontsize=axlabel_size)
        continue
    if i == 1:
        ax.set_xlabel('Train set', labelpad=15, fontsize=axlabel_size)
        ax.xaxis.set_label_coords(0.1, -0.2)
    # the second and third panels have no y ticks
    ax.set_yticks([])

# Tighten the gaps between panels
fig.subplots_adjust(hspace=0, wspace=0.025)

# Colour bar built from the first panel's image, placed next to the baselines panel
plt.colorbar(axes[0].images[0], ax=axes[2], aspect=30, shrink=0.82)

# The output file name encodes the plot options
name_suffix = ("_normalized" if normalize else "") + ("" if ordered else "_NOTordered")
figure_name = 'figures/generalization/exp5/generalization_exp5' + name_suffix
plt.savefig(f"{figure_name}.pdf", bbox_inches='tight', dpi=300)