In [51]:
%load_ext autoreload
%autoreload 2
import os
import copy
import numpy as np
import json
import argparse
import random
import scipy
import config
from GPT import GPT
from LLAMA import LLAMA
from StimulusModel import LMFeatures
from utils_stim import get_story_wordseqs
# from utils_resp import get_resp
from utils_ridge.ridge import ridge, bootstrap_ridge, ridge_corr
from utils_ridge.ridge_torch import ridge_torch, bootstrap_ridge_torch, ridge_corr_torch
from utils_ridge.stimulus_utils import TRFile, load_textgrids, load_simulated_trfiles
from utils_ridge.dsutils import make_word_ds
from utils_ridge.interpdata import lanczosinterp2D, lanczosinterp2D_torch
from utils_ridge.util import make_delayed
from utils_ridge.utils import mult_diag, counter
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import utils_llama.activation as ana

import scipy
import math
import matplotlib.pyplot as plt

import time
import h5py


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [66]:
# torch.cuda.memory._record_memory_history()
torch.cuda.empty_cache()
torch.set_grad_enabled(False)

class ARGS:
    def __init__(self):
        self.subject = 'S1'
        self.gpt = 'perceived'
        self.sessions = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 20]
        self.layer = 17
        self.act_name = 'ffn_gate'
        self.window = 5

args = ARGS()

# training stories
stories = []
with open(os.path.join(config.DATA_TRAIN_DIR, "sess_to_story.json"), "r") as f:
    sess_to_story = json.load(f) 
for sess in args.sessions:
    stories.extend(sess_to_story[str(sess)])

stories = stories[:10]


In [11]:
model_dir = '/ossfs/workspace/nas/gzhch/data/models/Llama-2-7b-hf'
model = AutoModelForCausalLM.from_pretrained(
    model_dir, 
    device_map='auto',
    torch_dtype=torch.float16,
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_dir)


## load cached llm act if possible
cache_dir = '/ossfs/workspace/act_cache'
llama = LLAMA(model, tokenizer, cache_dir)

2024-01-08 10:59:35,737 - accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.94s/it]


In [12]:
def get_stim(args, stories, llama, tr_stats = None, delay=True, vox=None):
    word_seqs = get_story_wordseqs(stories)
    word_vecs = {}
    for story in stories:
        words = word_seqs[story].data
        embs = llama.get_llm_act(story, words, args.window, args.act_name, args.layer)
        word_vecs[story] = embs
    
    word_mat = np.vstack([word_vecs[story] for story in stories])
    word_mean, word_std = word_mat.mean(0), word_mat.std(0)

    ds_vecs = {story : lanczosinterp2D(word_vecs[story], word_seqs[story].data_times, word_seqs[story].tr_times) 
               for story in stories}
    ds_mat = np.vstack([ds_vecs[story][5+config.TRIM:-config.TRIM] for story in stories])

    if vox is not None:
        ds_mat = ds_mat[:, vox]

    if tr_stats is None: 
        r_mean, r_std = ds_mat.mean(0), ds_mat.std(0)
        r_std[r_std == 0] = 1
    else: 
        r_mean, r_std = tr_stats
    ds_mat = np.nan_to_num(np.dot((ds_mat - r_mean), np.linalg.inv(np.diag(r_std))))
    if delay:
        del_mat = make_delayed(ds_mat, config.STIM_DELAYS)
    else:
        del_mat = ds_mat
    if tr_stats is None: return del_mat, (r_mean, r_std), (word_mean, word_std)
    else: return del_mat, None, None

def get_resp(subject, stories, stack = True, vox = None, delay=False):
    """loads response data
    """
    subject_dir = os.path.join(config.DATA_TRAIN_DIR, "train_response", subject)
    resp = {}
    for story in stories:
        resp_path = os.path.join(subject_dir, "%s.hf5" % story)
        hf = h5py.File(resp_path, "r")
        resp[story] = np.nan_to_num(hf["data"][:])
        if vox is not None:
            resp[story] = resp[story][:, vox]
        if delay:
            resp[story] = make_delayed(resp[story], config.RESP_DELAYS)
        hf.close()
    if stack: return np.vstack([resp[story] for story in stories]) 
    else: return resp


18

In [18]:
args2 = copy.deepcopy(args)
args2.layer = args.layer + 1
rstim, tr_stats, word_stats = get_stim(args, stories, llama, delay=True)
rresp, tr_stats, word_stats = get_stim(args2, stories, llama, delay=False)

rstim = torch.tensor(rstim).cuda()
rresp = torch.tensor(rresp).cuda()

nchunks = int(np.ceil(rresp.shape[0] / 5 / config.CHUNKLEN))

weights, alphas, bscorrs = bootstrap_ridge_torch(rstim, rresp, use_corr = False, alphas = config.ALPHAS,
        nboots = config.NBOOTS, chunklen = config.CHUNKLEN, nchunks = nchunks)        
bscorrs = bscorrs.mean(2).max(0)
vox = np.sort(np.argsort(bscorrs)[-config.VOXELS:])

# del rstim, rresp

2024-01-08 11:13:13,362 - ridge_corr - INFO - Selecting held-out test set..
2024-01-08 11:13:13,370 - ridge_corr - INFO - Doing SVD...
2024-01-08 11:13:15,443 - ridge_corr - INFO - Dropped 1 tiny singular values.. (U is now torch.Size([2977, 2976]))
2024-01-08 11:13:15,443 - ridge_corr - INFO - Training stimulus has Frobenius norm: 3397.224
2024-01-08 11:13:15,470 - ridge_corr - INFO - Training: alpha=10.000, mean corr=0.00000, max corr=0.00000, over-under(0.20)=0
2024-01-08 11:13:15,474 - ridge_corr - INFO - Training: alpha=16.681, mean corr=0.00000, max corr=0.00000, over-under(0.20)=0
2024-01-08 11:13:15,478 - ridge_corr - INFO - Training: alpha=27.826, mean corr=0.00001, max corr=0.04234, over-under(0.20)=0
2024-01-08 11:13:15,482 - ridge_corr - INFO - Training: alpha=46.416, mean corr=0.00009, max corr=0.10953, over-under(0.20)=0
2024-01-08 11:13:15,486 - ridge_corr - INFO - Training: alpha=77.426, mean corr=0.00231, max corr=0.17192, over-under(0.20)=0
2024-01-08 11:13:15,490 - r

In [23]:
del rstim, rresp
torch.cuda.empty_cache()

In [56]:
rresp, tr_stats, word_stats = get_stim(args, stories, llama, delay=False)
rstim = get_resp(args.subject, stories, stack = True, delay=True, vox = vox)

rstim = torch.tensor(rstim).cuda()
rresp = torch.tensor(rresp).cuda()

nchunks = int(np.ceil(rresp.shape[0] / 5 / config.CHUNKLEN))

_, alphas, bscorrs = bootstrap_ridge_torch(rresp, rstim, use_corr = False, alphas = config.ALPHAS,
        nboots = config.NBOOTS, chunklen = config.CHUNKLEN, nchunks = nchunks)        
bscorrs = bscorrs.mean(2).max(0)
llm_vox = np.sort(np.argsort(bscorrs)[-config.VOXELS:])

del rstim, rresp

2024-01-07 19:26:40,763 - ridge_corr - INFO - Selecting held-out test set..
2024-01-07 19:26:40,766 - ridge_corr - INFO - Doing SVD...
2024-01-07 19:26:42,320 - ridge_corr - INFO - Dropped 0 tiny singular values.. (U is now torch.Size([2977, 2977]))
2024-01-07 19:26:42,320 - ridge_corr - INFO - Training stimulus has Frobenius norm: 2486.823
2024-01-07 19:26:42,374 - ridge_corr - INFO - Training: alpha=10.000, mean corr=0.00000, max corr=0.00000, over-under(0.20)=0
2024-01-07 19:26:42,386 - ridge_corr - INFO - Training: alpha=16.681, mean corr=0.00000, max corr=0.00000, over-under(0.20)=0
2024-01-07 19:26:42,398 - ridge_corr - INFO - Training: alpha=27.826, mean corr=0.00000, max corr=0.00000, over-under(0.20)=0
2024-01-07 19:26:42,409 - ridge_corr - INFO - Training: alpha=46.416, mean corr=0.00000, max corr=0.01474, over-under(0.20)=0
2024-01-07 19:26:42,421 - ridge_corr - INFO - Training: alpha=77.426, mean corr=0.00008, max corr=0.10191, over-under(0.20)=0
2024-01-07 19:26:42,433 - r

In [24]:
stim_dict = {story : get_stim(args, [story], llama, delay=True)[0] for story in stories}
resp_dict = {story : get_stim(args2, [story], llama, delay=False)[0] for story in stories}

In [25]:
# noise_model = torch.zeros([len(vox), len(vox)]).cuda()
for hstory in stories:
    tstim, hstim = np.vstack([stim_dict[tstory] for tstory in stories if tstory != hstory]), stim_dict[hstory]
    tresp, hresp = np.vstack([resp_dict[tstory] for tstory in stories if tstory != hstory]), resp_dict[hstory]
    tstim, hstim = torch.tensor(tstim).cuda(), torch.tensor(hstim).cuda()
    tresp, hresp = torch.tensor(tresp).cuda(), torch.tensor(hresp).cuda()
    bs_weights = ridge_torch(tstim, tresp, alphas[vox])
    bs_weights = bs_weights.to(hstim.device).to(hstim.dtype)
    pred = hstim.matmul(bs_weights)
    # resids = hresp - pred
    # bs_noise_model = resids.T.matmul(resids)
    # noise_model += bs_noise_model / torch.diag(bs_noise_model).mean() / len(stories)

    pred = pred.cpu().numpy() 
    hresp = hresp.cpu().numpy()
    print(scipy.stats.pearsonr(pred.flatten(), hresp.flatten()))


PearsonRResult(statistic=0.2216686750672747, pvalue=0.0)
PearsonRResult(statistic=0.20648518891327244, pvalue=0.0)
PearsonRResult(statistic=0.20083938616936117, pvalue=0.0)
PearsonRResult(statistic=0.17719768939144936, pvalue=0.0)
PearsonRResult(statistic=0.20235533695530714, pvalue=0.0)
PearsonRResult(statistic=0.2174478444259279, pvalue=0.0)
PearsonRResult(statistic=0.20448339628057602, pvalue=0.0)
PearsonRResult(statistic=0.1847570591852558, pvalue=0.0)
PearsonRResult(statistic=0.2000519068334876, pvalue=0.0)
PearsonRResult(statistic=0.20726467936869294, pvalue=0.0)


In [61]:
stim_dict['alternateithicatom'].shape

(343, 40000)

In [67]:
a = 

(343, 11008)

In [7]:
scipy.stats.pearsonr([1,2,3,4], [0,1,2,3])

PearsonRResult(statistic=1.0, pvalue=0.0)

In [79]:
def get_resp_torch(subject, stories, stack = True, vox = None, delay=False):
    """loads response data
    """
    subject_dir = os.path.join(config.DATA_TRAIN_DIR, "train_response", subject)
    resp = {}
    for story in stories:
        resp_path = os.path.join(subject_dir, "%s.hf5" % story)
        hf = h5py.File(resp_path, "r")
        resp[story] = torch.nan_to_num(torch.tensor(hf["data"][:]))
        if vox is not None:
            resp[story] = resp[story][:, vox]
        if delay:
            resp[story] = make_delayed(resp[story], config.RESP_DELAYS)
        hf.close()
    if stack: return torch.vstack([resp[story] for story in stories]) 
    else: return resp

In [67]:
rstim, tr_stats, word_stats = get_stim(args, stories, llama, delay=True)

In [74]:
rstim_torch, tr_stats, word_stats = get_stim_torch(args, stories, llama, delay=True)

In [80]:
rstim_torch.cuda()

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.1978, -0.2035, -0.4431,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.4368, -0.7872, -1.7411,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 1.0638, -0.0788, -1.0208,  ...,  0.9335, -0.1072,  1.0394],
        [ 0.6023, -0.3566, -0.1209,  ..., -1.1465,  2.6773,  0.8388],
        [ 0.4143,  0.2978, -0.3690,  ...,  0.5060, -0.2307, -0.2083]],
       device='cuda:0')

In [81]:
rstim_torch.device

device(type='cpu')