In [177]:
import os
import random
import re
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from matplotlib.patches import Patch
from torch.utils.data import DataLoader

sys.path.append(os.getenv('SPARSE_PROBING_ROOT'))
from activations.activation_subset import load_activation_subset
from load import load_feature_dataset, load_model

In [178]:
# Which Wikidata property's ablation dataset to analyse, and how many
# sequences it holds (n_seq is part of the dataset's directory name).
wikidata_property = 'sex_or_gender'
n_seq = 20
# Alternative: wikidata_property = 'occupation_athlete'; n_seq = 5

# Model whose ablation results are loaded (used in the results path).
# Other options: 'pythia-70m', 'pythia-1b'.
model = 'pythia-6.9b'

### Load Dataset and Logits

In [179]:
# Load the ablation dataset for the configured property; its directory name
# embeds the property and the number of sequences.
dataset_name = f'wikidata_ablations_{wikidata_property}.pyth.128.{n_seq}'
dataset_path = os.path.join('ablation_datasets', dataset_name)
dataset = load_feature_dataset(dataset_path)
dataset

Dataset({
    features: ['prompt', 'name', 'class', 'mapped_class', 'tokens', 'logit_index'],
    num_rows: 20
})

In [180]:
# Load the nominal (un-ablated) logits plus one logits file per ablated
# neuron, named ablated_logits_<lix>_<nix>.pt (presumably layer index and
# neuron index — TODO confirm against the script that writes them).
data_dir = os.path.join(
    os.environ['RESULTS_DIR'],  # KeyError names the missing variable, unlike os.getenv -> None
    'ablations',
    model,
    'ablation_datasets',
    dataset_name,
)
data_files = os.listdir(data_dir)

nominal_logits = torch.load(os.path.join(data_dir, 'nominal_logits.pt'))

# Raw string: '\d' in a plain literal is an invalid escape sequence.
# fullmatch + escaped '.' only accepts exact names, so files such as
# 'ablated_logits_1_1.pt.bak' are skipped instead of shadowing the real one
# (or crashing on a None match).
ABLATED_LOGITS_PATTERN = re.compile(r'ablated_logits_(\d+)_(\d+)\.pt')

ablated_files = {}
for data_file in data_files:
    match = ABLATED_LOGITS_PATTERN.fullmatch(data_file)
    if match is None:
        continue
    lix, nix = match.groups()
    ablated_files[(int(lix), int(nix))] = data_file

# os.listdir order is arbitrary; sort by (lix, nix) so `neurons` is reproducible.
ablated_logits = {
    neuron: torch.load(os.path.join(data_dir, ablated_files[neuron]))
    for neuron in sorted(ablated_files)
}

neurons = list(ablated_logits.keys())
neurons

[(1, 1)]

In [181]:
# Tokenizer used to encode class labels and decode predictions. It is taken
# from 'pythia-70m' regardless of `model` above, presumably because all Pythia
# sizes share one tokenizer and the smallest model is cheapest to load —
# TODO confirm. Only `.tokenizer` is kept; the loaded model is not stored.
tokenizer = load_model('pythia-70m').tokenizer

Using pad_token, but it is not set yet.


Loaded pretrained model pythia-70m into HookedTransformer


### Analysis

In [182]:
# Sanity-check the nominal (un-ablated) model: print its top next-token
# predictions for the first few examples.
# TODO: do for each class
n = 3
top_k = 5


def _short_prompt(example):
    """Return the prompt truncated to its final, queried sentence for display.

    Starts at the last occurrence of the example's name when present
    (presumably the prompt ends with '<name> has <property>' — TODO confirm).
    Otherwise falls back to the text after the last '.', stripped. Splitting
    on '.' alone would cut names containing initials such as 'J. K. ...'.
    """
    prompt = example['prompt']
    name = example.get('name') or ''
    start = prompt.rfind(name) if name else -1
    if start >= 0:
        return prompt[start:]
    return prompt.split('.')[-1].strip()


def show_top_predictions(logits, examples, tokenizer, n=3, top_k=5):
    """Print and return the top-k next-token predictions for the first n examples.

    Args:
        logits: 2-D tensor; logits[i, :] holds the next-token logits for
            example i (row order assumed to match `examples`).
        examples: iterable of example dicts with 'prompt', 'name', 'class'
            and 'mapped_class' keys (e.g. `dataset`).
        tokenizer: object whose `batch_decode` maps token ids to strings.
        n: number of leading examples to show; fewer are shown when
            `examples` is shorter (no IndexError, unlike `select(range(n))`).
        top_k: number of highest-logit tokens reported per example.

    Returns:
        List of (short_prompt, [(token_str, prob_rounded_to_2dp), ...]).
    """
    results = []
    # zip stops after n items or when `examples` runs out, whichever is first.
    for ix, example in zip(range(n), examples):
        example_logits = logits[ix, :].to(torch.float32)
        probs = torch.nn.functional.softmax(example_logits, dim=0)

        # topk avoids fully sorting the vocabulary just to take a few entries.
        k = min(top_k, example_logits.shape[0])
        token_indices = torch.topk(example_logits, k).indices

        predictions = tokenizer.batch_decode(token_indices)
        prediction_probs = [round(p, 2) for p in probs[token_indices].tolist()]
        top = list(zip(predictions, prediction_probs))
        prompt = _short_prompt(example)

        print(f'Prompt: {prompt}; true class: {example["mapped_class"]} ({example["class"]})')
        print(f'\t{top}')
        results.append((prompt, top))
    return results


top_predictions = show_top_predictions(nominal_logits, dataset, tokenizer, n=n, top_k=top_k)

Prompt: Kalki Koechlin has gender; true class: female (female)
	[(' male', 0.76), (' female', 0.16), (' Male', 0.01), (' masculine', 0.01), (' gender', 0.01)]
Prompt: Sofia Coppola has gender; true class: female (female)
	[(' male', 0.56), (' female', 0.35), (' gender', 0.01), (' feminine', 0.01), (' masculine', 0.0)]
Prompt: Brenda Lee has gender; true class: female (female)
	[(' female', 0.63), (' male', 0.23), (' feminine', 0.01), (' woman', 0.01), (' gender', 0.01)]


In [183]:
# TODO: repeat the top-k prediction printout above using each ablated neuron's logits (ablated_logits[neuron]) instead of nominal_logits

In [184]:
# Tabulate the examples and look up the vocabulary id of each class label.
# Labels are tokenized with a leading space, matching the space-prefixed
# tokens the model predicts (e.g. ' male', ' female').
dataset_df = pd.DataFrame(dataset)
classes = dataset_df.mapped_class.unique().tolist()


def _class_token_id(label):
    """Return the id of the first token of ' <label>'.

    Only this single id is scored per class downstream, so warn when the label
    spans several tokens: its score would then reflect only its first sub-token.
    """
    token_ids = tokenizer(f' {label}').input_ids
    if len(token_ids) > 1:
        warnings.warn(
            f'class {label!r} tokenizes to {len(token_ids)} tokens; '
            'only the first token is used as its class token'
        )
    return token_ids[0]


class_tokens = [_class_token_id(c) for c in classes]

nominal_probs = torch.nn.functional.softmax(nominal_logits.to(torch.float32), dim=1)

In [185]:
# get accuracy for each class
for c, c_token in zip(classes, class_tokens):
    class_mask = (dataset_df.mapped_class == c).values
    class_probs = nominal_probs[class_mask, c_token]

    mean = class_probs.mean()
    acc = (nominal_probs[class_mask, :].argmax(dim=1) == c_token).to(torch.float32).mean()

    print(f'Class: {c}')
    print(f'\tAccuracy: {acc:.2f}')
    print(f'\tMean probability on correct token: {mean:.2f}')

Class: female
	Accuracy: 0.50
	Mean probability on correct token: 0.41
Class: male
	Accuracy: 0.70
	Mean probability on correct token: 0.57


In [186]:
# TODO: repeat the per-class accuracy / mean-probability report above for each ablated neuron's logits (ablated_logits[neuron])