In [9]:
import os
import sys

import joblib
import numpy as np
import pandas as pd
import torch

from model import BiLSTM

# Put the parent of the current working directory on the import path;
# presumably that is where `etl_data_preparing` (imported next) lives.
sys.path.append(os.path.abspath(os.pardir))
from etl_data_preparing import DataPrepareETL

In [None]:
# Seed the global numpy / torch RNGs before the model is built so that any
# random parameter initialisation (and later torch / numpy randomness) is
# repeatable after Restart Kernel -> Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Network that is optimised by the training helpers in this notebook.
model = BiLSTM()

In [None]:
# Optimisation setup: a mean-squared-error objective minimised with Adam.
# All three names are read as globals by the training helpers.
learning_rate = 0.001
loss_function = torch.nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [16]:
# Run the ETL step on the BTCUSDT price file with the previously fitted scaler
# and unpack its three outputs: an iterable of training batches plus the
# validation and test (inputs, targets) pairs.
# NOTE(review): valid_size / test_size look like fractions of the data set
# (10 % each) - confirm against DataPrepareETL.prepare_for_train.
train_loader, (X_valid, y_valid), (X_test, y_test) = (
    DataPrepareETL(
        data="../preprocessing/data/BTCUSDT.csv",
        scaler="./static/BTCUSDT_MinMaxScaler.pkl",
    ).prepare_for_train(valid_size=0.10, test_size=0.10)
)

In [None]:
# NOTE: `train` is defined in the "Utils" section further down, so that cell
# has to be executed before this one on a fresh kernel.
static_dir = "static/"
best_model_filename = os.path.join(static_dir, "best_BiLSTM.pkl")

# Keep the loss histories returned by `train` so they can be inspected or
# plotted later, instead of letting the notebook echo the raw tuple of lists
# as this cell's output.
train_losses, val_losses = train(
    epochs=25,
    best_model_filename=best_model_filename,
    test_count=1000,
)

# Utils: training and evaluation helpers (run this section before the training cell above)

In [17]:
def mae(x, y):
    """Return the element-wise absolute error of the global ``model`` on ``x``.

    Despite the name, nothing is averaged here: the result is the tensor
    ``|model(x) - y|`` and the caller reduces it (mean / std) itself.

    The forward pass runs in evaluation mode and without gradient tracking,
    so layers that behave differently during training (e.g. dropout) do not
    distort the metric. The mode the model was in before the call is restored
    afterwards, even if the forward pass raises.

    Args:
        x: Model inputs.
        y: Targets; must be broadcastable against ``model(x)``.

    Returns:
        Tensor of absolute errors with the shape of ``model(x) - y``.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            errors = abs(model(x) - y)
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)
    return errors

def train_one_epoch():
    """Run one optimisation pass over ``train_loader``.

    Works on the notebook globals ``model``, ``optimizer``, ``loss_function``,
    ``train_loader``, ``X_valid`` and ``y_valid``.

    Every 25th batch the current training loss and the loss on the whole
    validation set are recorded; every 100th batch the latest values are
    printed. The validation loss is computed in evaluation mode and without
    gradients, after which the model is switched back to training mode.

    Returns:
        tuple[list[float], list[float]]: ``(train_losses, val_losses)``, one
        entry per 25 processed batches.
    """
    sample_every = 25   # record losses every N batches
    log_every = 100     # print progress every N batches

    train_losses = []
    val_losses = []
    # Most recent validation loss; kept separately so the log line never
    # depends on the sampling branch having run in the same iteration.
    latest_val_loss = float("nan")

    model.train()
    for step, (X_batch, y_batch) in enumerate(train_loader, start=1):
        model.zero_grad()
        out = model(X_batch)
        loss = loss_function(out, y_batch)
        loss.backward()
        optimizer.step()

        if step % sample_every == 0:
            train_losses.append(loss.item())
            # Evaluate with training-only behaviour (e.g. dropout) disabled.
            model.eval()
            with torch.no_grad():
                latest_val_loss = loss_function(model(X_valid), y_valid).item()
            model.train()
            val_losses.append(latest_val_loss)

        if step % log_every == 0:
            print(f"Train Itteration: {step}, Train Loss: {loss.item()}, Valid Loss: {latest_val_loss}")

    return train_losses, val_losses


def train(epochs: int, best_model_filename: str, test_count: int = 1000,
          eval_batch_size: int = 100):
    """Train the global ``model`` and keep the best checkpoint on disk.

    After every epoch the absolute error is measured on the first
    ``test_count`` validation samples (``X_valid`` / ``y_valid``), processed
    in chunks of ``eval_batch_size``. Whenever the mean absolute error
    improves on the best value seen so far, the model is saved to
    ``best_model_filename``.

    Args:
        epochs: Number of passes over the training data.
        best_model_filename: Destination path for the best model.
        test_count: How many leading validation samples to score each epoch;
            clamped to the size of the validation set.
        eval_batch_size: Chunk size used while scoring.

    Returns:
        tuple[list[float], list[float]]: Training and validation losses
        collected by ``train_one_epoch`` over all epochs.

    Raises:
        ValueError: If there is nothing to evaluate (``test_count`` or the
            validation set is empty) or ``eval_batch_size`` is not positive.
    """
    # Never slice past the end of the validation set: empty chunks would be
    # fed to the model otherwise.
    eval_count = min(test_count, len(X_valid))
    if eval_count <= 0:
        raise ValueError("test_count and the validation set must both be non-empty")
    if eval_batch_size <= 0:
        raise ValueError("eval_batch_size must be a positive integer")

    best_error = np.inf
    train_losses = []
    val_losses = []
    for epoch in range(epochs):
        tl, vl = train_one_epoch()
        train_losses.extend(tl)
        val_losses.extend(vl)

        chunk_errors = []
        for start in range(0, eval_count, eval_batch_size):
            stop = min(start + eval_batch_size, eval_count)
            chunk_errors.append(mae(X_valid[start:stop], y_valid[start:stop]))
        epoch_errors = torch.cat(chunk_errors)

        epoch_mae = epoch_errors.mean().item()
        if epoch_mae < best_error:
            # NOTE(review): this pickles the whole module object, so loading
            # it needs the same class definition importable; saving
            # model.state_dict() is the more portable format, but changing it
            # would break existing loaders of this file.
            torch.save(model, best_model_filename)
            best_error = epoch_mae
        print("Epoch:", epoch+1, "MAE:", epoch_mae, "STD:", epoch_errors.std().item())
    return train_losses, val_losses
