In [None]:
import torch
from torch.utils.data import DataLoader, ConcatDataset
import torch.nn as nn
import os
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from torchinfo import summary
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import matplotlib.pyplot as plt 
from utils.timeseriesdataset import TimeSeriesDataset
from utils.padding import pad_batch, LABEL_PADDING_VALUE
from models.models import RegressionModel
import pickle 
from pathlib import Path
import optuna.visualization as vis
import optuna

# Mini-batch size shared by the training and validation DataLoaders.
BATCH_SIZE = 32

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

print('The model is running on:', DEVICE)

  from .autonotebook import tqdm as notebook_tqdm


The model is running on: cpu


# Create DataLoaders

In [None]:
simulated_tracks_directory = Path("<enter dir filepath that has train val and test data>")


def _load_pickled_instances(pickle_files):
    """Return one flat list holding the contents of every pickle file in ``pickle_files``.

    Each file is expected to unpickle to an iterable of dataset instances; the
    iterables are concatenated in the order the files are given.

    NOTE(security): ``pickle.load`` can execute arbitrary code while deserialising.
    Only load files that you generated yourself.
    """
    instances = []
    for pickle_file in pickle_files:
        with open(pickle_file, "rb") as handle:
            instances.extend(pickle.load(handle))
    return instances


# For faster training we use pickled data; the implementation without pickle is kept
# (commented out) at the bottom of this cell.
# sorted() makes the load order independent of the filesystem's directory ordering.
train_files = sorted(simulated_tracks_directory.glob("*/train_instances.pkl"))
val_files = sorted(simulated_tracks_directory.glob("*/val_instances.pkl"))

# Fail here with an actionable message instead of later, when ConcatDataset is handed an empty list.
if not train_files or not val_files:
    raise FileNotFoundError(
        f"Expected '*/train_instances.pkl' and '*/val_instances.pkl' under "
        f"'{simulated_tracks_directory}', but found {len(train_files)} train and "
        f"{len(val_files)} val file(s). Set simulated_tracks_directory to your data directory."
    )

train_instances = _load_pickled_instances(train_files)
val_instances = _load_pickled_instances(val_files)

print("Train data: ", len(train_instances),  "Val data: ", len(val_instances))

# Alternative without pickle (needs `import random`):
# filepaths = list(simulated_tracks_directory.rglob('*.parquet'))
# print("Number of files found:", len(filepaths))
# random.shuffle(filepaths)
# train_instances = [TimeSeriesDataset(filepath, augment=True) for filepath in filepaths[:int(len(filepaths)*0.7)]]
# test_instances = [TimeSeriesDataset(filepath, augment=False) for filepath in filepaths[int(len(filepaths)*0.7):int(len(filepaths)*0.85)]]
# val_instances = [TimeSeriesDataset(filepath, augment=False) for filepath in filepaths[int(len(filepaths)*0.85):]]

In [None]:
# Merge the per-file datasets of each split into a single dataset.
conc_train = ConcatDataset(train_instances)
conc_val = ConcatDataset(val_instances)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(conc_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_batch)
# Validation keeps a fixed order: shuffling brings no benefit for evaluation, and because
# evaluate_model stops after a capped number of batches, a shuffled loader would score a
# different random subset each epoch/trial, making validation losses incomparable.
val_loader = DataLoader(conc_val, batch_size=BATCH_SIZE, shuffle=False, collate_fn=pad_batch)

print("DataLoader Sizes:", len(train_loader), len(val_loader))

# Training Functions

In [None]:
# reduction='none' keeps the element-wise L1 losses, so the training and evaluation
# loops can zero out padded label positions with a mask before averaging.
continuous_loss_fn = nn.L1Loss(reduction='none')

def train_one_epoch(model, optimizer, dataloader, max_batches=10000):
    """Train ``model`` for one pass over ``dataloader`` and return the mean batch loss.

    For every batch the element-wise L1 loss between the model output and the alpha
    labels is averaged over the positions whose label differs from
    ``LABEL_PADDING_VALUE``; that masked mean is back-propagated and the optimizer
    is stepped once.

    Args:
        model: Module called as ``model(inputs)``; its output is squeezed on the last
            dimension before being compared with the labels.
        optimizer: Optimizer updating ``model``'s parameters.
        dataloader: Iterable yielding 4-tuples ``(inputs, alpha_labels, _, _)``; the
            last two items are ignored.
        max_batches: Maximum number of batches to train on in this call
            (default 10000, the previously hard-coded cap).

    Returns:
        float: Mean of the per-batch masked L1 losses over the batches trained on.

    Raises:
        ValueError: If no batch containing at least one unpadded label was processed
            (empty dataloader, ``max_batches <= 0``, or only fully padded batches).
    """
    model.train()
    running_loss = 0.0
    runs = 0

    for inputs, alpha_labels, _, _ in dataloader:

        if runs >= max_batches:
            break

        inputs, alpha_labels = inputs.to(DEVICE), alpha_labels.to(DEVICE)
        mask = (alpha_labels != LABEL_PADDING_VALUE).float()

        # A fully padded batch would give 0 / 0 = NaN, and back-propagating that NaN
        # would poison the parameters; skip it without counting it.
        valid_count = mask.sum()
        if valid_count.item() == 0:
            continue

        outputs = model(inputs)
        outputs = outputs.squeeze(-1)
        total_loss = (continuous_loss_fn(outputs, alpha_labels) * mask).sum() / valid_count

        optimizer.zero_grad()
        total_loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()

        running_loss += total_loss.item()
        runs += 1

    if runs == 0:
        raise ValueError(
            "train_one_epoch processed no batch with unpadded labels; "
            "check that the dataloader is not empty and max_batches > 0."
        )

    return running_loss / runs


def evaluate_model(model, dataloader, max_batches=10000):
    """Evaluate ``model`` on ``dataloader`` and return the mean batch loss.

    Runs in eval mode without gradient tracking. For every batch the element-wise
    L1 loss is averaged over the positions whose label differs from
    ``LABEL_PADDING_VALUE``.

    Args:
        model: Module called as ``model(inputs)``; its output is squeezed on the last
            dimension before being compared with the labels.
        dataloader: Iterable yielding 4-tuples ``(inputs, alpha_labels, _, _)``; the
            last two items are ignored.
        max_batches: Maximum number of batches to evaluate
            (default 10000, the previously hard-coded cap).

    Returns:
        float: Mean of the per-batch masked L1 losses over the evaluated batches.

    Raises:
        ValueError: If no batch containing at least one unpadded label was evaluated
            (empty dataloader, ``max_batches <= 0``, or only fully padded batches).
    """
    model.eval()

    running_val_total = 0.0
    val_runs = 0

    with torch.no_grad():
        for inputs, alpha_labels, _, _ in dataloader:

            if val_runs >= max_batches:
                break

            inputs, alpha_labels = inputs.to(DEVICE), alpha_labels.to(DEVICE)
            mask = (alpha_labels != LABEL_PADDING_VALUE).float()

            # A fully padded batch would give 0 / 0 = NaN and turn the whole
            # epoch average into NaN; skip it without counting it.
            valid_count = mask.sum()
            if valid_count.item() == 0:
                continue

            outputs = model(inputs)
            outputs = outputs.squeeze(-1)
            loss_alpha = (continuous_loss_fn(outputs, alpha_labels) * mask).sum() / valid_count
            running_val_total += loss_alpha.item()
            val_runs += 1

    if val_runs == 0:
        raise ValueError(
            "evaluate_model processed no batch with unpadded labels; "
            "check that the dataloader is not empty and max_batches > 0."
        )

    return running_val_total / val_runs

# Objective Function

Edit the objective function below to choose which hyperparameters are tuned, e.g. the model architecture, learning rate, number of epochs, or batch size.

In [None]:
def objective(trial, n_epochs=3):
    """Optuna objective: train a fresh ``RegressionModel`` and return its best validation loss.

    Samples the Adam weight decay (``"lambda_l2"``) and learning rate (``"lr"``)
    log-uniformly from [1e-6, 1e-1], trains on ``train_loader`` for up to
    ``n_epochs`` epochs, evaluates on ``val_loader`` after every epoch, and reports
    each validation loss to the trial.

    Args:
        trial: Optuna trial used to sample hyperparameters and to report
            intermediate validation losses.
        n_epochs: Maximum number of training epochs per trial (default 3, the
            previously hard-coded value).

    Returns:
        float: Lowest validation loss seen over the epochs; ``inf`` if no epoch
        produced a validation loss that compares below ``inf`` (e.g. only NaN).
    """
    import math  # local import: only needed for the divergence check below

    # Hyperparameter suggestions
    l2_lambda = trial.suggest_float("lambda_l2", 1e-6, 1e-1, log=True)
    learning_rate = trial.suggest_float("lr", 1e-6, 1e-1, log=True)

    # Initialize model and optimizer
    model = RegressionModel().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=l2_lambda)
    best_val_loss = float("inf")

    for epoch in range(n_epochs):

        train_loss = train_one_epoch(model, optimizer=optimizer, dataloader=train_loader)
        val_total_loss = evaluate_model(model, val_loader)

        # Report intermediate loss to Optuna
        trial.report(val_total_loss, epoch)

        # Prune trial if it should be pruned
        # if trial.should_prune():
        #     raise optuna.TrialPruned()

        # Update best validation loss (a NaN loss never compares smaller, so it is ignored)
        if val_total_loss < best_val_loss:
            best_val_loss = val_total_loss

        # A NaN mean training loss means a NaN batch loss was back-propagated and
        # stepped, so the model has presumably diverged beyond recovery; stop instead
        # of spending the remaining epochs on it.
        if math.isnan(train_loss):
            break

    return best_val_loss

# Optimise

In [None]:
# Number of hyperparameter configurations to evaluate.
N_TRIALS = 100

# Persist the study in a local SQLite file so results survive kernel restarts.
os.makedirs("optuna_study", exist_ok=True)
storage_path = "sqlite:///optuna_study/tune_alpha.db"

# pruner = optuna.pruners.MedianPruner()
study = optuna.create_study(direction="minimize",
                            # pruner=pruner,
                            storage=storage_path)

# No study_name is passed, so Optuna generates one. Print it: the name is needed
# to re-open this study from storage (e.g. in the plotting cell after a restart).
print("Optuna study name:", study.study_name)

study.optimize(objective, n_trials=N_TRIALS)

# Plot parameters

In [2]:
# Re-open the persisted study from storage so this cell also works in a fresh kernel.
storage_path = "sqlite:///optuna_study/tune_alpha.db"
# Auto-generated name of the study to visualise; replace it with the name of your own run.
study_name = 'no-name-b3db059b-57bf-4d87-8d74-0587b504d1f3'

try:
    study = optuna.load_study(study_name=study_name, storage=storage_path)
except KeyError as err:
    # Turn the bare KeyError into a message that lists the studies that do exist.
    available_studies = optuna.get_all_study_names(storage=storage_path)
    raise KeyError(
        f"Study {study_name!r} not found in {storage_path}. "
        f"Available studies: {available_studies}"
    ) from err

best_trial = study.best_trial

print(f"Best lambda_l2: {best_trial.params['lambda_l2']}")
print(f"Best lr: {best_trial.params['lr']}")
print(f"Best value: {best_trial.value}")

# Every plot shows the same single objective value, so use one label for it
# (the slice plots previously called it "Total Loss").
TARGET_NAME = "Alpha Loss"

# Plot optimization history
fig1 = vis.plot_optimization_history(study, target=lambda t: t.values[0], target_name=TARGET_NAME)
fig1.show()

# Plot hyperparameter importances
fig2 = vis.plot_param_importances(study)
fig2.show()

# Plot each hyperparameter against the objective value
fig3 = vis.plot_slice(study, params=['lambda_l2'], target=lambda t: t.values[0], target_name=TARGET_NAME)
fig3.show()

fig4 = vis.plot_slice(study, params=['lr'], target=lambda t: t.values[0], target_name=TARGET_NAME)
fig4.show()


Best lambda_l2: 1.2584473282901645e-06
Best lr: 0.0018892576262405766
Best value: 0.14468583373501898


[W 2025-03-12 21:05:59,089] Trial 1 is omitted in visualization because its objective value is inf or nan.


[W 2025-03-12 21:05:59,819] Trial 1 is omitted in visualization because its objective value is inf or nan.


[W 2025-03-12 21:05:59,836] Trial 1 is omitted in visualization because its objective value is inf or nan.
