In [1]:
from neural_nlp.benchmarks import benchmark_pool
pereira = benchmark_pool["Pereira2018-encoding"]
data = pereira._load_assembly(version='base')

  class Score(DataAssembly):


In [2]:
# here are the various stimuli (passages with their accompanying sentences) which 
# were presented to the human participants in the experiment

from collections import Counter

# from _PereiraBenchmark#call
# we add a new passage identifier (experiment + the index of the passage read)
# this will allow us to process each stimulus together (passage by passage)

stimulus_set = data.attrs['stimulus_set']
stimulus_set.loc[:, 'passage_id'] = stimulus_set['experiment'] + stimulus_set['passage_index'].astype(str)

print(stimulus_set)
print(Counter(stimulus_set['passage_id']))

                                              sentence  sentence_num  \
0    Beekeeping encourages the conservation of loca...             0   
1    It is in every beekeeper's interest to conserv...             1   
2    As a passive form of agriculture, it does not ...             2   
3    Beekeepers also discourage the use of pesticid...             3   
4    Artisanal beekeepers go to extremes for their ...             4   
..                                                 ...           ...   
622  Some windows have multiple panes to increase i...           379   
623                   A woman is a female human adult.           380   
624    A woman is stereotypically seen as a caregiver.           381   
625     A woman can become pregnant and bear children.           382   
626  A woman has different reproductive organs than...           383   

          stimulus_id    experiment                       story  \
0      243sentences.0  243sentences     243sentences.beekeeping   
1

In [4]:
import torch
from transformers import GPT2TokenizerFast, GPT2Model

tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2', output_hidden_states=True)
model = model.eval()  

model

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0): GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (1): GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP

In [8]:
from tqdm import tqdm

# now we run the stimuli through our model and get their corresponding activations
# we do so for each "story" (identified by passage_id) separately by concatenating
# it's constituent sentences, keeping track of each sentences start and end tokens
# to be able to retrieve their token representans

def extract_passage_activations():
    # from stimulus_id -> 13 x 768 tensor (final representations from each layer)
    activations = {}
    for story in tqdm(sorted(set(stimulus_set['passage_id'].values))):
        story_stimuli = stimulus_set[stimulus_set['passage_id'] == story]

        sentences = []
        stimulus_ids = []
        stimulus_ends = []
        length_so_far = 0
        for _, stimulus in story_stimuli.sort_values(by='sentence_num', ascending=True).iterrows():
            length_so_far += len(stimulus['sentence'])
            sentences.append(stimulus['sentence'])
            stimulus_ids.append(stimulus['stimulus_id'])
            stimulus_ends.append(length_so_far - 1)

            # we'll join the sentences with spaces 
            length_so_far += 1

        with torch.no_grad():
            tokenized = tokenizer(
                [' '.join(sentences)], 
                add_special_tokens=True,
                return_tensors='pt'
            )

            # note that the ending character here is usually a period 
            # (we can experiment w/ the last word by subtracting 1)
            stimulus_token_ends = [
                tokenized.char_to_token(stimulus_end) for stimulus_end in stimulus_ends
            ]

            output = model(**tokenized)

            for stimulus_id, stimulus_token_end in zip(stimulus_ids, stimulus_token_ends):
                assert stimulus_id not in activations

                # get hidden state of each final token for each stimulus

                activations[stimulus_id] = torch.stack([
                    output.hidden_states[i][0][stimulus_token_end] for i in range(len(output.hidden_states))
                ])
    return activations

def extract_sentence_activations():
    # from stimulus_id -> 13 x 768 tensor (final representations from each layer)
    activations = {}
    for stimulus_id in tqdm(sorted(set(stimulus_set['stimulus_id'].values))):
        stimulus = stimulus_set[stimulus_set['stimulus_id'] == stimulus_id]

        assert len(stimulus) == 1
        
        with torch.no_grad():
            tokenized = tokenizer(
                [stimulus.iloc[0]['sentence']], 
                add_special_tokens=True,
                return_tensors='pt'
            )

            output = model(**tokenized)
            
            assert stimulus_id not in activations
            
            activations[stimulus_id] = torch.stack([
                output.hidden_states[i][0][-1] for i in range(len(output.hidden_states))
            ])
    return activations

passage_activations = extract_passage_activations()
sentence_activations = extract_sentence_activations()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 627/627 [00:34<00:00, 18.01it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 168/168 [00:22<00:00,  7.35it/s]


In [9]:
import pickle

with open('../cache/gpt2_passage_activations.pkl', 'wb') as f:
    pickle.dump(passage_activations, f)

with open('../cache/gpt2_sentence_activations.pkl', 'wb') as f:
    pickle.dump(sentence_activations, f)

In [4]:
import pickle

with open('../cache/gpt2_passage_activations.pkl', 'rb') as f:
    passage_activations = pickle.load(f)

with open('../cache/gpt2_sentence_activations.pkl', 'rb') as f:
    sentence_activations = pickle.load(f)
    
# sanity check that we have gpt activations for every stimulus
    
assert set(passage_activations.keys()) == set(data['stimulus_id'].values)
assert set(sentence_activations.keys()) == set(data['stimulus_id'].values)

In [5]:
import numpy as np
from tqdm import tqdm
from collections import defaultdict

# now we have to split / group the data as done in the neural_nlp repo

# here are our raw voxels
print(data.values.shape)

# for each subject in our data (dim 0), here are their corresponding experiments
print(len(data['experiment']), Counter(data['experiment'].values))

# for each voxel in our data (dim 1), here is its corresponding brain region (atlas)
print(len(data['atlas']), Counter(data['atlas'].values))

# later, we'll group results by subject
# strangely, individual voxels are subject specific?
print(len(data['subject']), Counter(data['subject'].values))

# also, they calculate correlations by neuroid_id....
print(len(data['neuroid_id']), len(set(data['neuroid_id'].values)))

# we split the data by experiment and atlas (this is very slow...)
# from brainscore.metrics.transformations import CartesianProduct
# splitter = CartesianProduct(dividers=['experiment', 'atlas'])
# splits = splitter(data, apply=lambda split: split.drop_vars(['experiment', 'atlas']))

experiment_voxels = defaultdict(list)
experiment_voxel_ids = defaultdict(set)
experiment_voxel_nas = defaultdict(set)
experiment_stimuli = defaultdict(list)
for presentation_id, stimulus_id, experiment in tqdm(zip(
    range(data.shape[0]), 
    data['stimulus_id'].values, 
    data['experiment'].values
)):
    voxels = []
    for voxel_id, atlas in zip(range(data.shape[1]), data['atlas'].values):
        if atlas == 'language':
            experiment_voxel_ids[experiment].add(voxel_id)
            voxel = data.values[presentation_id][voxel_id]
            if np.isnan(voxel):
                experiment_voxel_nas[experiment].add(voxel_id)
            voxels.append(voxel)

    experiment_voxels[experiment].append(voxels)    
    experiment_stimuli[experiment].append(stimulus_id)

for experiment in experiment_voxel_ids:
    experiment_voxel_ids[experiment] = list(sorted(experiment_voxel_ids[experiment]))

(627, 103900)
627 Counter({'384sentences': 384, '243sentences': 243})
103900 Counter({'visual': 43741, 'MD': 29936, 'language': 13553, 'DMN': 10978, 'auditory': 5692})
103900 Counter({'288': 10854, '407': 10825, '296': 10625, '343': 10615, '426': 10611, '215': 10462, '366': 10444, '199': 10378, '289': 10139, '018': 8947})
103900 101248


627it [00:29, 21.29it/s]


In [6]:
import numpy as np

# we filter out the voxels that are na

for experiment in experiment_voxels:
    print(
        experiment, 
        len(experiment_voxels[experiment]), 
        len(experiment_voxels[experiment][0]), 
        len(experiment_voxel_nas[experiment])
    )
    
experiments = {}
experiment_subjects = {}
for experiment in experiment_voxels:
    voxel_ids = experiment_voxel_ids[experiment]
    experiments[experiment] = np.array(
        [
            [voxel for voxel_id, voxel in zip(voxel_ids, voxels) if voxel_id not in experiment_voxel_nas[experiment]]
            for voxels in experiment_voxels[experiment]
        ]
    )
    experiment_subjects[experiment] = [
        subject for voxel_id, subject in zip(voxel_ids, data['subject'].values)
        if voxel_id not in experiment_voxel_nas[experiment]
    ]
    print(experiment, experiments[experiment].shape)

384sentences 384 13553 1398
243sentences 243 13553 5522
384sentences (384, 12155)
243sentences (243, 8031)


In [7]:
from tqdm import tqdm
from scipy.stats import pearsonr
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import GroupShuffleSplit

activations = passage_activations

# TODO:
# for some reason, individual neuroids in the first experiment have 
# ~constant activations across the presentations which leads to nans when calculating pearson-r
# need to understand why!

# 2 experiments x 5 folds x 13 layers
experiment_pearsonrs = defaultdict(lambda: np.zeros((5, 13)))
for experiment, brain_reps in experiments.items(): 
    # splits need to be by stimulus_id (how do we shuffle here?)
    # (though really they should be by passage_id given how we're doing the GPT2 encoding...
    # otherwise the test set will leak into the train set...)
    # in the brain-score repo, CrossRegressedCorrelation uses a train_size of 0.9
    k_folds = GroupShuffleSplit(n_splits=5, train_size=0.9)

    for fold, (train_indices, test_indices) in enumerate(
        k_folds.split(brain_reps, groups=experiment_stimuli[experiment])
    ):
        train_brain_reps, test_brain_reps = brain_reps[train_indices], brain_reps[test_indices]
        for layer_num in tqdm(range(13), desc='%s-fold%s' % (experiment, fold)):
            train_hidden_states = np.stack([
                activations[experiment_stimuli[experiment][brain_rep_idx]][layer_num].numpy()
                for brain_rep_idx in train_indices
            ])
            test_hidden_states = np.stack([
                activations[experiment_stimuli[experiment][brain_rep_idx]][layer_num].numpy()
                for brain_rep_idx in test_indices
            ])

            # TODO: Are they doing any kind of hyperparameter tuning
            # (regularization, etc) here?  We're using SKLearn's defaults
            
            model = LinearRegression().fit(train_hidden_states, train_brain_reps)
            pred_brain_reps = model.predict(test_hidden_states)

            # We aggregated voxel/electrode/ROI predictivity scores by taking the
            # median of scores for each participant’s voxels/electrodes/ROIs and
            # then computing the median across participants. Finally, this score was
            # divided by the estimated ceiling value (see below) to yield a final score in
            # the range [0, 1].

            # https://github.com/brain-score/brain-score/blob/master/brainscore/metrics/xarray_utils.py#L78
            # https://github.com/brain-score/brain-score/blob/master/brainscore/metrics/regression.py#L33
            # https://github.com/brain-score/brain-score/blob/master/brainscore/metrics/transformations.py#L42

            # not totally sure this is right...
            layer_pearson_rs_by_subj = defaultdict(list)
            for idx, test_index in enumerate(test_indices):
                subject = experiment_subjects[experiment][test_index]
                pred_voxels = pred_brain_reps[:, idx]
                test_voxels = test_brain_reps[:, idx]
                layer_pearson_rs_by_subj[subject].append(pearsonr(pred_voxels, test_voxels)[0])

            experiment_pearsonrs[experiment][fold][layer_num] = np.median([
                np.median(subj_pearson_rs) for subj_pearson_rs in layer_pearson_rs_by_subj.values()
            ])

384sentences-fold0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:22<00:00,  1.76s/it]


384sentences-fold1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:22<00:00,  1.75s/it]


384sentences-fold2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:22<00:00,  1.76s/it]


384sentences-fold3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:22<00:00,  1.74s/it]


384sentences-fold4: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:22<00:00,  1.76s/it]
243sentences-fold0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.51it/s]
243sentences-fold1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.56it/s]
243sentences-fold2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.53it/s]
243sentences-fold3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 

In [8]:
for experiment, pearsonrs in experiment_pearsonrs.items():
    print(experiment)
    for layer_num in range(pearsonrs.shape[1]):
        print((layer_num, np.mean(pearsonrs[:, layer_num]), np.median(pearsonrs[:, layer_num])))

384sentences
(0, nan, nan)
(1, nan, nan)
(2, nan, nan)
(3, nan, nan)
(4, nan, nan)
(5, nan, nan)
(6, nan, nan)
(7, nan, nan)
(8, nan, nan)
(9, nan, nan)
(10, nan, nan)
(11, nan, nan)
(12, nan, nan)
243sentences
(0, 0.044156083646619784, 0.10456223984999405)
(1, 0.19331561215887155, 0.20244518091914154)
(2, 0.2802488794535649, 0.2614844823794658)
(3, 0.2538961571939675, 0.26393395981178697)
(4, 0.2590281368149035, 0.2256927193040175)
(5, 0.11670076962931715, 0.1160806411390303)
(6, 0.20304918827342341, 0.23146041784409135)
(7, 0.14270608302148732, 0.1422166237692077)
(8, 0.14111144108653745, 0.0980128107742348)
(9, 0.1759473819143163, 0.11888027686430289)
(10, 0.26596581363938454, 0.2935783145999833)
(11, 0.3436764894490004, 0.35978215380644335)
(12, 0.3289451251433617, 0.3093803155994607)


In [39]:
# Compare with the ceiling
pereira.ceiling