## Evaluate Encoder Model

In [2]:
import os
import numpy as np
import json
import argparse
import h5py

import config
from GPT import GPT
from EncodingModel import EncodingModel
from StimulusModel import StimulusModel, get_lanczos_mat, affected_trs, LMFeatures
from utils_stim import get_stim
from utils_resp import get_resp
from utils_ridge.ridge import ridge, bootstrap_ridge
# seed NumPy's global random number generator so random draws are repeatable across runs
np.random.seed(42)

In [9]:
# Options are declared as for a command-line script, but parsed from a fixed
# argv list because this code is run from notebook cells.
parser = argparse.ArgumentParser()
parser.add_argument("--subject", type=str, required=True)
parser.add_argument("--gpt", type=str, default="perceived")
parser.add_argument("--experiment", type=str, default="perceived_speech")
parser.add_argument("--task", type=str, default="wheretheressmoke")
parser.add_argument(
    "--sessions",
    nargs="+",
    type=int,
    default=[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 20],
)
args = parser.parse_args(["--subject", "S1"])

In [22]:
# pick which word rate model to use: the "speech" one for the imagined-speech
# and perceived-movies experiments, the "auditory" one for everything else
if args.experiment in ["imagined_speech", "perceived_movies"]:
    word_rate_voxels = "speech"
else:
    word_rate_voxels = "auditory"

In [23]:
# load the saved word rate model and encoding model for this subject
load_location = os.path.join(config.MODEL_DIR, args.subject)
word_rate_path = os.path.join(load_location, "word_rate_model_%s.npz" % word_rate_voxels)
encoding_path = os.path.join(load_location, "encoding_model_%s.npz" % args.gpt)
# NOTE(review): allow_pickle = True can execute arbitrary code from the file;
# only load word rate models that come from a trusted source
word_rate_model = np.load(word_rate_path, allow_pickle = True)
encoding_model = np.load(encoding_path)

In [24]:
# load the test responses for this subject / experiment / task
resp_path = os.path.join(config.DATA_TEST_DIR, "test_response", args.subject, args.experiment, args.task + ".hf5")
# the context manager closes the file even if reading "data" raises
with h5py.File(resp_path, "r") as hf:
    # read the whole dataset into memory and replace NaNs with finite numbers
    resp = np.nan_to_num(hf["data"][:])

In [25]:
# unpack the arrays stored in the encoding model file
weights, noise_model, tr_stats, word_stats = (
    encoding_model[key] for key in ("weights", "noise_model", "tr_stats", "word_stats")
)
# build the encoding model over the test responses and set its shrinkage
em = EncodingModel(
    resp,
    weights,
    encoding_model["voxels"],
    noise_model,
    device=config.EM_DEVICE,
)
em.set_shrinkage(config.NM_ALPHA)

In [26]:
# The test task must not be among the stories stored with the encoding model
# (the stories it was fit on). Raise explicitly instead of using `assert`,
# which is skipped when Python runs with -O.
if args.task in encoding_model["stories"]:
    raise ValueError("test task %r is one of the encoding model's training stories" % args.task)

In [12]:
# training stories: look up each requested session in the session -> stories
# mapping and concatenate the results in session order
mapping_path = os.path.join(config.DATA_TRAIN_DIR, "sess_to_story.json")
with open(mapping_path, "r") as f:
    sess_to_story = json.load(f)
stories = [story for sess in args.sessions for story in sess_to_story[str(sess)]]

In [13]:
# responses for the training stories; stack = True presumably concatenates the
# per-story responses into a single array (verify against utils_resp.get_resp)
rresp = get_resp(args.subject, stories, stack = True)

In [14]:
# notebook display of the response array's dimensions
rresp.shape

(27449, 81126)

#### Prepare Stimulus Features

In [18]:
# load the language model (vocabulary + weights) and wrap it as a feature extractor
gpt_dir = os.path.join(config.DATA_LM_DIR, args.gpt)
with open(os.path.join(gpt_dir, "vocab.json"), "r") as f:
    gpt_vocab = json.load(f)
gpt = GPT(
    path=os.path.join(gpt_dir, "model"),
    vocab=gpt_vocab,
    device=config.GPT_DEVICE,
)
features = LMFeatures(
    model=gpt,
    layer=config.GPT_LAYER,
    context_words=config.GPT_WORDS,
)

In [19]:
# build the stimulus features for the training stories; get_stim also returns
# tr_stats and word_stats, which rebind any values of those names set earlier
rstim, tr_stats, word_stats = get_stim(stories, features)

In [20]:
# notebook display of the stimulus feature array's dimensions
rstim.shape

(27449, 3072)

#### Calculate P(R|S)
In other words: "find the probability of a response (R) given a stimulus (S)", as scored by the encoding model.

In [21]:
# NOTE(review): the recorded run of this cell failed with
# "AttributeError: 'numpy.ndarray' object has no attribute 'float'", so
# `likelihoods` was never assigned. At least one argument is a NumPy array
# where em.prs calls `.float()` on it; presumably prs expects a different
# input type (e.g. a tensor) — confirm against EncodingModel.prs before
# relying on this call.
likelihoods = em.prs(rstim, rresp)

AttributeError: 'numpy.ndarray' object has no attribute 'float'

In [22]:
# number of CHUNKLEN-sized chunks needed to span one fifth of the response
# rows, rounded up
fifth_of_rows = rresp.shape[0] / 5
nchunks = int(np.ceil(fifth_of_rows / config.CHUNKLEN))

In [None]:
# fit the regression from stimulus features to responses with bootstrapped
# ridge; note that this rebinds `weights`
weights, alphas, bscorrs = bootstrap_ridge(
    rstim,
    rresp,
    use_corr=False,
    alphas=config.ALPHAS,
    nboots=config.NBOOTS,
    chunklen=config.CHUNKLEN,
    nchunks=nchunks,
)

In [49]:
# notebook display of the fitted weight array's dimensions
weights.shape

(3072, 81126)

In [47]:
# Persist the fitted weights, one file per subject. The filename follows
# args.subject (as the other load/save paths in this file do) instead of a
# hard-coded "S1", and the directory is created first so np.save cannot fail
# on a missing folder.
weights_dir = '/workspace/weights'
os.makedirs(weights_dir, exist_ok = True)
np.save(os.path.join(weights_dir, '%s.npy' % args.subject), weights)

In [50]:
# collapse the bootstrap scores: mean over axis 2, then max over axis 0,
# leaving one score per remaining index (presumably one per voxel — confirm
# against bootstrap_ridge's output layout)
bscorrs = bscorrs.mean(axis=2).max(axis=0)
# keep the config.VOXELS highest-scoring indices, in ascending index order
ranked = np.argsort(bscorrs)
vox = np.sort(ranked[-config.VOXELS:])

In [51]:
# per-story stimulus features, computed with the already-estimated tr_stats
stim_dict = {
    name: get_stim([name], features, tr_stats=tr_stats)
    for name in stories
}
# per-story responses (stack = False), restricted to the selected voxels
resp_dict = get_resp(args.subject, stories, stack=False, vox=vox)
# square accumulator with one row and one column per selected voxel
n_vox = len(vox)
noise_model = np.zeros((n_vox, n_vox))

In [52]:
## 138min 16.9s
# Leave-one-story-out estimate of the noise model: for each held-out story,
# fit ridge weights on all the other stories, compute the residuals on the
# held-out story, and accumulate their normalized cross-product.
n_stories = len(stories)
for hstory in stories:
    others = [tstory for tstory in stories if tstory != hstory]
    tstim = np.vstack([stim_dict[tstory] for tstory in others])
    hstim = stim_dict[hstory]
    tresp = np.vstack([resp_dict[tstory] for tstory in others])
    hresp = resp_dict[hstory]
    # refit using the per-voxel alphas of the selected voxels
    bs_weights = ridge(tstim, tresp, alphas[vox])
    resids = hresp - hstim.dot(bs_weights)
    bs_noise_model = resids.T.dot(resids)
    # scale by the mean diagonal entry, then average over held-out stories
    noise_model += bs_noise_model / np.diag(bs_noise_model).mean() / n_stories

In [53]:
# write the fitted encoding model into this subject's model directory
save_location = os.path.join(config.MODEL_DIR, args.subject)
os.makedirs(save_location, exist_ok=True)
out_path = os.path.join(save_location, "encoding_model_%s" % args.gpt)
np.savez(
    out_path,
    weights=weights,
    noise_model=noise_model,
    alphas=alphas,
    voxels=vox,
    stories=stories,
    tr_stats=np.array(tr_stats),
    word_stats=np.array(word_stats),
)