In [1]:
import argparse
import yaml
import importlib
import utils
import os
from tqdm import tqdm
import numpy as np

import torch
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

from dataset import Features_to_fMRI_Dataset

from models import MappingNetwork
from pl_trainer import VanillaTrainer
from loss import mse_cos_loss

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Global seed: used for reproducibility and to locate the matching
# checkpoint directory (its name embeds the seed) in the next cell.
seed = 42

utils.set_seed(seed)

# Load the training configuration. The context manager guarantees the file
# handle is closed. If config.yaml only holds plain YAML data,
# `yaml.safe_load` would be a stricter alternative to FullLoader.
with open('config.yaml', 'rb') as config_file:
    cfg = yaml.load(config_file, Loader=yaml.FullLoader)

Seed is set.


In [3]:
subj_idx = 1
side = 'right'
# Fall back to CPU when no GPU is visible so the notebook still runs.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Each training run directory is named after subject, hemisphere and seed.
checkpoint_dir = f"/SSD/slava/algonauts/clip_sam_nn_training/fmri_mapping_subj{subj_idx:02d}_{side}_seed_{seed}/"
# os.listdir returns entries in arbitrary order: sort so the chosen file is
# deterministic, and fail with a clear message if the directory is empty.
checkpoint_files = sorted(os.listdir(checkpoint_dir))
if not checkpoint_files:
    raise FileNotFoundError(f"No checkpoint found in {checkpoint_dir}")
if len(checkpoint_files) > 1:
    print(f"Warning: {len(checkpoint_files)} files in {checkpoint_dir}, using {checkpoint_files[0]}")
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_files[0])

full_dataset = Features_to_fMRI_Dataset(subj_idx=subj_idx, side=side, mode='test')

# Define model.
# NOTE(review): 512 / 1024 are presumably the CLIP and SAM feature sizes used at
# training time — they must match the checkpoint; TODO confirm / move to config.
model = MappingNetwork(
    clip_dim=512,
    sam_dim=1024,
    out_dim=full_dataset.feat_dim
)

# Optimizer and scheduler are rebuilt from the config only because they are
# passed as keyword arguments to `load_from_checkpoint` below.
module = importlib.import_module(cfg["OPTIMIZER"]["MODULE"])
optimizer = getattr(module, cfg["OPTIMIZER"]["CLASS"])(
    model.parameters(), **cfg["OPTIMIZER"]["ARGS"]
)

module = importlib.import_module(cfg["SCHEDULER"]["MODULE"])
scheduler = getattr(module, cfg["SCHEDULER"]["CLASS"])(
    optimizer, **cfg["SCHEDULER"]["ARGS"])

criterion = mse_cos_loss()

cfg['SEED'] = seed
cfg['OUTPUT_DIM'] = full_dataset.feat_dim

# map_location lets a checkpoint saved on GPU be restored on the selected device.
lightning_model = VanillaTrainer.load_from_checkpoint(
    checkpoint_path,
    map_location=device,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=criterion,
    config=cfg,
).to(device).eval()


First five filenames:  ['test-0001_nsd-00845.npy', 'test-0002_nsd-00946.npy', 'test-0003_nsd-01517.npy', 'test-0004_nsd-02655.npy', 'test-0005_nsd-02713.npy']
Test mode, no labels path


Global seed set to 42
  rank_zero_warn(


In [4]:
predictions = []

# Inference only: disable autograd so no computation graph is built per sample.
with torch.no_grad():
    for idx, (clip_img_feat, clip_txt_feat, sam_feat, filename) in tqdm(enumerate(full_dataset)):
        # Filenames look like 'test-0001_nsd-00845.npy'; the 1-based test index
        # must match iteration order so rows of `predictions` align with images.
        file_idx = int(filename.split('_')[0].split('-')[-1])
        assert idx + 1 == file_idx, f'unexpected order: position {idx} holds {filename}'

        # Add a batch dimension of size 1 for the model.
        clip_img_feat = clip_img_feat.to(device).unsqueeze(0)
        clip_txt_feat = clip_txt_feat.to(device).unsqueeze(0)
        sam_feat = sam_feat.to(device).unsqueeze(0)

        pred = lightning_model(clip_img_feat, clip_txt_feat, sam_feat)

        # Drop only the batch dimension; a bare squeeze() would also remove any
        # other size-1 dimension.
        pred = pred.squeeze(0).cpu().numpy()

        # Reject NaN and +/-inf alike.
        assert np.isfinite(pred).all(), f'non-finite values in prediction for {filename}'

        predictions.append(pred)

# Stack into (n_test_images, feat_dim); np.stack raises on mismatched shapes
# instead of silently producing an object array.
predictions = np.stack(predictions)
predictions.shape

159it [00:00, 223.04it/s]


(159, 20544)

In [6]:
# Submission layout used here: one directory per subject holding the
# per-hemisphere prediction file.
submission_dir = f'/SSD/slava/algonauts/clip_sam_nn_submission/subj{subj_idx:02d}'

hemisphere_files = {'left': 'lh_pred_test.npy', 'right': 'rh_pred_test.npy'}
if side not in hemisphere_files:
    raise ValueError(f"side must be 'left' or 'right', got {side!r}")
save_name = os.path.join(submission_dir, hemisphere_files[side])

print('Saving into ', save_name)
# Actually write the predictions (the directory may not exist yet).
os.makedirs(submission_dir, exist_ok=True)
# Cast defensively to float32; this is a no-op if the model already returned float32.
np.save(save_name, predictions.astype(np.float32))

Saving into  /SSD/slava/algonauts/clip_sam_nn_submission/subj01/rh_pred_test.npy
