In [None]:
import os
import torch
from torch import nn
from torch.nn.functional import mse_loss
from torch.optim.adam import Adam
from tqdm import tqdm
import pandas as pd
import numpy as np


class EarlyStop:
    """Patience-based early-stopping counter for a loss that should decrease.

    Feed each new loss to :meth:`count`. It answers ``EarlyStop.CONTINUE``
    while training should go on and ``EarlyStop.STOP`` once the loss has
    failed to improve on the best value seen for more than ``patience``
    consecutive calls.
    """

    STOP = 0
    CONTINUE = 1

    def __init__(self, patience=5):
        """Create a counter that tolerates ``patience`` non-improving losses."""
        self.max_patience = patience
        self.current_patience = 0
        # +inf (rather than a large finite sentinel) guarantees that the first
        # finite loss is recorded as the best one, whatever its magnitude.
        self.min_loss = float("inf")

    def count(self, loss):
        """Record ``loss`` and return ``EarlyStop.STOP`` or ``EarlyStop.CONTINUE``.

        A loss strictly below the best seen so far resets the patience
        counter; anything else (including an equal loss) increments it.
        """
        if loss < self.min_loss:
            self.min_loss = loss
            self.current_patience = 0
            return EarlyStop.CONTINUE

        self.current_patience += 1
        # Strictly greater: STOP is returned on the (patience + 1)-th
        # consecutive non-improving loss.
        if self.current_patience > self.max_patience:
            return EarlyStop.STOP
        return EarlyStop.CONTINUE

    def reset(self):
        """Forget the best loss and the patience counter."""
        self.current_patience = 0
        self.min_loss = float("inf")


class LSTM(nn.Module):
    """Single-layer LSTM followed by a three-layer MLP head with sigmoid output.

    The network consumes one sequence at a time (batch size 1) and returns the
    prediction for the last time step only. Because of the final sigmoid,
    every output component lies in the unit interval.

    The recurrent state lives in ``self.hidden_cell`` and is overwritten by
    every forward pass; callers that want independent sequences must reset it
    to zeros before each call.
    """

    def __init__(self, input_size=4, hidden_layer_size=128, output_size=4):
        """Build the layers.

        Args:
            input_size: number of features per time step.
            hidden_layer_size: width of the LSTM hidden state; the MLP head
                narrows it to a half, then a quarter, then ``output_size``.
            output_size: number of predicted values.
        """
        super().__init__()
        self.hidden_layer_size = hidden_layer_size

        self.lstm = nn.LSTM(input_size, hidden_layer_size)

        self.linear_0 = nn.Linear(hidden_layer_size, hidden_layer_size // 2)
        self.relu = nn.LeakyReLU()
        self.linear_1 = nn.Linear(hidden_layer_size // 2, hidden_layer_size // 4)
        self.relu2 = nn.LeakyReLU()
        self.linear_2 = nn.Linear(hidden_layer_size // 4, output_size)
        self.sigmoid = nn.Sigmoid()
        # (h_0, c_0) for one layer and a batch of one. These are plain
        # tensors, not parameters or buffers, so ``Module.to`` does not move
        # them; ``forward`` aligns them with the input instead.
        self.hidden_cell = (torch.zeros(1, 1, self.hidden_layer_size),
                            torch.zeros(1, 1, self.hidden_layer_size))

    def forward(self, input_seq):
        """Return the prediction for the final step of ``input_seq``.

        Args:
            input_seq: tensor whose first dimension is the time axis; it is
                reshaped to ``(len(input_seq), 1, -1)`` before the LSTM.

        Returns:
            1-D tensor with ``output_size`` elements.
        """
        steps = len(input_seq)
        # Moving the state is a no-op when it already sits on the input's
        # device, and prevents a device-mismatch error when the model was
        # moved after construction and the state was not reset by the caller.
        state = tuple(part.to(input_seq.device) for part in self.hidden_cell)

        lstm_out, self.hidden_cell = self.lstm(input_seq.view(steps, 1, -1), state)
        predictions = self.linear_0(lstm_out.view(steps, -1))
        predictions = self.relu(predictions)
        predictions = self.linear_1(predictions)
        predictions = self.relu2(predictions)
        predictions = self.linear_2(predictions)
        predictions = self.sigmoid(predictions)
        # Only the last time step is used as the forecast.
        return predictions[-1]


def make_dataset(stock_train_file, sequence_length, step=16):
    """Cut one stock's order book into fixed-length windows with next-row labels.

    The parquet file is grouped by ``time_id``. Within each group a window of
    ``sequence_length`` consecutive rows is taken every ``step`` rows, and the
    row immediately after the window is its label. Windows never span two
    ``time_id`` groups, and groups with at most ``sequence_length`` rows
    contribute nothing.

    Args:
        stock_train_file: path of the parquet file to read.
        sequence_length: number of rows per input window.
        step: stride, in rows, between the starts of consecutive windows.

    Returns:
        Tuple ``(sequences, labels)`` of float arrays with shapes
        ``(n, sequence_length, 4)`` and ``(n, 1, 4)``. When no window can be
        built, ``n`` is 0 but the trailing dimensions are preserved so the
        result can still be stacked with other stocks' data.
    """
    feature_columns = ['bid_price1', 'ask_price1', 'bid_size1', 'ask_size1']
    stock_book = pd.read_parquet(stock_train_file)

    sequence_data = []
    labels = []
    for _, time_id_data in tqdm(stock_book.groupby('time_id')):
        # Convert each group once instead of re-selecting and re-casting the
        # columns for every window.
        values = time_id_data[feature_columns].astype(float).to_numpy()
        num_trade = values.shape[0]
        # Stopping at num_trade - sequence_length guarantees the label row
        # (index window_end) exists.
        for start in range(0, num_trade - sequence_length, step):
            window_end = start + sequence_length
            sequence_data.append(values[start:window_end])
            labels.append(values[window_end:window_end + 1])

    if not sequence_data:
        # np.array([]) would have shape (0,), which cannot be stacked with
        # 3-D data from other stocks.
        num_features = len(feature_columns)
        return (np.empty((0, sequence_length, num_features)),
                np.empty((0, 1, num_features)))
    return np.array(sequence_data), np.array(labels)


def get_stock_file(root_data, stock):
    """Return the path of the data file stored for ``stock``.

    The file is looked up in the directory ``<root_data>/stock_id=<stock>``.
    When that directory holds several entries, the first one in sorted order
    is returned, so the choice is deterministic (``os.listdir`` itself gives
    no ordering guarantee).

    Args:
        root_data: directory containing one ``stock_id=<id>`` folder per stock.
        stock: stock identifier; it is converted with ``str``.

    Returns:
        Path of the selected entry, joined onto the stock directory.

    Raises:
        IndexError: the stock directory exists but is empty.
        OSError: the stock directory cannot be listed (raised by ``os.listdir``).
    """
    stock_dir = os.path.join(root_data, "stock_id=" + str(stock))
    entries = sorted(os.listdir(stock_dir))
    if not entries:
        raise IndexError(f"no files found in {stock_dir!r}")
    return os.path.join(stock_dir, entries[0])


def validate(model, val_data, device):
    """Return the summed loss of ``model`` over the validation set.

    Prices (columns 0-1) and sizes (columns 2-3) are min-max scaled with
    statistics computed from ``val_data`` itself, pooled over sequences and
    labels. Scaling is applied to per-sample copies, so the caller's arrays
    are left untouched and repeated calls see the same raw data.

    NOTE(review): the statistics come from the validation split, not from the
    training split the model was fitted on - confirm this is intended.

    The loss is computed with the module-level ``loss_function``, which must
    be defined before this function is called. The model is switched to eval
    mode and is left in that mode on return.

    Args:
        model: network exposing ``hidden_cell`` and ``hidden_layer_size``.
        val_data: pair ``(sequences, labels)`` of arrays indexed as
            ``[sample, row, feature]`` with four features.
        device: torch device the tensors are moved to.

    Returns:
        Sum (not mean) of the per-sample losses as a Python number.
    """
    sequences, targets = val_data[0], val_data[1]

    max_price = max(sequences[:, :, 0:2].max(), targets[:, :, 0:2].max())
    min_price = min(sequences[:, :, 0:2].min(), targets[:, :, 0:2].min())
    max_size = max(sequences[:, :, 2:4].max(), targets[:, :, 2:4].max())
    min_size = min(sequences[:, :, 2:4].min(), targets[:, :, 2:4].min())
    price_range = max_price - min_price
    size_range = max_size - min_size

    def _scaled_tensor(sample):
        # astype returns a copy, so the caller's array is never modified.
        scaled = sample.astype(float)
        scaled[:, 0:2] = (scaled[:, 0:2] - min_price) / price_range
        scaled[:, 2:4] = (scaled[:, 2:4] - min_size) / size_range
        return torch.from_numpy(scaled).type(torch.FloatTensor).to(device)

    # val step
    model.eval()
    val_loss = 0
    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for i in tqdm(range(sequences.shape[0])):
            seq = _scaled_tensor(sequences[i])
            labels = _scaled_tensor(targets[i])
            # Each sample is an independent sequence: start from a zero state.
            model.hidden_cell = (torch.zeros(1, 1, model.hidden_layer_size).to(device),
                                 torch.zeros(1, 1, model.hidden_layer_size).to(device))

            y_pred = model(seq)
            single_loss = loss_function(y_pred, labels.view(4))
            val_loss += single_loss.item()
    return val_loss


def get_stocks(train):
    """Return the distinct values of ``train['stock_id']`` in ascending order."""
    stock_ids = train['stock_id']
    return np.unique(stock_ids)


# Script entry point: pool windows from every stock, split them into
# train/validation sets, and train one LSTM with early stopping, keeping the
# weights of the epoch with the lowest validation loss.
if __name__ == "__main__":
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print(f"Working on {device}")
    root_data = "../input/optiver-realized-volatility-prediction/book_train.parquet"
    train = pd.read_csv('../input/optiver-realized-volatility-prediction/train.csv')

    epochs = 150  # upper bound on passes over the pooled training set
    sequence_length = 100
    train_ratio = 0.75
    lstm_hidden_size = 128
    input_size = 4
    output_size = 4
    learning_rate = 0.0001
    max_patience = 7
    data_step = 200
    early_stop = EarlyStop(max_patience)

    # Collect per-stock arrays and concatenate once at the end; stacking
    # inside the loop would re-copy everything gathered so far each time.
    # The empty seed arrays fix the result's shape and dtype even when no
    # stock yields a window.
    stocks = get_stocks(train)
    sequence_chunks = [np.ones((0, sequence_length, input_size))]
    label_chunks = [np.ones((0, 1, input_size))]
    for stock in stocks:
        stock_book_file = get_stock_file(root_data, stock)
        a_sequence_data, a_label = make_dataset(stock_book_file, sequence_length, step=data_step)
        if len(a_sequence_data) == 0:
            # A stock without any complete window contributes nothing.
            continue
        sequence_chunks.append(a_sequence_data)
        label_chunks.append(a_label)
    sequence_data = np.concatenate(sequence_chunks)
    labels = np.concatenate(label_chunks)

    # Shuffle sequences and labels with the same permutation.
    rand_indices = np.arange(len(sequence_data))
    np.random.shuffle(rand_indices)
    sequence_data = sequence_data[rand_indices]
    labels = labels[rand_indices]

    split = int(train_ratio * len(sequence_data))
    train_data = sequence_data[:split], labels[:split]
    val_data = sequence_data[split:], labels[split:]

    model = LSTM(input_size, lstm_hidden_size, output_size).to(device)
    optimizer = Adam(model.parameters(), lr=learning_rate)
    # Module-level name: validate() looks it up as a global.
    loss_function = mse_loss

    # scale the train data only (validate() scales the validation data itself)
    max_price = max(train_data[0][:, :, 0:2].max(), train_data[1][:, :, 0:2].max())
    min_price = min(train_data[0][:, :, 0:2].min(), train_data[1][:, :, 0:2].min())
    max_size = max(train_data[0][:, :, 2:4].max(), train_data[1][:, :, 2:4].max())
    min_size = min(train_data[0][:, :, 2:4].min(), train_data[1][:, :, 2:4].min())

    train_data[0][:, :, 0:2] = (train_data[0][:, :, 0:2] - min_price) / (max_price - min_price)
    train_data[1][:, :, 0:2] = (train_data[1][:, :, 0:2] - min_price) / (max_price - min_price)
    train_data[0][:, :, 2:4] = (train_data[0][:, :, 2:4] - min_size) / (max_size - min_size)
    train_data[1][:, :, 2:4] = (train_data[1][:, :, 2:4] - min_size) / (max_size - min_size)

    best_val_loss = float("inf")
    early_stop.reset()
    data_len = train_data[0].shape[0]
    for e in range(epochs):
        # train step
        model.train()
        train_loss = 0
        for i in tqdm(range(data_len)):
            seq = torch.from_numpy(train_data[0][i]).type(torch.FloatTensor).to(device)
            # Named `target` so the dataset-level `labels` array is not clobbered.
            target = torch.from_numpy(train_data[1][i]).type(torch.FloatTensor).to(device)
            optimizer.zero_grad()
            # Every sample is an independent sequence: start from a zero state.
            model.hidden_cell = (torch.zeros(1, 1, model.hidden_layer_size).to(device),
                                 torch.zeros(1, 1, model.hidden_layer_size).to(device))

            y_pred = model(seq)

            single_loss = loss_function(y_pred, target.view(4))
            train_loss += single_loss.item()
            single_loss.backward()
            optimizer.step()
        print(f'epoch: {e:3} train loss: {train_loss:10.8f}')

        # val step
        val_loss = validate(model, val_data, device)
        # Report before the early-stop check so the final epoch is logged too.
        print(f'epoch: {e:3} val loss: {val_loss:10.8f}')
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "model_best_all.pth")
        if early_stop.count(val_loss) == EarlyStop.STOP:
            break  # stop training
