# eval

In [13]:
import numpy as np
from sklearn.linear_model import LinearRegression
import pickle
import yaml
from tqdm import tqdm
import jax.numpy as jnp
from src.utils import get_args_and_config
from src.data import get_data
from src.eval import corr
from src.model import forward
from src.fmri import plot_brain

In [2]:
# Obtain the run arguments and configuration from the project helper and
# hand both to get_data.
args, config = get_args_and_config()
# `data` is consumed further down via data.items(), where each value is
# unpacked as a (folds, test_data) pair keyed by subject.
data = get_data(args, config)

100%|██████████| 6/6 [00:34<00:00,  5.70s/it]


In [15]:
def algonauts_model(subject, train_data, test_data):
    """Fit one linear baseline per hemisphere and score both on test data.

    Two independent ``LinearRegression`` models are fitted on the same
    stacked training inputs: one against the left-hemisphere targets and
    one against the right-hemisphere targets.

    Args:
        subject: Subject identifier; only printed next to the scores.
        train_data: Sequence of indexable items. Elements ``[0]``, ``[1]``
            and ``[2]`` of every item are row-stacked with ``np.vstack``
            and used as inputs, left-hemisphere targets and
            right-hemisphere targets respectively. The sequence is
            traversed three times, so it must be re-iterable.
        test_data: Four-element sequence ``(inputs, lh_targets,
            rh_targets, extra)``; the fourth element is ignored.

    Returns:
        Tuple ``(test_lh_corr, test_rh_corr)``: for each hemisphere, the
        ``jnp.median`` of ``corr(prediction, target)``.
    """
    train_img = np.vstack([fold[0] for fold in train_data])
    train_lh = np.vstack([fold[1] for fold in train_data])
    train_rh = np.vstack([fold[2] for fold in train_data])

    lh_model = LinearRegression().fit(train_img, train_lh)
    rh_model = LinearRegression().fit(train_img, train_rh)

    # save model (currently disabled)
    # pickle.dump(lh_model, open(f'./models/{subject}_lh_algonauts_model.pkl', 'wb'))
    # pickle.dump(rh_model, open(f'./models/{subject}_rh_algonauts_model.pkl', 'wb'))

    # test model
    test_img, test_lh, test_rh, _ = test_data

    test_lh_pred = lh_model.predict(test_img)
    test_rh_pred = rh_model.predict(test_img)
    # NOTE(review): corr comes from src.eval; presumably it returns one
    # correlation per target column, which the median then summarises --
    # confirm against its definition.
    test_lh_corr = jnp.median(corr(test_lh_pred, test_lh))
    test_rh_corr = jnp.median(corr(test_rh_pred, test_rh))
    print(subject)
    print(f"test lh corr: {test_lh_corr}")
    print(f"test rh corr: {test_rh_corr}")
    return test_lh_corr, test_rh_corr


# Per-subject, per-hemisphere median test correlation of the linear baseline;
# the None placeholders are overwritten in the loop below.
algonauts_baseline = {subject: {'lh': None, 'rh': None} for subject in data}

for subject, (folds, test_data) in tqdm(data.items()):
    lh_corr, rh_corr = algonauts_model(subject, folds, test_data)
    # .item() turns the returned array scalars into plain Python numbers so
    # the YAML file contains ordinary values rather than array objects.
    algonauts_baseline[subject]['lh'] = lh_corr.item()
    algonauts_baseline[subject]['rh'] = rh_corr.item()


# Write through a context manager so the file is flushed and closed as soon
# as the dump finishes, instead of relying on garbage collection.
with open('config/algonauts_baseline.yaml', 'w', encoding='utf-8') as baseline_file:
    yaml.dump(algonauts_baseline, baseline_file)

 17%|█▋        | 1/6 [00:08<00:41,  8.21s/it]

subj01
test lh corr: 0.21079011261463165
test rh corr: 0.2079717516899109


 33%|███▎      | 2/6 [00:16<00:32,  8.20s/it]

subj02
test lh corr: 0.2017333209514618
test rh corr: 0.20145922899246216


 50%|█████     | 3/6 [00:24<00:23,  7.94s/it]

subj03
test lh corr: 0.18003970384597778
test rh corr: 0.1731811910867691


 67%|██████▋   | 4/6 [00:31<00:15,  7.73s/it]

subj04
test lh corr: 0.15315435826778412
test rh corr: 0.18053296208381653


 83%|████████▎ | 5/6 [00:39<00:07,  7.89s/it]

subj05
test lh corr: 0.2231391966342926
test rh corr: 0.21678632497787476


100%|██████████| 6/6 [00:47<00:00,  7.97s/it]

subj07
test lh corr: 0.15720497071743011
test rh corr: 0.1556631624698639



