In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split

%matplotlib inline
# Apply seaborn's 'darkgrid' theme to figures drawn after this point.
sns.set_style('darkgrid')

In [2]:
# Read the train and test CSVs; header=None because the files carry no header row,
# so columns are labelled with integers.
mit_train, mit_test = (
    pd.read_csv(csv_path, header=None)
    for csv_path in (
        '/scratch/gilbreth/lu992/final_50024_project/ecg_data/mitdb_360_train.csv',
        '/scratch/gilbreth/lu992/final_50024_project/ecg_data/mitdb_360_test.csv',
    )
)

In [3]:
# The column labelled 360 is the target; the columns labelled 0 through 359 are the inputs.
LABEL_COL = 360

# .loc slices by label and includes the end label, hence LABEL_COL - 1.
X_train, y_train = mit_train.loc[:, :LABEL_COL - 1], mit_train[LABEL_COL]
X_test, y_test = mit_test.loc[:, :LABEL_COL - 1], mit_test[LABEL_COL]

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

In [5]:
# Wrap the underlying NumPy arrays as tensors (torch.from_numpy shares memory, no copy).
# NOTE: this rebinds the names from pandas objects to tensors, so the cell cannot be re-run
# without first re-running the cell that creates the DataFrames/Series.
X_train = torch.from_numpy(X_train.values)
y_train = torch.from_numpy(y_train.values)
X_test = torch.from_numpy(X_test.values)
y_test = torch.from_numpy(y_test.values)

In [6]:
# Insert a size-1 axis at position 1 so each input becomes a 3D tensor,
# matching the single input channel expected by the first nn.Conv1d layer.
X_train = torch.unsqueeze(X_train, dim=1)
X_test = torch.unsqueeze(X_test, dim=1)

In [7]:
# Mini-batch size used for training; the test loader uses batches twice as large.
bs = 128

# Pair inputs with their labels.
train_ds, test_ds = TensorDataset(X_train, y_train), TensorDataset(X_test, y_test)

# Only the training loader reshuffles its examples.
train_dl = DataLoader(dataset=train_ds, batch_size=bs, shuffle=True)
test_dl = DataLoader(dataset=test_ds, batch_size=2 * bs)

In [9]:
!pip install torchdiffeq

Defaulting to user installation because normal site-packages is not writeable
Collecting torchdiffeq
  Downloading torchdiffeq-0.2.3-py3-none-any.whl (31 kB)
Installing collected packages: torchdiffeq
Successfully installed torchdiffeq-0.2.3


In [8]:
import time
from models import norm, ResBlock, ODEfunc, ODENet, Flatten, count_parameters

In [9]:
# Helpers adapted from https://pytorch.org/tutorials/beginner/nn_tutorial.html

def get_model(is_odenet=True, dim=64, adam=False, **kwargs):
    """
    Build a 1D convolutional classifier and the optimizer that trains it.

    Parameters
    ----------
    is_odenet : bool
        If True the feature stage is a single ``ODENet(ODEfunc(dim), **kwargs)``;
        otherwise it is six ``ResBlock(dim)`` modules and ``kwargs`` is unused.
    dim : int
        Number of channels used throughout the network.
    adam : bool
        If True use ``optim.Adam`` with its default settings; otherwise use
        ``optim.SGD`` with ``lr=0.1`` and ``momentum=0.9``.
    **kwargs
        Forwarded to ``ODENet`` (only when ``is_odenet`` is True).

    Returns
    -------
    (model, opt) : the ``nn.Sequential`` network, which ends in a 5-output
    linear layer, and its optimizer.
    """
    # Downsampling stage: three convolutions, the last two with stride 2.
    stem = [
        nn.Conv1d(in_channels=1, out_channels=dim, kernel_size=3, stride=1),
        norm(dim),
        nn.ReLU(inplace=True),
        nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=4, stride=2, padding=1),
        norm(dim),
        nn.ReLU(inplace=True),
        nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=4, stride=2, padding=1),
    ]

    # Feature stage: one ODE block, or a stack of six residual blocks.
    if is_odenet:
        body = [ODENet(ODEfunc(dim), **kwargs)]
    else:
        body = []
        for _ in range(6):
            body.append(ResBlock(dim))

    # Classification head: pool over the length axis, flatten, project to 5 scores.
    head = [
        norm(dim),
        nn.ReLU(inplace=True),
        nn.AdaptiveAvgPool1d(output_size=1),
        Flatten(),
        nn.Linear(in_features=dim, out_features=5),
    ]

    model = nn.Sequential(*(stem + body + head))

    if adam:
        opt = optim.Adam(model.parameters())
    else:
        opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    return model, opt

def loss_batch(model, loss_func, xb, yb, opt=None):
    """
    Compute the loss for one batch and, if an optimizer is given, take one update step.

    Parameters
    ----------
    model : callable
        Maps the float-cast input batch to a 2D tensor of per-class scores.
    loss_func : callable
        Called as ``loss_func(output, yb.long())``; must return a scalar tensor.
    xb : torch.Tensor
        Input batch; cast to float before the forward pass.
    yb : torch.Tensor
        Target batch; cast to long for the loss.
    opt : optimizer or None
        When not None, gradients are computed and ``opt.step()`` is applied.
        When None, no backward pass is run (evaluation).

    Returns
    -------
    tuple
        ``(loss, n, correct)``: the loss as a Python float, the number of examples
        in the batch, and the number of correct predictions as a Python float.
    """
    output = model(xb.float())
    loss = loss_func(output, yb.long())
    preds = torch.argmax(output, dim=1)  # Index of the highest score in each row
    correct = (preds == yb).float().sum()  # Count correct predictions

    if opt is not None:
        # Clear gradients BEFORE backward: autograd accumulates into .grad, so any
        # stale gradient left on the parameters would otherwise leak into this step.
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Also clear afterwards so parameters are left without gradients, as before.
        opt.zero_grad()

    return loss.item(), len(xb), correct.item()

In [10]:
import csv
import time

def fit(epochs, model, loss_func, opt, train_dl, valid_dl, csv_filename='training_log.csv'):
    """
    Train the model for a number of epochs, evaluating after each one.

    For every epoch this runs one pass over ``train_dl`` with weight updates, then one
    pass over ``valid_dl`` under ``torch.no_grad()``. Example-weighted mean loss and
    accuracy for both passes are printed and appended to a CSV log.

    Parameters
    ----------
    epochs : int
        Number of passes over ``train_dl``.
    model : object
        Network to train; must provide ``train()`` and ``eval()``.
    loss_func : callable
        Loss function forwarded to ``loss_batch``.
    opt : optimizer
        Optimizer forwarded to ``loss_batch`` for the training pass.
    train_dl : sized iterable of ``(xb, yb)`` batches
        Training batches; ``len(train_dl)`` is used for the progress display.
    valid_dl : iterable of ``(xb, yb)`` batches
        Batches evaluated after each epoch without weight updates.
    csv_filename : str
        Path of the CSV log. The file is overwritten on every call.

    Notes
    -----
    The log is flushed after the header and after every epoch, so it can be read while
    training is still running and completed epochs survive an abrupt kernel death.
    An empty loader produces NaN metrics for that pass instead of raising.
    """
    with open(csv_filename, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Epoch', 'Train Loss', 'Train Accuracy', 'Validation Loss', 'Validation Accuracy'])
        file.flush()

        for epoch in range(epochs):
            print(f"Training... epoch {epoch + 1}")

            # ---- training pass ----
            model.train()
            total_loss, total_correct, total = 0, 0, 0

            start = time.time()
            for batch_count, (xb, yb) in enumerate(train_dl, start=1):
                loss, num, correct = loss_batch(model, loss_func, xb, yb, opt)
                # Weight the batch loss by its size so the epoch mean is per example.
                total_loss += loss * num
                total_correct += correct
                total += num
                percent = round(batch_count / len(train_dl) * 100, 1)
                elapsed = round((time.time() - start) / 60, 1)
                # '\r' rewrites the same console line instead of printing one line per batch.
                print(f"    Percent trained: {percent}%  Time elapsed: {elapsed} min", end='\r')

            train_loss = total_loss / total if total else float('nan')
            train_acc = total_correct / total if total else float('nan')

            # ---- validation pass ----
            model.eval()
            val_loss_sum, val_correct, val_total = 0.0, 0.0, 0
            with torch.no_grad():
                for xb, yb in valid_dl:
                    loss, num, correct = loss_batch(model, loss_func, xb, yb)
                    val_loss_sum += loss * num
                    val_correct += correct
                    val_total += num

            val_loss = val_loss_sum / val_total if val_total else float('nan')
            val_acc = val_correct / val_total if val_total else float('nan')

            # Print epoch summary
            print(f"\nEpoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                  f"Valid Loss: {val_loss:.4f}, Valid Acc: {val_acc:.4f}\n")

            # Persist this epoch's metrics immediately.
            writer.writerow([epoch + 1, train_loss, train_acc, val_loss, val_acc])
            file.flush()


In [11]:
# Seed PyTorch's global RNG before any weights are created so that a fresh
# "Restart & Run All" yields the same initialisation and training-batch order.
torch.manual_seed(0)

# ODE-based network trained with SGD; rtol/atol are forwarded to ODENet.
odenet, odeopt = get_model(is_odenet=True, adam=False, rtol=1e-3, atol=1e-3)

In [12]:
# Residual-block baseline trained with SGD (adam defaults to False in get_model).
resnet, resopt = get_model(is_odenet=False)

In [14]:
# Train the ODENet for 15 epochs; metrics after each epoch are computed on test_dl.
fit(
    epochs=15,
    model=odenet,
    loss_func=F.cross_entropy,
    opt=odeopt,
    train_dl=train_dl,
    valid_dl=test_dl,
    csv_filename='odenet_training_log.csv',
)

Training... epoch 1
    Percent trained: 100.0%  Time elapsed: 30.8 min
Epoch 1: Train Loss: 0.2960, Train Acc: 0.9145, Valid Loss: 0.9209, Valid Acc: 0.7537

Training... epoch 2
    Percent trained: 100.0%  Time elapsed: 37.1 min
Epoch 2: Train Loss: 0.1434, Train Acc: 0.9614, Valid Loss: 0.6880, Valid Acc: 0.8094

Training... epoch 3
    Percent trained: 100.0%  Time elapsed: 38.6 min
Epoch 3: Train Loss: 0.1041, Train Acc: 0.9708, Valid Loss: 0.6075, Valid Acc: 0.8343

Training... epoch 4
    Percent trained: 100.0%  Time elapsed: 42.4 min
Epoch 4: Train Loss: 0.0854, Train Acc: 0.9760, Valid Loss: 0.6123, Valid Acc: 0.8086

Training... epoch 5
    Percent trained: 100.0%  Time elapsed: 45.2 min
Epoch 5: Train Loss: 0.0715, Train Acc: 0.9800, Valid Loss: 0.4547, Valid Acc: 0.8691

Training... epoch 6
    Percent trained: 100.0%  Time elapsed: 47.0 min
Epoch 6: Train Loss: 0.0643, Train Acc: 0.9814, Valid Loss: 0.3805, Valid Acc: 0.8854

Training... epoch 7
    Percent trained: 100.0

In [15]:
# Train the ResNet for 15 epochs; metrics after each epoch are computed on test_dl.
fit(
    epochs=15,
    model=resnet,
    loss_func=F.cross_entropy,
    opt=resopt,
    train_dl=train_dl,
    valid_dl=test_dl,
    csv_filename='resnet_training_log.csv',
)

Training... epoch 1
    Percent trained: 100.0%  Time elapsed: 2.9 min
Epoch 1: Train Loss: 0.3096, Train Acc: 0.9123, Valid Loss: 1.0560, Valid Acc: 0.7277

Training... epoch 2
    Percent trained: 100.0%  Time elapsed: 2.9 min
Epoch 2: Train Loss: 0.1383, Train Acc: 0.9622, Valid Loss: 0.6655, Valid Acc: 0.7783

Training... epoch 3
    Percent trained: 100.0%  Time elapsed: 2.9 min
Epoch 3: Train Loss: 0.0971, Train Acc: 0.9729, Valid Loss: 0.3477, Valid Acc: 0.8946

Training... epoch 4
    Percent trained: 100.0%  Time elapsed: 2.9 min
Epoch 4: Train Loss: 0.0779, Train Acc: 0.9786, Valid Loss: 0.2969, Valid Acc: 0.9043

Training... epoch 5
    Percent trained: 100.0%  Time elapsed: 2.9 min
Epoch 5: Train Loss: 0.0650, Train Acc: 0.9817, Valid Loss: 0.2652, Valid Acc: 0.9263

Training... epoch 6
    Percent trained: 100.0%  Time elapsed: 2.9 min
Epoch 6: Train Loss: 0.0547, Train Acc: 0.9844, Valid Loss: 0.2282, Valid Acc: 0.9334

Training... epoch 7
    Percent trained: 100.0%  Tim

In [16]:
def accuracy(model, X_test, y_test):
    """
    Return the fraction of examples whose predicted class equals the label.

    Parameters
    ----------
    model : callable
        Network mapping a float tensor to a 2D tensor of per-class scores.
        It is switched to eval mode and run without gradient tracking.
    X_test : torch.Tensor
        All inputs, evaluated in a single forward pass.
    y_test : torch.Tensor
        Labels, one per row of the model output.

    Returns
    -------
    numpy.float64
        Mean of the element-wise ``prediction == label`` comparison.
    """
    model.eval()
    with torch.no_grad():
        logits = model(X_test.float())
    # argmax is taken on the raw scores: softmax is monotonic, so it cannot change the
    # winner, but its rounding can turn nearly equal scores into an exact tie.
    # .cpu() keeps the NumPy conversion valid when tensors live on another device.
    preds = torch.argmax(logits, dim=1).cpu().numpy()
    return (preds == y_test.cpu().numpy()).mean()

# Testing accuracy after 15 epochs for ResNet and ODENet

In [17]:
# Report test-set accuracy for each trained network, rounded to three decimals.
for label, net in (("ResNet", resnet), ("ODENet", odenet)):
    print(f"{label} accuracy: {round(accuracy(net, X_test, y_test), 3)}")

ResNet accuracy: 0.957
ODENet accuracy: 0.941


# Model complexity

In [18]:
# Report the value returned by count_parameters for each network.
print("Number of tunable parameters in...")
for label, net in (("ResNet", resnet), ("ODENet", odenet)):
    print(f"    {label}: {count_parameters(net)}")

Number of tunable parameters in...
    ResNet: 182853
    ODENet: 59333
