In [1]:
%load_ext autoreload
%autoreload 2

import sys         
sys.path.append('./../../src/') 

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.nn import functional as F
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from tqdm.auto import tqdm
from joblib import Parallel, delayed

from data.ptbxl import PTBXLDataModule

In [2]:
ptbxl_datamodule = PTBXLDataModule(
    representation_type = 'per_beat_features',
    fs = 100,
    target = 'diagnostic_class',
    batch_size = 64,
    num_workers = 8
)
ptbxl_datamodule.setup()

In [3]:
class MLPClassifier(nn.Module):
    def __init__(self, in_dim=1000, n_classes=5):
        super().__init__()
        self.model = nn.Sequential(
            # nn.Linear(in_dim, in_dim // 2),
            # nn.ReLU(),
            # nn.Linear(in_dim // 2, in_dim // 4),
            # nn.ReLU(),
            # nn.Dropout(0.5),
            nn.Linear(in_dim, n_classes)
        )
        
    def forward(self, x):
        out = self.model(x)
        probs = F.softmax(out, dim=1)
        if torch.any(probs.isnan()) or torch.any(probs.isinf()):
            print(out)
        return probs
    
class LSTMClassifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, num_layers, bidirectional, n_classes=5):
        super().__init__()
        output_layers = 2 if bidirectional else 1
        self.net = nn.ModuleDict({
            'lstm': nn.LSTM(
                input_size=in_dim,
                hidden_size=hidden_dim,
                num_layers=num_layers,
                batch_first=True,
                bidirectional=bidirectional
            ),
            'dropout': nn.Dropout(p=0.5),
            'linear': nn.Linear(output_layers * hidden_dim, n_classes),
        })
        
    def forward(self, x):
        batch_size, n_channels, n_beats, n_feats = x.shape
        x = x.view(batch_size, n_beats, n_channels * n_feats)
        out, _ = self.net['lstm'](x)
        out = out[:, -1, :]
        # print(f'lstm output={out.size()}')

        # From [seq len, batch, num_directions * hidden_size]
        # to [batches, seqs, seq_len,prediction]
        # out = out.view(x_batches, x_seqs, x_seq_len, -1)
        # print(f'transformed output={out.size()}')

        # Data is fed to the Linear layer
        out = self.net['linear'](out)
        # print(f'linear output={out.size()}')

        # The prediction utilizing the whole sequence is the last one
        # y_pred = out[:, :, -1].unsqueeze(-1)
        # print(f'y_pred={y_pred.size()}')
        probs = F.softmax(out, dim=1)
        return probs

In [4]:
from evaluation.metrics import get_classification_metrics

class PTBXLWaveFormClassifier(LightningModule):
    def __init__(self, classifier: nn.Module, learning_rate: float = 1e-3):
        super().__init__()
        self.classifier = classifier
        self.learning_rate = learning_rate
        self.save_hyperparameters('learning_rate')
    
    def forward(self, x):
        return self.classifier(x)

    def _common_step(self, batch, batch_ids, stage, log=True):
        x, labels = batch
        probs = self(x)
        log_probs = torch.log(probs)
        preds = log_probs.argmax(axis=1)
        if log:
            y_pred_proba = probs.detach().numpy()
            metrics = get_classification_metrics(y_pred_proba, labels, auc=True)
            for metric, val in metrics.items():
                self.log(f"{stage}/{metric}", val, on_step=False, on_epoch=True)
        return labels, probs, log_probs, preds
        
    def training_step(self, batch, batch_idx):
        labels, probs, log_probs, preds = self._common_step(batch, batch_idx, 'train', log=True)
        loss = F.nll_loss(log_probs, labels)
        return loss
    
    def validation_step(self, batch, batch_idx):
        labels, probs, log_probs, preds = self._common_step(batch, batch_idx, 'val', log=True)
        loss = F.nll_loss(log_probs, labels)
        return log_probs
    
    def predict_step(self, batch, batch_idx):
        _, _, _, preds = self._common_step(batch, batch_idx, 'predict', log=False)
        return preds

    def test_step(self, batch, batch_idx):
        _, _, _, preds = self._common_step(batch, batch_idx, 'test', log=True)
        return preds
    
    def test_epoch_end(self, test_step_outputs):
        pass
        # dummy_input = torch.zeros((1, self.hparams["in_dims"]), device=self.device)
        # model_filename = "model_final.onnx"
        # self.to_onnx(model_filename, dummy_input, export_params=True)
        # wandb.save(model_filename)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=1e-3)

In [5]:
classifier = LSTMClassifier(
    in_dim = 600, # 12 ECG channels, 50 features per beat
    hidden_dim = 100, 
    num_layers = 2, 
    bidirectional = False, 
    n_classes = 5
)
model = PTBXLWaveFormClassifier(classifier, learning_rate=1e-3)
trainer = Trainer()
trainer.fit(model, ptbxl_datamodule)

  rank_zero_warn(
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name       | Type           | Params
----------------------------------------------
0 | classifier | LSTMClassifier | 362 K 
----------------------------------------------
362 K     Trainable params
0         Non-trainable params
362 K     Total params
1.448     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [6]:
trainer.test(model, ptbxl_datamodule)

Testing: 0it [00:00, ?it/s]

[{'test/fscore': 0.45465998624474696,
  'test/acc': 0.6537530266343826,
  'test/auc': 0.8033713485644808}]