# PRNA Model Example

This notebook demonstrates how to use pre-trained Transformer models based on the PRNA architecture described in [this paper](https://www.physionet.org/files/challenge-2020/1.0.1/papers/CinC2020-107.pdf).

A full list of pre-trained models is available at [this location](https://github.com/stanfordmlgroup/aihc-win21-ed-monitor/blob/main/MODELS.md). Please contact tomjin \[at\] stanford.edu for access.

Before running this notebook, install the `edm` module by following the instructions in README.md on the homepage of this repository.


### Setup

In [1]:
import argparse
import csv
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from edm.models.transformer_model import load_best_model

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [10]:
# Define the output classes that the PRNA model was trained against

classes = sorted(['270492004', '164889003', '164890007', '426627000', '713427006',
                  '713426002', '445118002', '39732003', '164909002', '251146004',
                  '698252002', '10370003', '284470004', '427172004', '164947007',
                  '111975006', '164917005', '47665007', '427393009',
                  '426177001', '426783006', '427084000', '164934002',
                  '59931005'])

concept_to_desc = {
    "270492004": "first degree atrioventricular block",
    "164889003": "atrial fibrillation",
    "426627000": "bradycardia",
    "164890007": "atrial flutter",
    "713427006": "complete right bundle branch block",
    "713426002": "incomplete right bundle branch block",
    "445118002": "left anterior fascicular block",
    "39732003": "left axis deviation",
    "164909002": "left bundle branch block",
    "251146004": "low QRS voltage",
    "698252002": "non-specific intraventricular conduction delay",
    "10370003": "Pacing rhythm",
    "284470004": "Premature atrial contraction",
    "427172004": "Premature ventricular contractions",
    "164947007": "Prolonged PR interval",
    "111975006": "Prolonged QT interval",
    "164917005": "Q wave abnormal",
    "47665007": "Right axis deviation",
    "427393009": "Sinus arrhythmia",
    "426177001": "Sinus bradycardia",
    "426783006": "Sinus rhythm",
    "427084000": "Sinus tachycardia",
    "164934002": "T wave abnormal",
    "59931005": "T wave inversion",
    "59118001": "Right bundle branch block (disorder)",
    "63593006": "Supraventricular premature beats",
    "17338001": "Ventricular premature beats"
}
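The model's outputs follow the order of `classes`, which `sorted()` orders lexicographically as strings rather than numerically (for example, `'284470004'` sorts before `'39732003'`). `concept_to_desc` also contains a few codes that are not in `classes`, so only the sorted class list defines the output order. A minimal sketch of building the per-output description list, using a hypothetical helper `output_descriptions` and a small subset of the codes above for brevity:

```python
def output_descriptions(classes, concept_to_desc):
    """Return the description of each model output, in output order.

    Hypothetical helper: raises KeyError if a class has no description.
    Extra entries in concept_to_desc (codes the model does not predict) are ignored.
    """
    missing = [c for c in classes if c not in concept_to_desc]
    if missing:
        raise KeyError(f"No description for concept codes: {missing}")
    return [concept_to_desc[c] for c in classes]


# Subset of the codes above; sorted() compares the strings character by character
subset_classes = sorted(["39732003", "284470004", "10370003"])
subset_desc = {
    "39732003": "left axis deviation",
    "284470004": "Premature atrial contraction",
    "10370003": "Pacing rhythm",
    "59118001": "Right bundle branch block (disorder)",  # not a model output; ignored
}

print(subset_classes)
print(output_descriptions(subset_classes, subset_desc))
```

With the full lists, `output_descriptions(classes, concept_to_desc)` gives one description per model output.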


## Cardiac Abnormality Class Prediction

### Load Model

- `embedding_size` sets the size of the fully-connected layers. It should match the chosen pre-trained model (see the pre-trained model chart).
- `remove_last_layer` controls whether the final layer is removed. Set it to `False` to keep the layer and predict the cardiac abnormality classes.
- `model_path` is the path to the pre-trained model checkpoint.

In [34]:
embedding_size = 64
remove_last_layer = False
model_path = "/deep/group/ed-monitor/models/prna/outputs-wide-64-15sec-bs64/saved_models/ctn/fold_1/ctn.tar"

model = load_best_model(model_path, deepfeat_sz=embedding_size, remove_last_layer=remove_last_layer)
model.eval()

deepfeat_sz=64, nb_patient_feats=0
Loading best model: best_loss 0.10535395583685707 best_auroc tensor(0.8580) at epoch 31


DataParallel(
  (module): CTN(
    (encoder): Sequential(
      (0): Conv1d(1, 128, kernel_size=(14,), stride=(3,), padding=(2,), bias=False)
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv1d(128, 256, kernel_size=(14,), stride=(3,), bias=False)
      (4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv1d(256, 256, kernel_size=(10,), stride=(2,), bias=False)
      (7): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU(inplace=True)
      (9): Conv1d(256, 256, kernel_size=(10,), stride=(2,), bias=False)
      (10): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU(inplace=True)
      (12): Conv1d(256, 256, kernel_size=(10,), stride=(1,), bias=False)
      (13): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_runnin

### Run Model Inference

Create an example input of shape `(num_samples, num_channels, ecg_length)`:

- `num_samples`: size of your mini-batch
- `num_channels`: number of ECG channels. Most PRNA models were trained on a single channel, so this is typically `1`
- `ecg_length`: number of samples in the ECG waveform, i.e. duration × sampling rate. For example, a 15 s ECG sampled at 500 Hz has length 15 × 500 = 7500
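Recordings whose duration differs from the one a model expects must be brought to `ecg_length` samples before inference. The sketch below computes the length from duration and sampling rate, then center-crops longer signals and zero-pads shorter ones. The helper names and the crop/pad strategy are assumptions for illustration, not necessarily the preprocessing the PRNA models were trained with.

```python
def ecg_length(duration_s, sampling_rate_hz):
    """Number of samples in an ECG of the given duration and sampling rate."""
    return int(round(duration_s * sampling_rate_hz))


def fit_to_length(signal, length):
    """Center-crop or zero-pad a 1-D signal (list of floats) to `length` samples.

    Hypothetical preprocessing; the models' training pipeline may differ.
    """
    n = len(signal)
    if n >= length:
        # Keep the central `length` samples
        start = (n - length) // 2
        return list(signal[start:start + length])
    # Split the padding between both ends
    pad = length - n
    left = pad // 2
    return [0.0] * left + list(signal) + [0.0] * (pad - left)


target = ecg_length(15, 500)
print(target)  # 7500

long_signal = [1.0] * 10_000   # 20 s at 500 Hz
short_signal = [1.0] * 5_000   # 10 s at 500 Hz
print(len(fit_to_length(long_signal, target)), len(fit_to_length(short_signal, target)))
```

A fitted signal `x` can then be shaped for the model with `torch.tensor(x, dtype=torch.float32).view(1, 1, -1)`.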

In [35]:
input_tensor = torch.zeros((1, 1, 7500))
print(f"Size of input {input_tensor.shape}")

Size of input torch.Size([1, 1, 7500])


---

Run model inference and print the probability of each output class:

In [36]:
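# Disable gradient tracking during inference. None is passed as the second
# argument; the loaded model reports nb_patient_feats=0.
with torch.no_grad():
    logits = model(input_tensor, None)

# Sigmoid maps each output independently to a probability in (0, 1)
probs = logits.sigmoid().cpu().numpy().tolist()

print(f"Length of output {len(probs)}")
print("---")

# Take the first sample; outputs follow the order of `classes`
for concept, prob in zip(classes, probs[0]):
    desc = concept_to_desc[concept]
    print(f"   {desc}: feature strength = {round(prob, 6)}")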

Length of output 1
---
   Pacing rhythm: feature strength = 6.5e-05
   Prolonged QT interval: feature strength = 8.4e-05
   atrial fibrillation: feature strength = 0.044777
   atrial flutter: feature strength = 0.002004
   left bundle branch block: feature strength = 0.007083
   Q wave abnormal: feature strength = 7.5e-05
   T wave abnormal: feature strength = 0.000572
   Prolonged PR interval: feature strength = 0.0
   low QRS voltage: feature strength = 0.000371
   first degree atrioventricular block: feature strength = 0.007279
   Premature atrial contraction: feature strength = 0.046835
   left axis deviation: feature strength = 3.5e-05
   Sinus bradycardia: feature strength = 0.000714
   bradycardia: feature strength = 0.010025
   Sinus rhythm: feature strength = 0.030146
   Sinus tachycardia: feature strength = 0.021804
   Premature ventricular contractions: feature strength = 0.017828
   Sinus arrhythmia: feature strength = 0.000154
   left anterior fascicular block: feature str
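The raw model outputs are treated as logits: `.sigmoid()` maps each one independently to (0, 1), so several abnormalities can be flagged for the same ECG and the values need not sum to 1. A minimal sketch of turning such probabilities into predicted labels follows. The single threshold of 0.5 is an assumption (per-class thresholds tuned on validation data may be more appropriate), and `predict_labels` and `top_k` are hypothetical helpers.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function, the same mapping as torch.sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def predict_labels(probs, descriptions, threshold=0.5):
    """Return the descriptions whose probability meets the (assumed) threshold."""
    return [d for p, d in zip(probs, descriptions) if p >= threshold]


def top_k(probs, descriptions, k=3):
    """Return the k highest-probability (description, probability) pairs."""
    ranked = sorted(zip(descriptions, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Hypothetical logits for three classes
logits = [2.0, -1.0, 0.0]
descriptions = ["atrial fibrillation", "Sinus rhythm", "bradycardia"]
probs = [sigmoid(x) for x in logits]

print([round(p, 4) for p in probs])
print(predict_labels(probs, descriptions))
print(top_k(probs, descriptions, k=2))
```

In the notebook, pass the per-class probabilities of one sample together with `[concept_to_desc[c] for c in classes]` as `descriptions`.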

## ECG Embeddings

This uses the same model as the cardiac abnormality class prediction above, except that the last fully-connected layer is removed.
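Assuming the removed layer is a single fully-connected (linear) layer, the class outputs are computed from the embedding `e` as logits = W·e + b, where W has one row per class. With `embedding_size = 64` and the 24 codes in `classes`, W would be 24 × 64. The pure-Python sketch below illustrates that computation on toy sizes with made-up weights; `linear` is a hypothetical helper, not the CTN code.

```python
def linear(x, weight, bias):
    """Fully-connected layer: out[j] = sum_i weight[j][i] * x[i] + bias[j].

    `weight` has shape (out_features, in_features), the same layout as torch.nn.Linear.
    """
    if any(len(row) != len(x) for row in weight):
        raise ValueError("each weight row must match the input size")
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


# Toy sizes: a 4-dim embedding mapped to 2 class logits
# (the notebook's model would map 64 -> 24); the weights are made up.
embedding = [0.5, -1.0, 2.0, 0.0]
weight = [
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 1.0, 1.0],
]
bias = [0.1, -0.5]

logits = linear(embedding, weight, bias)
print(logits)
```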

### Load Model

- `embedding_size` sets the size of the fully-connected layers, and therefore the length of the output embedding. It should match the chosen pre-trained model (see the pre-trained model chart).
- `remove_last_layer` controls whether the final layer is removed. Set it to `True` to remove the layer and output embeddings instead of cardiac abnormality class predictions.
- `model_path` is the path to the pre-trained model checkpoint.

In [27]:
embedding_size = 64
remove_last_layer = True
model_path = "/deep/group/ed-monitor/models/prna/outputs-wide-64-15sec-bs64/saved_models/ctn/fold_1/ctn.tar"

model = load_best_model(model_path, deepfeat_sz=embedding_size, remove_last_layer=remove_last_layer)
model.eval()

deepfeat_sz=64, nb_patient_feats=0
Loading best model: best_loss 0.10535395583685707 best_auroc tensor(0.8580) at epoch 31


Sequential(
  (0): Sequential(
    (0): Conv1d(1, 128, kernel_size=(14,), stride=(3,), padding=(2,), bias=False)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv1d(128, 256, kernel_size=(14,), stride=(3,), bias=False)
    (4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv1d(256, 256, kernel_size=(10,), stride=(2,), bias=False)
    (7): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): Conv1d(256, 256, kernel_size=(10,), stride=(2,), bias=False)
    (10): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): Conv1d(256, 256, kernel_size=(10,), stride=(1,), bias=False)
    (13): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): ReLU(inplace=True)
    (15): Co

### Run Model Inference

Create an example input of shape `(num_samples, num_channels, ecg_length)`:

- `num_samples`: size of your mini-batch
- `num_channels`: number of ECG channels. Most PRNA models were trained on a single channel, so this is typically `1`
- `ecg_length`: number of samples in the ECG waveform, i.e. duration × sampling rate. For example, a 15 s ECG sampled at 500 Hz has length 15 × 500 = 7500
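To embed several recordings at once, stack them along the first dimension so the batch has shape `(num_samples, 1, ecg_length)`. A minimal sketch with a hypothetical `make_batch` helper using nested Python lists; it assumes every signal is single-channel and already has exactly `ecg_length` samples.

```python
def make_batch(signals, target_length):
    """Stack 1-D signals into a nested list of shape (num_samples, 1, target_length).

    Hypothetical helper: assumes single-channel signals that already have
    exactly `target_length` samples.
    """
    for i, s in enumerate(signals):
        if len(s) != target_length:
            raise ValueError(f"signal {i} has {len(s)} samples, expected {target_length}")
    # Wrap each signal in a list to add the channel dimension (num_channels = 1)
    return [[list(s)] for s in signals]


target_length = 15 * 500  # 15 s at 500 Hz
signals = [[0.0] * target_length for _ in range(4)]
batch = make_batch(signals, target_length)
print(len(batch), len(batch[0]), len(batch[0][0]))  # 4 1 7500
```

`torch.tensor(batch, dtype=torch.float32)` then yields a tensor of shape `torch.Size([4, 1, 7500])`.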

In [28]:
input_tensor = torch.zeros((1, 1, 7500))
print(f"Size of input {input_tensor.shape}")

Size of input torch.Size([1, 1, 7500])


---

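Run model inference and print the embedding. As written, the cell applies a sigmoid to the output of the truncated model, so every embedding value lies between 0 and 1: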

In [33]:
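# Disable gradient tracking during inference
with torch.no_grad():
    embedding = model(input_tensor).sigmoid().cpu().numpy()

# The first dimension is the batch size (num_samples)
print(f"Length of embedding {len(embedding)}")
print("---")
print(f"Embedding shape: {embedding[0].shape}")
print(f"Embedding: {embedding[0]}")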

Length of embedding 1
---
Embedding shape: (64,)
Embedding: [0.07299398 0.15182257 0.01397314 0.17013563 0.3177759  0.69479364
 0.83668965 0.09662822 0.3713273  0.77085304 0.4777339  0.85967314
 0.8841248  0.45576975 0.5174365  0.21101093 0.29974195 0.98962945
 0.7486474  0.16653965 0.686424   0.13565567 0.33195427 0.15843181
 0.12933792 0.6975239  0.0405843  0.02273056 0.6107944  0.23825002
 0.19473583 0.39380682 0.7354453  0.3443672  0.23817264 0.10366894
 0.7796624  0.28014058 0.7583948  0.45048335 0.24592234 0.7468526
 0.03179573 0.23923036 0.99203616 0.30388948 0.843778   0.6548131
 0.82875854 0.92582357 0.09824307 0.37577543 0.16038077 0.03856083
 0.9765047  0.4893258  0.09355064 0.3517185  0.69492686 0.1309337
 0.9475582  0.49119705 0.72859776 0.17606543]
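One common way to use such embeddings downstream is to compare recordings by cosine similarity. The sketch below illustrates that idea and is not part of the original pipeline; the embedding values are made up, and in the notebook you would pass rows of `embedding` (e.g. `embedding[0]`) instead.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors (1.0 = same direction)."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return dot / (norm_a * norm_b)


# Made-up 4-dim embeddings for illustration (the notebook's are 64-dim)
ecg_a = [0.1, 0.8, 0.3, 0.5]
ecg_b = [0.2, 1.6, 0.6, 1.0]   # same direction as ecg_a, scaled by 2
ecg_c = [0.9, 0.1, 0.7, 0.0]

print(round(cosine_similarity(ecg_a, ecg_b), 6))
print(round(cosine_similarity(ecg_a, ecg_c), 6))
```

Because cosine similarity ignores vector length, `ecg_a` and `ecg_b` score 1.0 even though their magnitudes differ.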
