# Data Preparation

In [1]:
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    roc_auc_score,
)

# Load the data
X, y = load_breast_cancer(return_X_y=True)
n_samples = X.shape[0]

# Define desired fractions
test_frac = 0.15
val_frac = 0.15

# Compute absolute counts
n_test = int(np.floor(test_frac * n_samples))
n_val = int(np.floor(val_frac * n_samples))
n_train = n_samples - n_test - n_val

print(f"Total samples: {n_samples}")
print(f"Train samples: {n_train}")
print(f"Validation samples: {n_val}")
print(f"Test samples: {n_test}")

# Split off the test set
X_train_val, X_test, y_train_val, y_test = train_test_split(
    X,
    y,
    test_size=n_test,
    random_state=42,
    stratify=y
)

# Split train versus validation set
X_train, X_val, y_train, y_val = train_test_split(
    X_train_val,
    y_train_val,
    test_size=n_val,
    random_state=42,
    stratify=y_train_val
)

# Fit scaler on training set only
scaler = StandardScaler().fit(X_train)

# Apply same transformation to train, val, and test
X_train = scaler.transform(X_train)
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)

# Print the shapes
print("\nTrain set shape:", X_train.shape, y_train.shape)
print("Validation set shape:", X_val.shape, y_val.shape)
print("Test set shape:", X_test.shape, y_test.shape)

Total samples: 569
Train samples: 399
Validation samples: 85
Test samples: 85

Train set shape: (399, 30) (399,)
Validation set shape: (85, 30) (85,)
Test set shape: (85, 30) (85,)


# Keras

In [2]:
import tensorflow as tf

# Build tf.data Datasets
BATCH_SIZE = 32

train_ds = (
    tf.data.Dataset.from_tensor_slices(
        (X_train.astype(np.float32), y_train.astype(np.float32))
    )
    .shuffle(len(X_train))
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

val_ds = (
    tf.data.Dataset.from_tensor_slices(
        (X_val.astype(np.float32), y_val.astype(np.float32))
    )
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

test_ds = (
    tf.data.Dataset.from_tensor_slices(
        (X_test.astype(np.float32), y_test.astype(np.float32))
    )
    .batch(BATCH_SIZE)
)

# Define the model
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(X_train.shape[1],)),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

# Compile the model
model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=[
        "accuracy",
        tf.keras.metrics.Precision(name="precision"),
        tf.keras.metrics.Recall(name="recall"),
        tf.keras.metrics.AUC(name="auc"),
    ],
)

# Set up model checkpointing
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    "best_model.keras",
    monitor="val_loss",
    save_best_only=True,
    save_weights_only=False,
)

# Train the model with validation set
EPOCHS = 20

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=[checkpoint_cb]
)

# Load the best model
model = tf.keras.models.load_model("best_model.keras")

# Evaluate on the test set
print("\nTest Set Evaluation:")
model.evaluate(test_ds)

# Get predictions and compute metrics
y_pred_prob = model.predict(test_ds).ravel()
y_pred = (y_pred_prob >= 0.5).astype(int)

print("\nAccuracy:", np.float32(accuracy_score(y_test, y_pred)).round(3))
print("Precision:", np.float32(precision_score(y_test, y_pred)).round(3))
print("Recall:", np.float32(recall_score(y_test, y_pred)).round(3))
print("ROC AUC:", np.float32(roc_auc_score(y_test, y_pred_prob)).round(3))

Epoch 1/20
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 205ms/step - accuracy: 0.4934 - auc: 0.7190 - loss: 0.7133 - precision: 0.7974 - recall: 0.2391 - val_accuracy: 0.9647 - val_auc: 0.9861 - val_loss: 0.3953 - val_precision: 0.9630 - val_recall: 0.9811
Epoch 2/20
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - accuracy: 0.9164 - auc: 0.9690 - loss: 0.3661 - precision: 0.9424 - recall: 0.9230 - val_accuracy: 0.9647 - val_auc: 0.9956 - val_loss: 0.2074 - val_precision: 0.9630 - val_recall: 0.9811
Epoch 3/20
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.9694 - auc: 0.9917 - loss: 0.2011 - precision: 0.9649 - recall: 0.9887 - val_accuracy: 0.9647 - val_auc: 0.9982 - val_loss: 0.1239 - val_precision: 0.9630 - val_recall: 0.9811
Epoch 4/20
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.9650 - auc: 0.9914 - loss: 0.1346 - precision: 0.9611 - recall: 0.9842 - val_a

# PyTorch

In [3]:
import copy
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Prepare TensorDatasets
train_ds = TensorDataset(
    torch.tensor(X_train, dtype=torch.float32),
    torch.tensor(y_train, dtype=torch.float32),
)

val_ds = TensorDataset(
    torch.tensor(X_val, dtype=torch.float32),
    torch.tensor(y_val, dtype=torch.float32),
)

test_ds = TensorDataset(
    torch.tensor(X_test, dtype=torch.float32),
    torch.tensor(y_test, dtype=torch.float32),
)

# Prepare DataLoaders
BATCH_SIZE = 32

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
)

test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
)

# Define the MLP
class MLP(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        return self.net(x).squeeze(1)

# Initialize the model
model = MLP(X.shape[1]).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop with validation set
EPOCHS = 20

best_val_loss = float("inf")
best_model_wts = copy.deepcopy(model.state_dict())

for epoch in range(EPOCHS):
    # Training step
    model.train()
    train_loss = 0.0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * xb.size(0)
    train_loss /= len(train_loader.dataset)

    # Validation step
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            loss = criterion(logits, yb)
            val_loss += loss.item() * xb.size(0)
    val_loss /= len(val_loader.dataset)

    # Checkpoint the best
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_wts = copy.deepcopy(model.state_dict())

    print(
        f"Epoch {epoch+1}: "
        f"train_loss={train_loss:.4f}, "
        f"val_loss={val_loss:.4f}"
    )

# Load the best weights
model.load_state_dict(best_model_wts)

# Evaluate on the test set
model.eval()
probs, preds = [], []
with torch.no_grad():
    for xb, _ in test_loader:
        xb = xb.to(device)
        logits = model(xb)
        p = torch.sigmoid(logits).cpu().numpy()
        probs.append(p)
        preds.append((p >= 0.5).astype(int))
probs = np.concatenate(probs)
preds = np.concatenate(preds)

print("\nAccuracy:", np.float32(accuracy_score(y_test, preds)).round(3))
print("Precision:", np.float32(precision_score(y_test, preds)).round(3))
print("Recall:", np.float32(recall_score(y_test, preds)).round(3))
print("ROC AUC:", np.float32(roc_auc_score(y_test, probs)).round(3))

Epoch 1: train_loss=0.6283, val_loss=0.5387
Epoch 2: train_loss=0.4631, val_loss=0.3543
Epoch 3: train_loss=0.3004, val_loss=0.1996
Epoch 4: train_loss=0.1887, val_loss=0.1144
Epoch 5: train_loss=0.1268, val_loss=0.0725
Epoch 6: train_loss=0.0943, val_loss=0.0510
Epoch 7: train_loss=0.0783, val_loss=0.0402
Epoch 8: train_loss=0.0675, val_loss=0.0323
Epoch 9: train_loss=0.0616, val_loss=0.0276
Epoch 10: train_loss=0.0561, val_loss=0.0241
Epoch 11: train_loss=0.0518, val_loss=0.0213
Epoch 12: train_loss=0.0489, val_loss=0.0198
Epoch 13: train_loss=0.0448, val_loss=0.0196
Epoch 14: train_loss=0.0440, val_loss=0.0159
Epoch 15: train_loss=0.0398, val_loss=0.0151
Epoch 16: train_loss=0.0376, val_loss=0.0148
Epoch 17: train_loss=0.0358, val_loss=0.0133
Epoch 18: train_loss=0.0331, val_loss=0.0115
Epoch 19: train_loss=0.0321, val_loss=0.0114
Epoch 20: train_loss=0.0304, val_loss=0.0106

Accuracy: 0.929
Precision: 0.98
Recall: 0.906
ROC AUC: 0.994


# PyTorch Lightning

In [4]:
!pip install -q pytorch-lightning

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.1/823.1 kB[0m [31m47.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m108.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m87.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m47.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m43.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [5]:
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

# Lightning Module
class LitMLP(pl.LightningModule):
    def __init__(self, in_dim):
        super().__init__()
        self.save_hyperparameters()
        self.model = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, x):
        return self.model(x).squeeze(1)

    def training_step(
        self,
        batch,
        _,
    ):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        self.log(
            "train_loss",
            loss,
            prog_bar=True,
        )
        return loss

    def validation_step(
        self,
        batch,
        _,
    ):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        self.log(
            "val_loss",
            loss,
            prog_bar=True,
        )

    def predict_step(
        self,
        batch,
        batch_idx,
        dataloader_idx=None
    ):
        x, _ = batch
        logits = self(x)
        return torch.sigmoid(logits)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Prepare TensorDatasets
train_ds = TensorDataset(
    torch.tensor(X_train, dtype=torch.float32),
    torch.tensor(y_train, dtype=torch.float32),
)

val_ds = TensorDataset(
    torch.tensor(X_val, dtype=torch.float32),
    torch.tensor(y_val, dtype=torch.float32),
)

test_ds = TensorDataset(
    torch.tensor(X_test, dtype=torch.float32),
    torch.tensor(y_test, dtype=torch.float32),
)

# Prepare DataLoaders
BATCH_SIZE = 32

pl_train = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

pl_val = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
)

pl_test = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
)

# Model checkpointing callback
checkpoint_cb = pl.callbacks.ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    filename="litmlp-best",
    save_top_k=1,
    verbose=True,
)

# Train with Trainer
EPOCHS = 20

trainer = pl.Trainer(
    max_epochs=EPOCHS,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    callbacks=[checkpoint_cb],
    enable_model_summary=True,
    enable_progress_bar=True,
    logger=True,
)

lit_model = LitMLP(X.shape[1])
trainer.fit(lit_model, pl_train, pl_val)

# Load the best checkpoint
best_path = checkpoint_cb.best_model_path
lit_model = LitMLP.load_from_checkpoint(
    best_path,
    in_dim=X_train.shape[1]
)

# Final predictions on test set
preds_list = trainer.predict(lit_model, pl_test)
probs = np.concatenate([p.numpy() for p in preds_list]).ravel()
preds = (probs >= 0.5).astype(int)

print("\nAccuracy:", np.float32(accuracy_score(y_test, preds)).round(3))
print("Precision:", np.float32(precision_score(y_test, preds)).round(3))
print("Recall:", np.float32(recall_score(y_test, preds)).round(3))
print("ROC AUC:", np.float32(roc_auc_score(y_test, probs)).round(3))

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type              | Params | Mode 
--------------------------------------------------------
0 | model     | Sequential        | 6.2 K  | train
1 | criterion | BCEWithLogitsLoss | 0      | train
--------------------------------------------------------
6.2 K     Trainable params
0         Non-trainable params
6.2 K     Total params
0.025     Total estimated model params size (MB)
7         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (13) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 0, global step 13: 'val_loss' reached 0.55907 (best 0.55907), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 1, global step 26: 'val_loss' reached 0.38890 (best 0.38890), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 2, global step 39: 'val_loss' reached 0.20531 (best 0.20531), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 3, global step 52: 'val_loss' reached 0.10817 (best 0.10817), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 4, global step 65: 'val_loss' reached 0.06756 (best 0.06756), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 5, global step 78: 'val_loss' reached 0.04747 (best 0.04747), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 6, global step 91: 'val_loss' reached 0.03780 (best 0.03780), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 7, global step 104: 'val_loss' reached 0.03116 (best 0.03116), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 8, global step 117: 'val_loss' reached 0.02701 (best 0.02701), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 9, global step 130: 'val_loss' reached 0.02320 (best 0.02320), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 10, global step 143: 'val_loss' reached 0.02106 (best 0.02106), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 11, global step 156: 'val_loss' reached 0.01898 (best 0.01898), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 12, global step 169: 'val_loss' reached 0.01749 (best 0.01749), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 13, global step 182: 'val_loss' reached 0.01691 (best 0.01691), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 14, global step 195: 'val_loss' reached 0.01523 (best 0.01523), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 15, global step 208: 'val_loss' reached 0.01464 (best 0.01464), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 16, global step 221: 'val_loss' reached 0.01328 (best 0.01328), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 17, global step 234: 'val_loss' reached 0.01279 (best 0.01279), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 18, global step 247: 'val_loss' reached 0.01271 (best 0.01271), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 19, global step 260: 'val_loss' reached 0.01226 (best 0.01226), saving model to '/content/lightning_logs/version_0/checkpoints/litmlp-best.ckpt' as top 1
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]


Accuracy: 0.941
Precision: 0.98
Recall: 0.925
ROC AUC: 0.994


# JAX

In [6]:
!pip install -q optax flax jax jaxlib

In [7]:
import jax
import jax.numpy as jnp
import optax
import pickle

# Convert data to JAX arrays
X_train_j = jnp.array(X_train)
y_train_j = jnp.array(y_train)
X_val_j = jnp.array(X_val)
y_val_j = jnp.array(y_val)
X_test_j = jnp.array(X_test)
y_test_j = jnp.array(y_test)

# Initialize parameters
def init_params(key):
    keys = jax.random.split(key, 3)
    def glorot(
        in_dim,
        out_dim,
        k,
    ):
        std = jnp.sqrt(2.0 / (in_dim + out_dim))
        return std * jax.random.normal(k, (in_dim, out_dim))
    return {
        "w1": glorot(X.shape[1], 64, keys[0]),
        "b1": jnp.zeros(64),
        "w2": glorot(64, 64, keys[1]),
        "b2": jnp.zeros(64),
        "w3": glorot(64, 1, keys[2]),
        "b3": jnp.zeros(1),
    }

# Forward pass
def forward(params, x):
    x = jax.nn.relu(x @ params["w1"] + params["b1"])
    x = jax.nn.relu(x @ params["w2"] + params["b2"])
    return jnp.squeeze(x @ params["w3"] + params["b3"], -1)

# Loss function
def loss_fn(params, x, y):
    logits = forward(params, x)
    return jnp.mean(optax.sigmoid_binary_cross_entropy(logits, y))

# Update step
opt = optax.adam(1e-3)

@jax.jit
def update(
    params,
    opt_state,
    x,
    y,
):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, new_opt_state = opt.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss

@jax.jit
def eval_loss(
    params,
    x,
    y,
):
    return loss_fn(params, x, y)

# Simple generator for shuffled mini-batches
def data_batches(X_arr, y_arr, batch_size=32):
    n = X_arr.shape[0]
    idx = np.random.permutation(n)
    for i in range(0, n, batch_size):
        batch = idx[i : i + batch_size]
        yield X_arr[batch], y_arr[batch]

# Training loop with validation set
EPOCHS = 20

key = jax.random.PRNGKey(0)
params = init_params(key)
opt_state = opt.init(params)

best_val_loss = float("inf")

for epoch in range(1, EPOCHS + 1):
    # Training step
    for xb, yb in data_batches(X_train_j, y_train_j):
        params, opt_state, _ = update(params, opt_state, xb, yb)

    # Validation step
    v_loss = float(eval_loss(params, X_val_j, y_val_j))
    print(f"Epoch {epoch:2d}: val_loss = {v_loss:.4f}")

    # Checkpoint if improved
    if v_loss < best_val_loss:
        best_val_loss = v_loss
        with open("best_params.pkl", "wb") as f:
            pickle.dump(params, f)

# Restore the best parameters
with open("best_params.pkl", "rb") as f:
    params = pickle.load(f)
print(f"\nRestored best params with val_loss = {best_val_loss:.4f}")

# Predict on the test set
logits_test = forward(params, X_test_j)
probs = jax.nn.sigmoid(logits_test)
preds = (probs >= 0.5).astype(int)

# Compute metrics
probs_np = np.array(probs)
preds_np = np.array(preds)

print("\nAccuracy:", np.float32(accuracy_score(y_test, preds_np)).round(3))
print("Precision:", np.float32(precision_score(y_test, preds_np)).round(3))
print("Recall:", np.float32(recall_score(y_test, preds_np)).round(3))
print("ROC AUC:", np.float32(roc_auc_score(y_test, probs_np)).round(3))

Epoch  1: val_loss = 0.3129
Epoch  2: val_loss = 0.1588
Epoch  3: val_loss = 0.0969
Epoch  4: val_loss = 0.0665
Epoch  5: val_loss = 0.0498
Epoch  6: val_loss = 0.0406
Epoch  7: val_loss = 0.0335
Epoch  8: val_loss = 0.0293
Epoch  9: val_loss = 0.0262
Epoch 10: val_loss = 0.0243
Epoch 11: val_loss = 0.0224
Epoch 12: val_loss = 0.0206
Epoch 13: val_loss = 0.0197
Epoch 14: val_loss = 0.0185
Epoch 15: val_loss = 0.0174
Epoch 16: val_loss = 0.0160
Epoch 17: val_loss = 0.0157
Epoch 18: val_loss = 0.0148
Epoch 19: val_loss = 0.0134
Epoch 20: val_loss = 0.0143

Restored best params with val_loss = 0.0134

Accuracy: 0.941
Precision: 0.98
Recall: 0.925
ROC AUC: 0.994


# Flax

In [8]:
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import train_state, checkpoints
import os
import pickle
from absl import logging

# Only show errors
logging.set_verbosity(logging.ERROR)

# Convert data to JAX arrays
X_train_j = jnp.array(X_train)
y_train_j = jnp.array(y_train)
X_val_j = jnp.array(X_val)
y_val_j = jnp.array(y_val)
X_test_j = jnp.array(X_test)
y_test_j = jnp.array(y_test)

# Define Flax MLP
class FlaxMLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return jnp.squeeze(x, -1)

# Create training state
def create_train_state(rng, learning_rate):
    model = FlaxMLP()
    params = model.init(rng, jnp.ones([1, X.shape[1]]))["params"]
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx,
    )

# Training step
@jax.jit
def train_step(
    state,
    x,
    y,
):
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, x)
        return jnp.mean(optax.sigmoid_binary_cross_entropy(logits, y))

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)

# Data batch generator
def data_batches(X_arr, y_arr, batch_size=32):
    n = X_arr.shape[0]
    idx = np.random.permutation(n)
    for i in range(0, n, batch_size):
        batch = idx[i : i + batch_size]
        yield X_arr[batch], y_arr[batch]

# Set up absolute checkpoint directory
CKPT_DIR = os.path.abspath("./checkpoints")
os.makedirs(CKPT_DIR, exist_ok=True)

# Initialize state
rng = jax.random.PRNGKey(0)
state = create_train_state(rng, learning_rate=1e-3)

# Training loop with validation set
EPOCHS = 20

best_val_loss = float("inf")

for epoch in range(1, EPOCHS + 1):
    # Training step
    for xb, yb in data_batches(X_train_j, y_train_j):
        state = train_step(state, xb, yb)

    # Validation step
    val_logits = state.apply_fn({"params": state.params}, X_val_j)
    v_loss = float(
        jnp.mean(optax.sigmoid_binary_cross_entropy(val_logits, y_val_j))
    )
    print(f"Epoch {epoch:2d}: val_loss = {v_loss:.4f}")

    # Checkpoint if improved
    if v_loss < best_val_loss:
        best_val_loss = v_loss
        checkpoints.save_checkpoint(
            ckpt_dir=CKPT_DIR,
            target=state,
            step=epoch,
            keep=1,
            overwrite=True,
        )

# Restore the best checkpoint
state = checkpoints.restore_checkpoint(
    ckpt_dir=CKPT_DIR,
    target=state,
)
print(f"\nRestored checkpoint at val_loss = {best_val_loss:.4f}")

# Evaluate on test set
logits_test = state.apply_fn({"params": state.params}, X_test_j)
probs = jax.nn.sigmoid(logits_test)
preds = (probs >= 0.5).astype(int)

# Compute metrics
probs_np = np.array(probs)
preds_np = np.array(preds)

print("\nAccuracy:", np.float32(accuracy_score(y_test, preds_np)).round(3))
print("Precision:", np.float32(precision_score(y_test, preds_np)).round(3))
print("Recall:", np.float32(recall_score(y_test, preds_np)).round(3))
print("ROC AUC:", np.float32(roc_auc_score(y_test, probs_np)).round(3))

Epoch  1: val_loss = 0.3484
Epoch  2: val_loss = 0.1932
Epoch  3: val_loss = 0.1141
Epoch  4: val_loss = 0.0754
Epoch  5: val_loss = 0.0554
Epoch  6: val_loss = 0.0450
Epoch  7: val_loss = 0.0374
Epoch  8: val_loss = 0.0311
Epoch  9: val_loss = 0.0272
Epoch 10: val_loss = 0.0243
Epoch 11: val_loss = 0.0218
Epoch 12: val_loss = 0.0205
Epoch 13: val_loss = 0.0186
Epoch 14: val_loss = 0.0176
Epoch 15: val_loss = 0.0166
Epoch 16: val_loss = 0.0137
Epoch 17: val_loss = 0.0140
Epoch 18: val_loss = 0.0138
Epoch 19: val_loss = 0.0124
Epoch 20: val_loss = 0.0112

Restored checkpoint at val_loss = 0.0112

Accuracy: 0.953
Precision: 0.98
Recall: 0.943
ROC AUC: 0.993


# Sklearn

In [9]:
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import (
    log_loss,
    accuracy_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
import numpy as np
import os
import pickle

# Create checkpoint directory
CKPT_DIR = os.path.abspath("./checkpoints")
os.makedirs(CKPT_DIR, exist_ok=True)
CKPT_FILE = os.path.join(CKPT_DIR, "best_mlp_model.pkl")

# Instantiate the model
model = MLPClassifier(
    hidden_layer_sizes=(64, 64),
    activation='relu',
    solver='adam',
    random_state=42,
    warm_start=True,
    max_iter=1,
)

# Perform initial partial_fit with all classes
classes = np.unique(y_train)
model.partial_fit(X_train, y_train, classes=classes)

best_val_loss = float("inf")

# Incremental training loop with checkpointing
EPOCHS = 20

for epoch in range(1, EPOCHS + 1):
    # Train one epoch
    model.partial_fit(X_train, y_train)

    # Compute validation loss
    val_pred_prob = model.predict_proba(X_val)[:, 1]
    val_loss = log_loss(y_val, val_pred_prob)
    print(f"Epoch {epoch:2d}: val_loss={val_loss:.4f}")

    # Save checkpoint if validation loss improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        with open(CKPT_FILE, "wb") as f:
            pickle.dump(model, f)

# Load the best model from checkpoint
with open(CKPT_FILE, "rb") as f:
    best_model = pickle.load(f)
    print(f"\nRestored best model with val_loss = {best_val_loss:.4f}")

# Final predictions and metrics on test set
y_prob = best_model.predict_proba(X_test)[:, 1]
y_pred = (y_prob >= 0.5).astype(int)

print("\nAccuracy:", np.float32(accuracy_score(y_test, y_pred)).round(3))
print("Precision:", np.float32(precision_score(y_test, y_pred)).round(3))
print("Recall:", np.float32(recall_score(y_test, y_pred)).round(3))
print("ROC AUC:", np.float32(roc_auc_score(y_test, y_prob)).round(3))

Epoch  1: val_loss=0.6307
Epoch  2: val_loss=0.5533
Epoch  3: val_loss=0.4864
Epoch  4: val_loss=0.4284
Epoch  5: val_loss=0.3774
Epoch  6: val_loss=0.3327
Epoch  7: val_loss=0.2939
Epoch  8: val_loss=0.2602
Epoch  9: val_loss=0.2309
Epoch 10: val_loss=0.2058
Epoch 11: val_loss=0.1840
Epoch 12: val_loss=0.1652
Epoch 13: val_loss=0.1490
Epoch 14: val_loss=0.1349
Epoch 15: val_loss=0.1229
Epoch 16: val_loss=0.1123
Epoch 17: val_loss=0.1032
Epoch 18: val_loss=0.0951
Epoch 19: val_loss=0.0880
Epoch 20: val_loss=0.0818

Restored best model with val_loss = 0.0818

Accuracy: 0.941
Precision: 0.962
Recall: 0.943
ROC AUC: 0.991


# Numpyro

In [10]:
!pip install -q numpyro

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/365.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m365.8/365.8 kB[0m [31m28.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [11]:
import numpy as np
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer.autoguide import AutoDiagonalNormal
from numpyro.infer import (
    SVI,
    Trace_ELBO,
    Predictive,
)
from numpyro.optim import Adam
import os
import pickle

# Convert data to JAX arrays
X_train_j = jnp.array(X_train)
y_train_j = jnp.array(y_train)
X_val_j = jnp.array(X_val)
y_val_j = jnp.array(y_val)
X_test_j = jnp.array(X_test)
y_test_j = jnp.array(y_test)

# Define Bayesian MLP model
def bayes_mlp(X, y=None):
    n_feats = X.shape[1]
    w1 = numpyro.sample(
        'w1',
        dist.Normal(0, 1)
        .expand([n_feats, 64])
        .to_event(2),
    )
    b1 = numpyro.sample(
        'b1',
        dist.Normal(0, 1)
        .expand([64])
        .to_event(1),
    )
    h1 = jax.nn.relu(jnp.dot(X, w1) + b1)

    w2 = numpyro.sample(
        'w2',
        dist.Normal(0, 1)
        .expand([64, 64])
        .to_event(2),
    )
    b2 = numpyro.sample(
        'b2',
        dist.Normal(0, 1)
        .expand([64])
        .to_event(1),
    )
    h2 = jax.nn.relu(jnp.dot(h1, w2) + b2)

    w3 = numpyro.sample(
        'w3',
        dist.Normal(0, 1)
        .expand([64, 1])
        .to_event(2),
    )
    b3 = numpyro.sample(
        'b3',
        dist.Normal(0, 1)
        .expand([1])
        .to_event(1),
    )
    logits = jnp.squeeze(jnp.dot(h2, w3) + b3, -1)

    with numpyro.plate('data', X.shape[0]):
        numpyro.sample(
            'obs',
            dist.Bernoulli(logits=logits),
            obs=y,
        )

# Set up SVI with a mean-field guide
guide = AutoDiagonalNormal(bayes_mlp)
svi = SVI(
    bayes_mlp,
    guide,
    Adam(1e-3),
    loss=Trace_ELBO(),
)

# Initialize SVI state
rng = jax.random.PRNGKey(0)
state = svi.init(rng, X_train_j, y_train_j)

# Checkpoint directory and tracking
CKPT_FILE = "best_svi_state.pkl"
best_val_loss = float("inf")

# Training loop with validation set
EPOCHS = 20

for epoch in range(1, EPOCHS + 1):
    # Perform one SVI update on full training set
    state, train_loss = svi.update(state, X_train_j, y_train_j)

    # Evaluate on validation set
    val_loss = svi.evaluate(state, X_val_j, y_val_j)

    print(
        f"Epoch {epoch:2d}: "
        f"train_loss={train_loss:.4f}, "
        f"val_loss={val_loss:.4f}"
    )

    # Checkpoint if validation improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        with open(CKPT_FILE, "wb") as f:
            pickle.dump(state, f)

# Restore the best checkpoint
with open(CKPT_FILE, "rb") as f:
    state = pickle.load(f)
print(f"\nRestored best SVI state at val_loss = {best_val_loss:.4f}")

# Posterior predictive on the test set
params = svi.get_params(state)

predictive = Predictive(
    bayes_mlp,
    guide=guide,
    params=params,
    num_samples=1000,
)

samples = predictive(rng, X_test_j)['obs']
pred_mean = jnp.mean(samples, axis=0)

y_pred = (np.array(pred_mean) >= 0.5).astype(int)
y_prob = np.array(pred_mean)

# Final test metrics
print("\nAccuracy:", np.float32(accuracy_score(y_test, y_pred)).round(3))
print("Precision:", np.float32(precision_score(y_test, y_pred)).round(3))
print("Recall:", np.float32(recall_score(y_test, y_pred)).round(3))
print("ROC AUC:", np.float32(roc_auc_score(y_test, y_prob)).round(3))

Epoch  1: train_loss=39555.5938, val_loss=19390.5527
Epoch  2: train_loss=34198.5586, val_loss=20088.6113
Epoch  3: train_loss=35490.2539, val_loss=20219.5586
Epoch  4: train_loss=37443.4336, val_loss=19862.0352
Epoch  5: train_loss=34640.1836, val_loss=19329.1621
Epoch  6: train_loss=32711.1016, val_loss=18780.4668
Epoch  7: train_loss=33420.7305, val_loss=19111.2520
Epoch  8: train_loss=31365.5781, val_loss=18970.6914
Epoch  9: train_loss=29998.2930, val_loss=18346.3770
Epoch 10: train_loss=27892.5430, val_loss=18525.0254
Epoch 11: train_loss=30377.2266, val_loss=18309.7031
Epoch 12: train_loss=30521.4492, val_loss=17837.3633
Epoch 13: train_loss=26068.6055, val_loss=18379.1680
Epoch 14: train_loss=29714.1953, val_loss=17473.5527
Epoch 15: train_loss=25563.0312, val_loss=17435.8281
Epoch 16: train_loss=23810.2090, val_loss=17668.1367
Epoch 17: train_loss=26875.6562, val_loss=17854.7539
Epoch 18: train_loss=26705.3203, val_loss=17260.2031
Epoch 19: train_loss=24422.1914, val_loss=1751

# Manual

In [12]:
import os
import pickle
import numpy as np
from scipy.special import expit  # Vectorized sigmoid

# Activation functions and derivatives
def relu(z):
    return np.maximum(0, z)

def relu_derivative(z):
    return (z > 0).astype(float)

def sigmoid(z):
    return expit(z)

def sigmoid_derivative(z):
    s = expit(z)
    return s * (1 - s)

# Xavier initialization
def init_params(layer_sizes, seed=42):
    params = {}
    rng = np.random.default_rng(seed)

    for i in range(len(layer_sizes) - 1):
        fan_in = layer_sizes[i]
        fan_out = layer_sizes[i + 1]
        limit = np.sqrt(6 / (fan_in + fan_out))
        params[f'W{i+1}'] = rng.uniform(
            -limit,
            limit,
            size=(fan_in, fan_out),
        )
        params[f'b{i+1}'] = np.zeros((1, fan_out))
    return params

# Forward pass
def forward(params, x):
    caches = {'A0': x}
    num_layers = len(params) // 2

    for i in range(1, num_layers + 1):
        w = params[f'W{i}']
        b = params[f'b{i}']
        z = caches[f'A{i-1}'] @ w + b
        caches[f'Z{i}'] = z
        if i < num_layers:
            a = relu(z)
        else:
            a = sigmoid(z)
        caches[f'A{i}'] = a
    return a, caches

# Loss (binary cross-entropy)
def compute_loss(y_true, y_pred):
    eps = 1e-15
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -np.mean(
        y_true * np.log(y_pred)
        + (1 - y_true) * np.log(1 - y_pred)
    )

# Backward pass
def backward(params, caches, y_true):
    grads = {}
    batch_size = y_true.shape[0]
    num_layers = len(params) // 2

    a_last = caches[f'A{num_layers}'].reshape(-1, 1)
    d_a = -(y_true.reshape(-1, 1) / a_last
            - (1 - y_true).reshape(-1, 1)
              / (1 - a_last))

    for i in reversed(range(1, num_layers + 1)):
        z = caches[f'Z{i}']
        a_prev = caches[f'A{i-1}']
        if i == num_layers:
            d_z = d_a * sigmoid_derivative(z)
        else:
            d_z = d_a * relu_derivative(z)
        grads[f'dW{i}'] = a_prev.T @ d_z / batch_size
        grads[f'db{i}'] = np.sum(
            d_z,
            axis=0,
            keepdims=True,
        ) / batch_size
        d_a = d_z @ params[f'W{i}'].T
    return grads

# Adam optimizer state initialization
def init_adam(params):
    v = {k: np.zeros_like(v) for k, v in params.items()}
    s = {k: np.zeros_like(v) for k, v in params.items()}
    return v, s

def adam_update(
    params,
    grads,
    v,
    s,
    t,
    lr=1e-3,
    beta1=0.9,
    beta2=0.999,
    eps=1e-8,
):
    for key in params:
        dv = grads[f'd{key}']
        v[key] = beta1 * v[key] + (1 - beta1) * dv
        s[key] = beta2 * s[key] + (1 - beta2) * (dv ** 2)
        v_corr = v[key] / (1 - beta1 ** t)
        s_corr = s[key] / (1 - beta2 ** t)
        params[key] -= lr * v_corr / (np.sqrt(s_corr) + eps)

# Prepare data and hyperparameters
layer_sizes = [X_train.shape[1], 64, 64, 1]
params = init_params(layer_sizes)
v_state, s_state = init_adam(params)
timestep = 0

EPOCHS = 20
BATCH_SIZE = 32
N_TRAIN = X_train.shape[0]

# Checkpoint setup
ckpt_dir = os.path.abspath('./checkpoints')
os.makedirs(ckpt_dir, exist_ok=True)
ckpt_file = os.path.join(ckpt_dir, 'best_params.pkl')
best_val_loss = float('inf')

# Training loop with validation set
for epoch in range(1, EPOCHS + 1):
    idx = np.random.permutation(N_TRAIN)

    X_shuffled = X_train[idx]
    y_shuffled = y_train[idx]

    # Training step
    for start in range(0, N_TRAIN, BATCH_SIZE):
        end = start + BATCH_SIZE
        X_batch = X_shuffled[start:end]
        y_batch = y_shuffled[start:end]
        y_pred, cache = forward(params, X_batch)
        grads = backward(params, cache, y_batch)
        timestep += 1
        adam_update(
            params,
            grads,
            v_state,
            s_state,
            timestep,
        )

    # Validation step
    y_val_pred, _ = forward(params, X_val)
    val_loss = compute_loss(
        y_val,
        y_val_pred.ravel(),
    )
    print(f'Epoch {epoch:2d}: val_loss={val_loss:.4f}')

    # Checkpointing
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        with open(ckpt_file, 'wb') as f:
            pickle.dump(params, f)

# Restore the best parameters
with open(ckpt_file, 'rb') as f:
    params = pickle.load(f)
print(f'\nRestored best params at val_loss = {best_val_loss:.4f}')

# Final evaluation on the test set
y_test_pred, _ = forward(params, X_test)
y_prob = y_test_pred.ravel()
y_pred = (y_prob >= 0.5).astype(int)

# Compute metrics
accuracy = np.mean(y_pred == y_test)

precision = np.sum(
    (y_pred == 1) & (y_test == 1)
) / np.sum(y_pred == 1)

recall = np.sum(
    (y_pred == 1) & (y_test == 1)
) / np.sum(y_test == 1)

def compute_roc_auc(y_true, y_scores):
    order = np.argsort(-y_scores)
    y_true_sorted = y_true[order]
    distinct = np.where(np.diff(y_scores[order]))[0]
    thresh_idxs = np.concatenate(
        [distinct, [len(y_scores) - 1]],
    )
    tps = np.cumsum(y_true_sorted)[thresh_idxs]
    fps = 1 + thresh_idxs - tps
    tpr = tps / tps[-1]
    fpr = fps / fps[-1]
    return np.trapezoid(tpr, fpr)

roc_auc = compute_roc_auc(y_test, y_prob)

print("\nAccuracy:", accuracy.round(3))
print("Precision:", precision.round(3))
print("Recall:", recall.round(3))
print("ROC AUC:", roc_auc.round(3))

Epoch  1: val_loss=0.2750
Epoch  2: val_loss=0.1408
Epoch  3: val_loss=0.0911
Epoch  4: val_loss=0.0664
Epoch  5: val_loss=0.0532
Epoch  6: val_loss=0.0451
Epoch  7: val_loss=0.0396
Epoch  8: val_loss=0.0347
Epoch  9: val_loss=0.0313
Epoch 10: val_loss=0.0283
Epoch 11: val_loss=0.0256
Epoch 12: val_loss=0.0252
Epoch 13: val_loss=0.0239
Epoch 14: val_loss=0.0234
Epoch 15: val_loss=0.0222
Epoch 16: val_loss=0.0213
Epoch 17: val_loss=0.0195
Epoch 18: val_loss=0.0204
Epoch 19: val_loss=0.0201
Epoch 20: val_loss=0.0192

Restored best params at val_loss = 0.0192

Accuracy: 0.965
Precision: 0.981
Recall: 0.962
ROC AUC: 0.995
