In [None]:
import sys
# Make the directory two levels above the current working directory importable,
# presumably the repository root, so the project packages imported below
# (iunets, dataset, networks, utils) resolve. Must run before those imports.
sys.path.append("../../")
# Debug aid: show the effective import search path.
print(sys.path)
import os
import json
import time
import datetime

import openpyxl
from openpyxl import Workbook
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable

import torch
import torchvision
from torch import nn
from torchvision.transforms import Compose
from torch.utils.data import RandomSampler, DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter


from tqdm import tqdm
from tqdm.notebook import tqdm

from iunets import iUNet
from dataset import FWIDataset
# from networks import iunet_network
from networks import forward_network, inverse_network, iunet_network

import utils.transforms as T
from utils.pytorch_ssim import *
import utils.utilities as utils
from utils.scheduler import WarmupMultiStepLR
from matplotlib.colors import ListedColormap
import matplotlib.gridspec as gridspec


# `np` is not imported explicitly anywhere in the import cell; it was presumably leaking
# in via `from utils.pytorch_ssim import *`. Import it explicitly so every later use of
# `np` (here, in true_items_dict, and in the plotting helpers) does not depend on that.
import numpy as np

# Custom colormap built from the project's saved 256-entry rainbow palette.
rainbow_cmap = ListedColormap(np.load('/projects/ml4science/OpenFWI/Latent_Bijectivity/utils/rainbow256.npy'))

# Model classes grouped by role (forward: velocity -> amplitude, inverse, joint).
forward_model_list = [forward_network.FNO2d, forward_network.WaveformNet, forward_network.WaveformNet_V2, iunet_network.IUnetForwardModel, iunet_network.UNetForwardModel]
inverse_model_list = [inverse_network.InversionNet, iunet_network.IUnetInverseModel, iunet_network.UNetInverseModel]
joint_model_list = [iunet_network.IUnetModel, iunet_network.JointModel, iunet_network.Decouple_IUnetModel]

In [None]:
step = 0
file_size = 500
vis_suffix = False
device = torch.device("cuda")

# Transform / loader settings read by the helper functions below.
k = 1  # passed as `k` to the Auto-Linear log transform
workers = 4
lambda_g1v = 1
lambda_g2v = 1
batch_size = 50
mask_factor = 0.0
sample_temporal = 1
distributed = False

num_images = 20  # 20 -> use the fixed indices in true_items_dict; otherwise sample randomly
skip = 0
cfg_path = "../../configs/"
latent_dim = 70

base_path = "/projects/ml4science/OpenFWI/Results/SupervisedExperiment/"
mode = "main_paper"  # "appendix" or "main_paper"
if mode not in ("appendix", "main_paper"):
    # Previously any other value silently fell through to the main-paper configuration.
    raise ValueError(f"Unknown mode {mode!r}; expected 'appendix' or 'main_paper'.")
viz_save_path = f'AAAI_viz_aug15/{mode}'
font_sizes = {"color_bar": 16, "sub_plt_title": 20, "dataset_name": 20}

# Supervised setting: models trained on model_train_datasets[i] are evaluated on evaluate_datasets[i].
evaluate_datasets = ["flatvel-a", "flatvel-b", "curvevel-a", "curvevel-b",
                     "flatfault-a", "flatfault-b", "curvefault-a", "curvefault-b", "style-a", "style-b"]

model_train_datasets = ['FlatVel-A', 'FlatVel-B', 'CurveVel-A', 'CurveVel-B',
                        'FlatFault-A', 'FlatFault-B', 'CurveFault-A', 'CurveFault-B', 'Style-A', 'Style-B']

# architecture_types[i] picks the loader used for the checkpoint directory architecture_names[i].
if mode == "appendix":
    architecture_types = ["FNO", "WaveformNet", "AutoLinearForward", "UNetForwardModel", "UNetForwardModel",
                          "IUNET", "IUNET"]
    architecture_names = ["FNO", "WaveformNet", "AutoLinear_Forward_ckpt", "UNetForwardModel_17M", "UNetForwardModel_33M",
                          "Invertible_XNet", "Invertible_XNet_cycle_warmup"]
else:
    architecture_types = ["FNO", "AutoLinearForward", "UNetForwardModel", "IUNET"]
    architecture_names = ["FNO", "AutoLinear_Forward_ckpt", "UNetForwardModel_33M", "Invertible_XNet"]

# The paired lists are indexed together downstream; fail fast if they drift apart.
assert len(architecture_types) == len(architecture_names), "architecture_types / architecture_names misaligned"
assert len(evaluate_datasets) == len(model_train_datasets), "evaluate / train dataset lists misaligned"

# Display names used in figure titles and legends.
plot_names = {"FNO": "FNO", "UNetForwardModel_17M": "Latent U-Net (Small)",
              "UNetForwardModel_33M": "Latent U-Net (Large)",
              "IUnetForwardModel": "IUnetForwardModel",
              "Invertible_XNet": "Invertible X-Net", "Invertible_XNet_cycle_warmup": "Invertible X-Net (Cycle)",
              "Velocity_GAN": "VelocityGAN", "WaveformNet": "WaveformNet", "ground_truth": "Ground Truth",
              "AutoLinear_Forward_ckpt": "Auto-Linear"}

plot_dataset_names = {"flatvel-a": "FVA", "flatvel-b": "FVB", "curvevel-a": "CVA", "curvevel-b": "CVB",
                      "flatfault-a": "FFA", "flatfault-b": "FFB", "curvefault-a": "CFA", "curvefault-b": "CFB",
                      "style-a": "STA", "style-b": "STB"}

# U-Net size settings; the 17M variant uses a single repeat block.
architecture_params = {"UNetForwardModel_17M": {"unet_depth": 2, "unet_repeat_blocks": 1},
                       "UNetForwardModel_33M": {"unet_depth": 2, "unet_repeat_blocks": 2},
                       "default": {"unet_depth": 2, "unet_repeat_blocks": 2}
                       }

# Relative checkpoint directory for every (training dataset, architecture) pair;
# plot_instance joins these under base_path.
model_paths = {
    train_dataset: {
        architecture_name: os.path.join(train_dataset, architecture_name, "fcn_l1loss_ffb")
        for architecture_name in architecture_names
    }
    for train_dataset in model_train_datasets
}

In [None]:

# Named dash patterns as (offset, (on, off, ...)), following the matplotlib
# "Linestyles" gallery example. 'dotted' was previously a copy of 'densely dotted'
# (0, (1, 1)); it now uses the gallery value (0, (1, 5)).
linestyle_tuple = [
     ('loosely dotted',        (0, (1, 10))),
     ('dotted',                (0, (1, 5))),
     ('densely dotted',        (0, (1, 1))),
     ('long dash with offset', (5, (10, 3))),
     ('loosely dashed',        (0, (5, 10))),
     ('dashed',                (0, (5, 5))),
     ('densely dashed',        (0, (5, 1))),

     ('loosely dashdotted',    (0, (3, 10, 1, 10))),
     ('dashdotted',            (0, (3, 5, 1, 5))),
     ('densely dashdotted',    (0, (3, 1, 1, 1))),

     ('dashdotdotted',         (0, (3, 5, 1, 5, 1, 5))),
     ('loosely dashdotdotted', (0, (3, 10, 1, 10, 1, 10))),
     ('densely dashdotdotted', (0, (3, 1, 1, 1, 1, 1)))]

linestyle_dict = dict(linestyle_tuple)

# Per-model styling for the 1-D trace plots. generate_plot_trace keeps only the
# values_dict entries whose key appears here.
plot_map = {
    "ground_truth": {
        "linestyle": "solid",
        "color": "black",
        "zorder": 0,
    },
    "FNO": {
        "linestyle": linestyle_dict["dashed"],
        "color": "blue",
        "zorder": 3,
    },
    "UNetForwardModel_33M": {
        "linestyle": linestyle_dict["densely dashed"],
        "color": "red",
        "zorder": 2,
    },
    "Invertible_XNet": {
        "linestyle": "dotted",  # matplotlib's built-in named style, not linestyle_dict
        "color": "darkgreen",
        "zorder": 1,
    },
}

In [None]:
# Fixed sample indices (20 per dataset) into each dataset's validation split.
# do_everything uses these in place of random sampling when num_images == 20, so the
# same samples are plotted across runs.
# NOTE(review): 'flatvel-a' lists 4383 twice, so that sample is plotted twice. This looks
# like the indices were originally drawn with replacement; confirm whether that is intended.
true_items_dict = {'flatvel-a': np.array([1488, 2253,  819, 1046, 3920,  633,  964, 5898, 5933, 4383,  315,
         256, 5837, 4492,  389, 2186, 5272, 4972, 4383, 3516]),
 'flatvel-b': np.array([ 974, 3454, 5279,  894, 2401, 5405,  114, 5189, 3564, 3560, 4970,
        2284, 1872, 3680,  331,   94, 5782, 4302, 5120,  569]),
 'curvevel-a': np.array([1441, 2121, 2343, 2331, 3659, 4457, 2403, 1546, 3575, 1234, 2294,
         255, 5613, 3010, 3647, 4285, 3442,  475, 4226,  942]),
 'curvevel-b': np.array([3417,  332, 5871, 3862,  200, 3372, 5210, 3169, 5479,  751, 3873,
        5162, 5260, 2470, 4395,  711, 1204, 3886, 2100, 3956]),
 'flatfault-a': np.array([1492, 3869, 5428, 2941, 4577, 5167, 5313, 1357, 4488, 2116, 4468,
        4232, 1402,  349,  801,   93, 5210, 1984,  640, 5807]),
 'flatfault-b': np.array([2200, 2507, 2531, 1484,  163,   87, 1756, 1864, 4972, 3109, 2258,
        4510, 5983, 5145, 4854, 1234, 2423,  696, 4337, 2290]),
 'curvefault-a': np.array([5458, 3833, 4271, 1892, 5448, 1351, 2580, 2004, 2699, 3761, 3630,
        1238,  300, 5753, 4806, 3049, 1157, 2890, 4465,  144]),
 'curvefault-b': np.array([4998,  580, 3749, 2775, 5798, 2342,   24, 3622, 2406,  443, 3561,
        5511, 4919, 2228, 4075, 1330, 4050, 1068,  263, 1840]),
 'style-a': np.array([ 121, 3384, 1096, 5874,  312, 4093, 6969, 6105, 4649, 6392, 3347,
        2649, 4184, 4634,   96, 5287, 2963, 4879, 3525,  684]),
 'style-b': np.array([1771,  508, 6559,  386, 1349, 4491, 3294,  131, 6602, 6553, 3832,
        2119,  880, 4273, 5477, 5846, 2790,  773, 3311,  416])}


In [None]:
# Sanity check: show the per-dataset, per-architecture checkpoint sub-directories
# (relative to base_path) before running the evaluation.
model_paths

In [None]:
def _mae(x, y):
    """Mean absolute error of two tensors, reduced to a scalar tensor."""
    return (x - y).abs().mean()


def _mse(x, y):
    """Mean squared error of two tensors, reduced to a scalar tensor."""
    return (x - y).pow(2).mean()


# Error metrics keyed by display name; each takes two tensors (x, y) and returns a scalar tensor.
criterions = dict(MAE=_mae, MSE=_mse)

def get_dataset_path(dataset):
    """Return the ``(train_list, val_list)`` split-file paths for a dataset key.

    ``"flatvel-a"`` maps to ``../../train_test_splits/flatvel_a_train.txt`` and
    ``../../train_test_splits/flatvel_a_val.txt``. Only the first two ``-``-separated
    parts of the key are used.
    """
    parts = dataset.split("-")
    stem = f"{parts[0]}_{parts[1]}"
    split_dir = "../../train_test_splits/"
    return (os.path.join(split_dir, stem + "_train.txt"),
            os.path.join(split_dir, stem + "_val.txt"))

def get_transforms(dataset, return_ctx=False):
    """Build the standard data/label transforms for ``dataset``.

    Statistics are read from ``../../dataset_config.json``. Data use
    ``T.Normalize(data_mean, data_std)``. Labels use
    ``T.MinMaxNormalize(label_min, label_max)``.

    Args:
        dataset: key into the dataset config JSON (e.g. "flatvel-a").
        return_ctx: if True, also return that dataset's raw config dict.

    Returns:
        ``(transform_data, transform_label)``, or with ``return_ctx``
        ``(transform_data, transform_label, ctx)``.
    """
    # Context manager closes the config file (the handle was previously leaked).
    with open('../../dataset_config.json') as f:
        ctx = json.load(f)[dataset]

    transform_data = T.Normalize(ctx['data_mean'], ctx['data_std'])
    transform_label = T.MinMaxNormalize(ctx['label_min'], ctx['label_max'])
    if return_ctx:
        return transform_data, transform_label, ctx
    return transform_data, transform_label

def get_transforms_auto_linear(dataset, return_ctx=False):
    """Build the Auto-Linear data/label transforms for ``dataset``.

    Data go through ``T.LogTransform(k=k)``. They are then min-max normalized
    between the log-transformed ``data_min`` and ``data_max`` from
    ``../../dataset_config.json``. Labels use
    ``T.MinMaxNormalize(label_min, label_max)``.

    The return arity depends on ``return_ctx``:
        return_ctx=True  -> (transform_data, transform_label, ctx)
        return_ctx=False -> (transform_data, transform_label, min_log, max_log)

    ``min_log`` and ``max_log`` are the log-space bounds needed to invert predictions
    (see ``inverse_transform_auto_linear``).
    """
    # Context manager closes the config file (the handle was previously leaked).
    with open('../../dataset_config.json') as f:
        ctx = json.load(f)[dataset]

    # Log-space bounds: computed once, used by the forward transform and returned for inversion.
    min_log = T.log_transform(ctx['data_min'], k=k)
    max_log = T.log_transform(ctx['data_max'], k=k)

    transform_data = Compose([
        T.LogTransform(k=k),
        T.MinMaxNormalize(min_log, max_log),
    ])
    transform_label = T.MinMaxNormalize(ctx['label_min'], ctx['label_max'])

    if return_ctx:
        return transform_data, transform_label, ctx
    return transform_data, transform_label, min_log, max_log

def inverse_transform_auto_linear(dataset, amp_pred):
    """Map Auto-Linear amplitude predictions back to the original data space.

    Undoes the min-max normalization with the log-space bounds from
    ``get_transforms_auto_linear``, then inverts the log transform.

    Args:
        dataset: dataset key whose config statistics define the bounds.
        amp_pred: prediction tensor (any device); it is detached and moved to CPU.

    Returns:
        numpy array in the original amplitude space.
    """
    _data_tf, _label_tf, log_lo, log_hi = get_transforms_auto_linear(dataset, return_ctx=False)
    amp_np = permute_fn(amp_pred).detach().cpu().numpy()
    denormalized = T.minmax_denormalize(amp_np, log_lo, log_hi)
    return T.log_inverse_transform(denormalized, k=k)

def get_dataloader(dataset, transform_mode="normal"):
    """Load the validation split of ``dataset`` and wrap it in a DataLoader.

    Args:
        dataset: dataset key (e.g. "flatvel-a"). It selects both the split file and
            the normalization statistics.
        transform_mode: "normal" (``get_transforms``) or "AutoLinear"
            (``get_transforms_auto_linear``).

    Returns:
        ``(dataset_valid, dataloader_valid, transform_data, transform_label)``.

    Raises:
        ValueError: for an unknown ``transform_mode``. Previously this surfaced
            later as an UnboundLocalError on ``transform_data``.
    """
    if transform_mode == "normal":
        transform_data, transform_label, ctx = get_transforms(dataset, return_ctx=True)
    elif transform_mode == "AutoLinear":
        transform_data, transform_label, ctx = get_transforms_auto_linear(dataset, return_ctx=True)
    else:
        raise ValueError(f"Unknown transform_mode {transform_mode!r}; expected 'normal' or 'AutoLinear'.")

    # Only the validation split is needed for visualization.
    _, val_anno = get_dataset_path(dataset)

    print(f'Loading {dataset} validation data')
    dataset_valid = FWIDataset(
        val_anno,
        preload=True,
        sample_ratio=sample_temporal,
        file_size=ctx['file_size'],
        transform_data=transform_data,
        transform_label=transform_label
    )

    # The RandomSampler controls ordering. DataLoader rejects shuffle=True combined
    # with a sampler, so shuffle stays False.
    valid_sampler = RandomSampler(dataset_valid)
    dataloader_valid = DataLoader(
        dataset_valid, batch_size=batch_size,
        sampler=valid_sampler, num_workers=workers,
        pin_memory=True, collate_fn=default_collate, shuffle=False)

    print('Data loading over')

    return dataset_valid, dataloader_valid, transform_data, transform_label

In [None]:
def set_forward_params(forward_model_params, model_path):
    """Fill in this notebook's settings for the U-Net based forward models.

    Mutates ``forward_model_params`` in place and returns it. The
    'IUnetForwardModel' and 'UNetForwardModel' sub-dicts are created if missing.
    Both get ``cfg_path`` and ``latent_dim``. 'UNetForwardModel' additionally gets
    ``unet_depth`` / ``unet_repeat_blocks`` and ``skip``. The U-Net settings come
    from ``architecture_params``: the "UNetForwardModel_17M" entry when that
    substring appears in ``model_path``, else "default".

    Args:
        forward_model_params: dict of per-model constructor kwargs.
        model_path: checkpoint path; used only to detect the 17M variant.

    Returns:
        The same (mutated) ``forward_model_params`` dict.
    """
    for model_key in ('IUnetForwardModel', 'UNetForwardModel'):
        model_cfg = forward_model_params.setdefault(model_key, {})
        model_cfg['cfg_path'] = cfg_path
        model_cfg['latent_dim'] = latent_dim

    size_key = "UNetForwardModel_17M" if "UNetForwardModel_17M" in model_path else "default"
    unet_cfg = forward_model_params['UNetForwardModel']
    unet_cfg['unet_depth'] = architecture_params[size_key]["unet_depth"]
    unet_cfg['unet_repeat_blocks'] = architecture_params[size_key]["unet_repeat_blocks"]
    unet_cfg['skip'] = skip
    return forward_model_params

def get_model(model_path, model_type):
    """Build a forward model, load its checkpoint, and fall back to legacy code on failure.

    First tries ``forward_network.model_dict[model_type]``, configured via
    ``set_forward_params``. If building or loading raises, it instead builds
    ``forward_network.model_dict[model_type + "_Legacy"]`` from
    ``forward_network.forward_params_legacy``, whose U-Net size keys are updated in
    place from ``architecture_params``.

    Args:
        model_path: checkpoint file containing a ``'model'`` state dict.
        model_type: key into ``forward_network.model_dict`` (e.g. "FNO", "UNetForwardModel").

    Returns:
        The model on ``device`` in eval mode.
    """
    size_key = "UNetForwardModel_17M" if "UNetForwardModel_17M" in model_path else "default"
    try:
        print(model_path)
        forward_model_params = set_forward_params(forward_network.forward_params, model_path)
        model = forward_network.model_dict[model_type](**forward_model_params[model_type]).to(device)
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model'])
    except Exception as exc:
        # `Exception` rather than a bare `except:` so KeyboardInterrupt/SystemExit still
        # propagate. The cause is printed so a genuine failure is not silently masked.
        print("Failed to load new model. Falling back to Legacy Code.")
        print(f"  reason: {type(exc).__name__}: {exc}")
        forward_model_params = forward_network.forward_params_legacy
        forward_model_params['unet_depth'] = architecture_params[size_key]["unet_depth"]
        forward_model_params['unet_repeat_blocks'] = architecture_params[size_key]["unet_repeat_blocks"]
        model = forward_network.model_dict[model_type + "_Legacy"](**forward_model_params).to(device)
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model'])

    model = model.to(device)
    model.eval()

    return model


def get_model_iunet(model_path, model_type):
    """Build an invertible-U-Net joint model and load its checkpoint.

    The model combines an amplitude autoencoder (``AmpAutoEncoder``) and a velocity
    autoencoder (``VelAutoEncoder``), both with 128-channel bottlenecks, with
    ``iUNet(in_channels=128)`` block(s).

    Args:
        model_path: checkpoint file containing a ``'model'`` state dict.
        model_type: "IUNET" (one shared iUNet) or "Decouple_IUnet" (separate
            amplitude and velocity iUNets).

    Returns:
        The model on ``device`` in eval mode.

    Raises:
        ValueError: for an unsupported ``model_type``. Previously this printed a
            message and then crashed with an UnboundLocalError on ``model``.
    """
    if model_type not in ("IUNET", "Decouple_IUnet"):
        raise ValueError(f"Invalid Model: {model_type}")

    # Amplitude autoencoder (5 input channels).
    amp_input_channel = 5
    amp_encoder_channel = [8, 16, 32, 64, 128]
    amp_decoder_channel = [128, 64, 32, 16, 5]
    amp_model = iunet_network.AmpAutoEncoder(amp_input_channel, amp_encoder_channel, amp_decoder_channel).to(device)

    # Velocity autoencoder (1 input channel).
    vel_input_channel = 1
    vel_encoder_channel = [8, 16, 32, 64, 128]
    vel_decoder_channel = [128, 64, 32, 16, 1]
    vel_model = iunet_network.VelAutoEncoder(vel_input_channel, vel_encoder_channel, vel_decoder_channel).to(device)

    if model_type == "IUNET":
        iunet_model = iUNet(in_channels=128, dim=2, architecture=(4, 4, 4, 4))
        model = iunet_network.IUnetModel(amp_model, vel_model, iunet_model)
        print("IUnet model initialized.")
    else:  # "Decouple_IUnet"
        amp_iunet_model = iUNet(in_channels=128, dim=2, architecture=(4, 4, 4, 4))
        vel_iunet_model = iUNet(in_channels=128, dim=2, architecture=(4, 4, 4, 4))
        model = iunet_network.Decouple_IUnetModel(amp_model, vel_model, amp_iunet_model, vel_iunet_model)
        print("Decoupled IUnetModel model initialized.")

    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model'])
    model = model.to(device)
    model.eval()

    return model

def get_model_auto_linear(model_path, model_type):
    """Load a saved TorchScript Auto-Linear model and switch it to inference mode.

    Args:
        model_path: path to a file written by ``torch.jit.save`` / ``ScriptModule.save``.
        model_type: unused; accepted so all loaders share one call signature.

    Returns:
        The loaded ScriptModule with ``training`` set to False.
    """
    scripted_model = torch.jit.load(model_path)
    scripted_model.train(False)  # same as .eval()
    return scripted_model
    

In [None]:
# def permute_fn(tensor):
#     return torch.einsum("ijkl->iklj", tensor)
def permute_fn(tensor):
    """Layout adapter applied before converting tensors to numpy.

    Currently an identity: ``tensor`` is returned unchanged. The commented-out
    einsum variant above instead moved axis 1 to the last position.
    """
    result = tensor
    return result

def plot_instance(amp_true, vel_true, transform_data,
                  amp_true_auto_linear, vel_true_auto_linear, transform_data_auto_linear,
                  model_paths, model_train_dataset="FlatVel-A", evaluate_dataset="flatvel-b", plot=True):
    """Run every configured forward model on one batch, then save the comparison plots.

    For each ``architecture_names[i]``, ``architecture_types[i]`` selects the loader.
    The checkpoint is
    ``<base_path>/<model_paths[model_train_dataset][name]>/latest_checkpoint.pth``.
    The model predicts amplitudes from the velocity batch, and the prediction is
    inverse-transformed back to data space. Auto-Linear models use their own
    normalization: they receive the ``*_auto_linear`` inputs and are inverted with
    ``inverse_transform_auto_linear``.

    Args:
        amp_true, vel_true: amplitude / velocity tensors from the validation dataset.
        transform_data: data transform whose ``inverse_transform`` un-normalizes amplitudes.
        amp_true_auto_linear, vel_true_auto_linear: the same samples under Auto-Linear normalization.
        transform_data_auto_linear: unused here; kept for interface compatibility.
        model_paths: nested dict ``{train_dataset: {architecture_name: relative_dir}}``.
        model_train_dataset: which training dataset's checkpoints to evaluate.
        evaluate_dataset: dataset key used for labels, file names and Auto-Linear inversion.
        plot: also display the figures (they are always saved).

    Returns:
        dict mapping "ground_truth" and each architecture name to a numpy array.
    """
    values_dict = {}

    amp_true_np = transform_data.inverse_transform(permute_fn(amp_true).detach().cpu().numpy())
    values_dict["ground_truth"] = amp_true_np
    for i, architecture_name in enumerate(architecture_names):
        architecture_type = architecture_types[i]
        print("Evaluating Architecture: ", architecture_name)

        model_path = os.path.join(base_path,
                                  model_paths[model_train_dataset][architecture_name],
                                  "latest_checkpoint.pth")
        print("Evaluating Model: ", model_path)

        # Inference only. no_grad avoids building an autograd graph per model, which
        # would otherwise keep activations alive on the GPU.
        with torch.no_grad():
            if architecture_type == "IUNET":
                model = get_model_iunet(model_path, architecture_type).to(device)
                amp_pred = model.forward(vel_true)
            elif architecture_type == "AutoLinearForward":
                model = get_model_auto_linear(model_path, architecture_type).to(device)
                # The TorchScript module returns a tuple; element 4 is used as the amplitude prediction.
                amp_pred = model(vel_true_auto_linear, amp_true_auto_linear)[4]
            else:
                model = get_model(model_path, architecture_type).to(device)
                amp_pred = model(vel_true)

        # Auto-Linear uses log + min-max normalization, so it needs its own inverse transform.
        if architecture_type == "AutoLinearForward":
            amp_pred_np = inverse_transform_auto_linear(evaluate_dataset, amp_pred)
        else:
            amp_pred_np = transform_data.inverse_transform(permute_fn(amp_pred).detach().cpu().numpy())
        values_dict[architecture_name] = amp_pred_np
        del model  # drop the reference before the next architecture is loaded

    # Shared output directory so all models for this training dataset are visualized together.
    vis_path = os.path.join(base_path, model_train_dataset, viz_save_path)
    os.makedirs(vis_path, exist_ok=True)
    generate_plot_viz(values_dict, vis_path, evaluate_dataset, num_images, plot=plot)
    generate_plot_trace(values_dict, os.path.join(vis_path, "trace_plots_forward", mode),
                        evaluate_dataset, num_images, plot=plot)
    return values_dict

def generate_plot_viz(values_dict, vis_path, evaluate_dataset, num_images, plot=True):
    """Save an image grid: one row per sample, one column per model, and a colorbar per row.

    Every panel uses ``imshow(..., vmin=-1, vmax=1, cmap="seismic")``. The figure is written
    to ``<vis_path>/<evaluate_dataset>_<mode>_<num_images>_forward.pdf``.

    Args:
        values_dict: mapping from model key to an array indexed as ``values[sample, 0]``
            (one 2-D image per sample), e.g. the output of ``plot_instance``.
        vis_path: existing output directory.
        evaluate_dataset: dataset key; its short name labels the first column.
        num_images: number of rows, i.e. leading samples to plot.
        plot: also display the figure.
    """
    save_name = f"{evaluate_dataset}_{mode}_{num_images}"
    num_cols = len(values_dict)

    fig = plt.figure(figsize=(3.3*num_cols, int(3.3*num_images)), dpi=150)
    # A narrow extra column on the right holds each row's colorbar.
    gs = gridspec.GridSpec(num_images, num_cols + 1, width_ratios=[1]*num_cols + [0.1])
    for row in range(num_images):
        for col, (key, values) in enumerate(values_dict.items()):
            ax = fig.add_subplot(gs[row, col])
            img = ax.imshow(values[row, 0], aspect='auto', vmin=-1, vmax=1, cmap="seismic")
            # Fall back to the raw key when no display name exists. The fallback was
            # previously computed but unused, so unknown keys raised KeyError.
            ax.set_title(plot_names.get(key, key), fontsize=font_sizes["sub_plt_title"])
            if col == 0:
                ax.set_ylabel(plot_dataset_names[evaluate_dataset], fontsize=font_sizes["dataset_name"])
            ax.set_xticks([])
            ax.set_yticks([])
        cbar = fig.colorbar(img, cax=fig.add_subplot(gs[row, num_cols]))
        cbar.ax.tick_params(labelsize=font_sizes["color_bar"])
    fig.tight_layout()
    fig.savefig(os.path.join(vis_path, f"{save_name}_forward.pdf"))
    if plot:
        plt.show()
    plt.close(fig)
    

def plot_trace_plots(values_dict, indices, direction='horizontal', image_id=0, vis_path="", save_name="", plot=True):
    """Plot image comparisons and 1-D trace overlays for one sample.

    Writes two PDFs into ``vis_path``:
      * ``<save_name>_Image<image_id>_<direction>_compare.pdf``: each model's image,
        dashed lines at the traced indices, and a shared colorbar.
      * ``<save_name>_Image<image_id>_<direction>_TracePlot.pdf``: one panel per index,
        overlaying every model's trace, styled from ``plot_map``.

    Parameters:
    values_dict (dict): model key -> array indexed as ``values[image_id, 0, row, col]``
        (e.g. the output of ``plot_instance``). Every key must exist in both
        ``plot_names`` and ``plot_map``.
    indices (list of int): row indices for 'horizontal', column indices otherwise.
    direction (str): 'horizontal' slices ``values[image_id, 0, idx, :]``; any other value
        slices ``values[image_id, 0, :, idx]``.
    image_id (int): which sample of the batch to plot.
    vis_path (str): existing output directory.
    save_name (str): filename prefix.
    plot (bool): also display the figures.
    """
    num_trace_plots = len(indices)
    num_comparison_plots = len(values_dict)

    # Figure 1: every model's image with the trace positions marked.
    fig = plt.figure(figsize=(3*num_comparison_plots, 3))
    gs = gridspec.GridSpec(1, num_comparison_plots + 1, width_ratios=[1]*num_comparison_plots + [0.1])
    for i, (key, values) in enumerate(values_dict.items()):
        ax = fig.add_subplot(gs[i])
        im = ax.imshow(values[image_id, 0], aspect='auto', vmin=-1, vmax=1, cmap="seismic")
        ax.set_title(f'{plot_names[key]}')
        for idx in indices:
            if direction == 'horizontal':
                ax.axhline(y=idx, color='k', linestyle='--')
            else:
                ax.axvline(x=idx, color='k', linestyle='--')
        ax.set_xlabel("Sensors Locations (m)")
        if i == 0:
            ax.set_ylabel("Time (ms)")
    cbar = fig.colorbar(im, cax=fig.add_subplot(gs[num_comparison_plots]))
    cbar.ax.tick_params(labelsize=font_sizes["color_bar"])
    fig.tight_layout()
    fig.savefig(os.path.join(vis_path, f"{save_name}_Image{image_id}_{direction}_compare.pdf"))
    if plot:
        plt.show()
    plt.close(fig)

    # Figure 2: one panel per index, overlaying every model's 1-D trace.
    # squeeze=False keeps `axes` indexable even for a single index.
    fig, axes = plt.subplots(1, num_trace_plots, figsize=(3 * num_trace_plots, 2.7), squeeze=False)
    axes = axes[0]
    for ax, idx in zip(axes, indices):
        for key, values in values_dict.items():
            slice_ = values[image_id, 0, idx, :] if direction == 'horizontal' else values[image_id, 0, :, idx]
            style = plot_map[key]
            ax.plot(slice_, linestyle=style["linestyle"], color=style["color"],
                    zorder=style["zorder"], label=f'{plot_names[key]}')
        ax.set_ylabel(f"Trace at index {idx} {direction}")
        ax.set_xlabel("Sensors Locations" if direction == "horizontal" else "Time (ms)")

    # One figure-level legend below the panels, shared by all of them.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, bbox_to_anchor=(0.5, 0.05), loc='upper center', ncol=4)
    fig.tight_layout()
    fig.savefig(os.path.join(vis_path, f"{save_name}_Image{image_id}_{direction}_TracePlot.pdf"),
                bbox_inches="tight")
    if plot:
        plt.show()
    plt.close(fig)
    
def generate_plot_trace(values_dict, vis_path, evaluate_dataset, num_images, plot=True):
    """Write horizontal and vertical trace plots for the first ``num_images`` samples.

    Only models with a style entry in ``plot_map`` are included. ``vis_path`` is
    created if missing.

    Args:
        values_dict: model key -> array indexed as ``values[image_id, 0, row, col]``.
        vis_path: output directory for the PDFs.
        evaluate_dataset: dataset key, used as the filename prefix.
        num_images: number of leading samples to plot.
        plot: also display the figures.
    """
    os.makedirs(vis_path, exist_ok=True)
    # Trace plots need a line style per model, so drop models without a plot_map entry.
    styled_values = {key: value for key, value in values_dict.items() if key in plot_map}

    save_name = evaluate_dataset
    # Row indices for horizontal traces and column indices for vertical traces.
    # A dead earlier assignment of [249, 499, 749] was removed; it was always overwritten.
    horizontal_indices = [150, 350, 550]
    vertical_indices = [17, 34, 51]
    for image_id in range(num_images):
        plot_trace_plots(styled_values, horizontal_indices, direction='horizontal', image_id=image_id,
                         vis_path=vis_path, save_name=save_name, plot=plot)
        plot_trace_plots(styled_values, vertical_indices, direction='vertical', image_id=image_id,
                         vis_path=vis_path, save_name=save_name, plot=plot)

        
def do_everything(model_train_datasets, evaluate_datasets, model_paths):
    """Evaluate and plot each evaluation dataset with the models trained on its paired dataset.

    ``evaluate_datasets[i]`` is evaluated with checkpoints trained on
    ``model_train_datasets[i]``. Samples come from the fixed ``true_items_dict`` indices
    when ``num_images == 20``. Otherwise ``num_images`` distinct indices are drawn at random.

    Args:
        model_train_datasets: training-dataset keys into ``model_paths``, paired by position.
        evaluate_datasets: evaluation dataset keys (e.g. "flatvel-a").
        model_paths: nested dict ``{train_dataset: {architecture_name: relative_dir}}``.

    Returns:
        ``(value_dict, items_dict)``. ``value_dict`` is the ``plot_instance`` result for the
        last evaluated dataset, or ``{}`` if ``evaluate_datasets`` is empty. ``items_dict``
        maps every evaluated dataset to the sample indices used.
    """
    items_dict = {}
    value_dict = {}  # defined up front so an empty evaluate_datasets does not raise UnboundLocalError
    for idx, evaluate_dataset in enumerate(evaluate_datasets):
        print("Target Dataset: ", evaluate_dataset)
        dataset_val, _, transform_data, _ = get_dataloader(evaluate_dataset)
        if num_images == 20:
            items = true_items_dict[evaluate_dataset]
        else:
            # replace=False: sampling with replacement could plot the same sample twice.
            items = np.random.choice(len(dataset_val), num_images, replace=False)
        # Record the indices on both paths, not only the random one, so every run is traceable.
        items_dict[evaluate_dataset] = items

        _, amp_true, vel_true = dataset_val[items]
        amp_true, vel_true = torch.tensor(amp_true).to(device), torch.tensor(vel_true).to(device)

        # Auto-Linear models use a different normalization, so load the same samples again under it.
        dataset_val_auto_linear, _, transform_data_auto_linear, _ = get_dataloader(
            evaluate_dataset, transform_mode="AutoLinear")
        _, amp_true_auto_linear, vel_true_auto_linear = dataset_val_auto_linear[items]
        amp_true_auto_linear = torch.tensor(amp_true_auto_linear).to(device)
        vel_true_auto_linear = torch.tensor(vel_true_auto_linear).to(device)

        model_train_dataset = model_train_datasets[idx]
        print("Source Dataset: ", model_train_dataset)
        value_dict = plot_instance(amp_true, vel_true, transform_data,
                                   amp_true_auto_linear, vel_true_auto_linear, transform_data_auto_linear,
                                   model_paths, model_train_dataset, evaluate_dataset, plot=False)
    return value_dict, items_dict



In [None]:
# do_everything returns (last dataset's predictions, sample indices per dataset);
# unpack it rather than binding the whole tuple to a misleading `value_dict` name.
value_dict, items_dict = do_everything(model_train_datasets, evaluate_datasets, model_paths)