In [1]:
%load_ext autoreload
%autoreload 2

from msr.training.data.datamodules import PtbXLDataModule
from msr.training.trainers import MLClassifierTrainer, MLRegressorTrainer
from msr.data.download.ptbxl import FS
from msr.evaluation.plotters import MatplotlibPlotter, PlotlyPlotter, plot_classifier_evaluation, BasePlotter
from msr.evaluation.loggers import MLWandbLogger

In [10]:
# Positional argument later handed to PtbXLDataModule.
rep_type = "whole_signal_features"

# Value forwarded as the datamodule's `target` keyword.
TARGET = "diagnostic_class"

# Keyword arguments unpacked into PtbXLDataModule(rep_type, **BASE_PARAMS).
BASE_PARAMS = {"fs": FS, "target": TARGET}

In [6]:
from typing import Any, Dict, List, Optional, Literal, Tuple
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from msr.evaluation.metrics import get_classification_metrics
from pytorch_lightning.loggers import WandbLogger
from abc import abstractmethod
from functools import partial

# Registry mapping a loss name to an instantiated torch criterion.
# ClassifierModule looks its `loss_metric` argument up in this mapping.
loss_functions = dict(
    mae=nn.L1Loss(),
    mse=nn.MSELoss(),
    negative_log_likelihood=nn.NLLLoss(),
    cross_entropy=nn.CrossEntropyLoss(),
    hinge_embedding=nn.HingeEmbeddingLoss(),
)
    
        
class BaseTask:
    """Interface describing the task-specific hooks used during evaluation.

    The class does not derive from ``abc.ABC``, so the ``@abstractmethod``
    markers are not enforced at instantiation time; both hooks simply
    return ``None`` unless overridden.
    """

    @abstractmethod
    def get_metrics(self):
        """Compute the metrics for the task (no-op placeholder here)."""

    @abstractmethod
    def plot_evaluation(
        self,
        y_values: Dict[str, Tuple[np.ndarray, np.ndarray]],
        metrics: Dict[str, float],
        plotter: BasePlotter,
        feature_importances: List[float] = None,
    ):
        """Produce evaluation figures for the task (no-op placeholder here).

        Args:
            y_values: per-key pair of arrays, as declared by the annotation.
            metrics: metric name to value mapping.
            plotter: plotting backend to draw with.
            feature_importances: optional per-feature importance values.
        """
   

class Classifier:
    """Mixin supplying classification metrics and plots to a trainer.

    Relies on ``self.datamodule`` (for ``num_classes`` and ``class_names``)
    and ``self.feature_names`` being set by the class it is mixed into.
    """

    def get_metrics(self, preds, target):
        """Return classification metrics for ``preds`` against ``target``."""
        n_classes = self.datamodule.num_classes
        return get_classification_metrics(num_clases=n_classes, preds=preds, target=target)

    def plot_evaluation(
        self,
        y_values: Dict[str, Tuple[np.ndarray, np.ndarray]],
        metrics: Dict[str, float],
        plotter: BasePlotter,
        feature_importances: List[float] = None,
    ):
        """Delegate plotting to ``plot_classifier_evaluation``.

        The datamodule's class names and the trainer's feature names are
        forwarded together with the caller-supplied arguments.
        """
        plot_kwargs = dict(
            y_values=y_values,
            metrics=metrics,
            class_names=self.datamodule.class_names,
            feature_names=self.feature_names,
            feature_importances=feature_importances,
            plotter=plotter,
        )
        return plot_classifier_evaluation(**plot_kwargs)
    
    
class Regressor:
    """Mixin supplying regression metrics and plots to a trainer.

    Relies on ``self.feature_names`` being set by the class it is mixed into.

    NOTE(review): ``get_regression_metrics`` and ``plot_regressor_evaluation``
    are not imported in the visible notebook cells; confirm where they live
    and import them before using this mixin.
    """

    def get_metrics(self, preds, target):
        """Return regression metrics for ``preds`` against ``target``."""
        regression_metrics = get_regression_metrics(preds=preds, target=target)
        return regression_metrics

    def plot_evaluation(
        self,
        y_values: Dict[str, Tuple[np.ndarray, np.ndarray]],
        metrics: Dict[str, float],
        plotter: BasePlotter,
        feature_importances: List[float] = None,
    ):
        """Delegate plotting to ``plot_regressor_evaluation``."""
        plot_kwargs = dict(
            y_values=y_values,
            metrics=metrics,
            feature_names=self.feature_names,
            feature_importances=feature_importances,
            plotter=plotter,
        )
        return plot_regressor_evaluation(**plot_kwargs)
    
    
    
    
    
class BaseTrainer:
    """Common trainer scaffolding shared by the ML and DL trainers.

    Stores the model and datamodule and implements ``evaluate``. ``fit`` and
    ``predict`` are left to subclasses; ``get_metrics`` and
    ``plot_evaluation`` are expected to come from a task mixin such as
    ``Classifier`` or ``Regressor``.
    """

    def __init__(self, model, datamodule):
        self.model = model
        self.datamodule = datamodule
        self.feature_names = datamodule.feature_names

    @abstractmethod
    def fit(self):
        """Train ``self.model``; implemented by subclasses."""

    @abstractmethod
    def predict(self, data):
        """Return predictions for ``data``; implemented by subclasses."""

    # Annotations are quoted so they are not evaluated at definition time.
    def evaluate(self, plotter: Optional["BasePlotter"] = None, logger: Optional["MLWandbLogger"] = None):
        """Evaluate the model on the validation and test splits.

        Args:
            plotter: when given, figures from ``self.plot_evaluation`` are
                stored under the ``"figs"`` key of the result.
            logger: when given, every result group is logged (entries whose
                name contains ``"/roc"`` are skipped) and the logger is
                finished afterwards.

        Returns:
            Dict with ``"metrics"`` (flat mapping keyed ``"<split>/<metric>"``)
            and, if a plotter was passed, ``"figs"``.
        """
        # Local import: pandas is not imported anywhere else in the notebook.
        import pandas as pd

        all_y_values = {}
        for split in ("val", "test"):
            dataset = getattr(self.datamodule, split)
            all_y_values[split] = {
                "preds": self.predict(dataset.data),
                "target": dataset.targets,
            }

        metrics = {split: self.get_metrics(**y_values) for split, y_values in all_y_values.items()}
        evaluation_results = {
            # json_normalize flattens the nested {split: {metric: value}} dict
            "metrics": pd.json_normalize(metrics, sep="/").to_dict(orient="records")[0]
        }
        if plotter is not None:
            evaluation_results["figs"] = self.plot_evaluation(all_y_values, metrics, plotter)
        if logger is not None:
            blacklist = ["/roc"]  # substrings of entry names that must not be logged
            for results in evaluation_results.values():
                filtered_results = {
                    key: value
                    for key, value in results.items()
                    if not any(blocked in key for blocked in blacklist)
                }
                logger.log(filtered_results)
            logger.finish()
        return evaluation_results
    

    
    
    
class DLTrainer(BaseTrainer):
    """Trainer that drives a Lightning ``Trainer`` with a module and datamodule."""

    def __init__(self, trainer: pl.Trainer, model: nn.Module, datamodule: pl.LightningDataModule):
        # The Lightning trainer is stored first, then the shared state is set up.
        self.trainer = trainer
        super().__init__(model=model, datamodule=datamodule)

    def fit(self):
        """Run ``trainer.fit`` on the stored model and datamodule."""
        lightning_trainer = self.trainer
        lightning_trainer.fit(self.model, self.datamodule)

    def predict(self, data):
        """Return the model's forward output for ``data``.

        The model is called directly; no ``eval()`` / ``no_grad`` handling
        happens here.
        """
        predictions = self.model(data)
        return predictions
    
    
class DLClassifierTrainer(DLTrainer, Classifier):
    """Deep-learning trainer combined with the classification mixin."""

    def __init__(self, trainer: pl.Trainer, model: nn.Module, datamodule: pl.LightningDataModule):
        # Pure delegation to DLTrainer's constructor.
        super().__init__(trainer=trainer, model=model, datamodule=datamodule)
        
        
class DLRegressorTrainer(DLTrainer, Regressor):
    """Deep-learning trainer combined with the regression mixin."""

    def __init__(self, trainer: pl.Trainer, model: nn.Module, datamodule: pl.LightningDataModule):
        # Pure delegation to DLTrainer's constructor.
        super().__init__(trainer=trainer, model=model, datamodule=datamodule)
        

        
        
class MLTrainer(BaseTrainer):
    """Trainer for estimators exposing ``fit(X=..., y=...)`` and ``predict``."""

    def fit(self):
        """Fit the model on the training split of the datamodule."""
        # `.numpy()` is called on the training data before it reaches the model.
        features = self.datamodule.train.data.numpy()
        labels = self.datamodule.train.targets
        self.model.fit(X=features, y=labels)

    def predict(self, X):
        """Return ``model.predict(X)``."""
        predictions = self.model.predict(X)
        return predictions
    

class MLClassifierTrainer(MLTrainer, Classifier):
    """Classical-ML trainer combined with the classification mixin."""

    def __init__(self, model, datamodule: pl.LightningDataModule):
        super().__init__(model=model, datamodule=datamodule)

    def predict(self, X):
        """Return the output of the model's ``predict_proba`` for ``X``."""
        class_probabilities = self.model.predict_proba(X)
        return class_probabilities


class MLRegressorTrainer(MLTrainer, Regressor):
    """Classical-ML trainer combined with the regression mixin.

    NOTE(review): ``model`` is annotated as ``nn.Module`` although ``MLTrainer``
    calls ``model.fit(X=..., y=...)`` / ``model.predict`` on it; confirm the
    intended type.
    """

    def __init__(self, model: nn.Module, datamodule: pl.LightningDataModule):
        super().__init__(model=model, datamodule=datamodule)

In [7]:
class ClassifierModule(pl.LightningModule):
    """LightningModule that trains a classification network.

    Args:
        net: network to optimise; it must expose a ``num_classes`` attribute.
        loss_metric: key of the criterion to use from ``loss_functions``.

    Raises:
        KeyError: if ``loss_metric`` is not a key of ``loss_functions``.
    """

    def __init__(
        self,
        net: nn.Module,
        loss_metric: Literal[
            "mae", "mse", "negative_log_likelihood", "cross_entropy", "hinge_embedding"
        ] = "mse",
    ):
        super().__init__()
        self.net = net
        self.save_hyperparameters(logger=False, ignore=['net'])
        self.criterion = loss_functions[loss_metric]
        # Metric function with the number of classes already bound.
        self.get_metrics = partial(get_classification_metrics, num_clases=self.net.num_classes)

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.net(x)

    def _common_step(self, batch, batch_idx: int, stage: str):
        """Run the forward pass and compute the loss for one batch.

        Returns:
            Dict with the ``loss``, the ``preds`` and the ``target`` of the batch.
        """
        data, target = batch
        preds = self.forward(data)
        loss = self.criterion(preds, target)
        return {"loss": loss, "preds": preds, "target": target}

    def training_step(self, batch, batch_idx: int):
        return self._common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx: int):
        return self._common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx: int):
        return self._common_step(batch, batch_idx, "test")

    def _common_epoch_end(self, outputs, stage: str):
        """Aggregate the step outputs of one epoch and log the metrics.

        Args:
            outputs: list of dicts as returned by ``_common_step``.
            stage: prefix used for the metric names (``train``/``val``/``test``).

        Returns:
            Dict with the stage-prefixed ``metrics`` and the concatenated
            ``y_values``. Nothing is logged during sanity checking or testing.
        """
        # Stack the per-batch losses instead of rebuilding a tensor from scalars.
        loss = torch.stack([output["loss"].detach() for output in outputs]).mean()
        preds = torch.cat([output["preds"] for output in outputs], dim=0)
        target = torch.cat([output["target"] for output in outputs], dim=0)
        # `self.get_metrics` is the partial created in __init__; the module has
        # no `self.metrics` attribute.
        metrics = self.get_metrics(preds=preds, target=target)
        metrics["loss"] = loss
        metrics = {f"{stage}/{name}": value for name, value in metrics.items()}
        results = {
            "metrics": metrics,
            "y_values": {"preds": preds, "target": target}
        }
        if self.trainer.sanity_checking or self.trainer.testing:
            return results
        # Logged with logger=False; the metrics go to the logger explicitly below.
        self.log(f"{stage}/loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=False)
        self.logger.log_metrics(metrics, step=self.current_epoch)
        return results

    def training_epoch_end(self, outputs: List):
        """Log the training metrics plus the current epoch and learning rate."""
        self._common_epoch_end(outputs, "train")
        trainer_metrics = {
            "epoch": self.current_epoch,
            "learning_rate": self.optimizers().param_groups[0]["lr"],
        }
        self.logger.log_metrics(trainer_metrics, step=self.current_epoch)

    def validation_epoch_end(self, outputs: List):
        self._common_epoch_end(outputs, "val")

    def test_epoch_end(self, outputs: List):
        self._common_epoch_end(outputs, "test")

    def configure_optimizers(self):
        """Return AdamW together with a ReduceLROnPlateau schedule on ``val/loss``."""
        optimizer = torch.optim.AdamW(
            params=self.parameters(),
            lr=0.001,
            betas=(0.9, 0.999),
            weight_decay=0.01,
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.1,
            patience=10,
            threshold=0.0001,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val/loss",
                "interval": "epoch",
                "frequency": 1,
            }
        }

In [8]:
class MLPClassifier(nn.Module):
    """Single linear layer followed by a softmax over the class axis.

    Args:
        in_dim: size of the last dimension of the input.
        num_classes: number of output classes.
    """

    def __init__(self, in_dim=1000, num_classes=5):
        super().__init__()
        self.num_classes = num_classes
        self.model = nn.Sequential(
            nn.Linear(in_dim, num_classes)
        )

    def forward(self, x):
        """Return class probabilities (not logits) for ``x``.

        ``nn.Linear`` always writes the classes to the last axis, so the
        softmax is taken over ``dim=-1``. For ``(batch, in_dim)`` input this
        equals ``dim=1``; it additionally supports unbatched 1-D input and
        inputs with extra leading dimensions.
        """
        out = self.model(x)
        probs = F.softmax(out, dim=-1)
        return probs

---
# **ML**

In [11]:
from sklearn.ensemble import RandomForestClassifier

# Fixed seed so the forest is identical on every Restart & Run All.
SEED = 42

model = RandomForestClassifier(random_state=SEED)
datamodule = PtbXLDataModule(rep_type, **BASE_PARAMS)
datamodule.setup()

# Uses the MLClassifierTrainer defined in this notebook, which shadows the
# imported one of the same name.
ptbxl_ml_trainer = MLClassifierTrainer(model, datamodule)

In [12]:
# Fit the random forest on the datamodule's training split (see MLTrainer.fit).
ptbxl_ml_trainer.fit()

In [56]:
# `dm` was never defined in this notebook; the datamodule created in the ML
# section is named `datamodule`.
net = MLPClassifier(
    in_dim=len(datamodule.feature_names),
    num_classes=len(datamodule.class_names),
)

# NOTE(review): MLPClassifier.forward returns softmax probabilities, while
# nn.NLLLoss is documented to expect log-probabilities; confirm this pairing.
model = ClassifierModule(net, "negative_log_likelihood")

In [57]:
# Weights & Biases logger for the deep-learning run.
logger = WandbLogger(project="ptbxl", name="DL")

# Lightning trainer; the accelerator is chosen automatically.
trainer = pl.Trainer(accelerator="auto", logger=logger)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [58]:
# `dm` is undefined in this notebook; use the `datamodule` created in the ML section.
ptbxl_trainer = DLClassifierTrainer(trainer, model, datamodule)