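# Training

Train the `HarGRU` model on the HAR70 data, save the training/validation metric histories and plots, and report metrics on the held-out test split.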

In [58]:
# Automatic reloading: re-import edited modules (such as the src.* helpers imported
# below) before each cell executes, so changes under src/ take effect without a
# kernel restart. Re-running this cell in the same kernel only prints an
# "already loaded" notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [59]:
from src.data_preparation import *
from src.models import *
from src.train_eval import *

from torch.utils.data import DataLoader

In [60]:
# Build the train/val/test splits (80/10/10, split seed 42). With load_if_exists=True
# the helper may reuse previously saved components instead of rebuilding them.
# NOTE(review): sequence_size / stride presumably set the window length and step
# over the raw signal — confirm in src.data_preparation.
split_kwargs = {
    "train_ratio": 0.8,
    "val_ratio": 0.1,
    "test_ratio": 0.1,
    "random_state": 42,
}
train_dataset, val_dataset, test_dataset, norm_stats = prepare_datasets(
    sequence_size=1000,
    stride=25,
    load_if_exists=True,
    **split_kwargs,
)

Loading saved components...


In [61]:
# Wrap each split in a DataLoader. Only the training loader is shuffled: evaluation
# order brings no benefit, and a fixed order keeps validation/test passes
# deterministic (e.g. if per-batch losses are averaged, the partial last batch would
# otherwise contain different samples on every run).
BATCH_SIZE = 64
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [62]:
# Seed torch's global RNG before building the model so weight initialisation is
# reproducible. The loaders have no explicit generator, so the training shuffle
# order is also drawn from this RNG. torch.manual_seed seeds all devices.
SEED = 42
torch.manual_seed(SEED)

model = HarGRU()

# Set to a checkpoint path, e.g.
# "models/HarGRU_2025-03-11_19-11-04/HarGRU_epoch100.pth",
# to resume from saved weights instead of training from scratch.
CHECKPOINT_PATH = None
if CHECKPOINT_PATH is not None:
    # map_location keeps loading working if the checkpoint was saved on another device.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=next(model.parameters()).device)
    )

LEARNING_RATE = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

HarGRU model loaded on cpu.


In [64]:
# Train for a fixed number of epochs, validating on the validation loader after each
# epoch. The call returns six metric histories (presumably one entry per epoch).
(
    training_loss_history,
    validation_loss_history,
    accuracy_history,
    f1_history,
    precision_history,
    recall_history,
) = train_HAR70_model(
    model,
    optimizer,
    train_dataloader,
    validation_dataloader,
    num_epochs=2,
)

Epoch [1/2] | Time: 65.23s
(Training) Loss: 0.0053
(Validation) Loss: 0.0044, Accuracy: 0.5404, F1: 0.5330, Precision: 0.5260, Recall: 0.5404
Epoch [2/2] | Time: 67.50s
(Training) Loss: 0.0040
(Validation) Loss: 0.0039, Accuracy: 0.5451, F1: 0.5371, Precision: 0.5296, Recall: 0.5451


In [65]:
# Save the loss/metric curves and raw histories under a run directory named after
# the model class. type(model).__name__ yields the class name directly, instead of
# rendering the full module repr and slicing it at the first "(".
save_training_plots_and_metric_history(
    training_loss_history,
    validation_loss_history,
    accuracy_history,
    f1_history,
    precision_history,
    recall_history,
    type(model).__name__,
)

✅ Plots saved to: results\HarGRU_2025-03-11_23-07-35
✅ Metric histories saved to: results\HarGRU_2025-03-11_23-07-35\metric_histories.pth


In [53]:
# Test metrics
loss, accuracy, f1, precision, recall, conf_matrix = evaluate_HAR70_model(model, test_dataloader)
print(f"(Test) Loss: {loss:.4f}, Accuracy: {accuracy:.4f}, F1: {f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}")

(Test) Loss: 0.0030, Accuracy: 0.5568, F1: 0.5503, Precision: 0.6124, Recall: 0.5568
