In [1]:
%load_ext autoreload
%autoreload 2
import os
import numpy as np
import json
import argparse
import random
import scipy
import config
from GPT import GPT
from LLAMA import LLAMA
from StimulusModel import LMFeatures
from utils_stim import get_stim, get_story_wordseqs
from utils_resp import get_resp
from utils_ridge.ridge import ridge, bootstrap_ridge, ridge_corr
from utils_ridge.ridge_torch import ridge_torch, bootstrap_ridge_torch, ridge_corr_torch
from utils_ridge.stimulus_utils import TRFile, load_textgrids, load_simulated_trfiles
from utils_ridge.dsutils import make_word_ds
from utils_ridge.interpdata import lanczosinterp2D
from utils_ridge.util import make_delayed
from utils_ridge.utils import mult_diag, counter
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import utils_llama.activation as ana

import scipy
import math
import matplotlib.pyplot as plt

import time

  from .autonotebook import tqdm as notebook_tqdm
2024-01-06 01:05:07,634 - numexpr.utils - INFO - Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-01-06 01:05:07,635 - numexpr.utils - INFO - NumExpr defaulting to 8 threads.
  _C._set_default_tensor_type(t)


In [98]:
# Ask PyTorch to release its cached, currently unused GPU memory.
# NOTE(review): this cell carries execution count 98, i.e. it was run out of order
# relative to the cells around it.
torch.cuda.empty_cache()

In [2]:
class ARGS:
    """Lightweight stand-in for command-line arguments used by this notebook.

    Parameters
    ----------
    subject : str
        Subject identifier; passed to ``get_resp`` by later cells.
    gpt : str
        Stored as-is; not read by any cell visible in this notebook section.
    sessions : iterable of int, optional
        Session numbers whose stories are loaded from ``sess_to_story.json``.
        Defaults to ``DEFAULT_SESSIONS``. A fresh list is stored on every
        instance so instances never share mutable state.
    """

    DEFAULT_SESSIONS = (2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 20)

    def __init__(self, subject='S1', gpt='perceived', sessions=None):
        self.subject = subject
        self.gpt = gpt
        chosen = self.DEFAULT_SESSIONS if sessions is None else sessions
        self.sessions = list(chosen)

args = ARGS()

# training stories: collect the story names of every requested session
stories = []
with open(os.path.join(config.DATA_TRAIN_DIR, "sess_to_story.json"), "r") as f:
    sess_to_story = json.load(f)
for sess in args.sessions:
    # JSON object keys are strings, hence str(sess)
    stories.extend(sess_to_story[str(sess)])

# Only the first 10 stories are used in this notebook.
stories = stories[:10]

## load cached llm act
# NOTE(review): absolute, machine-specific cache path.
cache_dir = '/ossfs/workspace/act_cache_ffn_gate'
llama = LLAMA(None, None, cache_dir)

# Which activation to read from the cache. These do not change per story,
# so they are set once here rather than on every loop iteration.
layer = 20
act_name = 'ffn_gate'

word_seqs = get_story_wordseqs(stories)
word_vecs = {}
for story in stories:
    words = word_seqs[story].data
    embs = llama.get_llm_act(story, words, config.GPT_WORDS, act_name, layer)
    word_vecs[story] = embs

In [3]:
# Word-level statistics over all training stories (kept in `word_stats` later on).
word_mat = np.vstack([word_vecs[story] for story in stories])
word_mean, word_std = word_mat.mean(0), word_mat.std(0)

# Interpolate each story's word vectors from its `data_times` onto its `tr_times`.
ds_vecs = {story : lanczosinterp2D(word_vecs[story], word_seqs[story].data_times, word_seqs[story].tr_times)
           for story in stories}

# Trim the start/end of every story, then stack all stories row-wise.
ds_mat = np.vstack([ds_vecs[story][5+config.TRIM:-config.TRIM] for story in stories])
r_mean, r_std = ds_mat.mean(0), ds_mat.std(0)
r_std[r_std == 0] = 1  # constant features: avoid division by zero
# Z-score by broadcasting. The previous form multiplied by inv(diag(r_std)), which
# builds and inverts a dense (n_features x n_features) matrix just to divide by a
# vector; the result is the same up to floating-point rounding.
ds_mat = np.nan_to_num((ds_mat - r_mean) / r_std)
del_mat = make_delayed(ds_mat, config.STIM_DELAYS)

In [4]:
# Sanity check: dimensions of the delayed training stimulus matrix.
del_mat.shape

(3737, 44032)

In [5]:
# Load the subject's responses for the training stories with stack=True
# (later cells treat the result as one 2-D array: `nresp, nvox = Rresp.shape`).
rresp = get_resp(args.subject, stories, stack = True)

In [6]:
# Bundle the regression inputs: stimulus matrix, TR-level (mean, std) and
# word-level (mean, std) normalisation statistics.
rstim, tr_stats, word_stats = del_mat, (r_mean, r_std), (word_mean, word_std)

# Number of chunks passed to the bootstrap: ceil(n_rows / 5 / CHUNKLEN).
# NOTE(review): presumably sized so about one fifth of the rows is held out
# per bootstrap — confirm against bootstrap_ridge_torch.
nchunks = int(np.ceil(rresp.shape[0] / 5 / config.CHUNKLEN))

In [7]:
# Copy stimulus and response into torch tensors on the GPU.
# Both names are rebound here, so from this cell on they are CUDA tensors.
rstim = torch.tensor(rstim).cuda()
rresp = torch.tensor(rresp).cuda()

In [8]:
# Scratch variables named like bootstrap-ridge arguments/locals.
# NOTE(review): none of these names is read by any later cell visible here, and
# the next cell calls bootstrap_ridge_torch with nboots = 15, not the value set
# below — looks like leftover step-through/debugging state; confirm before relying on it.
Rstim = rstim
Rresp = rresp
chunklen = config.CHUNKLEN
nchunks = nchunks  # self-assignment, no effect
singcutoff=1e-10
nboots = 1
nresp, nvox = Rresp.shape
bestalphas = torch.zeros((nboots, nvox))
valinds = []
Rcmats = []

In [9]:
# Fit the encoding model with bootstrapped ridge regression (15 bootstraps) over
# the candidate regularisation values in config.ALPHAS.
# `alphas` is later indexed per voxel (`alphas[vox]`); `bscorrs` is later reduced
# with .mean(2).max(0), so it is indexed along three axes.
# NOTE(review): exact return types/shapes come from bootstrap_ridge_torch — confirm there.
weights, alphas, bscorrs = bootstrap_ridge_torch(rstim, rresp, use_corr = False, alphas = config.ALPHAS,
    nboots = 15, chunklen = config.CHUNKLEN, nchunks = nchunks)    

2024-01-06 01:10:01,678 - ridge_corr - INFO - Selecting held-out test set..
2024-01-06 01:10:01,712 - ridge_corr - INFO - Doing SVD...
2024-01-06 01:10:03,701 - ridge_corr - INFO - Dropped 1 tiny singular values.. (U is now torch.Size([2977, 2976]))
2024-01-06 01:10:03,702 - ridge_corr - INFO - Training stimulus has Frobenius norm: 3147.409
2024-01-06 01:10:03,864 - ridge_corr - INFO - Training: alpha=10.000, mean corr=0.00000, max corr=0.00000, over-under(0.20)=0
2024-01-06 01:10:03,887 - ridge_corr - INFO - Training: alpha=16.681, mean corr=0.00000, max corr=0.00000, over-under(0.20)=0
2024-01-06 01:10:03,910 - ridge_corr - INFO - Training: alpha=27.826, mean corr=0.00000, max corr=0.00000, over-under(0.20)=0
2024-01-06 01:10:03,933 - ridge_corr - INFO - Training: alpha=46.416, mean corr=0.00000, max corr=0.00000, over-under(0.20)=0
2024-01-06 01:10:03,956 - ridge_corr - INFO - Training: alpha=77.426, mean corr=0.00000, max corr=0.04527, over-under(0.20)=0
2024-01-06 01:10:03,978 - r

In [10]:
# Score every voxel: average `bscorrs` over its third axis, then take the best
# value over its first axis.
# NOTE(review): axes are presumably (alpha, voxel, bootstrap) — confirm.
# The result goes into a new name instead of rebinding `bscorrs`: overwriting it
# made this cell fail when re-run (.mean(2) on the reduced array) and broke later
# cells that still index the full 3-D `bscorrs`.
voxel_scores = bscorrs.mean(2).max(0)
# Keep the config.VOXELS best-scoring voxels, as indices in ascending order.
vox = np.sort(np.argsort(voxel_scores)[-config.VOXELS:])

In [65]:
# Exploratory check: same voxel scoring, restricted to entries 5..9 of the third
# axis of `bscorrs`; displays the indices of the 100 best-scoring voxels
# (ascending score). Requires `bscorrs` to be 3-D, since it indexes three axes.
t = bscorrs[:, :, 5:10].mean(2).max(0)
np.argsort(t)[-100:]

array([35036, 30062, 26781, 37443, 26748, 12986, 21244, 27204, 60127,
       21246, 35086, 34583, 40058, 24600, 27260, 35035, 34944, 40642,
       24042, 39806, 37153, 24494, 40692, 49983, 37402, 62520, 29730,
       26637, 31838, 39846, 40643, 34502, 15813, 34899, 29106, 39974,
       37401, 62468, 24210, 62469, 38044, 24495, 21690, 29145, 40016,
       40842, 37318, 27261, 29414, 37154, 27312, 26735, 24547, 31878,
       60128, 37359, 29782, 40219, 62569, 30008, 29679, 24546, 27421,
       40691, 30009, 24439, 32361, 32266, 37478, 40017, 24545, 37647,
       37946, 27367, 29331, 43191, 43238, 37193, 39932, 32309, 24496,
       37233, 34543, 60176, 29330, 62518, 29731, 37600, 37360, 37695,
       34500, 62619, 39975, 60280, 34584, 29783, 37556, 34985, 34986,
       34501])

In [13]:
def get_voxel(corrs, n_voxels=10000):
    """Return the indices of the best-scoring voxels.

    A voxel's score is computed by averaging ``corrs`` over axis 2 and then
    taking the maximum over axis 0.

    Parameters
    ----------
    corrs : numpy.ndarray
        3-D array of correlations; the voxel axis is axis 1.
    n_voxels : int, optional
        Maximum number of voxels to return (default 10000, the previously
        hard-coded value). If fewer voxels exist, all are returned.

    Returns
    -------
    list of int
        Voxel indices ordered by ascending score (best voxel last).

    Raises
    ------
    ValueError
        If ``n_voxels`` is negative.
    """
    if n_voxels < 0:
        raise ValueError("n_voxels must be non-negative")
    if n_voxels == 0:
        # A slice of [-0:] would return every voxel instead of none.
        return []
    scores = corrs.mean(2).max(0)
    return np.argsort(scores)[-n_voxels:].tolist()

In [16]:
def get_stim(stories, tr_stats = None, layer = 20, act_name = 'ffn_gate'):
    """Build the z-scored, delayed stimulus matrix for `stories` from cached LLM activations.

    NOTE: this definition shadows the `get_stim` imported from `utils_stim` in the
    notebook's import cell, and it reads the notebook-level `llama` object.

    Parameters
    ----------
    stories : list of str
        Story names; their rows are stacked in this order.
    tr_stats : tuple (r_mean, r_std), optional
        Normalisation statistics to apply. When None they are estimated from
        `stories` and returned.
    layer : int, optional
        Layer argument passed to `llama.get_llm_act` (default 20).
    act_name : str, optional
        Activation name passed to `llama.get_llm_act` (default 'ffn_gate').

    Returns
    -------
    del_mat
        When `tr_stats` is given.
    (del_mat, (r_mean, r_std), (word_mean, word_std))
        When `tr_stats` is None.
    """
    word_seqs = get_story_wordseqs(stories)
    word_vecs = {
        story : llama.get_llm_act(story, word_seqs[story].data, config.GPT_WORDS, act_name, layer)
        for story in stories
    }

    # Interpolate each story's word vectors from its `data_times` onto its `tr_times`.
    ds_vecs = {story : lanczosinterp2D(word_vecs[story], word_seqs[story].data_times, word_seqs[story].tr_times)
               for story in stories}
    # Trim the start/end of every story, then stack all stories row-wise.
    ds_mat = np.vstack([ds_vecs[story][5+config.TRIM:-config.TRIM] for story in stories])

    fit_stats = tr_stats is None
    if fit_stats:
        r_mean, r_std = ds_mat.mean(0), ds_mat.std(0)
        r_std[r_std == 0] = 1  # constant features: avoid division by zero
    else:
        r_mean, r_std = tr_stats
        # Same zero guard for caller-supplied statistics, without mutating them.
        r_std = np.where(np.asarray(r_std) == 0, 1, r_std)

    # Z-score by broadcasting. The previous form multiplied by inv(diag(r_std)), which
    # builds and inverts a dense (n_features x n_features) matrix on every call; the
    # result is the same up to floating-point rounding.
    ds_mat = np.nan_to_num((ds_mat - r_mean) / r_std)
    del_mat = make_delayed(ds_mat, config.STIM_DELAYS)

    if not fit_stats:
        return del_mat

    # Word-level statistics are only returned when statistics are being fitted,
    # so they are only computed on that path.
    word_mat = np.vstack([word_vecs[story] for story in stories])
    word_mean, word_std = word_mat.mean(0), word_mat.std(0)
    return del_mat, (r_mean, r_std), (word_mean, word_std)

# One delayed stimulus matrix per story, each normalised with the shared
# training statistics in `tr_stats` (so get_stim returns only the matrix).
stim_dict = {story : get_stim([story], tr_stats = tr_stats) for story in stories}

In [15]:
# Responses restricted to the selected voxels `vox`, loaded with stack=False;
# later cells index the result by story name.
resp_dict = get_resp(args.subject, stories, stack = False, vox = vox)

In [17]:
# Leave-one-story-out estimate of the residual covariance between the selected voxels.
# `noise_model` is a (len(vox), len(vox)) accumulator on the GPU.
noise_model = torch.zeros([len(vox), len(vox)]).cuda()
for hstory in stories:
    # Train on every story except the held-out one (`hstory`).
    tstim, hstim = np.vstack([stim_dict[tstory] for tstory in stories if tstory != hstory]), stim_dict[hstory]
    tresp, hresp = np.vstack([resp_dict[tstory] for tstory in stories if tstory != hstory]), resp_dict[hstory]
    tstim, hstim = torch.tensor(tstim).cuda(), torch.tensor(hstim).cuda()
    tresp, hresp = torch.tensor(tresp).cuda(), torch.tensor(hresp).cuda()
    # Ridge fit using the bootstrap-selected alphas of the chosen voxels.
    bs_weights = ridge_torch(tstim, tresp, alphas[vox])
    # Match device and dtype of the held-out stimulus before multiplying.
    bs_weights = bs_weights.to(hstim.device).to(hstim.dtype)
    pred = hstim.matmul(bs_weights)
    resids = hresp - pred
    # Unnormalised residual cross-product; scaled by its mean diagonal and
    # averaged over the held-out stories.
    bs_noise_model = resids.T.matmul(resids)
    noise_model += bs_noise_model / torch.diag(bs_noise_model).mean() / len(stories)

    # Report the pooled prediction/response correlation for this held-out story.
    # After the loop, `hstim`, `bs_weights`, `pred` and `hresp` (now a NumPy array)
    # still hold the values of the last held-out story; later cells reuse them.
    pred = pred.cpu().numpy() 
    hresp = hresp.cpu().numpy()
    print(scipy.stats.pearsonr(pred.flatten(), hresp.flatten()))


PearsonRResult(statistic=0.1343166077704546, pvalue=0.0)
PearsonRResult(statistic=0.18005104517542656, pvalue=0.0)
PearsonRResult(statistic=0.13352700556254338, pvalue=0.0)
PearsonRResult(statistic=0.14024638964978753, pvalue=0.0)
PearsonRResult(statistic=0.11630805275207659, pvalue=0.0)
PearsonRResult(statistic=0.14408247753978906, pvalue=0.0)
PearsonRResult(statistic=0.1883748198242552, pvalue=0.0)
PearsonRResult(statistic=0.06479054619918578, pvalue=0.0)
PearsonRResult(statistic=0.16666112501921226, pvalue=0.0)
PearsonRResult(statistic=0.1544978554628468, pvalue=0.0)


In [33]:
# Control analysis: permute the feature columns of the held-out stimulus so the
# fitted weights no longer line up with their features, then recompute the pooled
# prediction/response correlation as a baseline.
# Relies on `hstim`, `bs_weights` and `hresp` left over from the last iteration of
# the leave-one-story-out loop; `hresp` is assumed to already be a NumPy array
# (it was converted inside that loop).
ids = list(range(hstim.shape[1]))
# A locally seeded generator makes the permutation (and the reported baseline)
# reproducible without touching the global `random` state.
random.Random(0).shuffle(ids)
pred = hstim[:,ids].matmul(bs_weights)
pred = pred.cpu().numpy()
print(scipy.stats.pearsonr(pred.flatten(), hresp.flatten()))



PearsonRResult(statistic=0.0314835300969397, pvalue=0.0)


In [35]:
# Display the string form of the correlation result for the current `pred` and `hresp`.
str(scipy.stats.pearsonr(pred.flatten(), hresp.flatten()))

'PearsonRResult(statistic=0.0314835300969397, pvalue=0.0)'

In [29]:
# Random permutation of the row indices of `hstim`.
# NOTE(review): `ids` is rebound here to row indices (another cell uses the same
# name for column indices), and the shuffle is unseeded, so it differs between runs.
ids = list(range(hstim.shape[0]))
random.shuffle(ids)

In [31]:
# Display `hstim` with its rows reordered by `ids`.
hstim[ids]

tensor([[ 2.8632,  0.0108, -0.7713,  ..., -0.2018, -0.6897, -1.0225],
        [ 0.0309, -0.0534, -0.2413,  ...,  0.0564,  1.1742, -1.2079],
        [ 1.5587,  0.7781,  0.0997,  ..., -1.0315, -0.2295, -0.4057],
        ...,
        [-0.2464, -0.7736,  0.3435,  ..., -1.1315,  0.4548,  1.1343],
        [-0.7421, -0.9829,  1.7259,  ..., -0.5219, -0.3622,  0.3405],
        [-0.1142,  0.6807, -0.5739,  ..., -0.1105,  0.6583,  0.6890]],
       device='cuda:0', dtype=torch.float64)