# Plot Scripts

In [None]:
import numpy as np
import pandas as pd

# matplotlib
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.transforms as mtransfor
from matplotlib.ticker import FuncFormatter
import seaborn as sns

# Reset to matplotlib's default style, then apply the house look used by every
# figure in this notebook: white background, light grid, framed grey-edged legend,
# and thin black axis spines.
plt.style.use('default')
plt.rcParams.update({
    "axes.facecolor": "white",
    "axes.grid": True,
    "grid.color": "gainsboro",
    "legend.frameon": True,
    "legend.facecolor": "white",
    "legend.edgecolor": "grey",
    "axes.edgecolor": "black",
    "axes.linewidth": 1,
})

## Plot OSMAE and ES scores across different thresholds γ

In [None]:
from datasets.loader.load_los_info import get_los_info
from datasets.loader.datamodule import EhrDataModule
from pipelines import DlPipeline
import lightning as L

In [None]:
# init config (CDSL dataset, TCN multitask model, fold-0, seed-0)
config = {
  'model': 'TCN',
  'dataset': 'cdsl',
  'task': 'multitask',
  'epochs': 100,
  'patience': 10,
  'batch_size': 128,
  'learning_rate': 0.001,
  'main_metric': 'auprc',
  'demo_dim': 2,
  'lab_dim': 97,
  'hidden_dim': 128,
  'output_dim': 1,
  }

# Thresholds γ = 0.1, 0.2, ..., 9.9 (99 values).
# They are built as integer / 10 rather than with a float-step np.arange, so each
# value is the correctly-rounded decimal (0.3, not 0.30000000000000004).
thresholds = (np.arange(1, 100) / 10).tolist()

# load CDSL fold-0 data
data_dir = f'datasets/{config["dataset"]}/processed/fold_0'
los_config = get_los_info(data_dir)
los_config['threshold'] = thresholds
config.update({"los_info": los_config})
dm = EhrDataModule(data_dir, batch_size=config["batch_size"])

# load TCN multitask model
checkpoint_path = f'logs/test/{config["dataset"]}/{config["task"]}/{config["model"]}-fold0-seed0/checkpoints/best.ckpt'
pipeline = DlPipeline(config)
trainer = L.Trainer(accelerator="cpu", max_epochs=1, logger=False, num_sanity_val_steps=0)
# ckpt_path makes Lightning load the saved weights into `pipeline` before testing.
# Without it, the freshly initialised (untrained) model would be evaluated.
trainer.test(pipeline, datamodule=dm, ckpt_path=checkpoint_path)

# get scores
perf = pipeline.test_performance

In [None]:
# Report the full score-list lengths, then keep every 4th point of the ES scores,
# the OSMAE scores and the thresholds. The sub-sampled series feed the plots below.
n_osmae = len(perf['osmae_list'])
n_es = len(perf['es_list'])
print(n_osmae, n_es)

stride = 4
es, osmae, thres = (
    seq[::stride] for seq in (perf['es_list'], perf['osmae_list'], thresholds)
)
print(len(es), len(osmae), len(thres))

In [None]:
# ES Score
# Scatter the ES score against the sub-sampled thresholds γ, with a linear fit.
from pathlib import Path

# savefig does not create missing directories.
Path('logs/figures').mkdir(parents=True, exist_ok=True)

# Use a dedicated figure. Under non-interactive backends plt.show() does not close
# the current figure, so drawing onto the implicit current axes could overlay
# this plot on a previous one.
fig, ax = plt.subplots()
ax = sns.regplot(
    x=thres,
    y=es,
    marker="o",
    color="g",
    line_kws={"color": "grey", "linestyle": "-", "linewidth": 1},
    ci=99.9999,  # size (%) of the bootstrap confidence band around the fit
    seed=0,      # fixed bootstrap seed so the band is reproducible across runs
    ax=ax,
)
ax.set_xlabel('Threshold γ')
ax.set_ylabel('ES Score')

fig.savefig('logs/figures/es_trend.pdf', dpi=500, format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# OSMAE Score
# Scatter the OSMAE score against the sub-sampled thresholds γ, with a linear fit.
from pathlib import Path

# savefig does not create missing directories.
Path('logs/figures').mkdir(parents=True, exist_ok=True)

# Use a dedicated figure. Under non-interactive backends plt.show() does not close
# the current figure, so the previous (ES) plot would otherwise end up in this PDF.
fig, ax = plt.subplots()
ax = sns.regplot(
    x=thres,
    y=osmae,
    marker="o",
    color="dodgerblue",
    line_kws={"color": "grey", "linestyle": "-", "linewidth": 1},
    ci=99.9999,  # size (%) of the bootstrap confidence band around the fit
    seed=0,      # fixed bootstrap seed so the band is reproducible across runs
    ax=ax,
)
ax.set_xlabel('Threshold γ')
ax.set_ylabel('OSMAE Score')

fig.savefig('logs/figures/osmae_trend.pdf', dpi=500, format="pdf", bbox_inches="tight")
plt.show()