In [15]:
import os

import pandas as pd
import yaml
import torch
from pytorch_lightning import Trainer, seed_everything
from eyemind.trainer.loops import KFoldLoop
from eyemind.models.classifier import EncoderClassifierMultiSequenceModel, EncoderClassifierModel
from eyemind.models.encoder_decoder import VariableSequenceLengthEncoderDecoderModel
from eyemind.dataloading.gaze_data import SequenceToLabelDataModule, SequenceToSequenceDataModule

In [2]:
# Experiment config: later cells read its "seed_everything", "data" and "trainer"
# sections. Defaults to the original author's local path; set EYEMIND_CONFIG_PATH
# to point at another checkout instead of editing the notebook.
config_path = os.environ.get(
    "EYEMIND_CONFIG_PATH",
    "/Users/rickgentry/emotive_lab/eyemind/experiment_configs/local/comprehension_nestedcv_test_config.yml",
)
with open(config_path, "r") as f:
    config = yaml.safe_load(f)
# safe_load yields None for an empty file (or a scalar/list for other documents);
# fail here rather than with a confusing TypeError when later cells index config[...].
if not isinstance(config, dict):
    raise TypeError(f"Expected a YAML mapping in {config_path}, got {type(config).__name__}")

In [3]:
# Fix the global seed from the config so predictions are reproducible;
# workers=True also seeds dataloader worker processes.
seed = config["seed_everything"]
seed_everything(seed, workers=True)

Global seed set to 42


42

## Load Model Checkpoint

In [4]:
# Checkpoint from the "fold3" run of the comprehension tuning sweep (per its path).
# Defaults to the original author's local path; set EYEMIND_CKPT_PATH to evaluate
# a different checkpoint without editing the notebook.
ckpt_path = os.environ.get(
    "EYEMIND_CKPT_PATH",
    "/Users/rickgentry/emotive_lab/eyemind/ray_results/comprehension-tune-sub/fold3_freeze_encoder=False/checkpoints/epoch=19-step=460.ckpt",
)
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine;
# device placement for prediction is left to the Trainer.
# NOTE(review): encoder_weights_path=None presumably stops the model from loading a
# separate pretrained-encoder file, since the checkpoint already holds the weights — confirm.
model = EncoderClassifierModel.load_from_checkpoint(
    ckpt_path,
    map_location="cpu",
    encoder_weights_path=None,
)

CNN Layers: [16, 32]
16 2
32 16




## Set Up Data and Trainer

In [5]:
# Build the sequence-to-label datamodule from the config's "data" section,
# then run its setup() step before the CV folds are configured below.
data_kwargs = config["data"]
datamodule = SequenceToLabelDataModule(**data_kwargs)
datamodule.setup()

In [6]:
# NOTE(review): the two 4s are presumably the outer and inner fold counts for
# nested CV (the config is "comprehension_nestedcv_test_config") — confirm against
# SequenceToLabelDataModule.setup_cv_folds.
cv_fold_counts = (4, 4)
datamodule.setup_cv_folds(*cv_fold_counts)

In [7]:
# Fold 3 is presumably chosen to match the "fold3" checkpoint loaded earlier; keep the
# two in sync when switching checkpoints.
# NOTE(review): the meaning of -1 for the second argument is not visible here — confirm.
cv_fold_selection = (3, -1)
datamodule.setup_cv_fold_index(*cv_fold_selection)

In [8]:
# Trainer options come straight from the config's "trainer" section.
trainer_kwargs = config["trainer"]
trainer = Trainer(**trainer_kwargs)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


## Predictions

In [9]:
# Loaders for the selected fold's train/val splits, plus the datamodule's test split.
# NOTE(review): if get_dataloader shuffles or drops the last batch for the train
# fold, the train summary below covers a shuffled/possibly truncated set — confirm.
make_loader = datamodule.get_dataloader
train_dl = make_loader(datamodule.train_fold)
val_dl = make_loader(datamodule.val_fold)
test_dl = datamodule.test_dataloader()

In [12]:
# Run the loaded checkpoint over each split; each result is the list of per-batch
# outputs returned by Trainer.predict.
train_preds = trainer.predict(model=model, dataloaders=train_dl)
val_preds = trainer.predict(model=model, dataloaders=val_dl)
test_preds = trainer.predict(model=model, dataloaders=test_dl)

  rank_zero_warn(


Predicting DataLoader 0: 100%|██████████| 48/48 [00:13<00:00,  3.68it/s]
Predicting DataLoader 0: 100%|██████████| 17/17 [00:11<00:00,  1.48it/s]
Predicting DataLoader 0: 100%|██████████| 12/12 [00:11<00:00,  1.06it/s]


In [16]:
def get_fraction_of_comprehension(preds):
    """Mean over every prediction produced by a ``Trainer.predict`` run.

    Args:
        preds: non-empty sequence of per-batch prediction tensors (the list that
            ``Trainer.predict`` returns for a single dataloader). Each batch is
            flattened, so ``(batch,)``, ``(batch, 1)`` and 0-dim outputs all work.
            If the predictions are 0/1 labels, the mean is the fraction of samples
            predicted as the positive class (presumably "comprehension" — confirm);
            if they are probabilities, it is the mean predicted probability.

    Returns:
        0-dim tensor holding the mean. Floating inputs keep their dtype; integer
        or bool inputs are averaged as float32.

    Raises:
        ValueError: if ``preds`` is empty (``torch.cat`` would otherwise fail with
            a less helpful RuntimeError).
    """
    if not preds:
        raise ValueError("preds is empty; there are no predictions to average")
    # Flatten each batch so 0-dim per-batch outputs can be concatenated too; the
    # mean over all elements is unchanged for higher-rank batches.
    all_preds = torch.cat([batch.reshape(-1) for batch in preds])
    # torch.mean rejects integer/bool tensors, which a predict_step returning class
    # labels would produce, so cast only non-floating inputs.
    if not all_preds.is_floating_point():
        all_preds = all_preds.float()
    return all_preds.mean()

In [19]:
# Mean prediction per split (the fraction predicted as "comprehension" if the model
# outputs 0/1 labels). Built as one Series so the cell renders a labelled table and
# every split goes through the same code path and formatting.
comprehension_by_split = pd.Series(
    {
        split: get_fraction_of_comprehension(split_preds).item()
        for split, split_preds in (
            ("train", train_preds),
            ("val", val_preds),
            ("test", test_preds),
        )
    },
    name="fraction_of_comprehension",
)
comprehension_by_split

train:0.6068152189254761, val: 0.580152690410614, test: 0.591160237789154
