In [None]:
%load_ext autoreload
%autoreload 2

import argparse
from collections import OrderedDict
import datetime
import gc
get_ipython().run_line_magic('matplotlib', 'inline')
import matplotlib
import matplotlib.pylab as plt
from numbers import Number
import numpy as np
import pandas as pd
pd.options.display.max_rows = 1500
pd.options.display.max_columns = 200
pd.options.display.width = 1000
pd.set_option('max_colwidth', 400)
import pdb
import pickle
import pprint as pp
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from deepsnap.batch import Batch as deepsnap_Batch

import sys, os
sys.path.append(os.path.join(os.path.dirname("__file__"), '..'))
sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..'))
from lamp.argparser import arg_parse
from lamp.datasets.load_dataset import load_data
from lamp.gnns import get_data_dropout
from lamp.models import load_model
from lamp.pytorch_net.util import groupby_add_keys, filter_df, get_unique_keys_df, Attr_Dict, Printer, get_num_params, get_machine_name, pload, pdump, to_np_array, get_pdict, reshape_weight_to_matrix, ddeepcopy as deepcopy, plot_vectors, record_data, filter_filename, Early_Stopping, str2bool, get_filename_short, print_banner, plot_matrices, get_num_params, init_args, filter_kwargs, to_string, COLOR_LIST
from lamp.utils import p, update_legacy_default_hyperparam, EXP_PATH, deepsnap_to_pyg, LpLoss, to_cpu, to_tuple_shape, parse_multi_step, loss_op, get_cholesky_inverse, get_device, get_data_comb

# Device used for loading models and running inference. GPU index 5 is the
# preferred device; on machines with fewer GPUs fall back to the highest
# available GPU index, and to the CPU when CUDA is not available at all, so
# the notebook does not fail later with an invalid-device error.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{min(5, torch.cuda.device_count() - 1)}")
else:
    device = torch.device("cpu")

## Functions:

In [None]:
# Analysis:
def get_results_1d(
    all_hash,
    mode="best",
    exclude_idx=(None,),
    dropout_mode="None",
    n_rollout_steps=-1,
    dirname=None,
    suffix="",
):
    """
    Perform analysis on the 1D Burgers' benchmark.

    For every hash in `all_hash` the matching experiment record is loaded from
    `EXP_PATH + dirname`, the selected model is restored, the test set is rolled
    out, the MSE curves and example predictions are plotted, and one summary
    row is collected. The resulting table is saved with `pdump` to
    "df_1d{suffix}.p" in the current working directory and returned.

    Args:
        all_hash: a list of hashes which indicates the experiments to load for analysis
        mode: choose from "best" (load the best model with lowest validation loss) or an integer,
            e.g. -1 (last saved model), -2 (second last saved model)
        exclude_idx: iterable of settings to evaluate one after another. For each element
            that is not None (and when dropout_mode == "None"), the data is passed through
            get_data_dropout(..., dropout_mode="node:0", exclude_idx=element) before inference.
        dropout_mode: the string "None" disables it; any other value is forwarded as
            `dropout_mode` to get_data_dropout for the quantitative evaluation.
        n_rollout_steps: number of rollout steps to predict and evaluate. -1 uses the full
            rollout defined by the test-time multi_step setting.
        dirname: name of the experiment directory under EXP_PATH, e.g. tailin-1d_2022-7-27.
            Must be provided.
        suffix: suffix for saving the analysis result.

    Returns:
        A pandas DataFrame with one row per successfully analyzed hash, holding the
        per-setting cumulative loss ("loss_cumu_{exclude_idx element}"), "best_epoch",
        "epoch" and all hyperparameters stored in the experiment's args.

    Raises:
        TypeError: if `dirname` is None.
    """
    if dirname is None:
        raise TypeError("dirname must be provided (e.g. 'evo-1d_2023-01-01'), got None.")

    isplot = True
    df_dict_list = []
    # Resolve the experiment directory once. The `dirname` argument itself is left
    # untouched so that the "not found" message always reports what the caller passed.
    exp_dir = EXP_PATH + dirname
    load_dir = exp_dir if exp_dir.endswith("/") else exp_dir + "/"
    for hash_str in all_hash:
        df_dict = {}
        df_dict["hash"] = hash_str
        # Load model: exactly one file in the directory must match the hash.
        filename = filter_filename(exp_dir, include=hash_str)
        if len(filename) != 1:
            print(f"hash {hash_str} does not exist in {dirname}! Please pass in the correct dirname.")
            continue

        # Best effort: a record that cannot be loaded is reported and skipped
        # so that the remaining hashes are still analyzed.
        try:
            data_record = pload(load_dir + filename[0])
        except Exception as e:
            print(f"error {e} in hash_str {hash_str}")
            continue
        p.print(f"Hash {hash_str}, best model at epoch {data_record['best_epoch']}:", banner_size=160)
        if isplot:
            plot_learning_curve(data_record)
        args = init_args(update_legacy_default_hyperparam(data_record["args"]))
        args.filename = filename
        if mode == "best":
            model = load_model(data_record["best_model_dict"], device=device)
            print("Load the model with best validation loss.")
        else:
            assert isinstance(mode, int)
            print(f'Load the model at epoch {data_record["epoch"][mode]}')
            model = load_model(data_record["model_dict"][mode], device=device)
        model.eval()
        kwargs = {}
        if data_record["best_model_dict"]["type"].startswith("GNNPolicy"):
            kwargs["is_deepsnap"] = True

        # Load test dataset:
        args_test = deepcopy(args)
        multi_step = (250 - 50) // args_test.temporal_bundle_steps
        args_test.multi_step = f"1^{multi_step}"
        args_test.is_test_only = True
        args_test.n_train = "-1"
        n_test_traj = 128
        (_, dataset_test), _ = load_data(args_test)
        nx = int(args.dataset.split("-")[2])
        time_stamps_effective = len(dataset_test) // n_test_traj
        # Largest prediction step of the full test rollout (loop-invariant).
        max_pred_step = max(parse_multi_step(args_test.multi_step).keys())
        for exclude_idx_ele in exclude_idx:
            loss_list = []
            pred_list = []
            y_list = []
            for i in range(n_test_traj):
                idx = i * time_stamps_effective + args_test.temporal_bundle_steps
                data = deepcopy(dataset_test[idx])
                if dropout_mode == "None":
                    if exclude_idx_ele is not None:
                        data = get_data_dropout(data, dropout_mode="node:0", exclude_idx=exclude_idx_ele)
                else:
                    data = get_data_dropout(data, dropout_mode=dropout_mode)
                data = data.to(device)
                preds, info = model(
                    data,
                    pred_steps=np.arange(1, n_rollout_steps + 1) if n_rollout_steps != -1 else np.arange(1, max_pred_step + 1),
                    latent_pred_steps=None,
                    is_recons=False,
                    use_grads=False,
                    use_pos=args.use_pos,
                    is_y_diff=False,
                    is_rollout=False,
                    **kwargs
                )
                y = data.node_label["n0"]
                if n_rollout_steps != -1:
                    # Keep only the part of the label covered by the requested rollout:
                    # each rollout step spans `temporal_bundle_steps` entries along axis 1
                    # (previously hard-coded as 25).
                    y = y[:, :args_test.temporal_bundle_steps * n_rollout_steps]
                pred = preds["n0"].reshape(y.shape)
                pred_list.append(pred.detach())
                y_list.append(y.detach())
                loss_ele = nn.MSELoss(reduction="sum")(pred, y) / nx
                loss_list.append(loss_ele.item())

            loss_mean = np.mean(loss_list)
            pred_list = torch.stack(pred_list).squeeze(-1)
            y_list = torch.stack(y_list).squeeze(-1)
            df_dict[f"loss_cumu_{exclude_idx_ele}"] = loss_mean
            print("\nTest for {} for exclude_idx={} is: {:.9f} at epoch {}, for {}/{} epochs".format(hash_str, exclude_idx_ele, loss_mean, data_record['best_epoch'], len(data_record["train_loss"]), args.epochs))

            # Per-step MSE: average over the first two axes of the stacked tensors,
            # leaving one value per remaining (rollout) index.
            mse_full = nn.MSELoss(reduction="none")(pred_list, y_list)
            mse_time = to_np_array(mse_full.mean((0,1)))
            p.print("Learning curve:", is_datetime=False, banner_size=100)
            plt.figure(figsize=(12,5))
            plt.subplot(1,2,1)
            plt.plot(mse_time)
            plt.xlabel("rollout step")
            plt.ylabel("MSE")
            plt.title("MSE vs. rollout step (linear scale)")
            plt.subplot(1,2,2)
            plt.semilogy(mse_time)
            plt.xlabel("rollout step")
            plt.ylabel("MSE")
            plt.title("MSE vs. rollout step (log scale)")
            plt.show()
            plt.figure(figsize=(6,5))
            plt.plot(mse_time.cumsum())
            plt.title("cumulative MSE vs. rollout step")
            plt.xlabel("rollout step")
            plt.ylabel("cumulative MSE")
            plt.show()

            # Visualization:
            # NOTE(review): these examples always use the full rollout and do not apply
            # `dropout_mode`, unlike the quantitative loop above -- confirm this is intended.
            for idx in range(6,8):
                p.print(f"Example {idx*128}:", banner_size=100, is_datetime=False)
                data = deepcopy(dataset_test[idx*128]).to(device)
                if exclude_idx_ele is not None:
                    data = get_data_dropout(data, dropout_mode="node:0", exclude_idx=exclude_idx_ele)
                preds, info = model(
                    data,
                    pred_steps=np.arange(1, max_pred_step + 1),
                    latent_pred_steps=None,
                    is_recons=False,
                    use_grads=False,
                    use_pos=args.use_pos,
                    is_y_diff=False,
                    is_rollout=False,
                    **kwargs
                )
                y = data.node_label["n0"]
                pred = preds["n0"].reshape(y.shape)
                visualize(pred, y)
                visualize_paper(pred, y)

            # `y` here is the label of the last visualized example.
            p.print(f"Individual prediction at rollout step {y.shape[1]}:", banner_size=100, is_datetime=False)
            time_step = -1
            for idx in range(0, 20, 5):
                plt.figure(figsize=(6,4))
                plt.plot(to_np_array(pred_list[idx,:,time_step]), label="pred")
                plt.plot(to_np_array(y_list[idx,:,time_step]), "--", label="y")
                plt.legend()
                plt.show()
        df_dict["best_epoch"] = data_record['best_epoch']
        df_dict["epoch"] = len(data_record["train_loss"])
        df_dict.update(args.__dict__)
        df_dict_list.append(df_dict)
    df = pd.DataFrame(df_dict_list)
    pdump(df, f"df_1d{suffix}.p")
    return df

# Plotting:
def plot_learning_curve(data_record):
    """
    Plot train/val/test loss curves side by side, in linear and log scale.

    Args:
        data_record: record of a training run; read with the keys "epoch",
            "train_loss", "val_loss", "test_loss" and, optionally, "test_epoch".
            When "test_epoch" is present it is used as the x-axis for the
            val/test curves, otherwise "epoch" is used for them as well.
    """
    plt.figure(figsize=(16,6))
    train_epochs = data_record["epoch"]
    if "test_epoch" in data_record:
        eval_epochs = data_record["test_epoch"]
    else:
        eval_epochs = train_epochs

    # One panel per y-axis scale; both panels draw the same three curves.
    panels = (
        (plt.plot, "Learning curve, linear scale"),
        (plt.semilogy, "Learning curve, log scale"),
    )
    for column, (draw_curve, panel_title) in enumerate(panels, start=1):
        plt.subplot(1, 2, column)
        draw_curve(train_epochs, data_record["train_loss"], label="train")
        draw_curve(eval_epochs, data_record["val_loss"], label="val")
        draw_curve(eval_epochs, data_record["test_loss"], label="test")
        plt.title(panel_title)
        plt.legend()
    plt.show()


def plot_colorbar(matrix, vmax=None, vmin=None, cmap="seismic", label=None):
    """
    Draw `matrix` as an image on the current axes, with a title and a colorbar.

    Args:
        matrix: 2D array-like exposing `.max()`, `.min()` and `.shape`.
        vmax: upper limit of the color scale; defaults to `matrix.max()`.
        vmin: lower limit of the color scale; defaults to `matrix.min()`.
            Each limit is defaulted independently, so an explicit `vmin` is
            respected even when `vmax` is omitted.
        cmap: name of the matplotlib colormap.
        label: title of the axes.
    """
    if vmax is None:
        vmax = matrix.max()
    if vmin is None:
        vmin = matrix.min()
    im = plt.imshow(matrix, vmax=vmax, vmin=vmin, cmap=cmap)
    plt.title(label)
    # Scale the colorbar with the image's aspect ratio so its height follows the image.
    im_ratio = matrix.shape[0] / matrix.shape[1]
    plt.colorbar(im, fraction=0.046*im_ratio, pad=0.04)


def visualize(pred, gt, animate=False):
    """
    Show ground truth, prediction, their difference and the MSE curve in one row.

    Args:
        pred: predicted field, torch tensor or numpy array, same shape as `gt`.
            The first two axes are shown as an image (transposed) for index 0 of
            the third axis -- presumably (space, time, feature); confirm against caller.
        gt: ground-truth field, torch tensor or numpy array.
        animate: when True nothing is drawn (no animation is implemented in this
            function); the static figure is only produced when False.
    """
    # Convert when either input is a tensor (previously `pred` was only converted
    # when `gt` happened to be a tensor).
    # NOTE(review): relies on to_np_array passing numpy arrays through unchanged.
    if torch.is_tensor(gt) or torch.is_tensor(pred):
        gt = to_np_array(gt)
        pred = to_np_array(pred)
    # Squared error averaged over the first and the last axis.
    mse_over_t = ((gt - pred) ** 2).mean(axis=0).mean(axis=-1)

    if animate:
        return

    diff = pred - gt
    # Symmetric color limits so that zero error sits at the center of the colormap.
    max_abs_diff = np.abs(diff).max()
    plt.figure(figsize=[15,5])
    plt.subplot(1,4,1)
    plot_colorbar(gt[:,:,0].T, label="gt")
    plt.subplot(1,4,2)
    plot_colorbar(pred[:,:,0].T, label="pred")
    plt.subplot(1,4,3)
    plot_colorbar(diff[:,:,0].T, vmax=max_abs_diff, vmin=-max_abs_diff, label="diff")
    plt.subplot(1,4,4)
    plt.plot(mse_over_t)
    plt.title("mse over t")
    plt.yscale('log')
    plt.tight_layout()
    plt.show()

def visualize_paper(pred, gt, is_save=False):
    """
    Plot snapshots of the predicted and the ground-truth field as colored line plots.

    Every 15th index along the second-to-last axis is drawn as one curve over x,
    colored with the "jet" colormap; at most the first 200 indices are considered.

    Args:
        pred: predicted field; `pred.shape[0]` is used as the number of spatial
            points and the second-to-last axis is indexed as the snapshot axis
            (presumably time -- confirm against caller).
        gt: ground-truth field, indexed the same way as `pred`.
        is_save: if True, save the figure to "1D_E2-{nx}.pdf" right after the
            prediction panel is drawn and to "1D_gt-{nx}.pdf" after the
            ground-truth panel is drawn.
    """
    nx = pred.shape[0]

    fontsize = 14
    # Cap the snapshot indices at what is available so that rollouts shorter than
    # 200 steps do not raise an IndexError.
    n_steps = min(200, pred.shape[-2], gt.shape[-2])
    idx_list = np.arange(0, n_steps, 15)
    color_list = np.linspace(0.01, 0.9, len(idx_list))
    x_axis = np.linspace(0,16,nx)
    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in matplotlib 3.9.
    cmap = plt.get_cmap('jet')

    def _style_axes(title):
        # Axis styling shared by both panels.
        plt.ylabel("u(t,x)", fontsize=fontsize)
        plt.xlabel("x", fontsize=fontsize)
        plt.tick_params(labelsize=fontsize)
        plt.xticks([0,8,16], [0,8,16])
        plt.ylim([-2.5,2.5])
        plt.title(title)

    plt.figure(figsize=(16,5))
    plt.subplot(1,2,1)
    for i, idx in enumerate(idx_list):
        pred_i = to_np_array(pred[...,idx,:].squeeze())
        rgb = cmap(color_list[i])[:3]
        plt.plot(x_axis, pred_i, color=rgb, label=f"t={np.round(i*0.3, 1)}s")
    _style_axes("Prediction")
    if is_save:
        plt.savefig(f"1D_E2-{nx}.pdf", bbox_inches='tight')

    plt.subplot(1,2,2)
    for i, idx in enumerate(idx_list):
        y_i = to_np_array(gt[...,idx,:])
        rgb = cmap(color_list[i])[:3]
        plt.plot(x_axis, y_i, color=rgb, label=f"t={np.round(i*0.3, 1)}s")
    # Only the ground-truth panel carries the legend, anchored outside the axes.
    plt.legend(fontsize=10, bbox_to_anchor=[1,1])
    _style_axes("Ground-truth")
    if is_save:
        plt.savefig(f"1D_gt-{nx}.pdf", bbox_inches='tight')
    plt.show()

## Analysis:

In [None]:
"""
all_hash is a list of hashes, each of which corresponds to one experiment.
    For example, if one experiment is saved under ./results/evo-1d_2023-01-01/mppde1d-E2-50_train_-1_algo_contrast_ebm_False_ebmt_cd_enc_cnn-s_evo_cnn_act_elu_hid_128_lo_rmse_recef_1.0_conef_1.0_nconv_4_nlat_1_clat_3_lf_True_reg_None_id_0_Hash_qvQry9QJ_ampere3.p
    Then, the "qvQry9QJ" (located at the end of the filename) is the {hash} of this file.
    The "evo-1d_2023-01-01" is the "{--exp_id}_{--date_time}" of the training command.
    all_hash can contain multiple hashes; they are analyzed sequentially, in the order given.
"""
# Hashes of the experiments to analyze; each one must match exactly one file
# inside the experiment directory passed as `dirname` below.
all_hash = ["mhkVkAaz"]

# Evaluate the first 7 rollout steps with the "uniform:2" dropout mode and
# store the summary table in "df_1d_0.p".
df9 = get_results_1d(
    all_hash,
    dropout_mode="uniform:2",
    n_rollout_steps=7,
    dirname="evo-1d_2023-01-01",
    suffix="_0",
)