In [1]:
%load_ext autoreload
%autoreload 2

from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
import torch
from funcs import run_all_ml_experiments, run_ml_experiment, run_dl_experiment
from msr.training.data.transforms import Flatten, Permute

from msr.models.modules import ClassifierModule
from msr.models.architectures.networks.mlp import MLPClassifier
from msr.models.architectures.networks.cnn import CNNClassifier

In [59]:
ml_transform = Flatten(start_dim=0, end_dim=-1)

In [2]:
def rf_model_provider():
    return RandomForestClassifier(n_jobs=-1)

def lr_model_provider():
    return LogisticRegression(n_jobs=-1)

def mlp_model_provider(num_classes, input_shape):
    input_size = input_shape[0]
    net = MLPClassifier(
        num_classes=num_classes, 
        input_size=input_size,
        hidden_dims=[input_size // (2**i) for i in range(1, 4)]
    )
    return ClassifierModule(net=net)

def cnn_model_provider(num_classes, input_shape):
    net = CNNClassifier(
        num_classes=num_classes, 
        dim=1,
        in_channels=12,
        out_channels=[4],
        kernel_size=[11],
        maxpool_kernel_size=[2]
    )
    return ClassifierModule(net=net)

# **Random Forest**

In [53]:
rf_results = run_all_ml_experiments(rf_model_provider)

A Jupyter Widget

whole_signal_waveforms 0.7158476710319519
whole_signal_features 0.8975158929824829
agg_beat_waveforms 0.8873464465141296
agg_beat_features 0.9018481969833374


# **Logistic Regression**

In [54]:
lr_results = run_all_ml_experiments(lr_model_provider)

A Jupyter Widget

whole_signal_waveforms 0.49
whole_signal_features 0.84
agg_beat_waveforms 0.85
agg_beat_features 0.87


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

# **MLP**

In [84]:
mlp_results = run_dl_experiment("agg_beat_features", mlp_model_provider, ml_transform, ml_transform)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type          | Params
--------------------------------------------
0 | net       | MLPClassifier | 577 K 
1 | criterion | NLLLoss       | 0     
--------------------------------------------
577 K     Trainable params
0         Non-trainable params
577 K     Total params
2.312     Total estimated model params size (MB)


A Jupyter Widget

A Jupyter Widget

A Jupyter Widget

A Jupyter Widget

A Jupyter Widget

A Jupyter Widget

A Jupyter Widget

A Jupyter Widget

A Jupyter Widget

A Jupyter Widget

A Jupyter Widget

A Jupyter Widget

`Trainer.fit` stopped: `max_epochs=10` reached.


# **CNN**

In [3]:
cnn_transform = Permute(dims=(1, 0))
cnn_results = run_dl_experiment("agg_beat_waveforms", cnn_model_provider, cnn_transform, cnn_transform)
cnn_results['metrics']['val/auroc']

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type          | Params
--------------------------------------------
0 | net       | CNNClassifier | 565   
1 | criterion | NLLLoss       | 0     
--------------------------------------------
565       Trainable params
0         Non-trainable params
565       Total params
0.002     Total estimated model params size (MB)


A Jupyter Widget

torch.Size([821, 12, 100])
torch.Size([821, 12, 100])


A Jupyter Widget

torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([

A Jupyter Widget

torch.Size([821, 12, 100])
torch.Size([821, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size

A Jupyter Widget

torch.Size([821, 12, 100])
torch.Size([821, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size

A Jupyter Widget

torch.Size([821, 12, 100])
torch.Size([821, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size

A Jupyter Widget

torch.Size([821, 12, 100])
torch.Size([821, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size

A Jupyter Widget

torch.Size([821, 12, 100])
torch.Size([821, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size

A Jupyter Widget

torch.Size([821, 12, 100])
torch.Size([821, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size

A Jupyter Widget

torch.Size([821, 12, 100])
torch.Size([821, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size

A Jupyter Widget

torch.Size([821, 12, 100])
torch.Size([821, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size

A Jupyter Widget

torch.Size([821, 12, 100])
torch.Size([821, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size([65, 12, 100])
torch.Size

A Jupyter Widget

`Trainer.fit` stopped: `max_epochs=10` reached.


torch.Size([821, 12, 100])
torch.Size([821, 12, 100])
torch.Size([1642, 12, 100])
torch.Size([1652, 12, 100])
