In [1]:
# import 
import numpy as np
from scipy.stats import chi2
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.dates as mdates
import matplotlib.colors as mcolors
import polars as pl
import pandas as pd
import seaborn as sns
import torch
import lightning as L
from sklearn.calibration import calibration_curve
from sklearn.metrics import roc_auc_score, average_precision_score, confusion_matrix, roc_curve
from omegaconf import OmegaConf, DictConfig
import hydra
import wandb
from dataset import SupervisedDataset
from lightning_modules import SupervisedTask
from models.ecg_models import *
from run import interpolate
from torchview import draw_graph
import cairosvg
from bs4 import BeautifulSoup
import torch.nn as nn
# Notebook-wide display settings.
pl.Config.set_tbl_rows(n=50)  # show up to 50 rows when a polars frame is displayed
MY_NAVY = '#001F54'  # hex colour constant; presumably used for plots in later cells

  torch.utils._pytree._register_pytree_node(


In [2]:
# W&B run path ("entity/project/run_id") whose config and checkpoint are loaded below.
WANDB_RUN = 'payalchandak/SILVER/nuzrc47q'

In [3]:
# Load the run's config and best checkpoint, then build a test-split loader
# relabelled to 'future_1_365_any_below_40'
# (NOTE(review): presumably "any LVEF < 40 within days 1-365" — confirm).
device = 'cuda:2'
cfg = OmegaConf.create(wandb.Api().run(WANDB_RUN).config)
L.seed_everything(cfg.utils.seed)

# interpolate() needs the training split (presumably to fill data-dependent
# config fields); the dataset itself is not needed afterwards.
train_pyd = hydra.utils.instantiate(cfg.dataset, split='train')
cfg = interpolate(cfg, train_pyd)
del train_pyd

# torch.device(...).index parses the full GPU ordinal (works for 'cuda:10'
# too), unlike taking the last character of the string.
trainer = L.Trainer(devices=[torch.device(device).index])
LM = SupervisedTask.load_from_checkpoint(cfg.best_model_path, map_location=torch.device(device))
model = LM.model
model.to(device)
model.eval()

cfg.optimizer.batch_size = 2048
cfg.dataset.config.label = 'future_1_365_any_below_40'
pyd = hydra.utils.instantiate(cfg.dataset, split='test')
pyd.data = pyd.data.reset_index(drop=True)
assert len(pyd), "test split is empty"

loader = torch.utils.data.DataLoader(
    dataset=pyd,
    batch_size=cfg.optimizer.batch_size,
    num_workers=0,
    collate_fn=pyd.collate,
    shuffle=False,
    pin_memory=True,
)
# Only the first batch is kept; the architecture cell uses it as example input.
batch = next(iter(loader))

Seed set to 140799
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [5]:
# architecture
# Trace the model's sub-networks with torchview and export each graph as a
# 300-dpi PNG with readable input/output labels. Only the PNGs are kept; the
# intermediate SVGs written here are deleted (and nothing else is touched).

import os  # NOTE(review): unused in this cell; kept in case later cells rely on it
from pathlib import Path


class ECGEncoder(nn.Module):
    """Chain ``encoder`` and ``decoder`` so torchview traces them as one graph.

    Visualisation-only wrapper: ``forward(x)`` returns ``decoder(encoder(x))``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        x = self.encoder(x)
        return self.decoder(x)


class Decoder(nn.Module):
    """Sum two embeddings element-wise and pass the result through ``mlp``.

    NOTE(review): presumably mirrors how the full model fuses the ECG and LVEF
    embeddings before its head — confirm against the model's forward.
    """

    def __init__(self, mlp):
        super().__init__()
        self.mlp = mlp

    def forward(self, x, y):
        return self.mlp(x + y)


def _relabel_svg(svg_text, input_labels, output_label):
    """Clean up a torchview SVG and return it as a string.

    - Drops every ``<text>`` element whose content starts with ``depth:``.
    - ``input-tensor`` nodes: if ``input_labels`` is a str, every input node
      gets that label; otherwise the i-th input node (document order) gets
      ``input_labels[i]`` and any further input nodes are left unchanged.
    - Every ``output-tensor`` node is renamed to ``output_label``.
    """
    soup = BeautifulSoup(svg_text, "xml")
    n_inputs_seen = 0
    for text in soup.find_all("text"):  # find_all returns a list, so decompose() is safe
        if not text.string:
            continue
        label = text.string.strip()
        if label.startswith("depth:"):
            text.decompose()  # delete the depth line entirely
        elif label == "input-tensor":
            if isinstance(input_labels, str):
                text.string.replace_with(input_labels)
            elif n_inputs_seen < len(input_labels):
                text.string.replace_with(input_labels[n_inputs_seen])
            n_inputs_seen += 1
        elif label == "output-tensor":
            text.string.replace_with(output_label)
    return str(soup)


def export_architecture_png(module, input_data, name, input_labels, output_label, dpi=300, **draw_kwargs):
    """Draw ``module`` with torchview and write ``<name>.png``.

    Parameters:
        module: the nn.Module to trace.
        input_data: example input (tensor or tuple of tensors) for the trace.
        name: output file stem; ``<name>.png`` is written to the working dir.
        input_labels: str, or sequence of str applied positionally (see
            ``_relabel_svg``).
        output_label: text replacing ``output-tensor``.
        dpi: rasterisation resolution passed to cairosvg.
        **draw_kwargs: forwarded to ``draw_graph`` (e.g. ``depth``).

    Returns the torchview graph object.
    """
    # save_graph is left off: the PNG is produced by cairosvg below.
    graph = draw_graph(
        module,
        input_data=input_data,
        graph_dir="TB",  # top-to-bottom layout
        expand_nested=True,
        roll=True,
        show_shapes=True,
        hide_module_functions=False,  # do not hide function nodes inside modules
        **draw_kwargs,
    )
    # render() returns the path it wrote; cleanup=True removes the DOT source.
    svg_path = Path(graph.visual_graph.render(format="svg", filename=name, cleanup=True))
    try:
        svg_text = svg_path.read_text(encoding="utf-8")
    finally:
        svg_path.unlink()  # remove only the SVG this call produced
    cairosvg.svg2png(
        bytestring=_relabel_svg(svg_text, input_labels, output_label).encode("utf-8"),
        write_to=f"{name}.png",
        dpi=dpi,
    )
    return graph


ecg_model = ECGEncoder(model.ecg_encoder, model.ecg_decoder)
ecg_input = batch['ecg'][0].unsqueeze(0)  # first ECG of the batch, batch dim restored

# Detailed (depth 4) and summary (depth 1) views of the ECG branch.
export_architecture_png(ecg_model, ecg_input, "ecg_encoder", "12-lead ECG", "ECG embed", depth=4)
export_architecture_png(ecg_model, ecg_input, "ecg_encoder_summary", "12-lead ECG", "ECG embed", depth=1)

# NOTE(review): 15 is assumed to be the LVEF-history width expected by prior_lvef_decoder — confirm.
export_architecture_png(model.prior_lvef_decoder, torch.randn(1, 15), "lvef_encoder", "LVEF history", "LVEF embed")

# NOTE(review): 512 is assumed to be the shared embedding width — confirm.
graph = export_architecture_png(
    Decoder(model.mlp),
    (torch.randn(1, 512), torch.randn(1, 512)),
    "decoder",
    ("ECG embed", "LVEF embed"),
    "Prediction",
)

