In [29]:
import pandas as pd
import yaml
import torch
import os
from pytorch_lightning import Trainer, seed_everything
from eyemind.trainer.loops import KFoldLoop
import eyemind
from eyemind.models.transformers import InformerEncoderDecoderModel, InformerEncoderFixationModel, InformerMultiTaskEncoderDecoder
from eyemind.dataloading.informer_data import InformerDataModule, InformerMultiLabelDatamodule,  InformerVariableLengthDataModule
import matplotlib.pyplot as plt
from eyemind.analysis.visualize import plot_scanpath_labels, viz_coding, fixation_image, plot_scanpath_pc

# Overview
Simply load in already-trained models and generate predictions

This shouldn't be this difficult!

# Notes


in the original informer code, there is a command line option
`parser.add_argument('--do_predict', action='store_true', help='whether to predict unseen future data')`
Check what this does

## encoder only
I have managed to get logits from the encoder, but model(batch) does not work when mdoel is the entire multitask enc-decoder stack.

## fixation decoder
How to load this? 
A: model.fi_decoder

# TODO ideas
[x] try loading encoder and fi decoder separately

[x] get predictions for other tasks

[x] plot predictions

[ ] plot attention




In [5]:
# Load our trained encoder decoder from checkpoint for each fold
# this shouold load the encoder and its 4 decoders

# using pytorch lightning module
# https://lightning.ai/docs/pytorch/stable/deploy/production_basic.html

repodir = os.path.dirname(os.path.dirname(eyemind.__file__))
test_data_dir = os.path.join(repodir,"data/EML/gaze+fix+reg")
test_label_file =  os.path.join(repodir,"./data/EML/EML1_pageLevel_500+_matchEDMinstances.csv")
# save_dir_base = f"{repodir}/lightning_logs/2024/cluster/new_multitask_informer_pretraining"
save_dir_base = f"{repodir}/lightning_logs/informer_pretraining_seed21"

is_old_version=True # i.e. before I made changes to model and datamodule to support multiple labels etc
fold=0
save_dir = os.path.join(save_dir_base, f'fold{fold}/')
config_path=os.path.join(save_dir,"config.yaml")
ckpt_path = os.path.join(save_dir,"checkpoints","last.ckpt")
# ckpt_path = os.path.join(save_dir,"checkpoints","epoch=168-step=8619.ckpt")


with open(config_path, "r") as f:
    config = yaml.safe_load(f)
seed_everything(config["seed_everything"], workers=True) # not sure if this is needed

model = InformerMultiTaskEncoderDecoder.load_from_checkpoint(ckpt_path,
                                                # encoder_weights_path=None
                                                )
encoder=model.encoder
# decoder=model.fi_decoder
model.eval()
encoder.eval()
# decoder.eval()

# set up an InformerDataModule to load the same data as used in training
# trainer = Trainer(**config["trainer"])
# data_dir = os.path.join(repodir,config["data"]["data_dir"])
# label_file = os.path.join(repodir,config["data"]["label_filepath"])
config["data"]["data_dir"]=test_data_dir
config["data"]["label_filepath"]=test_label_file

# edit config for consistency woth the new version of the datamodule
if is_old_version:
    config["data"]["min_sequence_length"] = config["data"]["min_scanpath_length"]
    config["data"].pop("min_scanpath_length")
    config["data"]["sample_label_col"]="fixation_label"
    config["data"]["file_label_col"]=None



Global seed set to 21


In [None]:

datamodule = InformerMultiLabelDatamodule(**config["data"])
datamodule.setup()

test_dl = datamodule.get_dataloader(datamodule.test_dataset) # this is the held out fold's dataloader
for i,batch in enumerate(test_dl):
    print(f"batch: {i}")
    print(f"length of batch: {len(batch)}")
    n_items = len(batch) # this is not the batch size, but the number of items (data and labels) to unpack
    if n_items==2: # just gaze sequence and fixation (sample) labels
        X, yi = batch
        X, X_mask = X 
        yi, yi_mask = yi
    elif n_items==4: # contrastive so X and X2 are present
        X, yi, X2, cl_y = batch
        X, X_mask = X 
        yi, yi_mask = yi
        X2, X2_mask = X2
    elif n_items==5: # sequence and fixation (sample) labels
        X, yi, seq_y, X2, cl_y = batch
        X, X_mask = X
        yi, yi_mask = yi
        seq_y, seq_y_mask = seq_y
        X2, X2_mask = X2
    with torch.no_grad():
        logits=encoder(X, None)


    if i==0: # just run a couple to check
        break


In [None]:
config["data"]

In [None]:

perc_nan = torch.isnan(logits).sum() / torch.numel(logits)
print(f'percentage of nans in  encoder logits: {100*perc_nan}')
# if torch.equal(logits2, logits):
#     print('providing labels made no difference')
# else:
#     print('removing labels changed the predicitons')

print(f'logits shape: {logits.shape}')


In [None]:
print(logits[0,:50,0])


In [None]:
# # We can see the attributes of the full model using dir
# dir(model)
# # it has a module list of decoders
# dir(model.decoders)
if hasattr(model, "fi_decoder"):
    print("model has fi_decoder")
elif hasattr(model, "fm_decoder"):
    model.fi_decoder = model.fm_decoder
    print("model has fm_decoder")
else:
    print("model has no fixation decoder")
# # Apart from pc we can access each task's decoder like so:
# print(f'FI decoder: {model.fi_decoder}')
# print(f'RC decoder: {model.rc_decoder}')
# print(f'CL decoder: {model.cl_decoder}')

In [None]:
# Great! Let's use the logits from the encoder as input to the fi decoder...

fixation_logits = model.fi_decoder.forward(logits)
# this is of shape n, len, 2. I assume the 2 dimensions here are prob(fix) and prob(sacc)? 
# do we need to just take one dimension? And softmax then threshold >/5?
fixation_preds=fixation_logits.max(2).indices
print(fixation_logits)
fixation_targets = yi
# how many unique values are there in the targets?
print(f"unique preds: {torch.unique(fixation_preds)}")
print(f"unique targets: {torch.unique(fixation_targets)}")
# or reshape batch into one long vector in Ricks's code, to get batch-wisemetric:
logits_long = fixation_logits.squeeze().reshape(-1,2)
targets_long = fixation_targets.reshape(-1).long()


In [None]:

# pick one from batch to plot
ix=4
one_pred = fixation_preds[ix,:]
one_target = fixation_targets[ix,:]
print(one_pred.shape)
print(one_target.shape)
fixation_image(one_pred, one_target, "Fixation Identification - Informer (top:pred, bottom:target)")

# SUCCESS!!


In [None]:
# evaluate FI 

#mask = torch.any(X == -180, dim=1)
loss = model.fi_criterion(logits_long, targets_long)
preds = model._get_preds(logits)
probs = model._get_probs(logits)
fixation_targets = fixation_targets.int()
auprc = model.fi_metric(probs, fixation_targets)

In [None]:
# predictive coding one batch
from eyemind.dataloading.batch_loading import predictive_coding_batch

pred_length = 150
label_length = 100
pc_seq_length = 350
X_pc, Y_pc = predictive_coding_batch(X, pc_seq_length=pc_seq_length, label_length=label_length, pred_length=pred_length)
with torch.no_grad():
    pc_logits = model.pc_decoder.forward(logits, Y_pc, pred_length=pred_length).squeeze()
pc_target=Y_pc[:,:-pred_length] # take just the predicted part as target
# coutn how many nan in pc_logits
nan_count = torch.isnan(pc_logits).sum()
# print(f"Number of nan in pc_logits: {nan_count} / {pc_logits.numel()}")
print(f"PC logits: {pc_logits.shape}")
perc_nan = nan_count/pc_logits.numel()
print(f"Percentage of nan in PC logits: {100*perc_nan:.2f}")
pc_logits = model.scaler.inverse_transform(pc_logits)
pc_target = model.scaler.inverse_transform(pc_target)
X = model.scaler.inverse_transform(X)

In [None]:
ix=9

handle = plot_scanpath_pc(X[ix,:,0], X[ix,:,1], pc_logits[ix,:,0].cpu().detach().numpy(), pc_logits[ix,:,1].cpu().detach().numpy())


In [None]:
# reconstruction
with torch.no_grad():
    recon_logits = model.rc_decoder.forward(logits, X)
recon_targets = X

perc_nan = torch.isnan(recon_logits).sum()/recon_logits.numel()
print(f"Percentage of nan in recon_logits: {100*perc_nan:.2f}")

# plot using pc function
ix=2
handle = plot_scanpath_pc(recon_targets[ix,:,0], recon_targets[ix,:,1], recon_logits[ix,:,0].cpu().detach().numpy(), recon_logits[ix,:,1].cpu().detach().numpy())


In [30]:
# make a dataloader with variable sequence length and padding 



{'data_dir': '/home/rosy/DeepGaze/data/EML/gaze+fix+reg',
 'label_filepath': '/home/rosy/DeepGaze/./data/EML/EML1_pageLevel_500+_matchEDMinstances.csv',
 'load_setup_path': None,
 'test_dir': None,
 'train_dataset': None,
 'val_dataset': None,
 'test_dataset': None,
 'train_fold': None,
 'val_fold': None,
 'num_workers': 4,
 'batch_size': 32,
 'pin_memory': True,
 'drop_last': True,
 'min_sequence_length': 10,
 'max_sequence_length': 2000,
 'label_col': None}

batch: 0
length of batch: 2
torch.Size([32, 2000, 2])
False
percentage of nans in  encoder logits: 0.00
percentage of nans in  encoder logits from masked: 0.00


tensor([[[ 0.7113,  0.0050,  0.4600,  ...,  0.0159, -0.0166,  0.1135],
         [-0.2500, -0.1553,  0.1488,  ..., -0.1673,  0.0022,  0.2387],
         [-0.2646, -0.1868,  0.2355,  ..., -0.2032,  0.0284,  0.2396],
         ...,
         [-0.2758,  0.0041,  0.4808,  ..., -0.1571,  0.1564,  0.2501],
         [-0.3866, -0.1225,  0.3696,  ..., -0.1294,  0.1486,  0.2532],
         [ 0.3169,  0.0255,  1.1996,  ..., -0.4156,  0.1024,  0.2111]],

        [[-0.6587, -0.0388,  1.5289,  ..., -0.2655,  0.1047,  0.0391],
         [-0.7262, -0.1766,  0.4579,  ..., -0.2528,  0.1531,  0.2002],
         [-0.7316, -0.1942,  0.4682,  ..., -0.2596,  0.1536,  0.1992],
         ...,
         [-0.7570,  0.1429,  1.1773,  ..., -0.1421, -0.1104,  0.1763],
         [-0.8990, -0.2138,  0.6514,  ..., -0.2344,  0.1967,  0.1939],
         [-0.6549, -0.0407,  1.5165,  ..., -0.3282,  0.1854,  0.0593]],

        [[ 0.2507,  0.0743,  0.7018,  ..., -0.3943, -0.2530,  0.1455],
         [-0.3186, -0.0794,  0.2028,  ..., -0

In [28]:
# dir(datamodule)
datamodule.sequence_length
datamodule.min_sequence_length
datamodule.max_sequence_length

500