In [1]:
import os
import json
import pandas
import numpy as np
import nibabel as nib
import torch
import transformers
from transformers import AutoTokenizer
from typing import Dict
import pickle
import sklearn
import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Tokenizer for the LLM whose cached activations are analysed (presumably the same checkpoint
# that produced llm_act — confirm).
model_dir = '/share/gzhch/resource/models/Llama-2-7b-hf/'
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# Participants skipped per story when collecting fMRI runs (reason not recorded — TODO document).
exclude_files = {
    'tunnel': ['sub-004', 'sub-013'],
    'lucy': ['sub-053', 'sub-065'],
}

In [3]:
def load_data(task_name):
    """Load cached LLM activations, transcript tokens, word alignment and fMRI runs for one story.

    Args:
        task_name: story / task identifier (e.g. 'lucy', 'tunnel'). It is used to build the file
            paths and to select fMRI files whose name contains it.

    Returns:
        dict with keys:
            llm_act: object unpickled from /data/gzhch/data/llm_act/{task_name}_1.pkl.
            input_aligned: the ``task_name`` entry of the dict unpickled from aligned_input.pkl.
            input_ids: token ids of the gentle transcript, from the module-level ``tokenizer``.
            fmri_imgs: images returned by ``nib.load`` for each file (voxel arrays are read
                separately by get_fdata).
            fmri_files: sorted paths of the matching .nii.gz files.

    Participants listed in the module-level ``exclude_files`` for this task are skipped. A task
    with no entry there keeps every participant instead of raising KeyError.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted cache files.
    with open('/data/gzhch/data/llm_act/{task_name}_1.pkl'.format(task_name=task_name), 'rb') as f:
        llm_act = pickle.load(f)

    with open('/data/gzhch/narratives/stimuli/gentle/{task_name}/align.json'.format(task_name=task_name), 'r') as f:
        raw_input = json.load(f)

    # Relative path: resolved against the notebook's current working directory.
    with open('aligned_input.pkl', 'rb') as f:
        all_aligned_inputs = pickle.load(f)

    input_ids = tokenizer(raw_input['transcript'])['input_ids']

    folder_path = '/data/gzhch/narratives/derivatives/afni-nosmooth/'
    excluded = set(exclude_files.get(task_name, ()))

    fmri_files = []
    for dirpath, _dirnames, filenames in os.walk(folder_path):
        for filename in filenames:
            # NOTE(review): substring match, so a task name contained in another task's name
            # would also match — confirm filenames make this unambiguous.
            if not (filename.endswith('nii.gz') and f'{task_name}' in filename):
                continue
            # The participant id is the first '_'-separated field (e.g. 'sub-004').
            participant = filename.split('_')[0]
            if participant in excluded:
                continue
            fmri_files.append(os.path.join(dirpath, filename))
    # os.walk order is not guaranteed; sort so subject indexing into fmri_act is stable.
    fmri_files.sort()

    fmri_imgs = [nib.load(path) for path in fmri_files]

    return dict(llm_act=llm_act, input_aligned=all_aligned_inputs[task_name], input_ids=input_ids, fmri_imgs=fmri_imgs, fmri_files=fmri_files)

def get_fdata(data):
    """Read the array of every image in data['fmri_imgs'] via ``get_fdata()``.

    The arrays are stored, in the same order, as a list under data['fmri_act']. The dict is
    modified in place and also returned so preprocessing calls can be chained.
    """
    arrays = []
    for image in data['fmri_imgs']:
        arrays.append(image.get_fdata())
    data['fmri_act'] = arrays
    return data

def tr_alignment(data, tr=1.5):
    """Assign aligned words to fMRI volumes (TRs).

    Args:
        data: dict whose 'input_aligned' entry is a list of word records with numeric 'start'
            and 'end' times (assumed to be seconds, matching ``tr`` — confirm).
        tr: repetition time, in the same unit as the word times. The default is 1.5.

    Returns:
        The same dict with data['tr_to_words'] set. It holds one list per TR, containing the
        indices of every word whose [floor(start/tr), ceil(end/tr)) span covers that TR.
        - A TR with no word inherits the last word of the preceding TR.
        - TRs before the first word receive the first assigned word.
        - The number of TRs is ceil(end / tr) of the last word with start < end.

    Raises:
        ValueError: if no word has start < end, so the run length cannot be determined.
    """
    words = data['input_aligned']

    # Run length comes from the last word with a non-degenerate interval (start < end).
    last_end = next((w['end'] for w in reversed(words) if w['start'] < w['end']), None)
    if last_end is None:
        raise ValueError("tr_alignment: no word with start < end in data['input_aligned']")
    max_tr = math.ceil(last_end / tr)

    tr_words = [[] for _ in range(max_tr)]
    for c, w in enumerate(words):
        # Clip to [0, max_tr). Degenerate words after the last valid one (or negative times)
        # would otherwise index past the run or wrap around to the end of it.
        a = max(math.floor(w['start'] / tr), 0)
        b = min(math.ceil(w['end'] / tr), max_tr)
        for i in range(a, b):
            tr_words[i].append(c)

    # Silent TRs carry forward the previous TR's last word. A bare tr_words[i - 1] at i == 0
    # would wrap around to the *last* TR, so leading silent TRs use the first assigned word.
    first_word = next((ws[0] for ws in tr_words if ws), None)
    for i, ws in enumerate(tr_words):
        if not ws:
            ws.append(tr_words[i - 1][-1] if i > 0 else first_word)

    data['tr_to_words'] = tr_words
    return data

def reshape_llm_act(data):
    """Expand each TR's word list into the contiguous range of token ids it spans.

    For each entry of data['tr_to_words'], the range runs from the first word's
    word_to_token[0] up to (excluding) the last word's word_to_token[1]. The ranges are stored
    as lists under data['tr_to_ids'] and the dict is returned. Every TR's word list must be
    non-empty.
    """
    word_lists = data['tr_to_words']
    spans = data['tr_to_ids'] = []
    for word_ids in word_lists:
        first_word = data['input_aligned'][word_ids[0]]
        last_word = data['input_aligned'][word_ids[-1]]
        spans.append(list(range(first_word['word_to_token'][0], last_word['word_to_token'][1])))
    return data

def get_filtered_neurons(data, th=1, layer=12):
    """For each row of layer ``layer``, list the cached neuron indices whose value exceeds ``th``.

    Rows come from unbinding data['llm_act']['indices'][layer] and
    data['llm_act']['values'][layer] along their first dimension. Callers index the result by
    token id, so rows are presumably tokens — confirm. Returns one Python list per row.
    """
    layer_indices = torch.unbind(data['llm_act']['indices'][layer])
    layer_values = torch.unbind(data['llm_act']['values'][layer])
    kept = []
    for row_indices, row_values in zip(layer_indices, layer_values):
        above = row_values > th
        kept.append(row_indices[above].tolist())
    return kept

def diff(a, b):
    """Return the Euclidean (Frobenius) norm of ``a - b`` as a Python float.

    Each argument may independently be a torch.Tensor, a numpy array or a nested sequence.
    ``torch.as_tensor`` converts it, without copying when it already is a Tensor or ndarray.
    Integer or bool-free non-floating results are promoted to float64, because torch's norm
    rejects integer tensors. Shapes must broadcast; otherwise torch raises a RuntimeError
    (e.g. when comparing runs with different numbers of TRs).
    """
    a = torch.as_tensor(a)
    b = torch.as_tensor(b)
    delta = a - b
    if not (delta.is_floating_point() or delta.is_complex()):
        delta = delta.double()
    return delta.norm().item()

In [41]:
# Run both stories through the same preprocessing stages, one stage at a time:
# files -> TR-to-word alignment -> TR-to-token-id spans -> voxel arrays (get_fdata, kept last).
data_lucy, data_tunnel = load_data('lucy'), load_data('tunnel')
data_lucy, data_tunnel = tr_alignment(data_lucy), tr_alignment(data_tunnel)
data_lucy, data_tunnel = reshape_llm_act(data_lucy), reshape_llm_act(data_tunnel)
data_lucy, data_tunnel = get_fdata(data_lucy), get_fdata(data_tunnel)

In [16]:
# Raw cached activation indices/values for lucy, not yet indexed by layer; keep the shape handy.
act_indices, act_values = data_lucy['llm_act']['indices'], data_lucy['llm_act']['values']
act_shape = act_indices.shape

In [33]:
# Per-row views of the cached indices/values for a single layer (th is set here but unused by this cell).
th, layer = 1, 10
act_indices, act_values = (
    torch.unbind(data_lucy['llm_act'][key][layer]) for key in ('indices', 'values')
)

In [124]:
# Neurons above threshold per token.
# NOTE(review): this uses get_filtered_neurons' default layer=12 (not the `layer = 10` set in an
# earlier cell) and th=0.5 (not `th = 1`) — confirm which layer/threshold is intended.
filtered = get_filtered_neurons(data_lucy, th=0.5)

In [55]:
# Token ids covered by the first TR.
data_lucy['tr_to_ids'][0]

[1, 2, 3, 4, 5, 6, 7, 8]

In [125]:
# For each TR: the set of neurons that pass the threshold on any token inside that TR.
# A comprehension keeps the loop variables (t, i, neurons) out of the notebook namespace.
tr_llm_neurons = [
    {neuron for token_id in token_ids for neuron in filtered[token_id]}
    for token_ids in data_lucy['tr_to_ids']
]

In [126]:
def set_diff(a, b):
    """Jaccard similarity |a & b| / |a | b| of two sets.

    Despite the name this is a similarity, not a distance: 1.0 for identical sets, 0.0 for
    disjoint ones. Two empty sets count as identical (1.0) instead of raising
    ZeroDivisionError. This keeps the diagonal of a self-similarity matrix at 1 even for TRs
    where no neuron passes the threshold.
    """
    union_size = len(a | b)
    if union_size == 0:
        return 1.0
    return len(a & b) / union_size

# Pairwise Jaccard similarity between the neuron sets of every pair of TRs (LLM side, lucy).
# Built with a comprehension so no loop variables (row_sim, i, j) leak into the notebook namespace.
llm_sim = torch.tensor([
    [set_diff(neurons_i, neurons_j) for neurons_j in tr_llm_neurons]
    for neurons_i in tr_llm_neurons
])
print(llm_sim.shape)

torch.Size([350, 350])


In [92]:
# Mean Jaccard similarity over all TR pairs (diagonal included).
llm_sim.mean()

tensor(0.0794)

In [96]:
# NOTE(review): brain_sim is only assigned inside the per-subject loops further down, so on a
# fresh kernel this cell raises NameError. The saved value presumably came from earlier kernel state.
brain_sim.abs().mean()

tensor(0.0692, dtype=torch.float64)

In [None]:
# NOTE(review): stray expression. No cell deliberately defines a notebook-level `l`, so this
# shows whatever a leftover loop variable holds (or raises NameError). The torch.Size output
# below is stale; remove it or replace it with the intended expression.
l

torch.Size([370, 370])


In [106]:
# Compare brain and LLM similarity matrices after shifting the brain side by `offset` TRs
# (presumably to compensate for haemodynamic lag — confirm). Relies on brain_sim from another cell.
offset = 6
t = brain_sim[offset: offset + llm_sim.shape[0], offset: offset + llm_sim.shape[1]]
print((t-llm_sim).norm(), (t-llm_sim).mean(), (t-llm_sim).std())

tensor(20.5144, dtype=torch.float64) tensor(-0.0311, dtype=torch.float64) tensor(0.0497, dtype=torch.float64)


In [127]:
# Per lucy subject: the TR-by-TR similarity of voxel patterns (|Pearson r|) is compared with
# llm_sim, after shifting the brain side by `offset` TRs. Handles fewer than 10 subjects.
for sub in range(min(10, len(data_lucy['fmri_act']))):
    voxels = torch.as_tensor(data_lucy['fmri_act'][sub])
    max_tr = voxels.shape[-1]  # time is the last axis of the fMRI array
    # Flatten space and put time first: one row of voxel values per TR.
    voxels_by_tr = voxels.reshape(-1, max_tr).transpose(0, 1)

    brain_sim = torch.corrcoef(voxels_by_tr).abs()
    for offset in range(6, 7):
        brain_window = brain_sim[offset: offset + llm_sim.shape[0], offset: offset + llm_sim.shape[1]]
        delta = brain_window - llm_sim
        print(f'offset:{offset}', delta.norm(), delta.abs().mean(), delta.std())

offset:6 tensor(31.7056, dtype=torch.float64) tensor(0.0762, dtype=torch.float64) tensor(0.0719, dtype=torch.float64)
offset:6 tensor(32.5146, dtype=torch.float64) tensor(0.0788, dtype=torch.float64) tensor(0.0641, dtype=torch.float64)
offset:6 tensor(34.1406, dtype=torch.float64) tensor(0.0847, dtype=torch.float64) tensor(0.0547, dtype=torch.float64)
offset:6 tensor(34.5137, dtype=torch.float64) tensor(0.0861, dtype=torch.float64) tensor(0.0544, dtype=torch.float64)
offset:6 tensor(36.3580, dtype=torch.float64) tensor(0.0928, dtype=torch.float64) tensor(0.0491, dtype=torch.float64)
offset:6 tensor(31.1921, dtype=torch.float64) tensor(0.0748, dtype=torch.float64) tensor(0.0666, dtype=torch.float64)
offset:6 tensor(37.1414, dtype=torch.float64) tensor(0.0953, dtype=torch.float64) tensor(0.0491, dtype=torch.float64)
offset:6 tensor(33.5578, dtype=torch.float64) tensor(0.0831, dtype=torch.float64) tensor(0.0582, dtype=torch.float64)
offset:6 tensor(34.7900, dtype=torch.float64) tensor(0.0

In [128]:
# Per tunnel subject: the TR-by-TR similarity of voxel patterns (|Pearson r|) is compared with
# llm_sim, after shifting the brain side by `offset` TRs. Handles fewer than 10 subjects.
# NOTE(review): llm_sim was built from *lucy* LLM activations, so this compares tunnel brains
# against lucy text — presumably a cross-story control for the lucy loop; confirm.
for sub in range(min(10, len(data_tunnel['fmri_act']))):
    voxels = torch.as_tensor(data_tunnel['fmri_act'][sub])
    max_tr = voxels.shape[-1]  # time is the last axis of the fMRI array
    # Flatten space and put time first: one row of voxel values per TR.
    voxels_by_tr = voxels.reshape(-1, max_tr).transpose(0, 1)

    brain_sim = torch.corrcoef(voxels_by_tr).abs()
    for offset in range(6, 7):
        brain_window = brain_sim[offset: offset + llm_sim.shape[0], offset: offset + llm_sim.shape[1]]
        delta = brain_window - llm_sim
        print(f'offset:{offset}', delta.norm(), delta.abs().mean(), delta.std())

offset:6 tensor(35.7530, dtype=torch.float64) tensor(0.0909, dtype=torch.float64) tensor(0.0513, dtype=torch.float64)
offset:6 tensor(37.3162, dtype=torch.float64) tensor(0.0961, dtype=torch.float64) tensor(0.0482, dtype=torch.float64)
offset:6 tensor(35.4993, dtype=torch.float64) tensor(0.0895, dtype=torch.float64) tensor(0.0528, dtype=torch.float64)
offset:6 tensor(35.7427, dtype=torch.float64) tensor(0.0903, dtype=torch.float64) tensor(0.0554, dtype=torch.float64)
offset:6 tensor(31.6582, dtype=torch.float64) tensor(0.0762, dtype=torch.float64) tensor(0.0693, dtype=torch.float64)
offset:6 tensor(33.9467, dtype=torch.float64) tensor(0.0847, dtype=torch.float64) tensor(0.0572, dtype=torch.float64)
offset:6 tensor(34.9147, dtype=torch.float64) tensor(0.0880, dtype=torch.float64) tensor(0.0531, dtype=torch.float64)
offset:6 tensor(33.2596, dtype=torch.float64) tensor(0.0821, dtype=torch.float64) tensor(0.0593, dtype=torch.float64)
offset:6 tensor(36.4222, dtype=torch.float64) tensor(0.0

In [85]:
## recover llm matrix
# NOTE(review): this re-defines get_filtered_neurons with a body identical to the definition in
# the helper-functions cell. This later definition silently shadows that one; keep only one copy.
def get_filtered_neurons(data, th=1, layer=12):
    """For each row of layer ``layer``, return the cached neuron indices whose value exceeds ``th``.

    Rows come from unbinding data['llm_act']['indices'][layer] and
    data['llm_act']['values'][layer] along their first dimension.
    """
    act_indices = torch.unbind(data['llm_act']['indices'][layer])
    act_values = torch.unbind(data['llm_act']['values'][layer])
    filtered_neurons = [i[v > th].tolist() for v, i in zip(act_values, act_indices)]
    return filtered_neurons

In [71]:
# Rebuild a dense (token x FFN-neuron) matrix for one layer from the cached indices/values, sum
# it over the tokens of each TR, then reduce it with randomized PCA.
LLM_LAYER = 16
FFN_DIM = 11008  # FFN intermediate width; 11008 matches Llama-2-7B — TODO confirm against the model config
PCA_RANK = 100

indices = data_lucy['llm_act']['indices'][LLM_LAYER].to(torch.long)
values = data_lucy['llm_act']['values'][LLM_LAYER]
# One row per cached token, sized from the data rather than a hard-coded token count.
ffn_gate = torch.zeros(indices.shape[0], FFN_DIM).half()
rows = torch.arange(indices.shape[0]).unsqueeze(1)
ffn_gate[rows, indices] = values.to(ffn_gate.dtype)

# reshape_llm_act already stored each TR's token span in data['tr_to_ids'].
# Sum in float32: accumulating in float16 loses precision.
tr_llm_act = torch.stack([ffn_gate[ids].float().sum(dim=0) for ids in data_lucy['tr_to_ids']])

# torch.pca_lowrank is randomized; seed it so the components are reproducible.
torch.manual_seed(0)
tr_llm_act_low_rank, s, _ = torch.pca_lowrank(tr_llm_act, q=PCA_RANK)


In [73]:
# Brain side for one lucy subject: (voxels..., time) -> (time, voxels), then randomized PCA.
sub = 0
# as_tensor shares memory with the numpy array instead of copying the whole run.
tr_brain_act = torch.as_tensor(data_lucy['fmri_act'][sub])
max_tr = tr_brain_act.shape[-1]
tr_brain_act = tr_brain_act.reshape(-1, max_tr).transpose(0, 1)
# torch.pca_lowrank is randomized; seed it so the components are reproducible.
torch.manual_seed(0)
tr_brain_act_low_rank, _, _ = torch.pca_lowrank(tr_brain_act, q=100)


In [74]:
# Brain PCA components, one row per TR.
tr_brain_act_low_rank

tensor([[-0.0396, -0.0143, -0.0514,  ..., -0.0005, -0.0624,  0.0025],
        [ 0.0062,  0.0307,  0.0071,  ..., -0.0205, -0.1000,  0.0121],
        [ 0.0265,  0.0026, -0.0122,  ...,  0.0098,  0.0123, -0.0214],
        ...,
        [ 0.0104,  0.0298,  0.0404,  ...,  0.0144,  0.0441,  0.0580],
        [ 0.0220, -0.0206,  0.0632,  ..., -0.0496, -0.0072,  0.0311],
        [-0.0032, -0.0542,  0.0024,  ...,  0.0008,  0.0153, -0.0453]],
       dtype=torch.float64)

In [64]:
# NOTE(review): the saved output (370, 6) predates the q=100 PCA call; re-run to refresh
# (with q=100 the second dimension would be 100).
tr_brain_act_low_rank.shape

torch.Size([370, 6])

In [89]:
from sklearn.linear_model import Ridge

# Ridge regression from LLM PCA components at TR t to brain PCA components at TR t + offset
# (the shift presumably compensates for haemodynamic lag — confirm). The split is chronological:
# the first n_train TRs train, the next n_eval TRs evaluate.
offset = 8
n_train = 270
n_eval = 30

x_train, x_eval = (
    tr_llm_act_low_rank[:n_train].numpy(),
    tr_llm_act_low_rank[n_train:n_train + n_eval].numpy(),
)
y_train, y_eval = (
    tr_brain_act_low_rank[offset:offset + n_train].numpy(),
    tr_brain_act_low_rank[offset + n_train:offset + n_train + n_eval].numpy(),
)

clf = Ridge(alpha=1)
clf.fit(x_train, y_train)
print(clf.score(x_train, y_train), clf.score(x_eval, y_eval))

0.29443128965352455 -0.07294517128063545


In [77]:
# Held-out R^2 of the ridge fit. The saved value differs from the one printed by the fitting
# cell, so it comes from an earlier fit.
clf.score(x_eval, y_eval)

-0.04338313383884836

In [78]:
# NOTE(review): this scores the model's predictions against an all-zero *target*.
# R^2 against a constant target is degenerate: recent sklearn versions report 0.0 unless the
# predictions are exactly constant, so this is not a meaningful baseline. A zero-*prediction*
# baseline would be sklearn.metrics.r2_score(y_eval, y_zero) — confirm intent.
y_zero = np.zeros_like(y_eval)
clf.score(x_eval, y_zero)

0.0

In [80]:
# Training R^2 of the current ridge fit.
clf.score(x_train, y_train)

0.29405949881305987

In [57]:
# Held-out LLM features. The saved output shows only 6 columns, so it predates the q=100 PCA.
x_eval

array([[ 2.19521765e-02, -6.95774257e-02, -1.03463326e-03,
        -2.44314522e-02,  1.58921684e-04, -3.22564654e-02],
       [-1.96992811e-02, -3.89579713e-04,  3.17841321e-02,
         6.31387811e-03,  9.23208706e-03, -3.31459269e-02],
       [ 5.14575541e-02,  9.25030001e-03,  6.08093552e-02,
         1.58198476e-02, -3.20081599e-03,  1.17002381e-02],
       [-1.81398459e-03,  7.36554805e-03,  3.48787047e-02,
         1.85895395e-02,  4.72770929e-02,  3.96052785e-02],
       [-8.09315871e-03, -3.89149189e-02, -2.39572320e-02,
        -3.43334451e-02,  6.89699948e-02,  2.32156254e-02],
       [-2.87540518e-02, -4.17115241e-02,  7.82132335e-03,
         4.75939084e-03, -9.60046527e-05,  7.54610375e-02],
       [ 5.60508221e-02, -7.56961331e-02, -4.41426560e-02,
         1.54635543e-02, -4.89861593e-02,  1.47646382e-01],
       [-6.93465164e-03, -3.93947251e-02,  2.20387429e-02,
         7.13125942e-03, -3.15610059e-02,  8.37234482e-02],
       [-5.10896966e-02, -5.21181375e-02,  1.234

In [54]:
# NOTE(review): the saved value (0.79) comes from an older fit; it differs from the fitting cell's printout.
clf.score(x_train, y_train)

0.790532093436896

In [43]:
def inner_dis(acts, time1, time2):
    """Print the mean and std of diff() over every unordered pair of runs in ``acts``.

    For each pair (i, j) with i < j, volume ``time1`` of run i is compared with volume ``time2``
    of run j. Runs are indexed as 4D arrays with time on the last axis. Prints two tensors and
    returns None.
    """
    pair_distances = [
        diff(acts[i][:, :, :, time1], acts[j][:, :, :, time2])
        for i in range(len(acts))
        for j in range(i + 1, len(acts))
    ]
    dis = torch.tensor(pair_distances)
    print(dis.mean(), dis.std())

def inter_dis(acts1, acts2, time1, time2):
    """Print the mean and std of diff() between every run in ``acts1`` and every run in ``acts2``.

    Volume ``time1`` of each run in ``acts1`` is compared with volume ``time2`` of each run in
    ``acts2``. Runs are indexed as 4D arrays with time on the last axis. Prints two tensors and
    returns None.
    """
    dis = torch.tensor([
        diff(run_a[:, :, :, time1], run_b[:, :, :, time2])
        for run_a in acts1
        for run_b in acts2
    ])
    print(dis.mean(), dis.std())

# Voxel-pattern distances at TR 100: within lucy, within tunnel, and lucy vs tunnel.
# NOTE(review): [:-1] drops the last subject of each story — reason not recorded; confirm.
inner_dis(data_lucy['fmri_act'][:-1], 100, 100)
inner_dis(data_tunnel['fmri_act'][:-1], 100, 100)
inter_dis(data_lucy['fmri_act'][:-1], data_tunnel['fmri_act'][:-1], 100, 100)

tensor(2730.3357) tensor(241.6544)
tensor(2775.7405) tensor(388.1245)
tensor(2756.1538) tensor(327.9261)


In [44]:
# Within-lucy distance between TR t1 of one subject and TR t2 of another, on a 20-TR grid with t2 >= t1.
max_time = data_lucy['fmri_act'][0].shape[-1]

for t1 in range(0, max_time, 20):
    for t2 in range(t1, max_time, 20):
        print(t1, t2)
        inner_dis(data_lucy['fmri_act'][:-1], t1, t2)


0 0
tensor(2792.5964) tensor(408.9689)
0 20
tensor(2791.8516) tensor(391.3250)
0 40
tensor(2741.9087) tensor(395.5639)
0 60
tensor(2708.2378) tensor(399.6494)
0 80
tensor(2745.5054) tensor(450.3113)
0 100
tensor(2745.5635) tensor(388.7130)
0 120
tensor(2750.7163) tensor(386.6628)
0 140
tensor(2831.6277) tensor(394.7903)
0 160
tensor(2799.4453) tensor(383.5285)
0 180
tensor(2830.8284) tensor(383.8633)
0 200
tensor(2750.9917) tensor(389.5215)
0 220
tensor(2842.1714) tensor(428.9813)
0 240
tensor(2793.2007) tensor(413.5020)
0 260
tensor(2749.4810) tensor(406.9154)
0 280
tensor(2764.4436) tensor(390.7912)
0 300
tensor(2817.4138) tensor(394.3676)
0 320
tensor(2829.8596) tensor(412.8351)
0 340
tensor(2794.2546) tensor(417.4810)
0 360
tensor(2763.5847) tensor(382.2531)
20 20
tensor(2769.5854) tensor(265.7102)
20 40
tensor(2736.1877) tensor(277.1757)
20 60
tensor(2693.0085) tensor(270.8752)
20 80
tensor(2732.4910) tensor(341.8155)
20 100
tensor(2733.6453) tensor(261.7294)
20 120
tensor(2737.55

In [32]:
# NOTE(review): this cell fails as saved. A whole-run diff() needs broadcastable shapes, but the
# lucy and tunnel runs have different numbers of TRs (370 vs 1040 in the error below).
# Compare single volumes (e.g. [..., t]) or a common time window instead.
i, j = 8, 9
diff(data_lucy['fmri_act'][i], data_tunnel['fmri_act'][j])


RuntimeError: The size of tensor a (370) must match the size of tensor b (1040) at non-singleton dimension 3

In [28]:
# NOTE(review): `a` is not defined at notebook level (it is diff()'s parameter name), so on a
# fresh kernel this raises NameError. Leftover debugging — safe to delete.
type(a) == torch.Tensor

True