# In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import r2_score, mean_absolute_percentage_error
from torch.utils.data import TensorDataset, DataLoader, random_split
import wandb
from transformers import TimeSeriesTransformerConfig
from transformers import TimeSeriesTransformerForPrediction
import matplotlib.pyplot as plt
import json
import os
import random
# Runtime configuration
# Use the GPU when PyTorch reports one as available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Settings of the earlier multi-file pipeline, kept for reference only:
#   PATH = r"/scratch/maxtheisen/Training_Pdrop"
#   MODEL_FOLDER = "~/MeOHReactor/models/"
#   model_config was read with json.load from
#   os.path.join("/home/maxtheisen/MeOHReactor/", 'config_transformer.json')
#   data_files = ["T", "P", "CO", "CO2", "H2", "CH4", "CH3OH", "H2O", "N2"]
#   states = [pd.read_csv(f"{PATH}/{file}.csv", header=None) for file in data_files]

# Seed PyTorch's global random number generator.
torch.manual_seed(10)

# Data loading
print("Load data")
path = "data.csv"
states = pd.read_csv(path, sep=",", header=0, index_col=False)
# Turn +/-inf into NaN first, then fill every NaN with the last valid value
# found above it in the same column.
states = states.replace([np.inf, -np.inf], np.nan).ffill()
print("Data loaded")

# Preprocessing
NUMBER_OF_FUNCTIONS = 3
NUMBER_OF_POINTS = states.shape[0] // NUMBER_OF_FUNCTIONS
PRED_LENGTH = 5
CONTEXT_LENGTH = 10
TRAIN_FRACTION = 0.7
# Evenly spaced coordinate in [0, 1], one value per point of a profile.
z = np.linspace(0, 1, NUMBER_OF_POINTS)

# The statistics and the permute below need a 3-D (functions, samples, points)
# numpy array; a pandas DataFrame supports neither `axis=(1, 2)` / `keepdims`
# nor `torch.from_numpy`.
# NOTE(review): this assumes data.csv stacks one block of NUMBER_OF_POINTS rows
# per function (function-major) with one column per sample -- confirm against
# the actual file layout before trusting the results.
if NUMBER_OF_POINTS == 0 or states.shape[0] != NUMBER_OF_FUNCTIONS * NUMBER_OF_POINTS:
    raise ValueError(
        "data.csv has {} rows, which cannot be split into {} equally sized, "
        "non-empty blocks".format(states.shape[0], NUMBER_OF_FUNCTIONS)
    )
df_raw = (
    states.to_numpy(dtype=np.float64)
    .reshape(NUMBER_OF_FUNCTIONS, NUMBER_OF_POINTS, -1)
    .transpose(0, 2, 1)
)

# Per-function statistics (only needed for the alternative standardisation
# df = (df_raw - mean) / std, which is currently not applied).
mean = df_raw.mean(axis=(1, 2), keepdims=True)
std = df_raw.std(axis=(1, 2), keepdims=True)

# Min-max scale every function to [0, 1]. The extrema get their own names so
# the builtins `min` / `max` stay usable in the rest of the module.
data_min = df_raw.min(axis=(1, 2), keepdims=True)
data_max = df_raw.max(axis=(1, 2), keepdims=True)
data_range = data_max - data_min
data_range[data_range == 0] = 1.0  # constant function: avoid 0 / 0 -> NaN
df = ((df_raw - data_min) / data_range).astype(np.float32)
# (functions, samples, points) -> (samples, points, functions)
df = torch.from_numpy(df).permute(1, 2, 0)[:, :NUMBER_OF_POINTS, ...]

# Input: the first point of every sample; target: the full profile.
dataset = TensorDataset(df[:, 0, :].unsqueeze(1), df[:, :, :])

# random_split needs integer lengths that add up to len(dataset).
train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=32)
val_loader = DataLoader(val_dataset, batch_size=32)


def train(model_folder=None):
    """Train the time-series transformer and checkpoint the best validation MAPE.

    The function reads these module-level names: DEVICE, PRED_LENGTH,
    CONTEXT_LENGTH, NUMBER_OF_FUNCTIONS, z, train_loader and val_loader.
    Every epoch it trains on ``train_loader``, samples predictions for
    ``val_loader`` with ``model.generate``, logs MSE / R2 / MAPE and example
    plots to wandb, and saves the model state dict whenever the mean
    validation MAPE improves.

    Args:
        model_folder: Directory the checkpoint is written to. When ``None``,
            a module-level ``MODEL_FOLDER`` is used if one is defined,
            otherwise the current working directory.
    """
    # Model configuration
    config = TimeSeriesTransformerConfig(
        prediction_length=PRED_LENGTH,
        context_length=CONTEXT_LENGTH,
        embedding_dimension=64,
        lags_sequence=[0],
        num_time_features=1,
        input_size=NUMBER_OF_FUNCTIONS,
        num_parallel_samples=1,
    )
    model = TimeSeriesTransformerForPrediction(config=config).to(DEVICE)

    optimizer = optim.Adam(model.parameters(), lr=6e-4, weight_decay=10e-4)
    # NOTE(review): StepLR documents step_size as an integer number of epochs;
    # 0.002 looks like a misplaced value. Left unchanged because the intended
    # schedule cannot be told from this file -- confirm and fix.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=0.002)
    loss_fn = nn.MSELoss()
    epochs = 100

    # wandb.log needs an active run; start one unless the caller already did.
    if wandb.run is None:
        wandb.init(project="MeOH-TubolarPacketBedReactor", name="Time_series_model_deterministic")

    # Start from +inf so the first finite MAPE always produces a checkpoint.
    best_mape = float("inf")
    model_name = 'model'
    if model_folder is None:
        model_folder = globals().get("MODEL_FOLDER", ".")
    model_folder = os.path.expanduser(model_folder)  # torch.save does not expand "~"
    os.makedirs(model_folder, exist_ok=True)
    checkpoint_path = os.path.join(model_folder, model_name)

    # Time feature as a (1, points, 1) float tensor, whether the module-level
    # z is a 1-D numpy array or already a tensor of that shape.
    time_grid = torch.as_tensor(z, dtype=torch.float32).reshape(1, -1, 1).to(DEVICE)
    # NOTE(review): the model is configured with prediction_length=PRED_LENGTH,
    # so targets and future time features are cut to that horizon (a no-op when
    # the profiles already have exactly PRED_LENGTH points) -- confirm that a
    # truncated horizon is what is wanted.
    plot_coords = time_grid[0, :PRED_LENGTH, 0].cpu().numpy()

    def build_inputs(first_points):
        """Return the keyword arguments shared by model() and model.generate()."""
        batch = first_points.size(0)
        # The single known point is repeated over the whole context window and
        # only its last copy is marked as observed.
        past_values = first_points.repeat(1, CONTEXT_LENGTH, 1).to(DEVICE)
        past_observed_mask = torch.zeros_like(past_values)
        past_observed_mask[:, -1:, :] = 1
        return {
            "past_values": past_values,
            "past_time_features": time_grid[:, 0:1].repeat(batch, CONTEXT_LENGTH, 1),
            "past_observed_mask": past_observed_mask,
            "future_time_features": time_grid[:, :PRED_LENGTH].repeat(batch, 1, 1),
        }

    def plot_predictions(coords, y_val, yhat):
        """Plot truth vs. prediction of one random sample, one figure per function."""
        sample = random.randint(0, y_val.shape[0] - 1)  # index within the batch
        images = []
        for a in range(NUMBER_OF_FUNCTIONS):
            fig, ax = plt.subplots()
            ax.plot(coords, y_val[sample, :, a], label="Ground truth")
            ax.plot(coords, yhat[sample, :, a], label="Prediction")
            ax.set_title("Function {}".format(a))
            ax.legend()
            # Convert the figure to a wandb Image, then close it to free memory.
            images.append(wandb.Image(fig))
            plt.close(fig)
        return images

    for epoch in range(epochs):
        # Training
        model.train()
        train_losses = []
        for x_batch, y_batch in train_loader:
            output = model(future_values=y_batch[:, :PRED_LENGTH].to(DEVICE),
                           **build_inputs(x_batch))
            loss = output.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            wandb.log({"Training loss": loss.item()})   # iteration loss = loss on the batch
        # The scheduler is stepped after the optimizer updates of the epoch.
        scheduler.step()

        # Validation
        model.eval()
        vector_lv = []
        vector_r2 = []
        vector_mape = []
        last_truth = None
        last_prediction = None
        with torch.no_grad():
            for x_val, y_val in val_loader:
                target = y_val[:, :PRED_LENGTH].to(DEVICE)
                generated = model.generate(**build_inputs(x_val))
                # Average over the parallel samples drawn by generate().
                yhat = generated.sequences.mean(dim=1)
                vector_lv.append(loss_fn(yhat, target).item())

                last_prediction = yhat.cpu().numpy()
                last_truth = target.cpu().numpy()
                # NOTE(review): as in the original metric code, the last
                # channel is left out of R2 / MAPE -- confirm this is still
                # wanted for the current set of functions.
                truth_2d = last_truth.reshape(-1, NUMBER_OF_FUNCTIONS)[:, :-1]
                pred_2d = last_prediction.reshape(-1, NUMBER_OF_FUNCTIONS)[:, :-1]
                vector_r2.append(r2_score(truth_2d, pred_2d))
                vector_mape.append(mean_absolute_percentage_error(truth_2d, pred_2d))

        nan = float("nan")
        epoch_train_loss = float(np.mean(train_losses)) if train_losses else nan
        epoch_val_loss = float(np.mean(vector_lv)) if vector_lv else nan
        epoch_r2 = float(np.mean(vector_r2)) if vector_r2 else nan
        epoch_mape = float(np.mean(vector_mape)) if vector_mape else nan

        # Example plots come from the last validation batch, if there was one.
        images_to_log = []
        if last_truth is not None:
            images_to_log = plot_predictions(plot_coords, y_val=last_truth, yhat=last_prediction)

        # Epoch-level metrics are means over all validation batches.
        wandb.log({"Validation loss": epoch_val_loss,
                   "R2 score": epoch_r2,
                   "Test MAPE score": epoch_mape,
                   "Epoch": epoch,
                   "Prediction_example": images_to_log})
        if epoch_mape < best_mape:  # save best checkpoint
            torch.save(model.state_dict(), checkpoint_path)
            best_mape = epoch_mape
        print("Epoch: {} | Training loss: {} | Validation loss: {} | R2: {} | MAPE: {}".format(epoch+1, epoch_train_loss, epoch_val_loss, epoch_r2, epoch_mape))

    print("#"*20, "Best MAPE: {}".format(best_mape), "#"*20)

# mape_per_length = []
# for x_val, y_val in val_loader:
#     # Inference on validation
#     past_time_features = torch.zeros_like(torch.linspace(-1, 0, CONTEXT_LENGTH).reshape(1, -1, 1).repeat(x_val.size(0), 1, 1)).to(DEVICE)
#     future_time_features = torch.zeros_like(y_val[..., 0]).unsqueeze(-1).to(DEVICE)
#     past_values = x_val.repeat(1, CONTEXT_LENGTH, 1).to(DEVICE)
#     past_observed_mask = torch.zeros_like(past_values).to(DEVICE)
#     past_observed_mask[:, -1:, :] = 1
#     output = model.generate(past_values=past_observed_mask, 
#                             past_time_features= past_time_features, 
#                             past_observed_mask = past_observed_mask,
#                             future_time_features = future_time_features)
#     yhat = output.sequences.mean(dim=1) 
#     yhat = yhat.transpose(2,1)
#     y_val = y_val.transpose(2,1)
#     yhat = yhat.detach().numpy()
#     y_val = y_val.detach().numpy()
#     mape = mean_absolute_percentage_error(y_val.reshape(-1, y_val.shape[2]), yhat.reshape(-1, yhat.shape[2]), multioutput="raw_values")
#     mape_per_length.append(mape)