# In [None]:
import torch

if not torch.cuda.is_available():
    print("CUDA is not available.")
else:
    # Record the CUDA toolkit and PyTorch versions in the run output.
    print("CUDA Version: ", torch.version.cuda)
    print("Pytorch Version: ", torch.__version__)

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch_scatter
import torch.nn as nn
import torch.optim as optim
from torch.nn import Linear, Sequential, LayerNorm, ReLU
from tqdm import trange

import os
import random
import numpy as np
import logging
import pandas as pd
import copy

import matplotlib.pyplot as plt


def normalize(to_normalize, mean_vec, std_vec):
    """Standardize ``to_normalize``: subtract ``mean_vec``, then divide by ``std_vec``.

    Plain elementwise arithmetic, so the usual broadcasting rules apply
    (e.g. per-feature statistics against a rows-by-features tensor).
    """
    centered = to_normalize - mean_vec
    return centered / std_vec

def unnormalize(to_unnormalize, mean_vec, std_vec):
    """Undo standardization: scale by ``std_vec``, then shift by ``mean_vec``."""
    scaled = to_unnormalize * std_vec
    return scaled + mean_vec

def _mean_std(total, total_sq, count, eps=1e-8):
    """Turn running sums into ``(mean, std)`` in the default float dtype, with std >= eps."""
    mean = total / count
    # E[v^2] - E[v]^2 can dip slightly below zero through floating-point
    # cancellation on (near-)constant features; clamp so sqrt never yields NaN
    # (a NaN would survive any later max(std, eps)).
    var = torch.clamp(total_sq / count - mean ** 2, min=0.0)
    std = torch.clamp(torch.sqrt(var), min=eps)
    dtype = torch.get_default_dtype()
    return mean.to(dtype), std.to(dtype)


def get_stats(data_list):
    '''
    Compute per-feature mean/std for node features (x), edge features
    (edge_attr) and labels (y) over the provided list of Data objects.

    Each object must expose ``x``, ``edge_attr`` and ``y`` tensors whose first
    dimension indexes rows (nodes / edges) and whose remaining dimensions form
    the feature shape. Accumulation stops after the first object at which more
    than 10**6 node rows or edge rows have been seen.

    Sums are accumulated in float64 on the CPU for accuracy; the results are
    returned in the default float dtype.

    Returns:
        [mean_vec_x, std_vec_x, mean_vec_edge, std_vec_edge, mean_vec_y, std_vec_y];
        every std entry is at least 1e-8.

    Raises:
        ValueError: if ``data_list`` is empty.
    '''
    if len(data_list) == 0:
        raise ValueError("get_stats() needs a non-empty data_list")

    max_accumulations = 10**6
    fields = ('x', 'edge_attr', 'y')
    first = data_list[0]

    totals = {f: torch.zeros(getattr(first, f).shape[1:], dtype=torch.float64) for f in fields}
    totals_sq = {f: torch.zeros_like(totals[f]) for f in fields}
    counts = dict.fromkeys(fields, 0)

    for dp in data_list:
        for f in fields:
            values = getattr(dp, f).detach().to(device='cpu', dtype=torch.float64)
            totals[f] += values.sum(dim=0)
            totals_sq[f] += (values ** 2).sum(dim=0)
            counts[f] += values.shape[0]

        # Enough node or edge rows seen for stable statistics.
        if counts['x'] > max_accumulations or counts['edge_attr'] > max_accumulations:
            break

    stats = []
    for f in fields:
        stats.extend(_mean_std(totals[f], totals_sq[f], counts[f]))
    return stats



def build_optimizer(args, params):
    """Create the optimizer (and optional learning-rate scheduler) described by ``args``.

    Args:
        args: attribute-style config. Reads ``opt`` ('adam' | 'sgd' | 'rmsprop'
            | 'adagrad'), ``lr``, ``weight_decay`` and ``opt_scheduler``
            ('none' | 'step' | 'cos'). The 'step' scheduler also reads
            ``opt_decay_step`` / ``opt_decay_rate``; 'cos' reads ``opt_restart``.
            ``sgd_momentum`` is optional (default 0.95) and used only by 'sgd'.
        params: iterable of parameters; those with ``requires_grad=False`` are skipped.

    Returns:
        ``(scheduler, optimizer)``. ``scheduler`` is None for 'none' and for
        unrecognized scheduler names (the latter logs a warning).

    Raises:
        ValueError: if ``args.opt`` is not one of the supported optimizers.
    """
    weight_decay = args.weight_decay
    trainable = [p for p in params if p.requires_grad]

    if args.opt == 'adam':
        optimizer = optim.Adam(trainable, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'sgd':
        momentum = getattr(args, 'sgd_momentum', 0.95)
        optimizer = optim.SGD(trainable, lr=args.lr, momentum=momentum, weight_decay=weight_decay)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(trainable, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'adagrad':
        optimizer = optim.Adagrad(trainable, lr=args.lr, weight_decay=weight_decay)
    else:
        raise ValueError(f"Unknown optimizer: {args.opt}")

    if args.opt_scheduler == 'none':
        return None, optimizer
    if args.opt_scheduler == 'step':
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.opt_decay_step, gamma=args.opt_decay_rate)
    elif args.opt_scheduler == 'cos':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.opt_restart)
    else:
        # Keep the permissive fallback, but make a typo visible instead of
        # silently training with a constant learning rate.
        logging.warning("Unknown opt_scheduler %r; continuing without a learning-rate scheduler.",
                        args.opt_scheduler)
        scheduler = None

    return scheduler, optimizer



def train(train_dataset, test_dataset, device, stats_list, loo, args):
    """Train a MeshGraphNet on ``train_dataset`` while tracking the test loss.

    Every 5th epoch (starting with epoch 0) the model is evaluated on
    ``test_dataset`` via ``test``. The best-scoring model so far is kept as a
    deep copy. If ``args.save_best_model`` is set, its ``state_dict`` is written
    to ``<checkpoint_dir>/<model_name>.pt`` on those epochs. On the other
    epochs the last measured test loss is carried forward, so ``test_losses``
    has one entry per epoch. The per-epoch history is written to
    ``<checkpoint_dir>/<model_name>.csv`` at the end.

    Args:
        train_dataset: non-empty list of graph Data objects used for training.
        test_dataset: list of graph Data objects used for evaluation.
        device: device the model, statistics and batches are moved to.
        stats_list: ``[mean_x, std_x, mean_edge, std_edge, mean_y, std_y]``,
            as returned by ``get_stats``.
        loo: identifier of the held-out simulation (used in file names).
        args: attribute-style config (num_layers, batch_size, hidden_dim, epochs,
            lr, weight_decay, shuffle, train_size, test_size, checkpoint_dir,
            save_velo_val, save_best_model and the optimizer/scheduler options).

    Returns:
        ``(test_losses, losses, best_model, best_test_loss, test_loader)``.

    Raises:
        ValueError: if ``train_dataset`` is empty.
    """
    model_name = (
        'MODEL_LOO_SIM_' + str(loo) +
        '_NL_' + str(args.num_layers) +
        '_BS_' + str(args.batch_size) +
        '_HD_' + str(args.hidden_dim) +
        '_EP_' + str(args.epochs) +
        '_WD_' + str(args.weight_decay) +
        '_LR_' + str(args.lr) +
        '_SHUFF_' + str(args.shuffle) +
        '_TR_' + str(args.train_size) +
        '_TE_' + str(args.test_size)
    )

    if len(train_dataset) == 0:
        raise ValueError("train() needs a non-empty train_dataset")

    # Training batches honour args.shuffle; test batches keep a fixed order.
    loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=args.shuffle)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    (mean_vec_x, std_vec_x, mean_vec_edge, std_vec_edge,
     mean_vec_y, std_vec_y) = (stat.to(device) for stat in stats_list)

    num_node_features = train_dataset[0].x.shape[1]
    num_edge_features = train_dataset[0].edge_attr.shape[1]
    num_classes = 3  # vx, vy, vz

    model = MeshGraphNet(num_node_features, num_edge_features, args.hidden_dim, num_classes, args).to(device)
    scheduler, opt = build_optimizer(args, model.parameters())

    # Create once, with parents, so both the checkpoint and the CSV writes succeed
    # (including when args.epochs == 0).
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    losses, test_losses = [], []
    history = []  # one row per epoch; converted to a DataFrame once at the end
    best_test_loss = np.inf
    best_model = None

    for epoch in trange(args.epochs, desc="Training", unit="Epochs"):
        total_loss = 0.0
        model.train()
        num_loops = 0

        for batch in loader:
            batch = batch.to(device)
            opt.zero_grad()
            pred = model(batch, mean_vec_x, std_vec_x, mean_vec_edge, std_vec_edge)
            loss = model.loss(pred, batch, mean_vec_y, std_vec_y)
            loss.backward()
            opt.step()

            total_loss += loss.item()
            num_loops += 1

        total_loss /= num_loops
        losses.append(total_loss)

        # build_optimizer may return a scheduler; advance it once per epoch.
        if scheduler is not None:
            scheduler.step()

        evaluate = epoch % 5 == 0
        if evaluate:
            test_loss, _ = test(test_loader, device, model, mean_vec_x, std_vec_x,
                                mean_vec_edge, std_vec_edge, mean_vec_y, std_vec_y,
                                args.save_velo_val)
            test_losses.append(test_loss.item())

            if test_loss < best_test_loss:
                best_test_loss = test_loss
                best_model = copy.deepcopy(model)
        else:
            # Carry the last measured test loss forward to keep one entry per epoch.
            test_losses.append(test_losses[-1])

        history.append({'epoch': epoch,
                        'train_loss': losses[-1],
                        'test_loss': test_losses[-1]})

        if evaluate:
            print("  train loss", round(total_loss, 5), "| test loss", round(test_loss.item(), 5))
            if args.save_best_model and best_model is not None:
                PATH = os.path.join(args.checkpoint_dir, model_name + '.pt')
                torch.save(best_model.state_dict(), PATH)

    df = pd.DataFrame(history, columns=['epoch', 'train_loss', 'test_loss'])
    CSV_PATH = os.path.join(args.checkpoint_dir, model_name + '.csv')
    df.to_csv(CSV_PATH, index=False)

    return test_losses, losses, best_model, best_test_loss, test_loader


def test(loader, device, test_model,
         mean_vec_x, std_vec_x, mean_vec_edge, std_vec_edge, mean_vec_y, std_vec_y,
         is_validation, delta_t=0.01, save_model_preds=False, model_type=None):
    """Evaluate ``test_model`` on every batch of ``loader`` without gradients.

    The model is switched to eval mode for the pass; its previous train/eval
    mode is restored afterwards, even if an exception is raised.

    Args:
        loader: iterable of batches supporting ``.to(device)`` and exposing ``x`` and ``y``.
        device: device each batch is moved to.
        test_model: called as ``test_model(data, mean_vec_x, std_vec_x,
            mean_vec_edge, std_vec_edge)``; must also provide
            ``loss(pred, data, mean_vec_y, std_vec_y)``.
        mean_vec_x, std_vec_x, mean_vec_edge, std_vec_edge, mean_vec_y, std_vec_y:
            normalization statistics.
        is_validation: if truthy, also compute a velocity RMSE per batch.
        delta_t: factor applied to the unnormalized prediction and to ``data.y``
            before adding them to ``data.x[:, 0:3]``.
        save_model_preds, model_type: accepted for call compatibility; unused.

    Returns:
        ``(loss, velo_rmse)``, each the mean of the per-batch values.
        ``velo_rmse`` is 0.0 when ``is_validation`` is falsy.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    was_training = test_model.training
    test_model.eval()

    loss = 0.0
    velo_rmse = 0.0
    num_loops = 0
    try:
        with torch.no_grad():
            for data in loader:
                data = data.to(device)
                pred = test_model(data, mean_vec_x, std_vec_x, mean_vec_edge, std_vec_edge)
                loss += test_model.loss(pred, data, mean_vec_y, std_vec_y)

                if is_validation:
                    # NOTE(review): x[:, 0:3] is read as velocity below, so the
                    # argmax over x[:, 2:] also sees column 2 (presumably vz). A large
                    # enough value there wins the argmax and drops the node from the
                    # mask. Index 1 of this slice is column 3; confirm the intended
                    # node-type column against the node-feature layout.
                    loss_mask = torch.argmax(data.x[:, 2:], dim=1) == torch.tensor(1, device=data.x.device)
                    eval_velo = data.x[:, 0:3] + unnormalize(pred, mean_vec_y, std_vec_y) * delta_t
                    gs_velo = data.x[:, 0:3] + data.y * delta_t
                    error = torch.sum((eval_velo - gs_velo) ** 2, dim=1)
                    # An empty mask makes this batch's mean (and so the RMSE) NaN.
                    velo_rmse += torch.sqrt(torch.mean(error[loss_mask]))

                num_loops += 1
    finally:
        test_model.train(was_training)

    if num_loops == 0:
        raise ValueError("test() received a loader that yielded no batches")

    return loss / num_loops, velo_rmse / num_loops


def _save_loss_plot(title, losses, test_losses, path):
    """Draw train/test loss curves in a new 6x6 figure, save it as EPS to ``path``, show it, and close it."""
    fig = plt.figure(figsize=(6, 6))
    plt.title(title)
    plt.plot(losses, label="Training Loss", color='#156082')
    plt.plot(test_losses, label="Test Loss", color='#FFC000')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.ylim(top=5)
    plt.legend()
    # Save before show(): some backends release or clear a figure once it has
    # been shown, which can leave an empty file.
    fig.savefig(path, bbox_inches='tight', format='eps')
    plt.show()
    # save_plots runs once per LOO fold; close explicitly so figures do not pile up.
    plt.close(fig)


def save_plots(args, loss_image_title, losses, test_losses):
    """Plot the training/test loss curves and save them as EPS files.

    Two otherwise identical figures are written to ``args.postprocess_dir``:
    one titled ``loss_image_title`` and one with an empty title (its file name
    contains ``NO_TITLE_``). File names start with ``SIM_`` plus the last
    space-separated token of ``loss_image_title`` and encode the hyper-parameters.

    Side effect: sets matplotlib's global ``font.family`` to serif with
    Times New Roman preferred.
    """
    file_prepend = 'SIM_' + loss_image_title.split(' ')[-1] + '_'
    no_title_prepend = file_prepend + 'NO_TITLE_'

    hyper_suffix = ('_NL_' + str(args.num_layers) + '_BS_' +
        str(args.batch_size) + '_HD_' + str(args.hidden_dim) + '_EP_' + str(args.epochs) + '_WD_' +
        str(args.weight_decay) + '_LR_' + str(args.lr) + '_SHUFF_' + str(args.shuffle) + '_TR_' +
        str(args.train_size) + '_TE_' + str(args.test_size))

    plot_file_name = file_prepend + hyper_suffix
    no_title_plot_file_name = no_title_prepend + hyper_suffix

    os.makedirs(args.postprocess_dir, exist_ok=True)

    PLOT_FILE_PATH = os.path.join(args.postprocess_dir, plot_file_name + '.eps')
    NO_TITLE_PLOT_FILE_PATH = os.path.join(args.postprocess_dir, no_title_plot_file_name + '.eps')

    plt.rcParams['font.family'] = 'serif'
    # With a serif family the preferred face belongs in the serif list (setting
    # font.sans-serif has no effect here). Keep the remaining entries as fallbacks.
    plt.rcParams['font.serif'] = ['Times New Roman'] + [
        name for name in plt.rcParams['font.serif'] if name != 'Times New Roman'
    ]

    _save_loss_plot(loss_image_title, losses, test_losses, PLOT_FILE_PATH)
    _save_loss_plot('', losses, test_losses, NO_TITLE_PLOT_FILE_PATH)




class MeshGraphNet(torch.nn.Module):
    """Encode-process-decode graph network over node and edge features.

    Node and edge features are normalized, encoded to ``hidden_dim`` by
    two-layer MLPs (with LayerNorm), passed through ``args.num_layers``
    ProcessorLayer message-passing steps, and decoded per node to ``output_dim``.
    """

    def __init__(self, input_dim_node, input_dim_edge, hidden_dim, output_dim, args, emb=False):
        # Args: input_dim_node / input_dim_edge are the raw feature widths,
        # hidden_dim is the latent width, output_dim is the per-node output width,
        # and args must provide num_layers. `emb` is accepted but not used here.
        # NOTE: construction order below fixes the seeded parameter init and the
        # state_dict key names; do not reorder.
        super(MeshGraphNet, self).__init__()

        self.num_layers = args.num_layers

        self.node_encoder = Sequential(Linear(input_dim_node, hidden_dim),
                                       ReLU(),
                                       Linear(hidden_dim, hidden_dim),
                                       LayerNorm(hidden_dim))

        self.edge_encoder = Sequential(Linear(input_dim_edge, hidden_dim),
                                       ReLU(),
                                       Linear(hidden_dim, hidden_dim),
                                       LayerNorm(hidden_dim))

        self.processor = nn.ModuleList()
        assert (self.num_layers >= 1), 'Number of message passing layers is not >=1'

        # One independent (non-shared) processor layer per message-passing step.
        processor_layer = self.build_processor_model()
        for _ in range(self.num_layers):
            self.processor.append(processor_layer(hidden_dim, hidden_dim))

        # No LayerNorm on the decoder output: it is compared directly to normalized labels.
        self.decoder = Sequential(Linear(hidden_dim, hidden_dim),
                                  ReLU(),
                                  Linear(hidden_dim, output_dim))

    def build_processor_model(self):
        """Return the message-passing layer class instantiated by __init__."""
        return ProcessorLayer

    def forward(self, data, mean_vec_x, std_vec_x, mean_vec_edge, std_vec_edge):
        """Predict per-node outputs in normalized label space.

        ``data`` must provide ``x``, ``edge_index`` and ``edge_attr``; the
        mean/std vectors are used to normalize ``x`` and ``edge_attr`` first.
        Returns the decoder output with ``output_dim`` columns per node.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        x = normalize(x, mean_vec_x, std_vec_x)
        edge_attr = normalize(edge_attr, mean_vec_edge, std_vec_edge)

        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)

        # Each layer updates both node and edge latents.
        for i in range(self.num_layers):
            x, edge_attr = self.processor[i](x, edge_index, edge_attr)

        return self.decoder(x)

    def loss(self, pred, inputs, mean_vec_y, std_vec_y):
        """Masked root-mean-square error in normalized label space.

        The per-node squared L2 error between ``pred`` and ``normalize(inputs.y)``
        is averaged over the masked nodes only, and the square root is taken.
        If no node is selected, the mean over an empty tensor is NaN.
        """
#         toolpiece = torch.tensor(2)
#         workpiece = torch.tensor(1)
        # Nodes whose argmax over x[:, 2:] is 1 contribute to the loss.
        # NOTE(review): elsewhere x[:, 0:3] is read as velocity, so this slice
        # also contains column 2 (presumably vz), which can win the argmax.
        # Index 1 of the slice is column 3. Confirm the intended node-type column.
        loss_mask = torch.argmax(inputs.x[:, 2:], dim=1) == torch.tensor(1, device=inputs.x.device)
        labels = normalize(inputs.y, mean_vec_y, std_vec_y)
        error = torch.sum((labels - pred) ** 2, axis=1)
        loss = torch.sqrt(torch.mean(error[loss_mask]))
        return loss


class ProcessorLayer(MessagePassing):
    """One message-passing step with residual edge and node updates.

    Edge update: e' = e + edge_mlp([x_i, x_j, e])
    Node update: x' = x + node_mlp([x, agg]), where agg sums e' over the edges
    whose row-0 entry of ``edge_index`` is that node.

    The residual additions require ``out_channels == in_channels``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(ProcessorLayer, self).__init__(**kwargs)

        # Input is the concatenation [x_i, x_j, e], hence 3 * in_channels.
        # NOTE: construction order fixes the seeded init and state_dict keys.
        self.edge_mlp = Sequential(Linear(3 * in_channels, out_channels),
                                   ReLU(),
                                   Linear(out_channels, out_channels),
                                   LayerNorm(out_channels))

        # Input is the concatenation [x, aggregated messages].
        self.node_mlp = Sequential(Linear(2 * in_channels, out_channels),
                                   ReLU(),
                                   Linear(out_channels, out_channels),
                                   LayerNorm(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the Linear layers of both MLPs (LayerNorms are left as built)."""
        self.edge_mlp[0].reset_parameters()
        self.edge_mlp[2].reset_parameters()
        self.node_mlp[0].reset_parameters()
        self.node_mlp[2].reset_parameters()

    def forward(self, x, edge_index, edge_attr, size=None):
        """Return ``(updated_nodes, updated_edges)``; ``aggregate`` supplies both via propagate."""
        out, updated_edges = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
        updated_nodes = torch.cat([x, out], dim=1)
        updated_nodes = x + self.node_mlp(updated_nodes)
        return updated_nodes, updated_edges

    def message(self, x_i, x_j, edge_attr):
        """Compute the residual edge update from both endpoint latents and the edge latent."""
        updated_edges = torch.cat([x_i, x_j, edge_attr], dim=1)
        updated_edges = self.edge_mlp(updated_edges) + edge_attr
        return updated_edges

    def aggregate(self, updated_edges, edge_index, dim_size=None):
        """Sum edge messages onto the node at row 0 of ``edge_index``.

        ``dim_size`` is forwarded so the output has one row per node even when
        the highest-numbered nodes have no edges. Without it, torch_scatter
        sizes the output as ``index.max() + 1``, and ``torch.cat`` in
        ``forward`` fails on the row mismatch. When ``dim_size`` is None, the
        size is inferred from the index as before.

        Returns the aggregated node messages and the per-edge messages, so
        ``forward`` can pass the updated edge latents to the next layer.
        """
        node_dim = 0
        out = torch_scatter.scatter(updated_edges, edge_index[0, :], dim=node_dim,
                                    dim_size=dim_size, reduce='sum')
        return out, updated_edges


class objectview:
    """Attribute-style view over a dict: ``objectview(d).key`` reads ``d['key']``.

    The instance adopts ``d`` itself (not a copy) as its attribute dictionary,
    so attribute writes on the view show up in ``d`` and vice versa.
    """

    def __init__(self, d):
        object.__setattr__(self, '__dict__', d)



def torch_load_trusted(path, map_location=None):
    """
    Load a .pt file produced by our 03 script that may contain NumPy arrays or
    other pickled objects.

    Full unpickling is requested (``weights_only=False``) when the installed
    torch.load supports that keyword; on older versions without it, torch.load
    is called normally. Only use on files you trust (your own): unpickling can
    execute arbitrary code.

    Args:
        path: file to load.
        map_location: forwarded to torch.load (e.g. 'cpu' to load CUDA-saved
            tensors on a CPU-only machine); None keeps torch.load's default.
    """
    import inspect

    kwargs = {'map_location': map_location}
    # Detect keyword support up front instead of catching TypeError: a TypeError
    # raised while unpickling would otherwise trigger a second, differently
    # configured load and hide the real error.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        kwargs['weights_only'] = False
    return torch.load(path, **kwargs)


logging.basicConfig(level=logging.INFO, format='%(asctime)s:%(levelname)s: %(message)s')

# Experiment configuration, read as attributes (args.lr, args.epochs, ...).
# 'device' is overwritten with the detected device inside the LOO loop.
args = objectview(dict(
    model_type='meshgraphnet',
    num_layers=10,
    batch_size=16,
    hidden_dim=10,
    epochs=750,
    opt='adam',
    opt_scheduler='none',
    opt_restart=0,
    weight_decay=5e-4,
    lr=0.001,
    train_size=45,
    test_size=10,
    device='cuda',
    shuffle=False,
    save_velo_val=False,
    save_best_model=True,
    checkpoint_dir='./best_models/',
    postprocess_dir='./3d_loss_plots/',
))

# Seed the torch, Python and NumPy RNGs (in that order) so runs start from the same state.
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(5)

# Tensor set to load: must match the world_<n> folder generated by the 03 script.
tool_to_wp_connections = 3

root_dir = r"C:\Users\fbagher1\Desktop\run sam\input"
episode_file = os.path.join(root_dir, "graph_csv", "episode_summary.csv")
episode_desc = pd.read_csv(episode_file)

tensor_dir = os.path.join(root_dir, "tensor_files", "world_" + str(tool_to_wp_connections))

logging.info('Episode Description \n--------------------------------------\n %s\n', episode_desc)


# Number of simulations = number of episode rows whose first column is negative;
# the loops below iterate over exactly this many simulation ids.
simulation_counter = sum(1 for sim_time in episode_desc.iloc[:, 0] if sim_time < 0)

# Node-tensor direction to load. NOTE(review): presumably the alternative is
# "tool_to_wp"; confirm against the file names the 03 script wrote to tensor_dir.
DIRECTION = "wp_to_tool"


def to_tensor(obj):
    """Return ``obj`` as a torch tensor.

    NumPy arrays are wrapped with ``torch.from_numpy`` (memory is shared),
    tensors are returned unchanged, and anything else is converted through
    ``np.array`` first.
    """
    if isinstance(obj, np.ndarray):
        return torch.from_numpy(obj)
    if isinstance(obj, torch.Tensor):
        return obj
    return torch.from_numpy(np.array(obj))


for i in range(simulation_counter):
    # Three-digit, zero-padded id ("000", "001", ...) used in every file name.
    simulation_mesh_file_indicator = f"{i:03d}"
    globals()[f"data_list_{simulation_mesh_file_indicator}"] = []

    logging.info('Creating data list for simulation %s', simulation_mesh_file_indicator)

    # Node tensors depend on DIRECTION; the mesh edge tensors do not.
    simulation_onehot_tensor_file = os.path.join(
        tensor_dir, f"{simulation_mesh_file_indicator}_{DIRECTION}_node_onehot.pt")
    simulation_node_velocity_tensor_file = os.path.join(
        tensor_dir, f"{simulation_mesh_file_indicator}_{DIRECTION}_node_velo.pt")
    simulation_y_tensor_file = os.path.join(
        tensor_dir, f"{simulation_mesh_file_indicator}_{DIRECTION}_node_y.pt")
    simulation_edge_tensor_file = os.path.join(
        tensor_dir, f"{simulation_mesh_file_indicator}_mesh_edge.pt")
    simulation_edge_attr_tensor_file = os.path.join(
        tensor_dir, f"{simulation_mesh_file_indicator}_simp_mesh_edge_attr.pt")

    required_paths = [
        simulation_onehot_tensor_file,
        simulation_node_velocity_tensor_file,
        simulation_y_tensor_file,
        simulation_edge_tensor_file,
        simulation_edge_attr_tensor_file,
    ]
    # Fail fast, with a configuration hint, before loading anything.
    for pth in required_paths:
        if not os.path.exists(pth):
            raise FileNotFoundError(
                f"Missing file: {pth}\n"
                f"Check DIRECTION='{DIRECTION}', tool_to_wp_connections={tool_to_wp_connections}, and tensor_dir='{tensor_dir}'."
            )
    print(f"[OK] Using DIRECTION='{DIRECTION}' from: {tensor_dir}")

    # Node tensors may have been saved as NumPy arrays; convert them to tensors.
    simulation_onehot_tensor = to_tensor(torch_load_trusted(simulation_onehot_tensor_file))
    simulation_node_velocity_tensor = to_tensor(torch_load_trusted(simulation_node_velocity_tensor_file))
    simulation_y_tensor = to_tensor(torch_load_trusted(simulation_y_tensor_file))

    # Edge tensors are used as loaded (expected to be torch tensors already).
    simulation_edge_index_tensor = torch_load_trusted(simulation_edge_tensor_file)
    simulation_edge_attr_tensor = torch_load_trusted(simulation_edge_attr_tensor_file)

    # One Data object per timestep: x = [node velocity | one-hot] cast to float32,
    # the same edge_index for every timestep, and per-timestep edge_attr and y.
    simulation_timesteps = simulation_node_velocity_tensor.shape[0]
    data_list = []
    for j in range(simulation_timesteps):
        timestep_x = torch.cat((simulation_node_velocity_tensor[j],
                                simulation_onehot_tensor), dim=-1).type(torch.float)
        data_list.append(
            Data(
                x=timestep_x,
                edge_index=simulation_edge_index_tensor,
                edge_attr=simulation_edge_attr_tensor[j],
                y=simulation_y_tensor[j],
            )
        )

    globals()[f"data_list_{simulation_mesh_file_indicator}"] = data_list

logging.info('Data list creation complete.')

# Iterate over a snapshot: k and v are themselves module globals, so binding
# them while iterating globals() directly would change the dict mid-iteration.
for k, v in list(globals().items()):
    if k.startswith("data_list_"):
        # Use the first entry as the shape sample so single-entry (and empty)
        # lists do not raise IndexError; '%3s' right-aligns counts below 1000.
        sample = v[0] if len(v) > 0 else None
        logging.info('  %s: %3s entries with dimensions %s', k, len(v), sample)

for i in range(simulation_counter):
    simulation_mesh_file_indicator = f"{i:03d}"

    logging.info('Starting LOO Training for %s', simulation_mesh_file_indicator)
    loo = simulation_mesh_file_indicator

    held_out_key = f"data_list_{simulation_mesh_file_indicator}"
    # Derive the title from the fold index up front, so a missing data list can
    # neither leave the previous fold's title in place nor an undefined name.
    sim_number = i
    loss_image_title = "GNN LOOCV Loss: Simulation " + str(sim_number)

    overall_data_list = []
    train_data_list = []
    test_data_list = []

    # Snapshot with list(): k and v are module globals themselves.
    for k, v in list(globals().items()):
        if not k.startswith("data_list_"):
            continue
        # Exact match: a prefix test would also treat e.g. data_list_1000 as
        # the held-out list for fold "100".
        if k == held_out_key:
            logging.info('  LOO test data list:  %s', k)
            logging.info('  LOO simulation number: %s', sim_number)
            test_data_list = v
        else:
            train_data_list.extend(v)
        overall_data_list.extend(v)

    logging.info('  Overall data list length %s', len(overall_data_list))
    logging.info('  Train data list length %s', len(train_data_list))
    logging.info('  Test data list length %s', len(test_data_list))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.device = device
    logging.info('  Proceeding with %s device.', device)

    # Normalization statistics come from the training simulations only; including
    # the held-out simulation would leak test data into the LOO evaluation.
    stats_list = get_stats(train_data_list if train_data_list else overall_data_list)

    test_losses, losses, best_model, best_test_loss, test_loader = train(
        train_data_list, test_data_list, device, stats_list, loo, args
    )

    save_plots(args, loss_image_title, losses, test_losses)

exit(0)
