## Train Encoding Model

In [1]:
import os
import numpy as np
import json
import argparse

import config
from GPT import GPT
from StimulusModel import LMFeatures
from utils_stim import get_stim
from utils_resp import get_resp
from utils_ridge.ridge import ridge, bootstrap_ridge
# Seed NumPy's global random number generator with a fixed value so that any
# code drawing from np.random produces the same sequence on every run.
# NOTE(review): presumably this is what makes the bootstrap resampling in
# bootstrap_ridge repeatable — confirm that it uses the global NumPy RNG.
np.random.seed(42)

  from .autonotebook import tqdm as notebook_tqdm
  _C._set_default_tensor_type(t)


In [5]:
# Reuse the command-line interface inside the notebook: the parser is built as
# it would be for a script, then fed a hard-coded argument string instead of
# sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--subject",
    type=str,
    required=True,
)
parser.add_argument(
    "--gpt",
    type=str,
    default="perceived",
)
parser.add_argument(
    "--sessions",
    nargs="+",
    type=int,
    default=[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 20],
)
# Only --subject is supplied here, so --gpt and --sessions take their defaults.
args = parser.parse_args("--subject S1".split())

In [6]:
# Collect the training stories: look up every requested session in the
# session -> stories mapping and flatten the results, in session order.
with open(os.path.join(config.DATA_TRAIN_DIR, "sess_to_story.json"), "r") as f:
    sess_to_story = json.load(f)
# The JSON keys are strings, so the integer session ids are converted first.
stories = [
    story
    for sess in args.sessions
    for story in sess_to_story[str(sess)]
]

In [7]:
# Load the language model selected by --gpt. Its vocabulary file and model
# directory both sit under the same per-model folder.
gpt_dir = os.path.join(config.DATA_LM_DIR, args.gpt)
with open(os.path.join(gpt_dir, "vocab.json"), "r") as f:
    gpt_vocab = json.load(f)
gpt = GPT(
    path=os.path.join(gpt_dir, "model"),
    vocab=gpt_vocab,
    device=config.GPT_DEVICE,
)
# Feature extractor wrapping the model; layer and context length come from config.
features = LMFeatures(
    model=gpt,
    layer=config.GPT_LAYER,
    context_words=config.GPT_WORDS,
)

In [8]:
# estimate encoding model
# Build the stimulus feature matrix for all training stories. get_stim also
# returns tr_stats and word_stats; tr_stats is passed back into get_stim when
# the per-story stimuli are built later, and both are written to the saved model.
rstim, tr_stats, word_stats = get_stim(stories, features)

#### Training Responses

In [9]:
# Load this subject's responses for the same training stories.
# NOTE(review): stack = True presumably concatenates the stories along the
# first axis so the rows line up with rstim — confirm in utils_resp.
rresp = get_resp(args.subject, stories, stack = True)

In [10]:
# Sanity check: show the stimulus and response shapes, one per line, so their
# first dimensions can be compared before fitting.
for arr in (rstim, rresp):
    print(arr.shape)

(27449, 3072)
(27449, 81126)


In [11]:
# Value handed to bootstrap_ridge as nchunks: one fifth of the response rows,
# expressed in chunks of config.CHUNKLEN rows and rounded up.
n_rows = rresp.shape[0]
nchunks = int(np.ceil(n_rows / 5 / config.CHUNKLEN))

In [None]:
# Fit the encoding model with bootstrapped ridge regression. Besides the
# weights, the call returns the chosen alphas and the bootstrap scores that are
# used afterwards to rank voxels.
weights, alphas, bscorrs = bootstrap_ridge(
    rstim,
    rresp,
    use_corr=False,
    alphas=config.ALPHAS,
    nboots=config.NBOOTS,
    chunklen=config.CHUNKLEN,
    nchunks=nchunks,
)

In [None]:
# Display the shape of the fitted weight matrix as the cell's output.
weights.shape

(3072, 81126)

In [None]:
# Collapse the bootstrap scores to one value per voxel: average over axis 2,
# then take the maximum over axis 0.
# NOTE(review): presumably axis 2 indexes bootstrap samples and axis 0 indexes
# candidate alphas — confirm against bootstrap_ridge.
# NOTE: this rebinds bscorrs to the reduced array, so the cell cannot simply be
# re-run without first re-running the bootstrap_ridge cell.
bscorrs = bscorrs.mean(axis=2).max(axis=0)
# Keep the config.VOXELS highest-scoring voxel indices, stored in ascending
# index order.
ranked = np.argsort(bscorrs)
vox = np.sort(ranked[-config.VOXELS:])

In [None]:
# Per-story inputs for the noise-model estimate. Each story's stimulus is
# rebuilt on its own, reusing the tr_stats obtained from the full training set.
stim_dict = {}
for story in stories:
    stim_dict[story] = get_stim([story], features, tr_stats=tr_stats)
# Unstacked responses restricted to the selected voxels.
resp_dict = get_resp(args.subject, stories, stack=False, vox=vox)
# Square accumulator with one row and column per selected voxel.
n_vox = len(vox)
noise_model = np.zeros([n_vox, n_vox])

In [None]:
# Estimate the noise model by leave-one-story-out cross-validation: for every
# held-out story, fit ridge weights on the remaining stories, compute the
# residuals on the held-out story, and accumulate the normalised residual
# cross-product matrix.
# Earlier recorded wall time for this cell: 138min 16.9s.
#
# Every story must be visited: each fold is scaled by 1 / len(stories), so
# stopping after the first fold (as a leftover debugging `break` used to do)
# would leave noise_model at only one fold's scaled contribution instead of the
# average over all folds.

# Start from zeros so that re-running this cell does not add to a previous result.
noise_model = np.zeros([len(vox), len(vox)])
for hstory in stories:
    # Training data: every story except the held-out one, stacked row-wise.
    tstim, hstim = np.vstack([stim_dict[tstory] for tstory in stories if tstory != hstory]), stim_dict[hstory]
    tresp, hresp = np.vstack([resp_dict[tstory] for tstory in stories if tstory != hstory]), resp_dict[hstory]
    print(tstim.shape)
    print(tresp.shape)
    # Refit using the alphas already selected for the chosen voxels.
    bs_weights = ridge(tstim, tresp, alphas[vox])
    print(bs_weights.shape)
    # Prediction error on the held-out story.
    resids = hresp - hstim.dot(bs_weights)
    bs_noise_model = resids.T.dot(resids)
    # Normalise each fold by its mean diagonal before averaging over folds.
    noise_model += bs_noise_model / np.diag(bs_noise_model).mean() / len(stories)

In [53]:
# Save everything the fitted model consists of into a single archive under
# <MODEL_DIR>/<subject>/, creating the folder if needed.
save_location = os.path.join(config.MODEL_DIR, args.subject)
os.makedirs(save_location, exist_ok=True)
model_arrays = {
    "weights": weights,
    "noise_model": noise_model,
    "alphas": alphas,
    "voxels": vox,
    "stories": stories,
    "tr_stats": np.array(tr_stats),
    "word_stats": np.array(word_stats),
}
# np.savez appends the ".npz" extension because the filename does not have it.
np.savez(os.path.join(save_location, "encoding_model_%s" % args.gpt), **model_arrays)