# Import necessary libraries
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import torch
import torch.nn as nn
import itertools
import random

# Define your LSTM model class
class LSTMModel(nn.Module):
    """LSTM followed by a linear head that emits one value per input sequence.

    Args:
        input_size: Number of features per time step fed to the LSTM.
        hidden_size: Width of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        learning_rate: Stored on the instance only; the model itself never
            reads it.
        window_size: Stored on the instance only; the model itself never
            reads it.
    """

    def __init__(self, input_size, hidden_size, num_layers, learning_rate, window_size):
        super().__init__()

        # Hyperparameters kept as plain attributes for later inspection.
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.window_size = window_size
        self.learning_rate = learning_rate

        # The LSTM is built before the linear head so that parameter
        # initialisation draws from the RNG in a fixed order.
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return the linear head applied to the LSTM output at the last step.

        `x` is laid out batch-first (the LSTM is built with
        `batch_first=True`); the result has one row per batch element and a
        single column.
        """
        # Fresh all-zero hidden and cell states for every call.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = torch.zeros(state_shape, device=x.device)
        c0 = torch.zeros(state_shape, device=x.device)

        sequence_out, _ = self.lstm(x, (h0, c0))

        # Only the final time step of each sequence feeds the linear head.
        last_step = sequence_out[:, -1, :]
        return self.fc(last_step)


# Define a function to split the data
def split_data_with_window(x_in, y_in, split_window_size):
    """Split rows into two sets by walking fixed-size, non-overlapping blocks.

    The input is cut into consecutive blocks of ``split_window_size + 1``
    rows. Within each block the first ``split_window_size`` rows go to the
    first output set and the final row goes to the second output set.
    Trailing rows that do not fill a complete block are dropped.

    Args:
        x_in: pandas object (DataFrame/Series) holding the features; its
            length determines the block positions.
        y_in: pandas object holding the targets, sliced at the same
            positions as ``x_in``.
        split_window_size: Number of rows per block assigned to the first
            output set.

    Returns:
        Tuple ``(x_out1, y_out1, x_out2, y_out2)``. The original index labels
        are preserved. If the input is too short to hold one complete block,
        all four results are empty slices of the inputs.
    """
    block_len = split_window_size + 1
    # Start offsets of every complete block.
    starts = range(0, len(x_in) - split_window_size, block_len)

    # pd.concat raises ValueError on an empty list, so short inputs are
    # answered with empty slices that keep the columns and dtypes.
    if len(starts) == 0:
        return x_in.iloc[0:0], y_in.iloc[0:0], x_in.iloc[0:0], y_in.iloc[0:0]

    x_out1 = pd.concat([x_in.iloc[s:s + split_window_size] for s in starts])
    y_out1 = pd.concat([y_in.iloc[s:s + split_window_size] for s in starts])
    x_out2 = pd.concat([x_in.iloc[s + split_window_size:s + block_len] for s in starts])
    y_out2 = pd.concat([y_in.iloc[s + split_window_size:s + block_len] for s in starts])

    return x_out1, y_out1, x_out2, y_out2


# Load the raw AAPL table: drop the "date" column, replace nulls with zero,
# cast everything to float32 and keep only the first 10747 rows
# (noted in the original as the data up to 31.07.2023).
data = (
    pd.read_csv("../data/data/aapl_raw_data.csv")
    .drop("date", axis=1)
    .fillna(0)
    .astype('float32')
    .iloc[:10747]
)

# Seed every random number generator in use so runs are repeatable
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

# Ask cuDNN for deterministic kernels and disable its auto-tuner
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Hyperparameter grid: one list of candidate values per setting
input_sizes = [7]
hidden_sizes = [4]
num_layers_list = [1]
learning_rates = [0.0003]
window_sizes = [5]

num_epochs = 1500
patience = 10  # Epochs without validation improvement before stopping

# Every combination of the candidate values, as
# (input_size, hidden_size, num_layers, learning_rate, window_size) tuples
hyperparameter_combinations = list(
    itertools.product(
        input_sizes,
        hidden_sizes,
        num_layers_list,
        learning_rates,
        window_sizes,
    )
)

# Helper shared by the validation and test passes below
def _summed_window_loss(model, criterion, x_tensor, y_tensor, window_size):
    """Return the summed loss of `model` over all sliding windows, as a float.

    A window starts at every row of `x_tensor`. Windows that would run past
    the end are truncated, so the last rows are scored on shorter inputs.
    Each window is compared against the target at its final row. No
    gradients are recorded.
    """
    total = 0.0
    n_rows = len(x_tensor)
    with torch.no_grad():
        for start in range(n_rows):
            window_end = min(start + window_size, n_rows)
            inputs = x_tensor[start:window_end].unsqueeze(0)
            # Slice (not index) so the label keeps shape (1, 1), matching the
            # model output and avoiding MSELoss's size-mismatch warning.
            labels = y_tensor[window_end - 1:window_end]
            total += criterion(model(inputs), labels).item()
    return total


# Loop through each column to use it as the target variable
for target_column in data.columns:
    print(f"Training model with target variable: {target_column}")
    print()

    # Set the target column as y_data and the rest as x_data
    y_data = data[target_column]
    x_data = data.drop(columns=[target_column])

    # Split Data to train and temp: of every 4 consecutive rows,
    # the first 3 go to train and the 4th goes to temp
    split_window_size = 3
    x_train, y_train, x_temp, y_temp = split_data_with_window(x_data, y_data, split_window_size)

    # Split temp into val and test: rows alternate val, test, val, test, ...
    split_window_size = 1
    x_val, y_val, x_test, y_test = split_data_with_window(x_temp, y_temp, split_window_size)

    # Normalize the features; the scaler is fitted on the training rows only
    scaler = MinMaxScaler()
    x_train_normalized = scaler.fit_transform(x_train)
    x_val_normalized = scaler.transform(x_val)
    x_test_normalized = scaler.transform(x_test)

    # Convert to PyTorch tensors; targets become column vectors of shape (N, 1)
    x_train_tensor = torch.tensor(x_train_normalized, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train.values, dtype=torch.float32).view(-1, 1)

    x_val_tensor = torch.tensor(x_val_normalized, dtype=torch.float32)
    y_val_tensor = torch.tensor(y_val.values, dtype=torch.float32).view(-1, 1)

    x_test_tensor = torch.tensor(x_test_normalized, dtype=torch.float32)
    y_test_tensor = torch.tensor(y_test.values, dtype=torch.float32).view(-1, 1)

    num_train = len(x_train_tensor)

    # Sliding-window training for each hyperparameter combination
    for hyperparams in hyperparameter_combinations:
        # NOTE(review): input_size must equal the number of feature columns
        # in x_data, otherwise the LSTM rejects the input - confirm against
        # the CSV schema.
        input_size, hidden_size, num_layers, learning_rate, window_size = hyperparams

        # Print hyperparameters
        print(f"Hyperparameters: input_size={input_size}, hidden_size={hidden_size}, num_layers={num_layers}, learning_rate={learning_rate}, window_size={window_size}")

        # Initialize the model
        model = LSTMModel(input_size, hidden_size, num_layers, learning_rate, window_size)

        # Define the loss function and optimizer
        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        best_val_loss = float('inf')
        counter = 0  # Consecutive epochs without validation improvement

        # Train the model, one optimizer step per sliding window
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0

            for i in range(num_train):
                window_end = min(i + window_size, num_train)
                inputs = x_train_tensor[i:window_end].unsqueeze(0)
                # Slice keeps shape (1, 1) so it matches the model output
                labels = y_train_tensor[window_end - 1:window_end]

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            model.eval()
            val_loss = _summed_window_loss(model, criterion, x_val_tensor, y_val_tensor, window_size)

            # Early stopping based on validation loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                counter = 0
            else:
                counter += 1
                if counter >= patience:
                    # Reported 1-based, consistent with the per-epoch line
                    print(f'Early stopping at epoch {epoch + 1}')
                    break

            print(f'Epoch [{epoch + 1}/{num_epochs}], Training Loss: {running_loss / num_train}, Validation Loss: {val_loss / len(x_val_tensor)}')

        # Calculate test loss after training is complete
        # NOTE(review): this scores the weights from the last epoch run, not
        # the best-validation checkpoint - confirm that is intended.
        model.eval()
        test_loss = _summed_window_loss(model, criterion, x_test_tensor, y_test_tensor, window_size)

        print(f'Final Test Loss: {test_loss / len(x_test_tensor)}')

        for _ in range(4):
            print()

print()