# Introduction

We're going to train a model to detect anomalous electrocardiogram (ECG) signals.

# Dataset


[The PTB Diagnostic ECG Database](https://www.physionet.org/physiobank/database/ptbdb/)

> This dataset has been used in exploring heartbeat classification using deep neural network architectures, and observing some of the capabilities of transfer learning on it. The signals correspond to electrocardiogram (ECG) shapes of heartbeats for the normal case and the cases affected by different arrhythmias and myocardial infarction. These signals are preprocessed and segmented, with each segment corresponding to a heartbeat.
- Number of Samples: 14552
- Number of Categories: 2
- Sampling Frequency: 125Hz
- Data Source: Physionet's PTB Diagnostic Database

**Note:** _All the samples are cropped, downsampled and padded with zeroes if necessary to the fixed dimension of 188._

In [None]:
import altair as alt
import matplotlib as plt
import numpy as np
import polars as pl
import random
import torch
import torcheval
from torcheval.metrics import BinaryConfusionMatrix, BinaryF1Score, BinaryPrecision, BinaryRecall, BinaryAccuracy

from collections import OrderedDict
from datetime import datetime, timedelta
from pathlib import Path
from tqdm.notebook import tqdm
from ecg import DATA_DIR

# Let Altair embed frames of any size: the plotting helpers below pass whole
# polars frames (one row per plotted point) straight to `alt.Chart`.
alt.data_transformers.disable_max_rows()

In [None]:
# The MIT-BIH train/test files have been downloaded, converted to .parquet, and
# moved to the `data/` directory. They are only inspected below; the model in
# this notebook is trained on the PTB files loaded further down.
train_file = "mitbih_train.parquet"
test_file = "mitbih_test.parquet"

train_df, test_df = (pl.read_parquet(DATA_DIR / name) for name in (train_file, test_file))

In [None]:
# Peek at the last rows of the MIT-BIH training frame.
train_df.tail(5)

In [None]:
# (rows, columns) of the MIT-BIH test frame.
test_df.height, test_df.width

In [None]:
# PTB Diagnostic database: judging by the file names, one file per class.
abnormal_file = "ptbdb_abnormal.parquet"
normal_file = "ptbdb_normal.parquet"

abnormal_df, normal_df = (pl.read_parquet(DATA_DIR / name) for name in (abnormal_file, normal_file))

## EDA

In [None]:
# (rows, columns) for the abnormal and normal frames.
(abnormal_df.height, abnormal_df.width), (normal_df.height, normal_df.width)

In [None]:
# Tag every row with a human-readable class label, then stack both files into a
# single frame. The existing `target` column is kept unchanged.
df = pl.concat(
    [
        frame.with_columns(pl.lit(label).alias("class"))
        for label, frame in (("normal", normal_df), ("abnormal", abnormal_df))
    ],
    how="vertical",
)

In [None]:
def plot_samples(df: pl.DataFrame, class_name: str, samples: int = 100, opacity: float = 0.05) -> alt.Chart:
    """Overlay a random selection of heartbeats from one class as faint lines.

    Args:
        df: Frame with one column per measurement plus ``class`` and ``target``
            columns.
        class_name: Value of the ``class`` column whose rows are plotted.
        samples: Maximum number of heartbeats to draw. If the class has fewer
            rows than this, every row of the class is drawn.
        opacity: Line opacity; low values make dense regions stand out.

    Returns:
        An Altair line chart with one line per sampled heartbeat.

    Raises:
        ValueError: If no row has ``class == class_name``.
    """
    class_rows = df.filter(pl.col("class") == class_name).drop(["class", "target"])
    if class_rows.height == 0:
        raise ValueError(f"no rows with class {class_name!r}")

    # `DataFrame.sample` raises when asked for more rows than exist (without
    # replacement), so cap the request at the class size.
    n_samples = min(samples, class_rows.height)
    # No seed on purpose, so each call shows a different selection. After the
    # transpose every column is one heartbeat and every row one measurement.
    data = class_rows.sample(n=n_samples).transpose()
    n_measurements = data.height

    # Long format: one row per (heartbeat, measurement) point.
    plot_df = pl.concat(
        [
            data
            .select(col)
            .rename({col: "signal"})
            .with_columns(
                pl.Series("case", [i] * n_measurements),
                pl.Series("measurement", list(range(n_measurements))),
                # Constant colour so every line shares one hue.
                pl.Series("color", [0] * n_measurements),
            )
            for i, col in enumerate(data.columns)
        ],
        how="vertical",
    )

    return alt.Chart(plot_df, title=class_name).mark_line().encode(
        x=alt.X("measurement", title="measurement"),
        y=alt.Y("signal", title="signal", scale=alt.Scale(domain=[0.0, 1.2])),
        color=alt.Color("color:N", title=None),
        detail="case",
        opacity=alt.value(opacity),
    ).properties(height=300, width=400)

In [None]:
# Side-by-side overlays of 100 random heartbeats from each class.
plot_samples(df, "normal", samples=100, opacity=0.1) | plot_samples(df, "abnormal", samples=100, opacity=0.1)

We can see that the signals for both the "normal" and "abnormal" cases lie between 0.0 and 1.0, and all cases start at or near 1.0.  
From the 75th measurement (0.6 seconds from the start) the signals can spike to near 1.0 for a couple of time steps before dropping back to nominal levels.
Some cases drop to 0.0 from about the 100th measurement (0.8 seconds from the start); this just means the signal ended early and has been padded with zeros up to the 188th measurement.

The normal cases have a spike in the signal at about the 35th measurement (0.28 seconds from the start) spread across ~30 measurements (0.24 seconds).  Those that don't end early tend to have another spike at the end spread across ~30 measurements.

The abnormal cases are more variable, especially around where the first spike occurs in the normal cases.  The nominal level for each case is also more variable.

Let's plot the mean and standard deviation across each of the classes.

In [None]:
def plot_rolling_mean(df: pl.DataFrame, class_name: str, window: int = 5) -> alt.Chart:
    """Plot one class's mean signal with a shaded mean ± 2·std band.

    The mean and standard deviation are taken across all rows of
    ``class_name`` at every measurement. Each curve is then smoothed with a
    polars rolling mean over ``window`` index positions.

    Args:
        df: Frame with one column per measurement plus ``class`` and ``target``.
        class_name: Value of the ``class`` column to summarise.
        window: Rolling-window length, in measurements.

    Returns:
        The band (clipped to [0, 1]) layered under the mean line.
    """
    def smooth(stat_row: pl.DataFrame) -> pl.DataFrame:
        # One-row statistic -> one row per measurement -> rolling average.
        per_measurement = stat_row.drop("target").transpose().with_row_index()
        return (
            per_measurement
            .rolling("index", period=f"{window}i")
            .agg([pl.col(pl.Float64).mean()])
            .drop("index")
        )

    def clip_unit(frame: pl.DataFrame) -> pl.DataFrame:
        # Signals are never negative, so keep the band inside [0, 1].
        return frame.with_columns(pl.col(pl.Float64).clip(0.0, 1.0))

    class_rows = df.filter(pl.col("class") == class_name).drop("class")
    centre = smooth(class_rows.mean())
    half_width = smooth(class_rows.std()) * 2

    plot_df = pl.concat(
        [
            centre.rename({"column_0": "mean"}),
            clip_unit(centre - half_width).rename({"column_0": "lower"}),
            clip_unit(centre + half_width).rename({"column_0": "upper"}),
        ],
        how="horizontal",
    ).with_row_index()

    band = (
        alt.Chart(plot_df, title=class_name)
        .mark_area()
        .encode(
            x=alt.X("index", title="measurement"),
            y=alt.Y("lower", scale=alt.Scale(domain=[-0.1, 1.1])),
            y2=alt.Y2("upper"),
            opacity=alt.value(0.25),
        )
        .properties(height=300, width=400)
    )
    line = (
        alt.Chart(plot_df)
        .mark_line()
        .encode(
            x=alt.X("index", title="measurement"),
            y=alt.Y("mean", title="mean signal"),
        )
    )
    return band + line

In [None]:
# Smoothed class means side by side, 5-measurement window.
plot_rolling_mean(df, "normal", window=5) | plot_rolling_mean(df, "abnormal", window=5)

Here we've calculated the mean and standard deviation for each class at each measurement, then taken a rolling average over the given window to smooth out the curves.
The shaded band shows the mean ± two standard deviations, clipped to the range [0.0, 1.0].

In the normal cases, we see a distinct spike at about the 30th measurement and another, more spread-out spike at about the 110th measurement. The nominal signal level is ~0.2. The signals end at about the 120th measurement (where the nominal level starts to decrease).

In the abnormal cases, we see the same spikes but with less prominence, which suggests that the timing of the spikes varies more from case to case. The second spike occurs notably earlier, at about the 90th measurement. Overall, the signal has a much higher spread and the nominal signal level is ~0.25–0.3. The signals tend to end about 20 measurements earlier than the normal signals, at around the 100th measurement (where the nominal level starts to decrease).

## Modelling

In [None]:
class ECGDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-processed ECG heartbeats.

    Features and targets are converted to ``float32`` tensors once, at
    construction, so ``__getitem__`` is a plain tensor index. The original
    polars objects are kept for inspection through :meth:`get_row`.

    TODO: randomly sample variable lengths
    """

    def __init__(self, X: pl.DataFrame, y: pl.Series) -> None:
        """
        Args:
            X: One row per heartbeat, one column per measurement.
            y: Matching targets. This is either a Series or, as the notebook
                passes it (``df.select("target")``), a single-column DataFrame.
        """
        self.data = X
        self.target = y
        self.X = X.to_torch().to(torch.float32)
        self.y = y.to_torch().to(torch.float32)

    def __len__(self) -> int:
        return len(self.X)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        return self.X[idx], self.y[idx]

    def get_row(self, idx: int) -> pl.DataFrame:
        """Return row ``idx`` as a one-row frame: target column(s), then features.

        Negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is outside the dataset.
        """
        n_rows = len(self.X)
        if not -n_rows <= idx < n_rows:
            raise IndexError(f"row {idx} out of range for dataset of length {n_rows}")
        # A Series must become a frame before it can be concatenated horizontally.
        target = self.target.to_frame() if isinstance(self.target, pl.Series) else self.target
        # Slice first so only the requested row is concatenated.
        return pl.concat([target.slice(idx, 1), self.data.slice(idx, 1)], how="horizontal")

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class ConvAutoEncoder(nn.Module):
    """Fully connected autoencoder: input_dim -> 128 -> 64 -> 32 -> 16 and back.

    NOTE(review): despite its name this model has no convolutions. In this
    notebook it is shadowed by the convolutional ``ConvAutoEncoder`` defined
    right after it. ``hidden_dim`` is stored, but the bottleneck width is fixed
    at 16.
    """

    def __init__(self, input_dim: int, hidden_dim: int) -> None:
        # Zero-argument super(): the explicit super(ConvAutoEncoder, self) form
        # looks the class up by its global name at call time. That name is
        # rebound by the next class definition, which makes construction raise
        # TypeError.
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Encoder: Linear + ReLU stack, input_dim ==> 16 (no activation on the code).
        self.encoder = nn.Sequential(OrderedDict([
            ('linear1', nn.Linear(self.input_dim, 128)),
            ('relu1', nn.ReLU()),
            ('linear2', nn.Linear(128, 64)),
            ('relu2', nn.ReLU()),
            ('linear3', nn.Linear(64, 32)),
            ('relu3', nn.ReLU()),
            ('linear4', nn.Linear(32, 16)),
        ]))

        # Decoder: mirror of the encoder, 16 ==> input_dim. The final Sigmoid
        # keeps the reconstruction in [0, 1].
        self.decoder = nn.Sequential(OrderedDict([
            ('linear1', nn.Linear(16, 32)),
            ('relu1', nn.ReLU()),
            ('linear2', nn.Linear(32, 64)),
            ('relu2', nn.ReLU()),
            ('linear3', nn.Linear(64, 128)),
            ('relu3', nn.ReLU()),
            ('linear4', nn.Linear(128, self.input_dim)),
            ("activation", torch.nn.Sigmoid())
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Reconstruct ``x``, a batch of shape ``(batch, input_dim)``."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
        
class ConvAutoEncoder(nn.Module):
    """1-D convolutional autoencoder with an auxiliary binary-classifier head.

    * Encoder: Conv1d(1 -> hidden_dim) -> ReLU -> BatchNorm1d -> MaxPool1d.
      The max-pool uses stride 1 and ``(kernel_size - 1) // 2`` padding, so it
      preserves length.
    * Classifier: flatten the encoded map -> Linear -> Sigmoid, giving one
      probability per sample.
    * Decoder: ConvTranspose1d(hidden_dim -> 1) -> ReLU -> BatchNorm1d, sized
      so the reconstruction has exactly ``input_dim`` samples for any valid
      ``kernel_size`` / ``stride``.

    Args:
        input_dim: Length of each input sequence.
        hidden_dim: Number of convolution channels.
        kernel_size: Odd kernel width shared by the conv, pool and transposed conv.
        stride: Stride of the first convolution and of the transposed convolution.

    Raises:
        ValueError: If ``kernel_size`` is even or larger than ``input_dim``, or
            if ``stride`` < 1.
    """

    def __init__(self, input_dim: int, hidden_dim: int, kernel_size: int, stride: int) -> None:
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")
        if kernel_size > input_dim:
            raise ValueError(f"kernel_size {kernel_size} exceeds input_dim {input_dim}")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.stride = stride

        # Length after the unpadded first convolution; the pool keeps it.
        # (187 samples, kernel 9, stride 1 -> 179, the value once hard-coded.)
        self.encoded_length = (input_dim - kernel_size) // stride + 1
        pool_padding = (kernel_size - 1) // 2
        # ConvTranspose1d output = (L - 1) * stride + kernel_size + output_padding.
        # Add back the samples the strided conv dropped so the output is input_dim.
        output_padding = (input_dim - kernel_size) % stride

        self.encoder = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv1d(1, self.hidden_dim, kernel_size=kernel_size, stride=self.stride)),
            ('relu1', nn.ReLU()),
            ('norm1', nn.BatchNorm1d(self.hidden_dim)),
            ('pool1', nn.MaxPool1d(kernel_size=kernel_size, stride=1, padding=pool_padding)),
        ]))

        # Classifier head on the encoded feature map.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.hidden_dim * self.encoded_length, 1),
            nn.Sigmoid(),
        )

        # Decoder back to a single channel of input_dim samples. Note: it ends in
        # BatchNorm, so the output is not bounded to [0, 1].
        self.decoder = nn.Sequential(OrderedDict([
            ('conv3', nn.ConvTranspose1d(
                self.hidden_dim, 1,
                kernel_size=kernel_size,
                stride=self.stride,
                output_padding=output_padding,
            )),
            ('relu3', nn.ReLU()),
            ('norm3', nn.BatchNorm1d(1)),
        ]))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch and return ``(class probability, reconstruction)``.

        Args:
            x: Signals of shape ``(batch, input_dim)``.

        Returns:
            ``pred`` of shape ``(batch, 1)`` and ``decoded`` of shape
            ``(batch, input_dim)``.
        """
        # Conv1d expects (batch, channels, length): add a single channel axis.
        encoded = self.encoder(x.unsqueeze(1))
        pred = self.classifier(encoded)
        decoded = self.decoder(encoded)
        # Drop only the channel axis. A bare squeeze() would also drop the batch
        # axis when batch == 1 and misalign the MSE against x.
        return pred, decoded.squeeze(1)

In [None]:
# hyperparameters
seed = 123
# Seed every RNG *before* anything random happens. Previously the split ran
# first, so the train/val/test partition changed on every run.
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)

dataset = ECGDataset(df.drop(["target", "class"]), df.select("target"))
# A dedicated, seeded generator keeps the split reproducible and independent
# of any other use of the global torch RNG.
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [0.8, 0.1, 0.1], generator=torch.Generator().manual_seed(seed)
)

epochs = 10
input_dim = dataset[0][0].shape[0]
hidden_dim = 128
batch_size = 32
kernel_size = 9
stride = 1

model = ConvAutoEncoder(input_dim, hidden_dim, kernel_size, stride)

# Reconstruction loss for the decoder output.
mse_loss_function = torch.nn.MSELoss()
# Binary cross-entropy for the classifier's sigmoid probability.
bce_loss_function = torch.nn.BCELoss()

# Adam with lr = 0.1, which is high for Adam. The plateau scheduler below lowers
# it when the loss it is stepped with (the validation loss) stops improving.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-1,
    weight_decay=1e-8,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    patience=2,
)

def seed_worker(worker_id):
    """Seed Python's and NumPy's RNGs inside a DataLoader worker.

    Intended as a ``worker_init_fn``. Inside a worker, ``torch.initial_seed()``
    is the per-worker seed PyTorch assigned, and the other libraries' seeds are
    derived from it. ``worker_id`` is required by the ``worker_init_fn``
    signature but unused.
    """
    # NumPy seeds must fit in 32 bits.
    worker_seed = torch.initial_seed() % 2**32
    # numpy is imported as `np` in this notebook; `numpy.` raised NameError.
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Shuffling generator for the training loader, seeded from the same `seed` as
# every other RNG above (it was hard-coded to 0).
gen = torch.Generator()
gen.manual_seed(seed)

train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=gen,
)

# Validation losses and metrics are aggregated over the whole split, so order
# does not matter. Not shuffling, and not sharing `gen`, keeps each validation
# pass from advancing the training loader's generator and changing the
# training order of later epochs.
val_dataloader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=False,
    worker_init_fn=seed_worker,
)

In [None]:
accuracy = BinaryAccuracy()
confusion_matrix = BinaryConfusionMatrix()
f1_score = BinaryF1Score()
precision = BinaryPrecision()
recall = BinaryRecall()

accuracy_history = []
confusion_matrix_history = []
f1_score_history = []
precision_history = []
recall_history = []

# Pair every metric with the history it feeds, so each list records its own
# metric. Previously all five appended accuracy.compute().
metric_histories = [
    (accuracy, accuracy_history),
    (confusion_matrix, confusion_matrix_history),
    (f1_score, f1_score_history),
    (precision, precision_history),
    (recall, recall_history),
]

outputs = []
train_batch_losses = []
train_epoch_losses = []
val_batch_losses = []
val_epoch_losses = []

train_batch_mse_losses = []
train_epoch_mse_losses = []
train_batch_bce_losses = []
train_epoch_bce_losses = []

val_batch_mse_losses = []
val_epoch_mse_losses = []
val_batch_bce_losses = []
val_epoch_bce_losses = []

# Weights of the reconstruction (MSE) and classification (BCE) loss terms.
alpha_mse = 1.0
alpha_bce = 1.0

for epoch in (pbar := tqdm(range(epochs), desc=f"lr={scheduler.get_last_lr()[0]:0.6f}")):
    ####################
    # Train
    ####################
    model.train()

    # Running sums weighted by batch size, so the epoch mean is a per-sample
    # mean and the smaller final batch does not skew it.
    train_epoch_loss = 0.0
    train_epoch_mse_loss = 0.0
    train_epoch_bce_loss = 0.0
    n_train_samples = 0
    for X, target in train_dataloader:
        batch_n = X.shape[0]

        # The model returns the class probability and the reconstruction.
        pred, reconstructed = model(X)

        mse_loss = mse_loss_function(reconstructed, X)
        bce_loss = bce_loss_function(pred, target)
        loss = alpha_mse * mse_loss + alpha_bce * bce_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_batch_losses.append(loss.item())
        train_batch_mse_losses.append(mse_loss.item())
        train_batch_bce_losses.append(bce_loss.item())
        train_epoch_loss += loss.item() * batch_n
        train_epoch_mse_loss += mse_loss.item() * batch_n
        train_epoch_bce_loss += bce_loss.item() * batch_n
        n_train_samples += batch_n

    n_train_samples = max(n_train_samples, 1)  # guard against an empty loader
    train_epoch_mean_loss = train_epoch_loss / n_train_samples
    train_epoch_losses.append(train_epoch_mean_loss)
    train_mean_epoch_mse_loss = train_epoch_mse_loss / n_train_samples
    train_epoch_mse_losses.append(train_mean_epoch_mse_loss)
    train_mean_epoch_bce_loss = train_epoch_bce_loss / n_train_samples
    train_epoch_bce_losses.append(train_mean_epoch_bce_loss)

    ####################
    # Validation
    ####################
    model.eval()
    # Metrics accumulate across update() calls; reset so each epoch's values
    # describe that epoch only.
    for metric, _ in metric_histories:
        metric.reset()

    val_epoch_loss = 0.0
    val_epoch_mse_loss = 0.0
    val_epoch_bce_loss = 0.0
    n_val_samples = 0
    with torch.no_grad():
        for X, target in val_dataloader:
            batch_n = X.shape[0]
            pred, reconstructed = model(X)
            mse_loss = mse_loss_function(reconstructed, X)
            bce_loss = bce_loss_function(pred, target)
            val_loss = alpha_mse * mse_loss + alpha_bce * bce_loss

            val_batch_losses.append(val_loss.item())
            val_batch_mse_losses.append(mse_loss.item())
            val_batch_bce_losses.append(bce_loss.item())
            val_epoch_loss += val_loss.item() * batch_n
            val_epoch_mse_loss += mse_loss.item() * batch_n
            val_epoch_bce_loss += bce_loss.item() * batch_n
            n_val_samples += batch_n

            # flatten() rather than squeeze(), so a batch of one stays 1-D.
            probabilities = pred.flatten()
            labels = target.flatten().int()
            for metric, _ in metric_histories:
                metric.update(probabilities, labels)

    n_val_samples = max(n_val_samples, 1)  # guard against an empty loader
    val_mean_epoch_loss = val_epoch_loss / n_val_samples
    val_mean_epoch_mse_loss = val_epoch_mse_loss / n_val_samples
    val_mean_epoch_bce_loss = val_epoch_bce_loss / n_val_samples
    val_epoch_losses.append(val_mean_epoch_loss)
    val_epoch_mse_losses.append(val_mean_epoch_mse_loss)
    val_epoch_bce_losses.append(val_mean_epoch_bce_loss)

    for metric, history in metric_histories:
        history.append(metric.compute())

    # Lower the learning rate when the validation loss stops improving.
    scheduler.step(val_mean_epoch_loss)
    lr = scheduler.get_last_lr()[0]
    pbar.set_description(f"lr={lr:0.6f}")
    print(f"{epoch} (lr={lr:0.6f}) - train loss: ({train_mean_epoch_mse_loss:0.6f}, {train_mean_epoch_bce_loss:0.6f}) | val loss: ({val_mean_epoch_mse_loss:0.6f}, {val_mean_epoch_bce_loss:0.6f})")


In [None]:
def _stack_losses(index_name: str, train_losses: list, val_losses: list) -> pl.DataFrame:
    """Long-format frame of train then validation losses, each indexed from 0."""
    return pl.DataFrame({
        index_name: [*range(len(train_losses)), *range(len(val_losses))],
        "loss": train_losses + val_losses,
        "mode": ["train"] * len(train_losses) + ["validation"] * len(val_losses),
    })


def _loss_line(frame: pl.DataFrame, index_name: str, title: str) -> alt.Chart:
    """Small line chart of loss against `index_name`, one colour per mode."""
    return (
        alt.Chart(frame.to_pandas(), title=title)
        .mark_line()
        .encode(x=f"{index_name}:Q", y="loss:Q", color="mode:N")
        .properties(height=150, width=400)
    )


# Combined loss
batch_losses_df = _stack_losses("batch", train_batch_losses, val_batch_losses)
epoch_losses_df = _stack_losses("epoch", train_epoch_losses, val_epoch_losses)
batch_loss_chart = _loss_line(batch_losses_df, "batch", "Loss per Batch")
epoch_loss_chart = _loss_line(epoch_losses_df, "epoch", "Average Loss per Epoch")

# MSE (reconstruction) loss
batch_mse_losses_df = _stack_losses("batch", train_batch_mse_losses, val_batch_mse_losses)
epoch_mse_losses_df = _stack_losses("epoch", train_epoch_mse_losses, val_epoch_mse_losses)
batch_mse_loss_chart = _loss_line(batch_mse_losses_df, "batch", "MSE Loss per Batch")
epoch_mse_loss_chart = _loss_line(epoch_mse_losses_df, "epoch", "Average MSE Loss per Epoch")

# BCE (classification) loss
batch_bce_losses_df = _stack_losses("batch", train_batch_bce_losses, val_batch_bce_losses)
epoch_bce_losses_df = _stack_losses("epoch", train_epoch_bce_losses, val_epoch_bce_losses)
batch_bce_loss_chart = _loss_line(batch_bce_losses_df, "batch", "BCE Loss per Batch")
epoch_bce_loss_chart = _loss_line(epoch_bce_losses_df, "epoch", "Average BCE Loss per Epoch")

# One row per loss type: per-batch on the left, per-epoch on the right.
(batch_loss_chart | epoch_loss_chart) & (batch_mse_loss_chart | epoch_mse_loss_chart) & (batch_bce_loss_chart | epoch_bce_loss_chart)

# Scratch Pad

In [None]:
# NOTE(review): casting probabilities to an integer truncates toward zero, so
# every value below 1.0 becomes 0. Use `(pred >= 0.5)` for thresholded labels.
pred.int().squeeze()

In [None]:
# Shape walk-through of one encoder/decoder pass, roughly mirroring
# ConvAutoEncoder (without its ReLUs). Tensors are [nBatch, nChannels, length].
hidden_features = 128
# stride needs to be 1 and kernel_size odd for the pooling padding to keep the
# length unchanged, so the dimensions line up
kernel_size = 9

# encode
print("Encoding:")
m = nn.Conv1d(1, hidden_features, kernel_size=kernel_size, stride=1)
signals = torch.randn(32, 1, 187)
output = m(signals)
print("Conv:     ", signals.shape, output.shape)

norm_input = output
norm = nn.BatchNorm1d(hidden_features)
norm_output = norm(norm_input)
print("Norm:     ", norm_input.shape, norm_output.shape)

pool_input = norm_output
# (kernel_size-1)//2 padding keeps a stride-1 max-pool length-preserving - https://stackoverflow.com/a/71022586
pool = nn.MaxPool1d(kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2)
pool_output = pool(pool_input)
print("Pool:     ", pool_input.shape, pool_output.shape)


# decode: start from the pooled map, which is what the model's decoder receives
print("\nDecoding:")
conv_input = pool_output
conv = nn.ConvTranspose1d(hidden_features, 1, kernel_size=kernel_size, stride=1)
conv_output = conv(conv_input)
print("ConvTrans:", conv_input.shape, conv_output.shape)

norm_input = conv_output
norm = nn.BatchNorm1d(1)
norm_output = norm(norm_input)
print("Norm:     ", norm_input.shape, norm_output.shape)

In [None]:
# Insert a unit axis at position 2, then check squeeze() removes it again.
a = torch.zeros(4, 5, 6).unsqueeze(2)
a.shape, a.squeeze().shape

In [None]:
# Add a channel axis: (batch, length) -> (batch, 1, length).
torch.randn(32, 187).unsqueeze(1).shape

In [None]:
# Padding that keeps a stride-1 MaxPool1d length-preserving when kernel_size is odd.
(kernel_size-1)//2

In [None]:
# NOTE: this rebinds `dataset`, replacing the one built for the train/val/test split.
dataset = ECGDataset(
    X=df.drop(["target", "class"]),
    y=df.select("target"),
)
dataset.get_row(0)

In [None]:
# Softmax over dim=1 of a (5, 1) tensor normalises single-element rows, so every
# output is exactly 1. A single logit needs a sigmoid instead.
input = torch.randn(5, 1)
m = nn.Softmax(dim=1)
output = m(input)

In [None]:
(input, output)

In [None]:
# NOTE(review): an integer cast truncates probabilities (all < 1.0 become 0);
# it does not threshold them.
pred.to(torch.int32)

In [None]:
# nn.Threshold(0.5, 0.0): keep values above 0.5, replace the rest with 0.0.
input = torch.randn(2)
m = nn.Threshold(0.5, 0.0)
output = m(input)

In [None]:
(input, output)

In [None]:
# BCELoss expects probabilities, so the raw scores pass through a sigmoid first.
m, loss = nn.Sigmoid(), nn.BCELoss()
input = torch.randn(3, 2, requires_grad=True)
target = torch.rand(3, 2)  # requires_grad is False by default
output = loss(m(input), target)
output.backward()

In [None]:
(m(input), target)