In [6]:
import pandas as pd
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

from pt_runner.v1 import CheckpointHandler, DataHandlerPT, EarlyStopper

In [7]:
# Load the experiment table (indexed by its "exp" column). The last N_TARGETS
# columns are the regression targets; every column before them is an input feature.
DATA_PATH = "data.xlsx"
N_TARGETS = 3  # number of trailing target columns in the sheet

df = pd.read_excel(DATA_PATH, index_col="exp")
# Without this guard a sheet holding only target columns would silently yield
# a zero-width feature matrix.
assert df.shape[1] > N_TARGETS, "sheet must contain at least one feature column besides the targets"
_X = df.iloc[:, :-N_TARGETS].values
_Y = df.iloc[:, -N_TARGETS:].values
print(_X.shape)
print(_Y.shape)
# Fresh StandardScalers are handed to the handler; presumably they are fitted
# inside split_and_scale (pt_runner) — TODO confirm.
data_handler = DataHandlerPT(
    _X=_X, _Y=_Y, scalerX=StandardScaler(), scalerY=StandardScaler()
)

(100, 47)
(100, 3)


In [None]:
# Define the model
# Model input/output widths: column counts of the feature and target arrays held by the handler.
num_features, num_outputs = (arr.shape[1] for arr in (data_handler._X, data_handler._Y))


class MyModel(nn.Module):
    """Fully connected regressor: num_features -> h1 -> h2 -> h3 -> num_outputs.

    A ReLU follows each of the three hidden layers; the output layer is linear
    (no activation), so predictions are unbounded.

    Args:
        num_features: width of the input vector.
        num_outputs: number of predicted values.
        hidden_sizes: widths of the three hidden layers. Defaults to (24, 12, 6),
            the original architecture. The layer attribute names fc1..fc4 are
            kept fixed so existing checkpoints' state_dict keys still match.

    Raises:
        ValueError: if ``hidden_sizes`` does not contain exactly three widths.
    """

    def __init__(self, num_features, num_outputs, hidden_sizes=(24, 12, 6)):
        super().__init__()
        if len(hidden_sizes) != 3:
            raise ValueError(
                f"hidden_sizes must contain exactly 3 widths, got {len(hidden_sizes)}"
            )
        h1, h2, h3 = hidden_sizes
        self.fc1 = nn.Linear(num_features, h1)
        self.fc2 = nn.Linear(h1, h2)
        self.fc3 = nn.Linear(h2, h3)
        self.fc4 = nn.Linear(h3, num_outputs)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a batch of feature vectors to raw (linear) regression outputs."""
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        return self.fc4(x)


# Training hyperparameters, named so they are easy to find and tune.
SEED = 0
LEARNING_RATE = 1e-3
# Number of scheduler.step() calls without improvement of the monitored value
# (mode "min": lower is better) before ReduceLROnPlateau lowers the LR.
LR_PATIENCE = 5

# Seed before building the model so weight initialisation is reproducible.
torch.manual_seed(SEED)
model = MyModel(num_features, num_outputs)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=LR_PATIENCE)
loss_fn = nn.MSELoss()  # mean squared error

In [9]:
# Data
# Split into train / test / validation subsets and apply the handler's scalers.
# test_size / val_size are fractions interpreted by pt_runner — TODO confirm whether
# val_size is relative to the full data set or to what remains after the test split.
data_handler.split_and_scale(random_state=0, test_size=0.2, val_size=0.1)
ds_train, ds_test, ds_val = (
    data_handler.get_train(),
    data_handler.get_test(),
    data_handler.get_val(),
)


In [10]:
# Restore a saved checkpoint and evaluate it on the test split.
CHECKPOINT_PATH = "./checkpoints/2025-05-25_08-58.pth"

cph = CheckpointHandler()
model, optimizer, epoch, val_loss = cph.load(
    save_path=CHECKPOINT_PATH, model=model, optimizer=optimizer
)
print(f"Load model @ epoch {epoch} with loss {val_loss}")

model.eval()  # switch to evaluation mode before inference
with torch.no_grad():
    # Assumes a full slice of the test dataset yields (features, targets) tensors — TODO confirm.
    X_test, Y_test = ds_test[:]
    test_pred = model(X_test)
    # Loss is measured in whatever space the handler returns (presumably the
    # StandardScaler-transformed targets), not necessarily original units — TODO confirm.
    final_loss = loss_fn(test_pred, Y_test)
    print(f"Test loss: {final_loss.item():.4f}")

Load model @ epoch 1330 with loss 0.02943030558526516
Test loss: 0.0228
