# Training

In [None]:
from pathlib import Path
from dataclasses import dataclass
import sys

# Add the root project directory to the Python path
ROOT = Path.cwd().parent  # This will get the project root since the notebook is in 'notebooks/'
sys.path.append(str(ROOT))
from configs.path_config import EXTRACTED_DATA_DIR, WEIGHTS_DIR, INPUT_FEATURES, OUTPUT_FEATURES
from src.processing import dataset
from models import lstm_model
from src import utils

#### Create the training data loader

In [None]:
# Build the strain dataset from the extracted-data folder.
# sequence_length=128 sets the length of each sample sequence; test_size=0.3
# presumably reserves 30% of the data for evaluation (confirm in StrainDataset).
folder_path = EXTRACTED_DATA_DIR / 'group_alvbrodel_shifted'
data = dataset.StrainDataset(
    folder_path,
    INPUT_FEATURES,
    OUTPUT_FEATURES,
    sequence_length=128,
    start_idx=0,
    test_size=0.3,
)

# Only the training loader is consumed by the cells below.
train_loader = data.train_dataloader

#### Train the model

In [None]:
# Model parameters. Feature counts come from the dataset so the model's
# input/output sizes always match the data actually loaded.
input_dim = data.input_feature_count  # Number of input features
output_dim = data.output_feature_count  # Number of output features
print(f"Input dimension: {input_dim}")
print(f"Output dimension: {output_dim}")
hidden_dim = 32
num_layers = 1
num_epochs = 5
learning_rate = 0.01
# NOTE(review): if LSTMModel forwards `dropout` to torch.nn.LSTM, it has no
# effect with num_layers=1 (PyTorch applies it only between stacked layers)
# — confirm against the LSTMModel definition.
dropout = 0.4
# weight_decay is not passed to training_loop below, so this cell does not
# configure any weight decay.

# Location and name for saving the model; the file name encodes the
# hyperparameters so different runs do not overwrite each other.
model_folder = WEIGHTS_DIR
model_name = f'lstm_model_{input_dim}_{hidden_dim}_{num_layers}_{num_epochs}_{learning_rate}_{dropout}.pth'

# Create the model and train it.
# Assumes the signature LSTMModel(input_dim, output_dim, hidden_dim,
# num_layers, dropout) — TODO confirm. The second argument must be the number
# of *output* features; passing input_dim there only works by coincidence
# when both feature counts happen to be equal.
model = lstm_model.LSTMModel(input_dim, output_dim, hidden_dim, num_layers, dropout)
losses, prediction = lstm_model.training_loop(
    model=model,
    train_loader=train_loader,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    model_folder=model_folder,
    model_name=model_name,
    input_features=INPUT_FEATURES,
    output_features=OUTPUT_FEATURES,
)


#### Plot the epoch losses

In [None]:
# Visualise the losses returned by training_loop (presumably one value per
# epoch, given the helper's name) across the configured number of epochs.
utils.plot_epochs_loss(
    num_epochs,
    losses,
)