##### Import

In [5]:
import warnings
import papermill as pm
import scrapbook as sb
import pandas as pd
import numpy as np
from scipy.stats import spearmanr
from tqdm import tqdm
import shap
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import StandardScaler
# from sklearn.linear_model import LinearRegression, Ridge, Lasso
# from sklearn.pipeline import Pipeline
import os
import gc
import sys

# Filter out warning messages
warnings.filterwarnings('ignore')

# Set pandas display options
pd.set_option('display.max_columns', 10000)
pd.set_option('display.max_rows', 10000)

# Set seaborn style
sns.set_style('whitegrid')

# Add the parent directory to sys.path
sys.path.insert(1, os.path.join(sys.path[0], '..'))

# Index and deciles for data slicing
idx = pd.IndexSlice



from pathlib import Path

# Paths to the downloaded datasets, model, and hyperparameters
data_dir = Path('data/')
model_dir = Path('models/')
best_hyperparams_dir = Path('best_hyperparams/')
study_dir = Path('study/')

# Create directories if they do not exist
data_dir.mkdir(parents=True, exist_ok=True)
model_dir.mkdir(parents=True, exist_ok=True)
best_hyperparams_dir.mkdir(parents=True, exist_ok=True)
study_dir.mkdir(parents=True, exist_ok=True)

In [6]:
# from pathlib import Path
# import pandas as pd
# from utils import rank_stocks_and_quantile
# # UNSEEN_KEY = '/data/YEAR_20220803_20230803'
# top = 250  # parameters -> papermill
# DATA_STORE = Path(f'data/{top}_dataset.h5')
# with pd.HDFStore(DATA_STORE) as store:
#     # unseen = store[UNSEEN_KEY]
#     print(store.keys())

In [7]:
"""
Process Large Financial Datasets from HDF5 Format.

This script reads, processes, and normalizes financial datasets stored in an HDF5 format.
The primary processing steps involve converting data types, handling infinite values, and
scaling the dataset. The MinMaxScaler, computed from the entire dataset, is employed for normalization.
Once data processing is complete, stocks are ranked, and quantiles are determined in post-processing.

Attributes:
    - top (int): Number of top stocks to consider.
    - DATA_STORE (Path): Path to the HDF5 file containing the datasets.
    - dataset_keys (list of str): Keys identifying which datasets to process in the HDF5 store.
    - target_string (str): Target column identifier for post-processing.
    - CHUNK_SIZE (int): Size of chunks in which data is read and processed.

Functions:
    - convert_dtype(chunk, feature_columns, dtype='float32'): Converts dtype of specified columns in a chunk.
    - handle_infinite_values(chunk, feature_columns): Handles infinite values in a chunk.
    - process_chunk(chunk, feature_columns, scaler=None): Process a single chunk with optional normalization.

Workflow:
    1. Set parameters and paths.
    2. Define utility functions.
    3. Identify features and target columns from the first chunk.
    4. Determine the MinMaxScaler using all chunks in the dataset.
    5. Process and concatenate chunks to form the dataset.
    6. Rank stocks and compute quantiles in post-processing.
"""

import gc
import numpy as np
import pandas as pd
from pathlib import Path
from utils import rank_stocks_and_quantile
from sklearn.preprocessing import MinMaxScaler

# Parameters and data paths
TOP = top = 250
DATA_STORE = Path(f'data/{top}_dataset.h5')
dataset_keys = [
    '/data/YEAR_20200930_20220802',
    '/data/YEAR_20181024_20200929',
    '/data/YEAR_20161116_20181023',
    '/data/YEAR_20141210_20161115'
]
target_string = 'TARGET_ret_fwd'
CHUNK_SIZE = 50000

def convert_dtype(chunk, feature_columns, dtype='float32'):
    """Converts the datatype of the specified columns."""
    chunk[feature_columns] = chunk[feature_columns].astype(dtype)
    return chunk

def handle_infinite_values(chunk, feature_columns):
    """Handle infinite values by replacing them with the maximum finite value."""
    max_val = np.finfo('float32').max
    chunk[feature_columns] = chunk[feature_columns].replace([np.inf, -np.inf], max_val)
    return chunk

def process_chunk(chunk, feature_columns, scaler=None):
    """Process a single chunk of data."""
    chunk = convert_dtype(chunk, feature_columns)
    chunk = handle_infinite_values(chunk, feature_columns)
    
    # Normalize with scaler if provided
    if scaler:
        chunk[feature_columns] = scaler.transform(chunk[feature_columns])
    
    return chunk

# Identify features and targets based on the first chunk
with pd.HDFStore(DATA_STORE) as store:
    first_chunk = store.select(dataset_keys[0], stop=CHUNK_SIZE)
    features = [col for col in first_chunk.columns if col.startswith('FEATURE_')]
    target = [col for col in first_chunk.columns if col.startswith('TARGET_')]

# Determine the scaler using the entire dataset for the identified features
scaler = MinMaxScaler()

for key in dataset_keys:
    with pd.HDFStore(DATA_STORE) as store:
        for chunk in store.select(key, chunksize=CHUNK_SIZE):
            # Convert dtype and handle infinite values
            chunk = convert_dtype(chunk, features)
            chunk = handle_infinite_values(chunk, features)
            scaler.partial_fit(chunk[features])

# Process and concatenate chunks
dataset = pd.DataFrame()
for key in dataset_keys:
    with pd.HDFStore(DATA_STORE) as store:
        for chunk in store.select(key, chunksize=CHUNK_SIZE):
            processed_chunk = process_chunk(chunk, features, scaler)
            dataset = pd.concat([dataset, processed_chunk], ignore_index=False)
            del processed_chunk
            gc.collect()

# Post-processing steps
dataset = rank_stocks_and_quantile(dataset, target_substring=target_string)
dataset.index.set_levels(dataset.index.levels[0].tz_localize(None), \
    level=0, inplace=True)

In [8]:
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from tqdm import tqdm
from joblib import Parallel, delayed

PADDING_VALUE = -1
MAX_LEN = None  # If you have a predefined value, set it here; otherwise, it gets calculated automatically.

def pad_sequence(inputs, padding_value=-1, max_len=None):
    if max_len is None:
        max_len = max([input.shape[0] for input in inputs])
    padded_inputs = []
    masks = []
    for input in inputs:
        pad_len = max_len - input.shape[0]
        padded_input = F.pad(input, (0, 0, 0, pad_len), value=padding_value)
        mask = torch.ones((input.shape[0], 1), dtype=torch.float)
        masks.append(
            torch.cat((mask, torch.zeros((pad_len, 1), dtype=torch.float)), dim=0)
        )
        padded_inputs.append(padded_input)
    return torch.stack(padded_inputs), torch.stack(masks)

def convert_to_torch(timestamp, data):
    feature_names = [col for col in data.columns if col.startswith('FEATURE_')]
    target_names = [col for col in data.columns if col.startswith('TARGET_')]
    
    inputs = torch.from_numpy(
                data[feature_names].values.astype(np.float32))
    labels = torch.from_numpy(
                data[target_names].values.astype(np.float32))

    padded_inputs, masks_inputs = pad_sequence(
            [inputs], padding_value=PADDING_VALUE, max_len=MAX_LEN)
    padded_labels, masks_labels = pad_sequence(
            [labels], padding_value=PADDING_VALUE, max_len=MAX_LEN)

    return {
        timestamp: (
            padded_inputs,
            padded_labels,
            masks_inputs,
            target_names
        )
    }

def get_era2data(df):
    # Group by the Timestamp index (level=0)
    res = Parallel(n_jobs=-1, prefer="threads")(
        delayed(convert_to_torch)(timestamp, data)
        for timestamp, data in tqdm(df.groupby(level=0)))
    
    era2data = {}
    for r in tqdm(res):
        era2data.update(r)
    return era2data

# Assuming DataFrame is named "dataset": testing the function
timestamp2data_dataset = get_era2data(dataset)

100%|██████████| 1924/1924 [00:03<00:00, 547.97it/s]
100%|██████████| 1924/1924 [00:00<00:00, 2235412.99it/s]


In [9]:
import torch
import torch.nn as nn

def pearsonr(x, y):
    xm, ym = x - x.mean(), y - y.mean()
    r_num = torch.sum(xm * ym)
    r_den = torch.sqrt(torch.sum(xm ** 2) + 1e-10) * torch.sqrt(torch.sum(ym ** 2) + 1e-10)
    correlation = r_num / r_den
    return correlation.requires_grad_()  # Ensure that the returned tensor requires gradients

def spearmanr(x, y):
    rank_x = x.argsort().argsort().float()
    rank_y = y.argsort().argsort().float()
    return pearsonr(rank_x, rank_y)

def pairwise_ranking_loss(outputs, target_labels, masks_inputs):
    sorted_indices = torch.argsort(target_labels, dim=-1, descending=True)
    sorted_outputs = torch.gather(outputs, -1, sorted_indices)
    
    diff_vector = sorted_outputs[:, 1:] - sorted_outputs[:, :-1]
    sigmoid_diff = 1.0 / (1.0 + torch.exp(-diff_vector))
    loss = -torch.log(torch.clamp(sigmoid_diff, min=1e-10, max=1-1e-10))
    
    min_dim = min(masks_inputs.shape[1], loss.shape[1])
    masked_loss = loss[:, :min_dim] * masks_inputs[:, :min_dim]
    
    return torch.sum(masked_loss)

def calculate_loss(outputs, criterion, target_labels, masks_inputs, alpha_mse=0.5, alpha_corr=1.0, alpha_rank=1.0):
    
    # 1. Print basic info
    # print(f"Outputs Range: {outputs.min().item()}, {outputs.max().item()}")
    # print(f"Target Labels Range: {target_labels.min().item()}, {target_labels.max().item()}")
    # print(f"Masks Range: {masks_inputs.min().item()}, {masks_inputs.max().item()}")
    
    # Calculating the MSE loss
    mse_main = criterion(outputs * masks_inputs, target_labels * masks_inputs)
    # print(f"MSE Loss: {mse_main.item()}")
    
    # Calculating the Spearman Correlation
    non_zero_mask = masks_inputs.view(-1).nonzero().squeeze()
    spearman_corr = spearmanr(outputs[0][:, 0][non_zero_mask], target_labels[0][:, 0][non_zero_mask])
    # print(f"Spearman Correlation: {spearman_corr.item()}")

    # Calculating the Ranking Loss
    ranking_loss = pairwise_ranking_loss(outputs, target_labels, masks_inputs)
    # print(f"Ranking Loss: {ranking_loss.item()}")

    # Using alpha values to weight the losses and normalize them
    losses = [mse_main, -spearman_corr, ranking_loss]
    alphas = [alpha_mse, alpha_corr, alpha_rank]
    weights = [alpha / (loss + 1e-10) for alpha, loss in zip(alphas, losses)]
    normalized_weights = [weight / sum(weights) for weight in weights]
    
    combined_loss = sum(w * l for w, l in zip(normalized_weights, losses))
    
    return combined_loss.requires_grad_(), mse_main, spearman_corr


In [10]:
# Training loop
def train_on_batch(model, criterion, optimizer, batch, lookahead):
    inputs, labels, masks_inputs, target_names = batch

    # Get index for specific label dynamically
    specific_label_name = f'TARGET_ret_fwd_{lookahead:02d}d_rank_quantiled'
    specific_label_index = target_names.index(specific_label_name)

    # Use that index to fetch the specific column
    labels = labels[:, :, specific_label_index].unsqueeze(2)

    # print('labels shape: ', labels.shape)

    # Zero the parameter gradients
    optimizer.zero_grad()

    outputs = model(inputs / 4.0, masks_inputs)

    # print("Output shape: ", outputs.shape)

    assert labels.shape == outputs.shape, \
        f"Shape mismatch: labels {labels.shape}, outputs {outputs.shape}"

    loss, _mse, _corr = calculate_loss(outputs, criterion, labels, masks_inputs)
    
    loss.backward()
    optimizer.step()
    
    return loss.item(), _mse.item(), _corr.item()

def evaluate_on_batch(model, criterion, batch, lookahead):
    inputs, labels, masks_inputs, target_names = batch

    # Get index for specific label dynamically
    specific_label_name = f'TARGET_ret_fwd_{lookahead:02d}d_rank_quantiled'
    specific_label_index = target_names.index(specific_label_name)

    # Use that index to fetch the specific column
    labels = labels[:, :, specific_label_index].unsqueeze(2)

    model.eval()
    with torch.no_grad():
        outputs = model(inputs / 4, masks_inputs)

        # print('output form eval: ', outputs.shape)

        assert labels.shape == outputs.shape, \
            f"Shape mismatch: labels {labels.shape}, outputs {outputs.shape}"

        loss, mse, corr = calculate_loss(outputs, criterion, labels, masks_inputs)

        # Assuming masks_inputs is of shape (1, 253, 1)
        non_zero_indices = masks_inputs.squeeze().nonzero().squeeze()

        # Gather values from outputs tensor using the non-zero indices
        preds = torch.gather(outputs.squeeze(), 0, non_zero_indices).cpu().numpy()

    return loss.item(), mse.item(), corr.item(), preds


def compute_fold_metrics(era_scores, weights=None):
    era_scores = pd.Series(era_scores)
    
    # Calculate metrics
    mean_correlation = np.mean(era_scores)
    std_deviation = np.std(era_scores)
    sharpe_ratio = mean_correlation / std_deviation
    max_dd = (era_scores.cummax() - era_scores).max()

    # Smart Sharpe
    smart_sharpe = mean_correlation \
        / (std_deviation + np.std(era_scores.diff()))
    
    # Autocorrelation
    autocorrelation = era_scores.autocorr()

    metrics = pd.Series({
        'mean_correlation': mean_correlation,
        'std_deviation': std_deviation,
        'sharpe_ratio': sharpe_ratio,
        'smart_sharpe': smart_sharpe,
        'autocorrelation': autocorrelation,
        'max_dd': max_dd,
        'min_correlation': era_scores.min(),
        'max_correlation': era_scores.max(),
    })

    if weights:
        normalized_metrics = (metrics - metrics.min()) / (metrics.max() - metrics.min())
        weighted_values = normalized_metrics.multiply(pd.Series(weights))
        metrics["weighted_score"] = weighted_values.sum()

    _ = gc.collect()

    return metrics

In [11]:
from tqdm import tqdm

def train_model(model, criterion, optimizer, scheduler, \
                num_epochs, patience, train_loader, lookahead, \
                device, val_loader=None, is_lr_scheduler=True):
    best_score = float('-inf')  # Initialize with negative infinity since we want to maximize Sharpe ratio
    best_corr = None
    best_model_wts = None  # Changing from 'best_model' to avoid confusion with the model object
    all_val_scores = []
    all_val_outputs = {}
    no_improve_epoch = 0

    model = model.to(device)

    epoch_progress = tqdm(range(num_epochs), desc="Epochs", leave=False)

    for epoch in epoch_progress:
        total_loss = []
        total_corr = []

        # Training
        for era_num in tqdm(train_loader, desc="Training", leave=False):
            data = train_loader[era_num]
            batch = (data[0].to(device), data[1].to(device), data[2].to(device), data[3])
            
            loss, _mse, _corr = train_on_batch(model, criterion, optimizer, batch, lookahead)
            # print(loss)
            total_loss.append(loss)
            total_corr.append(_corr)

        # Adjust learning rate if is_lr_scheduler is True
        if is_lr_scheduler:
            scheduler.step()

        # Validation - Only if val_loader is provided
        if val_loader:
            val_total_loss = []
            val_total_corr = []
            val_total_outputs = {}

            with torch.no_grad():
                for era_num in tqdm(val_loader, desc="Validation", leave=False):
                    data = val_loader[era_num]
                    batch = (data[0].to(device), data[1].to(device), data[2].to(device), data[3])
                    
                    loss, _mse, _corr, outputs = evaluate_on_batch(model, criterion, batch, lookahead)
                    val_total_loss.append(loss)
                    val_total_corr.append(_corr)
                    val_total_outputs[era_num] = outputs

            all_val_scores.append(val_total_corr) 
            all_val_outputs.update(val_total_outputs)

            # Early stopping check based on Sharpe score
            current_score = np.mean(val_total_corr) / np.std(val_total_corr)  # Assuming Sharpe ratio here
            if current_score > best_score:
                best_score = current_score
                best_corr = val_total_corr.copy()
                best_model_wts = model.state_dict().copy()
                no_improve_epoch = 0
            else:
                no_improve_epoch += 1
                if no_improve_epoch >= patience:
                    epoch_progress.set_description(f'Early stopping at epoch {epoch+1}')
                    epoch_progress.refresh()
                    break

        torch.cuda.empty_cache()
        _ = gc.collect()

    if val_loader:  # If validation data was provided
        return best_model_wts, best_corr, all_val_scores
    else:  # If only training data was used without validation
        return model.state_dict(), None, None

In [12]:
import optuna
import mlflow
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from utils import CustomBackwardMultipleTimeSeriesCV
from model import Transformer
from model import RankPredictorNN
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


# Constants and hyperparameters
NUM_EPOCHS = 15
PATIENCE = 5
FEATURE_DIM = len(features)  # Assuming 'features' is defined elsewhere in your code
OUTPUT_DIM = 1
NUM_TRAIL = 25
device = "cuda" if torch.cuda.is_available() else "cpu"

# # Choose model
# model = Transformer(
#     input_dim=FEATURE_DIM,
#     d_model=hidden_dim,
#     output_dim=OUTPUT_DIM,
#     num_heads=num_heads,
#     num_layers=num_layers,
# ).to(device)

weights = {
    'mean_correlation': 0.0,
    'std_deviation': -0.025, # Mild penalty for higher volatility
    'sharpe_ratio': 0.95,    # Primary objective, so highest weight
    'smart_sharpe': 0.075,   # Supplementary to Sharpe Ratio but considering autocorrelation
    'autocorrelation': -0.1, # Penalize strategies showing signs of overfitting
    'max_dd': -0.1,          # Major risk metric, negative to penalize higher drawdowns
    'min_correlation': 0.0,
    'max_correlation': 0.0,
}

def objective(trial, dataset, device):  # Placeholder for dataset
    print(f"\n--- Starting Trial: {trial.number + 1} ---")

    # Suggest hyperparameters
    train_length_multiplier = trial.suggest_int('train_length_multiplier', 10, 15)
    val_period_length = trial.suggest_categorical('val_period_length', [21, 42, 63])
    lookahead = trial.suggest_categorical('lookahead', [1, 5, 21])
    num_heads = trial.suggest_int("num_heads", 1, 5)
    hidden_dim = trial.suggest_int("hidden_dim", 64, 256, step=2)
    num_layers = trial.suggest_int("num_layers", 1, 5)
    lr = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)

    print(f"Hyperparameters for this trial: {trial.params}")


    # Initialize CV and other variables
    cv = CustomBackwardMultipleTimeSeriesCV(dataset,
                                    train_period_length=int(21 * train_length_multiplier),
                                    test_period_length=val_period_length,
                                    lookahead=lookahead, date_idx='date')

    cv.update_lookahead(lookahead)
    fold_weighted_scores = []
    for train_idx, test_idx in cv:
        # Choose model
        model = Transformer(
            input_dim=FEATURE_DIM,
            d_model=hidden_dim,
            output_dim=OUTPUT_DIM,
            num_heads=num_heads,
            num_layers=num_layers).to(device)

        # # Initialize model, loss, optimizer
        # model = RankPredictorNN(input_dim=FEATURE_DIM, output_dim=OUTPUT_DIM).to(device)
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = StepLR(optimizer, step_size=100, gamma=0.1)

        # Prepare data batches
        train_data = dataset.iloc[train_idx]
        test_data = dataset.iloc[test_idx]
        train_batches = get_era2data(train_data)
        validation_batches = get_era2data(test_data)

        # Train and validate model
        _, val_corr_on_fold, _ = train_model(
            model, criterion, optimizer, scheduler, NUM_EPOCHS, PATIENCE,
            train_batches, lookahead, device, validation_batches, is_lr_scheduler=True
        )

        # print(val_corr_on_fold)

        # Compute metrics
        scores_on_fold = compute_fold_metrics(val_corr_on_fold)

        # Normalize and weight scores
        normalized_scores = (scores_on_fold - scores_on_fold.min()) \
            / (scores_on_fold.max() - scores_on_fold.min())
        weighted_scores_on_fold = normalized_scores.multiply(pd.Series(weights))

        # Append to list
        fold_weighted_scores.append(weighted_scores_on_fold.sum())

    # Calculate overall score
    overall_score = np.mean(fold_weighted_scores)
    # print('==================')
    # print('Overall score: ', overall_score)
    # print('==================')

    # Log metrics
    with mlflow.start_run():
        mlflow.log_params(trial.params)
        mlflow.log_metric("avg_score_across_folds", overall_score)

    return -overall_score if not np.isnan(overall_score) else 1e-9

def callback(study, trial):
    print(f"\n--- Trial {trial.number + 1} finished ---")
    print(f"Value: {trial.value} and parameters: {trial.params}")
    
    completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
    
    if completed_trials:
        best_trial_number = study.best_trial.number + 1  # Adding 1 to align with your display
        print(f"Best is trial {best_trial_number} with value: {study.best_trial.value}\n")
    else:
        print("No successful trials yet.\n")

study_dir = "/home/sayem/Desktop/Project/study"
# study = optuna.create_study(study_name='Maximizing the Sharpe', direction='minimize',
#                             storage=f'sqlite:///{study_dir}/study.db', load_if_exists=True)
study = optuna.create_study(study_name='Maximizing the Sharpe', \
    direction='minimize', load_if_exists=True)
# study.optimize(objective, n_trials=NUM_TRAIL, callbacks=[callback])
study.optimize(lambda trial: objective(trial, dataset, device), n_trials=NUM_TRAIL, callbacks=[callback])

[I 2023-10-10 19:32:58,405] A new study created in memory with name: Maximizing the Sharpe



--- Starting Trial: 1 ---
Hyperparameters for this trial: {'train_length_multiplier': 14, 'val_period_length': 63, 'lookahead': 1, 'num_heads': 4, 'hidden_dim': 94, 'num_layers': 2, 'learning_rate': 0.0034386896867631745}


100%|██████████| 294/294 [00:00<00:00, 553.89it/s]
100%|██████████| 294/294 [00:00<00:00, 1412514.75it/s]
100%|██████████| 63/63 [00:00<00:00, 214.80it/s]
100%|██████████| 63/63 [00:00<00:00, 1355082.83it/s]
100%|██████████| 294/294 [00:00<00:00, 486.38it/s]                         
100%|██████████| 294/294 [00:00<00:00, 2129750.22it/s]
100%|██████████| 63/63 [00:00<00:00, 326.79it/s]
100%|██████████| 63/63 [00:00<00:00, 1601461.53it/s]
100%|██████████| 294/294 [00:00<00:00, 584.69it/s]                       
100%|██████████| 294/294 [00:00<00:00, 1992125.00it/s]
100%|██████████| 63/63 [00:00<00:00, 202.65it/s]
100%|██████████| 63/63 [00:00<00:00, 1428330.55it/s]
100%|██████████| 294/294 [00:00<00:00, 434.01it/s]                        
100%|██████████| 294/294 [00:00<00:00, 1167732.36it/s]
100%|██████████| 63/63 [00:00<00:00, 305.58it/s]
100%|██████████| 63/63 [00:00<00:00, 1270390.15it/s]
100%|██████████| 294/294 [00:00<00:00, 578.51it/s]                         
100%|██████████| 294


--- Trial 1 finished ---
Value: -0.5374418203570192 and parameters: {'train_length_multiplier': 14, 'val_period_length': 63, 'lookahead': 1, 'num_heads': 4, 'hidden_dim': 94, 'num_layers': 2, 'learning_rate': 0.0034386896867631745}
Best is trial 1 with value: -0.5374418203570192


--- Starting Trial: 2 ---
Hyperparameters for this trial: {'train_length_multiplier': 13, 'val_period_length': 63, 'lookahead': 1, 'num_heads': 5, 'hidden_dim': 100, 'num_layers': 1, 'learning_rate': 0.07862796272993648}


100%|██████████| 273/273 [00:00<00:00, 507.79it/s]
100%|██████████| 273/273 [00:00<00:00, 2041078.42it/s]
100%|██████████| 63/63 [00:00<00:00, 266.21it/s]
100%|██████████| 63/63 [00:00<00:00, 1327844.98it/s]
100%|██████████| 273/273 [00:00<00:00, 485.84it/s]                       
100%|██████████| 273/273 [00:00<00:00, 1585934.89it/s]
100%|██████████| 63/63 [00:00<00:00, 341.21it/s]
100%|██████████| 63/63 [00:00<00:00, 917504.00it/s]
100%|██████████| 273/273 [00:00<00:00, 397.35it/s]                       
100%|██████████| 273/273 [00:00<00:00, 1998333.32it/s]
100%|██████████| 63/63 [00:00<00:00, 258.09it/s]
100%|██████████| 63/63 [00:00<00:00, 1078535.31it/s]
100%|██████████| 273/273 [00:00<00:00, 650.95it/s]                       
100%|██████████| 273/273 [00:00<00:00, 1619582.73it/s]
100%|██████████| 63/63 [00:00<00:00, 263.85it/s]
100%|██████████| 63/63 [00:00<00:00, 1554359.72it/s]
100%|██████████| 273/273 [00:00<00:00, 641.92it/s]                       
100%|██████████| 273/273 [


--- Trial 2 finished ---
Value: -0.9712407100769145 and parameters: {'train_length_multiplier': 13, 'val_period_length': 63, 'lookahead': 1, 'num_heads': 5, 'hidden_dim': 100, 'num_layers': 1, 'learning_rate': 0.07862796272993648}
Best is trial 2 with value: -0.9712407100769145


--- Starting Trial: 3 ---
Hyperparameters for this trial: {'train_length_multiplier': 13, 'val_period_length': 21, 'lookahead': 5, 'num_heads': 3, 'hidden_dim': 94, 'num_layers': 2, 'learning_rate': 0.0038600273457302173}


100%|██████████| 273/273 [00:00<00:00, 605.23it/s]
100%|██████████| 273/273 [00:00<00:00, 1579372.40it/s]
100%|██████████| 21/21 [00:00<00:00, 6076.60it/s]
100%|██████████| 21/21 [00:00<00:00, 456375.05it/s]
100%|██████████| 273/273 [00:00<00:00, 546.08it/s]                       
100%|██████████| 273/273 [00:00<00:00, 2055736.07it/s]
100%|██████████| 21/21 [00:00<00:00, 7175.59it/s]
100%|██████████| 21/21 [00:00<00:00, 543706.07it/s]
100%|██████████| 273/273 [00:00<00:00, 395.49it/s]                       
100%|██████████| 273/273 [00:00<00:00, 1930935.91it/s]
100%|██████████| 21/21 [00:00<00:00, 6558.48it/s]
100%|██████████| 21/21 [00:00<00:00, 688128.00it/s]
100%|██████████| 273/273 [00:00<00:00, 581.52it/s]     
100%|██████████| 273/273 [00:00<00:00, 1647546.75it/s]
100%|██████████| 21/21 [00:00<00:00, 6491.30it/s]
100%|██████████| 21/21 [00:00<00:00, 540370.45it/s]
100%|██████████| 273/273 [00:00<00:00, 481.00it/s]                       
100%|██████████| 273/273 [00:00<00:00, 1534


--- Trial 3 finished ---
Value: -0.9358579392046803 and parameters: {'train_length_multiplier': 13, 'val_period_length': 21, 'lookahead': 5, 'num_heads': 3, 'hidden_dim': 94, 'num_layers': 2, 'learning_rate': 0.0038600273457302173}
Best is trial 2 with value: -0.9712407100769145


--- Starting Trial: 4 ---
Hyperparameters for this trial: {'train_length_multiplier': 14, 'val_period_length': 63, 'lookahead': 5, 'num_heads': 5, 'hidden_dim': 126, 'num_layers': 3, 'learning_rate': 0.004944406684974152}


100%|██████████| 294/294 [00:00<00:00, 491.20it/s]
100%|██████████| 294/294 [00:00<00:00, 2093591.47it/s]
100%|██████████| 63/63 [00:00<00:00, 216.29it/s]
100%|██████████| 63/63 [00:00<00:00, 1591814.17it/s]
100%|██████████| 294/294 [00:00<00:00, 457.34it/s]                       
100%|██████████| 294/294 [00:00<00:00, 1244324.29it/s]
100%|██████████| 63/63 [00:00<00:00, 185.15it/s]
100%|██████████| 63/63 [00:00<00:00, 1234771.74it/s]
100%|██████████| 294/294 [00:00<00:00, 471.93it/s]                        
100%|██████████| 294/294 [00:00<00:00, 1646362.32it/s]
100%|██████████| 63/63 [00:00<00:00, 279.54it/s]
100%|██████████| 63/63 [00:00<00:00, 1190275.46it/s]
100%|██████████| 294/294 [00:00<00:00, 505.79it/s]                       
100%|██████████| 294/294 [00:00<00:00, 2266774.59it/s]
100%|██████████| 63/63 [00:00<00:00, 294.72it/s]
100%|██████████| 63/63 [00:00<00:00, 1148874.57it/s]
100%|██████████| 294/294 [00:00<00:00, 636.44it/s]                         
100%|██████████| 294/2


--- Trial 4 finished ---
Value: -0.8275171345926184 and parameters: {'train_length_multiplier': 14, 'val_period_length': 63, 'lookahead': 5, 'num_heads': 5, 'hidden_dim': 126, 'num_layers': 3, 'learning_rate': 0.004944406684974152}
Best is trial 2 with value: -0.9712407100769145


--- Starting Trial: 5 ---
Hyperparameters for this trial: {'train_length_multiplier': 12, 'val_period_length': 42, 'lookahead': 1, 'num_heads': 1, 'hidden_dim': 152, 'num_layers': 4, 'learning_rate': 0.00099624858661813}


100%|██████████| 252/252 [00:00<00:00, 724.27it/s]
100%|██████████| 252/252 [00:00<00:00, 2202009.60it/s]
100%|██████████| 42/42 [00:00<00:00, 225.03it/s]
100%|██████████| 42/42 [00:00<00:00, 1285845.02it/s]
100%|██████████| 252/252 [00:00<00:00, 466.01it/s]                       
100%|██████████| 252/252 [00:00<00:00, 1950119.20it/s]
100%|██████████| 42/42 [00:00<00:00, 223.03it/s]
100%|██████████| 42/42 [00:00<00:00, 978670.93it/s]
100%|██████████| 252/252 [00:00<00:00, 500.00it/s]                        
100%|██████████| 252/252 [00:00<00:00, 1509949.44it/s]
100%|██████████| 42/42 [00:00<00:00, 239.12it/s]
100%|██████████| 42/42 [00:00<00:00, 957395.48it/s]
100%|██████████| 252/252 [00:00<00:00, 517.29it/s]                        
100%|██████████| 252/252 [00:00<00:00, 1880719.94it/s]
100%|██████████| 42/42 [00:00<00:00, 219.81it/s]
100%|██████████| 42/42 [00:00<00:00, 842874.49it/s]
100%|██████████| 252/252 [00:00<00:00, 607.76it/s]                       
100%|██████████| 252/252 [


--- Trial 5 finished ---
Value: -0.7052254776520345 and parameters: {'train_length_multiplier': 12, 'val_period_length': 42, 'lookahead': 1, 'num_heads': 1, 'hidden_dim': 152, 'num_layers': 4, 'learning_rate': 0.00099624858661813}
Best is trial 2 with value: -0.9712407100769145


--- Starting Trial: 6 ---
Hyperparameters for this trial: {'train_length_multiplier': 14, 'val_period_length': 21, 'lookahead': 1, 'num_heads': 5, 'hidden_dim': 108, 'num_layers': 2, 'learning_rate': 0.001792251065956423}


100%|██████████| 294/294 [00:00<00:00, 505.35it/s]
100%|██████████| 294/294 [00:00<00:00, 1874050.72it/s]
100%|██████████| 21/21 [00:00<00:00, 7247.03it/s]
100%|██████████| 21/21 [00:00<00:00, 503316.48it/s]
100%|██████████| 294/294 [00:00<00:00, 553.45it/s]                        
100%|██████████| 294/294 [00:00<00:00, 1941929.73it/s]
100%|██████████| 21/21 [00:00<00:00, 7208.48it/s]
100%|██████████| 21/21 [00:00<00:00, 350917.86it/s]
100%|██████████| 294/294 [00:00<00:00, 550.58it/s]                       
100%|██████████| 294/294 [00:00<00:00, 2018208.47it/s]
100%|██████████| 21/21 [00:00<00:00, 7521.81it/s]
100%|██████████| 21/21 [00:00<00:00, 451694.28it/s]
100%|██████████| 294/294 [00:00<00:00, 674.23it/s]                        
100%|██████████| 294/294 [00:00<00:00, 1771731.86it/s]
100%|██████████| 21/21 [00:00<00:00, 5737.76it/s]
100%|██████████| 21/21 [00:00<00:00, 727937.06it/s]
100%|██████████| 294/294 [00:00<00:00, 397.71it/s]     
100%|██████████| 294/294 [00:00<00:00, 16


--- Trial 6 finished ---
Value: -0.8942850392270816 and parameters: {'train_length_multiplier': 14, 'val_period_length': 21, 'lookahead': 1, 'num_heads': 5, 'hidden_dim': 108, 'num_layers': 2, 'learning_rate': 0.001792251065956423}
Best is trial 2 with value: -0.9712407100769145


--- Starting Trial: 7 ---
Hyperparameters for this trial: {'train_length_multiplier': 13, 'val_period_length': 63, 'lookahead': 21, 'num_heads': 4, 'hidden_dim': 138, 'num_layers': 4, 'learning_rate': 0.0004036225082442523}


100%|██████████| 273/273 [00:00<00:00, 693.37it/s]
100%|██████████| 273/273 [00:00<00:00, 1676493.40it/s]
100%|██████████| 63/63 [00:00<00:00, 329.01it/s]
100%|██████████| 63/63 [00:00<00:00, 1518627.31it/s]
100%|██████████| 273/273 [00:00<00:00, 480.15it/s]     
100%|██████████| 273/273 [00:00<00:00, 1964056.59it/s]
100%|██████████| 63/63 [00:00<00:00, 196.40it/s]
100%|██████████| 63/63 [00:00<00:00, 1468006.40it/s]
100%|██████████| 273/273 [00:00<00:00, 562.57it/s]                       
100%|██████████| 273/273 [00:00<00:00, 2136277.97it/s]
100%|██████████| 63/63 [00:00<00:00, 295.08it/s]
100%|██████████| 63/63 [00:00<00:00, 1355082.83it/s]
100%|██████████| 273/273 [00:00<00:00, 678.01it/s]     
100%|██████████| 273/273 [00:00<00:00, 2097152.00it/s]
100%|██████████| 63/63 [00:00<00:00, 246.07it/s]
100%|██████████| 63/63 [00:00<00:00, 1124430.43it/s]
100%|██████████| 273/273 [00:00<00:00, 759.83it/s]                       
100%|██████████| 273/273 [00:00<00:00, 2245186.26it/s]
100%|█


--- Trial 7 finished ---
Value: -0.9802443947929206 and parameters: {'train_length_multiplier': 13, 'val_period_length': 63, 'lookahead': 21, 'num_heads': 4, 'hidden_dim': 138, 'num_layers': 4, 'learning_rate': 0.0004036225082442523}
Best is trial 7 with value: -0.9802443947929206


--- Starting Trial: 8 ---
Hyperparameters for this trial: {'train_length_multiplier': 10, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 2, 'hidden_dim': 200, 'num_layers': 1, 'learning_rate': 9.660658336668312e-05}


100%|██████████| 210/210 [00:00<00:00, 503.58it/s]
100%|██████████| 210/210 [00:00<00:00, 1492887.86it/s]
100%|██████████| 42/42 [00:00<00:00, 211.77it/s]
100%|██████████| 42/42 [00:00<00:00, 811800.77it/s]
100%|██████████| 210/210 [00:00<00:00, 556.42it/s]     
100%|██████████| 210/210 [00:00<00:00, 1550710.99it/s]
100%|██████████| 42/42 [00:00<00:00, 229.21it/s]
100%|██████████| 42/42 [00:00<00:00, 1129235.69it/s]
100%|██████████| 210/210 [00:00<00:00, 518.65it/s]     
100%|██████████| 210/210 [00:00<00:00, 2102157.14it/s]
100%|██████████| 42/42 [00:00<00:00, 232.60it/s]
100%|██████████| 42/42 [00:00<00:00, 1190275.46it/s]
100%|██████████| 210/210 [00:00<00:00, 437.06it/s]                         
100%|██████████| 210/210 [00:00<00:00, 1637181.86it/s]
100%|██████████| 42/42 [00:00<00:00, 262.02it/s]
100%|██████████| 42/42 [00:00<00:00, 855149.36it/s]
100%|██████████| 210/210 [00:00<00:00, 662.76it/s]     
100%|██████████| 210/210 [00:00<00:00, 2087212.89it/s]
100%|██████████| 42/42 [


--- Trial 8 finished ---
Value: -0.9827196478552414 and parameters: {'train_length_multiplier': 10, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 2, 'hidden_dim': 200, 'num_layers': 1, 'learning_rate': 9.660658336668312e-05}
Best is trial 8 with value: -0.9827196478552414


--- Starting Trial: 9 ---
Hyperparameters for this trial: {'train_length_multiplier': 14, 'val_period_length': 63, 'lookahead': 21, 'num_heads': 3, 'hidden_dim': 178, 'num_layers': 3, 'learning_rate': 0.002990679283249535}


100%|██████████| 294/294 [00:00<00:00, 554.72it/s]
100%|██████████| 294/294 [00:00<00:00, 1914790.96it/s]
100%|██████████| 63/63 [00:00<00:00, 309.70it/s]
100%|██████████| 63/63 [00:00<00:00, 1110256.94it/s]
100%|██████████| 294/294 [00:00<00:00, 572.20it/s]                       
100%|██████████| 294/294 [00:00<00:00, 1537562.81it/s]
100%|██████████| 63/63 [00:00<00:00, 276.72it/s]
100%|██████████| 63/63 [00:00<00:00, 1308124.51it/s]
100%|██████████| 294/294 [00:00<00:00, 530.62it/s]                         
100%|██████████| 294/294 [00:00<00:00, 1957341.87it/s]
100%|██████████| 63/63 [00:00<00:00, 298.52it/s]
100%|██████████| 63/63 [00:00<00:00, 1223338.67it/s]
100%|██████████| 294/294 [00:00<00:00, 566.93it/s]                       
100%|██████████| 294/294 [00:00<00:00, 1944992.71it/s]
100%|██████████| 63/63 [00:00<00:00, 163.36it/s]
100%|██████████| 63/63 [00:00<00:00, 1484500.85it/s]
[I 2023-10-10 20:07:26,358] Trial 8 finished with value: -0.758372578141234 and parameters: {'tra


--- Trial 9 finished ---
Value: -0.758372578141234 and parameters: {'train_length_multiplier': 14, 'val_period_length': 63, 'lookahead': 21, 'num_heads': 3, 'hidden_dim': 178, 'num_layers': 3, 'learning_rate': 0.002990679283249535}
Best is trial 8 with value: -0.9827196478552414


--- Starting Trial: 10 ---
Hyperparameters for this trial: {'train_length_multiplier': 14, 'val_period_length': 21, 'lookahead': 5, 'num_heads': 5, 'hidden_dim': 230, 'num_layers': 5, 'learning_rate': 4.080452605741284e-05}


100%|██████████| 294/294 [00:00<00:00, 446.64it/s]
100%|██████████| 294/294 [00:00<00:00, 1789732.04it/s]
100%|██████████| 21/21 [00:00<00:00, 7328.43it/s]
100%|██████████| 21/21 [00:00<00:00, 642922.51it/s]
100%|██████████| 294/294 [00:00<00:00, 510.42it/s]                       
100%|██████████| 294/294 [00:00<00:00, 1635444.80it/s]
100%|██████████| 21/21 [00:00<00:00, 7266.16it/s]
100%|██████████| 21/21 [00:00<00:00, 716100.68it/s]
100%|██████████| 294/294 [00:00<00:00, 496.53it/s]     
100%|██████████| 294/294 [00:00<00:00, 2174824.30it/s]
100%|██████████| 21/21 [00:00<00:00, 7062.25it/s]
100%|██████████| 21/21 [00:00<00:00, 629145.60it/s]
100%|██████████| 294/294 [00:00<00:00, 564.65it/s]     
100%|██████████| 294/294 [00:00<00:00, 2100724.66it/s]
100%|██████████| 21/21 [00:00<00:00, 7515.39it/s]
100%|██████████| 21/21 [00:00<00:00, 688128.00it/s]
100%|██████████| 294/294 [00:00<00:00, 519.65it/s]                         
100%|██████████| 294/294 [00:00<00:00, 1969848.84it/s]
100%


--- Trial 10 finished ---
Value: -0.9791242284956218 and parameters: {'train_length_multiplier': 14, 'val_period_length': 21, 'lookahead': 5, 'num_heads': 5, 'hidden_dim': 230, 'num_layers': 5, 'learning_rate': 4.080452605741284e-05}
Best is trial 8 with value: -0.9827196478552414


--- Starting Trial: 11 ---
Hyperparameters for this trial: {'train_length_multiplier': 10, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 1, 'hidden_dim': 248, 'num_layers': 1, 'learning_rate': 1.1582577045212088e-05}


100%|██████████| 210/210 [00:00<00:00, 467.51it/s]
100%|██████████| 210/210 [00:00<00:00, 1578501.51it/s]
100%|██████████| 42/42 [00:00<00:00, 236.72it/s]
100%|██████████| 42/42 [00:00<00:00, 952220.37it/s]
100%|██████████| 210/210 [00:00<00:00, 614.60it/s]     
100%|██████████| 210/210 [00:00<00:00, 1765137.96it/s]
100%|██████████| 42/42 [00:00<00:00, 228.19it/s]
100%|██████████| 42/42 [00:00<00:00, 912750.09it/s]
100%|██████████| 210/210 [00:00<00:00, 538.73it/s]     
100%|██████████| 210/210 [00:00<00:00, 1631118.22it/s]
100%|██████████| 42/42 [00:00<00:00, 218.92it/s]
100%|██████████| 42/42 [00:00<00:00, 1129235.69it/s]
100%|██████████| 210/210 [00:00<00:00, 709.88it/s]     
100%|██████████| 210/210 [00:00<00:00, 2043628.40it/s]
100%|██████████| 42/42 [00:00<00:00, 228.81it/s]
100%|██████████| 42/42 [00:00<00:00, 880803.84it/s]
100%|██████████| 210/210 [00:00<00:00, 607.36it/s]                         
100%|██████████| 210/210 [00:00<00:00, 2038897.78it/s]
100%|██████████| 42/42 [0


--- Trial 11 finished ---
Value: -0.9733146744148821 and parameters: {'train_length_multiplier': 10, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 1, 'hidden_dim': 248, 'num_layers': 1, 'learning_rate': 1.1582577045212088e-05}
Best is trial 8 with value: -0.9827196478552414


--- Starting Trial: 12 ---
Hyperparameters for this trial: {'train_length_multiplier': 10, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 2, 'hidden_dim': 196, 'num_layers': 5, 'learning_rate': 0.00012022262507505991}


100%|██████████| 210/210 [00:00<00:00, 529.65it/s]
100%|██████████| 210/210 [00:00<00:00, 1874050.72it/s]
100%|██████████| 42/42 [00:00<00:00, 216.62it/s]
100%|██████████| 42/42 [00:00<00:00, 1048576.00it/s]
100%|██████████| 210/210 [00:00<00:00, 533.75it/s]     
100%|██████████| 210/210 [00:00<00:00, 1970478.39it/s]
100%|██████████| 42/42 [00:00<00:00, 208.62it/s]
100%|██████████| 42/42 [00:00<00:00, 978670.93it/s]
100%|██████████| 210/210 [00:00<00:00, 752.45it/s]                       
100%|██████████| 210/210 [00:00<00:00, 1581335.44it/s]
100%|██████████| 42/42 [00:00<00:00, 223.02it/s]
100%|██████████| 42/42 [00:00<00:00, 1054854.90it/s]
100%|██████████| 210/210 [00:00<00:00, 548.92it/s]                         
100%|██████████| 210/210 [00:00<00:00, 1914790.96it/s]
100%|██████████| 42/42 [00:00<00:00, 260.98it/s]
100%|██████████| 42/42 [00:00<00:00, 1006632.96it/s]
100%|██████████| 210/210 [00:00<00:00, 490.70it/s]     
100%|██████████| 210/210 [00:00<00:00, 2029501.94it/s]
100%|


--- Trial 12 finished ---
Value: -0.9816870682846127 and parameters: {'train_length_multiplier': 10, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 2, 'hidden_dim': 196, 'num_layers': 5, 'learning_rate': 0.00012022262507505991}
Best is trial 8 with value: -0.9827196478552414


--- Starting Trial: 13 ---
Hyperparameters for this trial: {'train_length_multiplier': 10, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 2, 'hidden_dim': 196, 'num_layers': 5, 'learning_rate': 0.00012971424321307542}


100%|██████████| 210/210 [00:00<00:00, 660.39it/s]
100%|██████████| 210/210 [00:00<00:00, 1631118.22it/s]
100%|██████████| 42/42 [00:00<00:00, 227.05it/s]
100%|██████████| 42/42 [00:00<00:00, 880803.84it/s]
100%|██████████| 210/210 [00:00<00:00, 590.79it/s]     
100%|██████████| 210/210 [00:00<00:00, 1744166.02it/s]
100%|██████████| 42/42 [00:00<00:00, 223.99it/s]
100%|██████████| 42/42 [00:00<00:00, 1143901.09it/s]
100%|██████████| 210/210 [00:00<00:00, 639.74it/s]     
100%|██████████| 210/210 [00:00<00:00, 1674532.02it/s]
100%|██████████| 42/42 [00:00<00:00, 254.90it/s]
100%|██████████| 42/42 [00:00<00:00, 1087412.15it/s]
100%|██████████| 210/210 [00:00<00:00, 451.81it/s]     
100%|██████████| 210/210 [00:00<00:00, 1503078.23it/s]
100%|██████████| 42/42 [00:00<00:00, 240.90it/s]
100%|██████████| 42/42 [00:00<00:00, 995258.58it/s]
100%|██████████| 210/210 [00:00<00:00, 729.41it/s]     
100%|██████████| 210/210 [00:00<00:00, 1655646.32it/s]
100%|██████████| 42/42 [00:00<00:00, 220.85i


--- Trial 13 finished ---
Value: -0.9819451970021351 and parameters: {'train_length_multiplier': 10, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 2, 'hidden_dim': 196, 'num_layers': 5, 'learning_rate': 0.00012971424321307542}
Best is trial 8 with value: -0.9827196478552414


--- Starting Trial: 14 ---
Hyperparameters for this trial: {'train_length_multiplier': 11, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 2, 'hidden_dim': 210, 'num_layers': 4, 'learning_rate': 0.00011651596652491465}


100%|██████████| 231/231 [00:00<00:00, 565.48it/s]
100%|██████████| 231/231 [00:00<00:00, 1997699.43it/s]
100%|██████████| 42/42 [00:00<00:00, 222.90it/s]
100%|██████████| 42/42 [00:00<00:00, 952220.37it/s]
100%|██████████| 231/231 [00:00<00:00, 697.87it/s]                       
100%|██████████| 231/231 [00:00<00:00, 1907252.41it/s]
100%|██████████| 42/42 [00:00<00:00, 219.70it/s]
100%|██████████| 42/42 [00:00<00:00, 957395.48it/s]
100%|██████████| 231/231 [00:00<00:00, 592.78it/s]     
100%|██████████| 231/231 [00:00<00:00, 1945550.65it/s]
100%|██████████| 42/42 [00:00<00:00, 227.25it/s]
100%|██████████| 42/42 [00:00<00:00, 1122043.11it/s]
100%|██████████| 231/231 [00:00<00:00, 543.40it/s]     
100%|██████████| 231/231 [00:00<00:00, 2065851.22it/s]
100%|██████████| 42/42 [00:00<00:00, 262.27it/s]
100%|██████████| 42/42 [00:00<00:00, 1054854.90it/s]
100%|██████████| 231/231 [00:00<00:00, 598.00it/s]     
100%|██████████| 231/231 [00:00<00:00, 1575421.50it/s]
100%|██████████| 42/42 [00


--- Trial 14 finished ---
Value: -0.9692764319505631 and parameters: {'train_length_multiplier': 11, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 2, 'hidden_dim': 210, 'num_layers': 4, 'learning_rate': 0.00011651596652491465}
Best is trial 8 with value: -0.9827196478552414


--- Starting Trial: 15 ---
Hyperparameters for this trial: {'train_length_multiplier': 11, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 2, 'hidden_dim': 174, 'num_layers': 5, 'learning_rate': 1.3768852654454538e-05}


100%|██████████| 231/231 [00:00<00:00, 497.47it/s]
100%|██████████| 231/231 [00:00<00:00, 2124746.11it/s]
100%|██████████| 42/42 [00:00<00:00, 225.46it/s]
100%|██████████| 42/42 [00:00<00:00, 1249367.15it/s]
100%|██████████| 231/231 [00:00<00:00, 575.62it/s]     
100%|██████████| 231/231 [00:00<00:00, 1699796.88it/s]
100%|██████████| 42/42 [00:00<00:00, 221.87it/s]
100%|██████████| 42/42 [00:00<00:00, 1122043.11it/s]
100%|██████████| 231/231 [00:00<00:00, 575.55it/s]                         
100%|██████████| 231/231 [00:00<00:00, 1676270.28it/s]
100%|██████████| 42/42 [00:00<00:00, 223.07it/s]
100%|██████████| 42/42 [00:00<00:00, 947100.90it/s]
100%|██████████| 231/231 [00:00<00:00, 545.02it/s]     
100%|██████████| 231/231 [00:00<00:00, 2101701.14it/s]
100%|██████████| 42/42 [00:00<00:00, 232.80it/s]
100%|██████████| 42/42 [00:00<00:00, 1190275.46it/s]
100%|██████████| 231/231 [00:00<00:00, 620.41it/s]     
100%|██████████| 231/231 [00:00<00:00, 2115467.74it/s]
100%|██████████| 42/42 


--- Trial 15 finished ---
Value: -0.96038748384224 and parameters: {'train_length_multiplier': 11, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 2, 'hidden_dim': 174, 'num_layers': 5, 'learning_rate': 1.3768852654454538e-05}
Best is trial 8 with value: -0.9827196478552414


--- Starting Trial: 16 ---
Hyperparameters for this trial: {'train_length_multiplier': 11, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 2, 'hidden_dim': 202, 'num_layers': 1, 'learning_rate': 0.00013035233958766097}


100%|██████████| 231/231 [00:00<00:00, 557.27it/s]
100%|██████████| 231/231 [00:00<00:00, 2279727.59it/s]
100%|██████████| 42/42 [00:00<00:00, 230.77it/s]
100%|██████████| 42/42 [00:00<00:00, 1231893.48it/s]
100%|██████████| 231/231 [00:00<00:00, 544.10it/s]     
100%|██████████| 231/231 [00:00<00:00, 1593559.58it/s]
100%|██████████| 42/42 [00:00<00:00, 230.57it/s]
100%|██████████| 42/42 [00:00<00:00, 962627.15it/s]
100%|██████████| 231/231 [00:00<00:00, 759.27it/s]     
100%|██████████| 231/231 [00:00<00:00, 1949465.24it/s]
100%|██████████| 42/42 [00:00<00:00, 227.19it/s]
100%|██████████| 42/42 [00:00<00:00, 872083.01it/s]
100%|██████████| 231/231 [00:00<00:00, 637.38it/s]     
100%|██████████| 231/231 [00:00<00:00, 1609442.23it/s]
100%|██████████| 42/42 [00:00<00:00, 232.01it/s]
100%|██████████| 42/42 [00:00<00:00, 1048576.00it/s]
100%|██████████| 231/231 [00:00<00:00, 573.85it/s]     
100%|██████████| 231/231 [00:00<00:00, 1969276.88it/s]
100%|██████████| 42/42 [00:00<00:00, 234.46i


--- Trial 16 finished ---
Value: -0.9806579964260068 and parameters: {'train_length_multiplier': 11, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 2, 'hidden_dim': 202, 'num_layers': 1, 'learning_rate': 0.00013035233958766097}
Best is trial 8 with value: -0.9827196478552414


--- Starting Trial: 17 ---
Hyperparameters for this trial: {'train_length_multiplier': 10, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 1, 'hidden_dim': 256, 'num_layers': 3, 'learning_rate': 3.806522346414092e-05}


100%|██████████| 210/210 [00:00<00:00, 438.26it/s]
100%|██████████| 210/210 [00:00<00:00, 1661894.04it/s]
100%|██████████| 42/42 [00:00<00:00, 228.39it/s]
100%|██████████| 42/42 [00:00<00:00, 995258.58it/s]
100%|██████████| 210/210 [00:00<00:00, 508.74it/s]     
100%|██████████| 210/210 [00:00<00:00, 1649445.39it/s]
100%|██████████| 42/42 [00:00<00:00, 244.87it/s]
100%|██████████| 42/42 [00:00<00:00, 1074151.02it/s]
100%|██████████| 210/210 [00:00<00:00, 547.19it/s]                        
100%|██████████| 210/210 [00:00<00:00, 2241231.15it/s]
100%|██████████| 42/42 [00:00<00:00, 235.45it/s]
100%|██████████| 42/42 [00:00<00:00, 932067.56it/s]
100%|██████████| 210/210 [00:00<00:00, 578.15it/s]     
100%|██████████| 210/210 [00:00<00:00, 1529173.33it/s]
100%|██████████| 42/42 [00:00<00:00, 244.07it/s]
100%|██████████| 42/42 [00:00<00:00, 1012418.21it/s]
100%|██████████| 210/210 [00:00<00:00, 633.47it/s]     
100%|██████████| 210/210 [00:00<00:00, 2024836.41it/s]
100%|██████████| 42/42 [0


--- Trial 17 finished ---
Value: -0.9766937384121194 and parameters: {'train_length_multiplier': 10, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 1, 'hidden_dim': 256, 'num_layers': 3, 'learning_rate': 3.806522346414092e-05}
Best is trial 8 with value: -0.9827196478552414


--- Starting Trial: 18 ---
Hyperparameters for this trial: {'train_length_multiplier': 12, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 3, 'hidden_dim': 64, 'num_layers': 2, 'learning_rate': 0.000373425693937419}


100%|██████████| 252/252 [00:00<00:00, 616.79it/s]
100%|██████████| 252/252 [00:00<00:00, 2174824.30it/s]
100%|██████████| 42/42 [00:00<00:00, 183.70it/s]
100%|██████████| 42/42 [00:00<00:00, 894217.10it/s]
100%|██████████| 252/252 [00:00<00:00, 464.53it/s]     
100%|██████████| 252/252 [00:00<00:00, 2105507.19it/s]
100%|██████████| 42/42 [00:00<00:00, 221.58it/s]
100%|██████████| 42/42 [00:00<00:00, 1000913.45it/s]
100%|██████████| 252/252 [00:00<00:00, 663.26it/s]     
100%|██████████| 252/252 [00:00<00:00, 2109709.80it/s]
100%|██████████| 42/42 [00:00<00:00, 228.32it/s]
100%|██████████| 42/42 [00:00<00:00, 1030179.93it/s]
100%|██████████| 252/252 [00:00<00:00, 558.40it/s]     
100%|██████████| 252/252 [00:00<00:00, 2060359.86it/s]
100%|██████████| 42/42 [00:00<00:00, 205.04it/s]
100%|██████████| 42/42 [00:00<00:00, 1048576.00it/s]
100%|██████████| 252/252 [00:00<00:00, 583.36it/s]                       
100%|██████████| 252/252 [00:00<00:00, 2056351.38it/s]
100%|██████████| 42/42 [0


--- Trial 18 finished ---
Value: -0.9830387355178691 and parameters: {'train_length_multiplier': 12, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 3, 'hidden_dim': 64, 'num_layers': 2, 'learning_rate': 0.000373425693937419}
Best is trial 18 with value: -0.9830387355178691


--- Starting Trial: 19 ---
Hyperparameters for this trial: {'train_length_multiplier': 12, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 4, 'hidden_dim': 68, 'num_layers': 2, 'learning_rate': 0.0004279967617071587}


100%|██████████| 252/252 [00:00<00:00, 569.66it/s]
100%|██████████| 252/252 [00:00<00:00, 1623601.55it/s]
100%|██████████| 42/42 [00:00<00:00, 210.12it/s]
100%|██████████| 42/42 [00:00<00:00, 1067641.02it/s]
100%|██████████| 252/252 [00:00<00:00, 480.26it/s]                         
100%|██████████| 252/252 [00:00<00:00, 1819216.19it/s]
100%|██████████| 42/42 [00:00<00:00, 231.06it/s]
100%|██████████| 42/42 [00:00<00:00, 885229.99it/s]
100%|██████████| 252/252 [00:00<00:00, 586.10it/s]                         
100%|██████████| 252/252 [00:00<00:00, 1904440.74it/s]
100%|██████████| 42/42 [00:00<00:00, 230.56it/s]
100%|██████████| 42/42 [00:00<00:00, 984138.37it/s]
100%|██████████| 252/252 [00:00<00:00, 506.54it/s]     
100%|██████████| 252/252 [00:00<00:00, 1669770.31it/s]
100%|██████████| 42/42 [00:00<00:00, 198.80it/s]
100%|██████████| 42/42 [00:00<00:00, 989667.24it/s]
100%|██████████| 252/252 [00:00<00:00, 672.82it/s]     
100%|██████████| 252/252 [00:00<00:00, 1363825.30it/s]
100%|


--- Trial 19 finished ---
Value: -0.983351756286768 and parameters: {'train_length_multiplier': 12, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 4, 'hidden_dim': 68, 'num_layers': 2, 'learning_rate': 0.0004279967617071587}
Best is trial 19 with value: -0.983351756286768


--- Starting Trial: 20 ---
Hyperparameters for this trial: {'train_length_multiplier': 12, 'val_period_length': 42, 'lookahead': 5, 'num_heads': 4, 'hidden_dim': 70, 'num_layers': 2, 'learning_rate': 0.0004970957343277877}


100%|██████████| 252/252 [00:00<00:00, 604.86it/s]
100%|██████████| 252/252 [00:00<00:00, 1638704.82it/s]
100%|██████████| 42/42 [00:00<00:00, 237.03it/s]
100%|██████████| 42/42 [00:00<00:00, 1054854.90it/s]
100%|██████████| 252/252 [00:00<00:00, 698.22it/s]                       
100%|██████████| 252/252 [00:00<00:00, 2072479.62it/s]
100%|██████████| 42/42 [00:00<00:00, 230.04it/s]
100%|██████████| 42/42 [00:00<00:00, 912750.09it/s]
100%|██████████| 252/252 [00:00<00:00, 723.73it/s]     
100%|██████████| 252/252 [00:00<00:00, 2118165.55it/s]
100%|██████████| 42/42 [00:00<00:00, 236.43it/s]
100%|██████████| 42/42 [00:00<00:00, 885229.99it/s]
100%|██████████| 252/252 [00:00<00:00, 527.76it/s]     
100%|██████████| 252/252 [00:00<00:00, 2060359.86it/s]
100%|██████████| 42/42 [00:00<00:00, 237.62it/s]
100%|██████████| 42/42 [00:00<00:00, 903388.55it/s]
100%|██████████| 252/252 [00:00<00:00, 727.47it/s]                       
100%|██████████| 252/252 [00:00<00:00, 2028722.86it/s]
100%|████


--- Trial 20 finished ---
Value: -0.9799811275939825 and parameters: {'train_length_multiplier': 12, 'val_period_length': 42, 'lookahead': 5, 'num_heads': 4, 'hidden_dim': 70, 'num_layers': 2, 'learning_rate': 0.0004970957343277877}
Best is trial 19 with value: -0.983351756286768


--- Starting Trial: 21 ---
Hyperparameters for this trial: {'train_length_multiplier': 15, 'val_period_length': 21, 'lookahead': 21, 'num_heads': 4, 'hidden_dim': 64, 'num_layers': 2, 'learning_rate': 0.0004025197421153041}


100%|██████████| 315/315 [00:00<00:00, 525.37it/s]
100%|██████████| 315/315 [00:00<00:00, 2017108.03it/s]
100%|██████████| 21/21 [00:00<00:00, 7612.17it/s]
100%|██████████| 21/21 [00:00<00:00, 688128.00it/s]
100%|██████████| 315/315 [00:00<00:00, 517.06it/s]     
100%|██████████| 315/315 [00:00<00:00, 2144814.55it/s]
100%|██████████| 21/21 [00:00<00:00, 7432.94it/s]
100%|██████████| 21/21 [00:00<00:00, 629145.60it/s]
100%|██████████| 315/315 [00:00<00:00, 559.53it/s]     
100%|██████████| 315/315 [00:00<00:00, 2497553.42it/s]
100%|██████████| 21/21 [00:00<00:00, 8003.67it/s]
100%|██████████| 21/21 [00:00<00:00, 537075.51it/s]
100%|██████████| 315/315 [00:00<00:00, 685.97it/s]                       
100%|██████████| 315/315 [00:00<00:00, 2054752.35it/s]
100%|██████████| 21/21 [00:00<00:00, 7394.88it/s]
100%|██████████| 21/21 [00:00<00:00, 657316.30it/s]
100%|██████████| 315/315 [00:00<00:00, 662.66it/s]                         
100%|██████████| 315/315 [00:00<00:00, 2051561.74it/s]
100%


--- Trial 21 finished ---
Value: -0.9797462574324666 and parameters: {'train_length_multiplier': 15, 'val_period_length': 21, 'lookahead': 21, 'num_heads': 4, 'hidden_dim': 64, 'num_layers': 2, 'learning_rate': 0.0004025197421153041}
Best is trial 19 with value: -0.983351756286768


--- Starting Trial: 22 ---
Hyperparameters for this trial: {'train_length_multiplier': 12, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 3, 'hidden_dim': 74, 'num_layers': 1, 'learning_rate': 0.00034378137722093204}


100%|██████████| 252/252 [00:00<00:00, 529.25it/s]
100%|██████████| 252/252 [00:00<00:00, 2056351.38it/s]
100%|██████████| 42/42 [00:00<00:00, 218.30it/s]
100%|██████████| 42/42 [00:00<00:00, 917504.00it/s]
100%|██████████| 252/252 [00:00<00:00, 685.59it/s]     
100%|██████████| 252/252 [00:00<00:00, 2097152.00it/s]
100%|██████████| 42/42 [00:00<00:00, 256.76it/s]
100%|██████████| 42/42 [00:00<00:00, 942036.19it/s]
100%|██████████| 252/252 [00:00<00:00, 473.21it/s]                        
100%|██████████| 252/252 [00:00<00:00, 2273042.17it/s]
100%|██████████| 42/42 [00:00<00:00, 263.25it/s]
100%|██████████| 42/42 [00:00<00:00, 947100.90it/s]
100%|██████████| 252/252 [00:00<00:00, 492.58it/s]     
100%|██████████| 252/252 [00:00<00:00, 2024836.41it/s]
100%|██████████| 42/42 [00:00<00:00, 182.19it/s]
100%|██████████| 42/42 [00:00<00:00, 947100.90it/s]
100%|██████████| 252/252 [00:00<00:00, 569.88it/s]     
100%|██████████| 252/252 [00:00<00:00, 1608774.14it/s]
100%|██████████| 42/42 [00:


--- Trial 22 finished ---
Value: -0.9856337684812899 and parameters: {'train_length_multiplier': 12, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 3, 'hidden_dim': 74, 'num_layers': 1, 'learning_rate': 0.00034378137722093204}
Best is trial 22 with value: -0.9856337684812899


--- Starting Trial: 23 ---
Hyperparameters for this trial: {'train_length_multiplier': 12, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 3, 'hidden_dim': 78, 'num_layers': 1, 'learning_rate': 0.0007979001637517091}


100%|██████████| 252/252 [00:00<00:00, 727.69it/s]
100%|██████████| 252/252 [00:00<00:00, 2380550.92it/s]
100%|██████████| 42/42 [00:00<00:00, 220.24it/s]
100%|██████████| 42/42 [00:00<00:00, 1122043.11it/s]
100%|██████████| 252/252 [00:00<00:00, 512.00it/s]     
100%|██████████| 252/252 [00:00<00:00, 2183811.17it/s]
100%|██████████| 42/42 [00:00<00:00, 158.69it/s]
100%|██████████| 42/42 [00:00<00:00, 989667.24it/s]
100%|██████████| 252/252 [00:00<00:00, 579.18it/s]                       
100%|██████████| 252/252 [00:00<00:00, 1918266.08it/s]
100%|██████████| 42/42 [00:00<00:00, 231.28it/s]
100%|██████████| 42/42 [00:00<00:00, 942036.19it/s]
100%|██████████| 252/252 [00:00<00:00, 724.72it/s]                       
100%|██████████| 252/252 [00:00<00:00, 1718641.64it/s]
100%|██████████| 42/42 [00:00<00:00, 180.93it/s]
100%|██████████| 42/42 [00:00<00:00, 1122043.11it/s]
100%|██████████| 252/252 [00:00<00:00, 529.91it/s]     
100%|██████████| 252/252 [00:00<00:00, 2174824.30it/s]
100%|███


--- Trial 23 finished ---
Value: -0.9846422637071294 and parameters: {'train_length_multiplier': 12, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 3, 'hidden_dim': 78, 'num_layers': 1, 'learning_rate': 0.0007979001637517091}
Best is trial 22 with value: -0.9856337684812899


--- Starting Trial: 24 ---
Hyperparameters for this trial: {'train_length_multiplier': 12, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 3, 'hidden_dim': 82, 'num_layers': 1, 'learning_rate': 0.0011291860360294976}


100%|██████████| 252/252 [00:00<00:00, 455.97it/s]
100%|██████████| 252/252 [00:00<00:00, 1648930.75it/s]
100%|██████████| 42/42 [00:00<00:00, 191.98it/s]
100%|██████████| 42/42 [00:00<00:00, 978670.93it/s]
100%|██████████| 252/252 [00:00<00:00, 525.09it/s]                         
100%|██████████| 252/252 [00:00<00:00, 2064384.00it/s]
100%|██████████| 42/42 [00:00<00:00, 211.65it/s]
100%|██████████| 42/42 [00:00<00:00, 834885.16it/s]
100%|██████████| 252/252 [00:00<00:00, 677.79it/s]                        
100%|██████████| 252/252 [00:00<00:00, 2109709.80it/s]
100%|██████████| 42/42 [00:00<00:00, 211.58it/s]
100%|██████████| 42/42 [00:00<00:00, 1190275.46it/s]
100%|██████████| 252/252 [00:00<00:00, 640.51it/s]     
100%|██████████| 252/252 [00:00<00:00, 2101321.29it/s]
100%|██████████| 42/42 [00:00<00:00, 212.48it/s]
100%|██████████| 42/42 [00:00<00:00, 827045.86it/s]
100%|██████████| 252/252 [00:00<00:00, 614.34it/s]                         
100%|██████████| 252/252 [00:00<00:00, 16


--- Trial 24 finished ---
Value: -0.9839688462530451 and parameters: {'train_length_multiplier': 12, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 3, 'hidden_dim': 82, 'num_layers': 1, 'learning_rate': 0.0011291860360294976}
Best is trial 22 with value: -0.9856337684812899


--- Starting Trial: 25 ---
Hyperparameters for this trial: {'train_length_multiplier': 11, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 3, 'hidden_dim': 122, 'num_layers': 1, 'learning_rate': 0.0011984778258248034}


100%|██████████| 231/231 [00:00<00:00, 528.45it/s]
100%|██████████| 231/231 [00:00<00:00, 1957341.87it/s]
100%|██████████| 42/42 [00:00<00:00, 219.15it/s]
100%|██████████| 42/42 [00:00<00:00, 1000913.45it/s]
100%|██████████| 231/231 [00:00<00:00, 641.69it/s]                         
100%|██████████| 231/231 [00:00<00:00, 1997699.43it/s]
100%|██████████| 42/42 [00:00<00:00, 235.43it/s]
100%|██████████| 42/42 [00:00<00:00, 1054854.90it/s]
100%|██████████| 231/231 [00:00<00:00, 645.90it/s]                       
100%|██████████| 231/231 [00:00<00:00, 1650569.38it/s]
100%|██████████| 42/42 [00:00<00:00, 230.34it/s]
100%|██████████| 42/42 [00:00<00:00, 967916.31it/s]
100%|██████████| 231/231 [00:00<00:00, 532.72it/s]                       
100%|██████████| 231/231 [00:00<00:00, 2044059.54it/s]
100%|██████████| 42/42 [00:00<00:00, 244.62it/s]
100%|██████████| 42/42 [00:00<00:00, 863533.18it/s]
100%|██████████| 231/231 [00:00<00:00, 688.61it/s]     
100%|██████████| 231/231 [00:00<00:00, 2124


--- Trial 25 finished ---
Value: -0.8411013899435794 and parameters: {'train_length_multiplier': 11, 'val_period_length': 42, 'lookahead': 21, 'num_heads': 3, 'hidden_dim': 122, 'num_layers': 1, 'learning_rate': 0.0011984778258248034}
Best is trial 22 with value: -0.9856337684812899



In [13]:
STOP

NameError: name 'STOP' is not defined

In [None]:
# After all trials have finished, retrieve the best trial's parameters
best_params = study.best_trial.params

# # Create the best model using the Transformer
# best_model = Transformer(
#     input_dim=FEATURE_DIM,
#     d_model=best_params["hidden_dim"],
#     output_dim=OUTPUT_DIM,
#     num_heads=best_params["num_heads"],
#     num_layers=best_params["num_layers"]
# ).to(device)

# Below is the SimpleNN code, commented out:
best_model = SimpleNN(input_dim=FEATURE_DIM, output_dim=OUTPUT_DIM).to(device)

# Train the best model on the entire dataset
criterion = nn.MSELoss()
lr = best_params['learning_rate']
optimizer = optim.Adam(best_model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=100, gamma=0.1)

# Assuming get_era2data() can handle the entire dataset
all_batches = get_era2data(dataset)  

# You might need to adjust/train_model to handle no validation set or adjust accordingly.
_, _, _ = train_model(
    best_model, criterion, optimizer, scheduler, NUM_EPOCHS, PATIENCE, 
    all_batches, None, is_lr_scheduler=True  # Assuming train_model can handle None for validation_batches
)

# Saving the model
model_name = best_model.__class__.__name__
lookahead = best_params.get("lookahead", "NA")
filename = f"{top}_{model_name}_{target_string}_{lookahead:02d}d_rank_quantiled.pkl"
file_path = os.path.join(model_dir, filename)

save_data = {
    'model_type': 'Transformer',
    'model_state_dict': best_model.state_dict(),
    'trial_params': best_params
}
torch.save(save_data, file_path)

In [None]:
# Loading the saved data
loaded_data = torch.load(file_path)

# Create the correct model based on the saved type
if loaded_data['model_type'] == 'Transformer':
    model = Transformer(
        input_dim=FEATURE_DIM,
        d_model=loaded_data['trial_params']["hidden_dim"],
        output_dim=OUTPUT_DIM,
        num_heads=loaded_data['trial_params']["num_heads"],
        num_layers=loaded_data['trial_params']["num_layers"]
    ).to(device)
else:
    model = SimpleNN(input_dim=FEATURE_DIM, output_dim=OUTPUT_DIM).to(device)

# Load the saved parameters into the model
model.load_state_dict(loaded_data['model_state_dict'])

In [None]:
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.optim.lr_scheduler import StepLR
# from model import RankPredictorNN  # Assuming this is where your SimpleNN class is defined
# import os
# from model import Transformer
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# # Constants and hyperparameters
# NUM_EPOCHS = 150
# PATIENCE = 5
# FEATURE_DIM = len(features)  # Assuming 'features' is defined elsewhere in your code
# OUTPUT_DIM = 1
# NUM_TRAIL = 25
# device = "cuda" if torch.cuda.is_available() else "cpu"

# # Best parameters from Optuna study
# best_params = {
#     'train_length_multiplier': 14,
#     'val_period_length': 63,
#     'lookahead': 5,
#     'num_heads': 5,  # Updated to match with the Transformer definition below
#     'hidden_dim': 208,
#     'num_layers': 10,
#     'learning_rate': 0.002
# }

# # Choose model
# best_model = Transformer(
#     input_dim=FEATURE_DIM,
#     d_model=best_params['hidden_dim'],
#     output_dim=OUTPUT_DIM,
#     num_heads=best_params['num_heads'],
#     num_layers=best_params['num_layers'],
# ).to(device)


# # # Initialize the best model using SimpleNN
# # best_model = RankPredictorNN(input_dim=FEATURE_DIM, \
# #     output_dim=OUTPUT_DIM).to(device)

# # Initialize loss function, optimizer, and learning rate scheduler
# criterion = nn.MSELoss()
# optimizer = optim.Adam(best_model.parameters(), lr=best_params['learning_rate'])
# scheduler = StepLR(optimizer, step_size=100, gamma=0.1)


# # label = f'TARGET_ret_fwd_{params["lookahead"]:02d}d_rank_quantiled'
# # Assuming get_era2data() can handle the entire dataset and returns a DataLoader
# all_batches = get_era2data(dataset)  # Replace this with your actual data loading function

# # Training Loop
# # Training Loop
# for epoch in range(NUM_EPOCHS):
#     best_model.train()
    
#     total_loss = 0.0
#     total_mse = 0.0
#     total_corr = 0.0
    
#     # Define the specific label using lookahead
#     for timestamp, (inputs, labels, masks_inputs, target_names) in all_batches.items():
        
#         # Move tensors to the desired device
#         inputs = inputs.to(device)
#         labels = labels.to(device)
#         masks_inputs = masks_inputs.to(device)

#         # Get index for specific label dynamically
#         specific_label_name = f'TARGET_ret_fwd_{best_params["lookahead"]:02d}d_rank_quantiled'
#         specific_label_index = target_names.index(specific_label_name)

#         # Use that index to fetch the specific column
#         labels = labels[:, :, specific_label_index].unsqueeze(2)
#         # print(f"Target Labels Range from training loop: {labels.min().item()}, {labels.max().item()}")

#         # Zero the parameter gradients
#         optimizer.zero_grad()

#         # Forward pass
#         outputs = best_model(inputs / 4, masks_inputs)
#         # print(outputs)
        
#         # Asserting that shapes of labels and outputs match
#         assert labels.shape == outputs.shape, \
#             f"Shape mismatch: labels {labels.shape}, outputs {outputs.shape}"

#         # Compute loss using the custom loss function
#         loss, mse, corr = calculate_loss(outputs, criterion, \
#                     labels, masks_inputs)

#         # print(f"Current batch loss: {loss.item()}, Current batch MSE: {mse.item()}, Current batch Correlation: {corr.item()}")
#         total_loss += loss.item()
#         total_mse += mse.item()
#         total_corr += corr.item()

#         # print(f"Accumulated Total Loss after this batch: {total_loss}, Accumulated Total MSE after this batch: {total_mse}, Accumulated Total Correlation after this batch: {total_corr}")

            
#         # Backward pass and optimization
#         loss.backward()
#         optimizer.step()
#         # break
        
#     # Step the learning rate scheduler
#     scheduler.step()

#     # # At the end of the training loop:
#     # print(f"Total loss: {total_loss}, Total MSE: {total_mse}, Total Correlation: {total_corr}, Number of batches: {len(all_batches)}")

#     avg_loss = total_loss / len(all_batches)
#     avg_mse = total_mse / len(all_batches)
#     avg_corr = total_corr / len(all_batches)
        
#     print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {avg_loss:.4f}, MSE: {avg_mse:.4f}, Correlation: {avg_corr:.4f}")

#     # break


# print("Training complete.")