In [3]:
import pandas as pd
import numpy as np
import torch
from torch import nn
from torchmetrics.classification import BinaryAccuracy

In [11]:

# Load data
X_train = np.loadtxt("./X_train.csv", delimiter=",", dtype=float)
X_test = np.loadtxt("./X_test.csv", delimiter=",", dtype=float)
y_train = np.loadtxt("./y_train.csv", delimiter=",", dtype=float)
y_test = np.loadtxt("./y_test.csv", delimiter=",", dtype=float)

In [14]:
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32).reshape(-1, 1)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32).reshape(-1, 1)

  X_train = torch.tensor(X_train, dtype=torch.float32)


In [21]:
X_train.shape, y_train.shape, X_test.shape, y_test.shape

(torch.Size([316175, 80]),
 torch.Size([316175, 1]),
 torch.Size([79044, 80]),
 torch.Size([79044, 1]))

In [22]:
# Create model
model = nn.Sequential(
    nn.Linear(80, 40),
    nn.LayerNorm(40),
    nn.ReLU(),
    nn.Linear(40, 20),
    nn.ReLU(),
    nn.Linear(20, 10),
    nn.ReLU(),
    nn.Linear(10, 1),
    nn.Sigmoid(),
)

# Define loss function
loss_function = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


In [23]:

# Train model
epochs = 2000
batch_size = 1000
losses = []
losses_val = []
accuracies = []
accuracies_val = []
measure_accuracy = BinaryAccuracy()
best_val_accuracy = 0.0
early_stop_patience = 5
early_stop_counter = 0

# Create dataset and loader
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)


In [24]:

# Train
for epoch in range(epochs):
    # Iterate over the data in batches
    for X_batch, y_batch in train_loader:
        # Forward pass - get predictions for train
        y_train_pred = model.forward(X_batch)

        # Calculate loss
        loss = loss_function(y_train_pred, y_batch)

        # Calculate other metrics
        accuracy = measure_accuracy(y_train_pred, y_batch)

        # Save metrics and print
        losses.append(loss)
        accuracies.append(accuracy)

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Calculate validation loss and accuracy
    with torch.no_grad():
        y_test_pred = model.forward(X_test)
        val_loss = loss_function(y_test_pred, y_test)
        val_accuracy = measure_accuracy(y_test_pred, y_test)

    # Early stopping
    if val_accuracy >= best_val_accuracy:
        best_val_accuracy = val_accuracy
        early_stop_counter = 0
        best_model = model.state_dict()
    else:
        early_stop_counter += 1

    # Save metrics and print
    losses_val.append(val_loss)
    accuracies_val.append(val_accuracy)
    print(
        f"epoch: {epoch}  loss: {losses[-1].item():10.8f} val_loss: {val_loss.item():10.8f} accuracy: {accuracies[-1].item():10.8f} val_accuracy: {val_accuracy.item():10.8f}"
    )

    # Check early stopping
    if early_stop_counter >= early_stop_patience:
        print(
            f"Early stopping at epoch {epoch}, best validation accuracy: {best_val_accuracy}"
        )
        break


epoch: 0  loss: 0.34584638 val_loss: 0.28053796 accuracy: 0.85142857 val_accuracy: 0.88816357
epoch: 1  loss: 0.27333695 val_loss: 0.26875427 accuracy: 0.89714283 val_accuracy: 0.88870758
epoch: 2  loss: 0.28632250 val_loss: 0.26728269 accuracy: 0.85714287 val_accuracy: 0.88856840
epoch: 3  loss: 0.23953845 val_loss: 0.26218408 accuracy: 0.90857142 val_accuracy: 0.88889730
epoch: 4  loss: 0.27422708 val_loss: 0.26639074 accuracy: 0.87428570 val_accuracy: 0.88744241
epoch: 5  loss: 0.20336471 val_loss: 0.26022887 accuracy: 0.93142855 val_accuracy: 0.88901120
epoch: 6  loss: 0.26233146 val_loss: 0.26108950 accuracy: 0.88000000 val_accuracy: 0.88894790
epoch: 7  loss: 0.25390762 val_loss: 0.26133147 accuracy: 0.89714283 val_accuracy: 0.88894790
epoch: 8  loss: 0.22811857 val_loss: 0.25924847 accuracy: 0.90857142 val_accuracy: 0.88888466
epoch: 9  loss: 0.22042516 val_loss: 0.25958446 accuracy: 0.89714283 val_accuracy: 0.88866961
epoch: 10  loss: 0.17372759 val_loss: 0.25907132 accuracy: 0