In [None]:
import sys
sys.path.append("../../")
print(sys.path)
import os
import json
import time
import datetime

import numpy as np
import openpyxl
from openpyxl import Workbook
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable

import torch
import torchvision
from torch import nn
from torchvision.transforms import Compose
from torch.utils.data import RandomSampler, DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter


from tqdm import tqdm
from tqdm.notebook import tqdm

from iunets import iUNet
from dataset import FWIDataset
# from networks import iunet_network
from networks import forward_network, inverse_network, iunet_network

import utils.transforms as T
from utils.pytorch_ssim import *
import utils.utilities as utils
from utils.scheduler import WarmupMultiStepLR
from matplotlib.colors import ListedColormap
import matplotlib.gridspec as gridspec

# 256-entry "rainbow" colormap loaded from disk; every velocity-map plot below uses it.
rainbow_cmap = ListedColormap(
    np.load('/projects/ml4science/OpenFWI/Latent_Bijectivity/utils/rainbow256.npy')
)

# Registries of available model classes, split into forward, inverse and joint variants.
forward_model_list = [
    forward_network.FNO2d,
    forward_network.WaveformNet,
    forward_network.WaveformNet_V2,
    iunet_network.IUnetForwardModel,
    iunet_network.UNetForwardModel,
]
inverse_model_list = [
    inverse_network.InversionNet,
    iunet_network.IUnetInverseModel,
    iunet_network.UNetInverseModel,
]
joint_model_list = [
    iunet_network.IUnetModel,
    iunet_network.JointModel,
    iunet_network.Decouple_IUnetModel,
]

In [None]:
# ---- Run / data-loading settings ---------------------------------------------
step = 0
file_size = 500
vis_suffix = False
device = torch.device("cuda")

k = 1                 # log-transform constant used by the log-scaled data transforms
workers = 4
lambda_g1v = 1
lambda_g2v = 1
batch_size = 50
mask_factor = 0.0
sample_temporal = 1
distributed = False

num_images = 2        # rows drawn per visualisation figure

# ---- Model construction settings ---------------------------------------------
cfg_path = "../../configs/"
latent_dim = 70
skip = 0

# Root directory containing one checkpoint folder per (train dataset, architecture).
base_path = "/projects/ml4science/OpenFWI/Results/SupervisedExperiment/"

mode = "appendix"  # "appendix" or "main_paper"
viz_save_path = f'AAAI_viz_aug15/{mode}'

font_sizes = {
    "color_bar": 16,
    "sub_plt_title": 20,
    "dataset_name": 20,
}

evaluate_datasets = ["marmousi", "marmousi_smooth"]
model_train_datasets = [
    'FlatVel-A', 'FlatVel-B',
    'CurveVel-A', 'CurveVel-B',
    'FlatFault-A', 'FlatFault-B',
    'CurveFault-A', 'CurveFault-B',
    'Style-A', 'Style-B',
]

# architecture_types[i] selects the loader; architecture_names[i] is the checkpoint folder.
if mode != "appendix":
    architecture_types = ["InversionNet", "AutoLinearInverse", "UNetInverseModel", "IUNET"]
    architecture_names = [
        "InversionNet",
        "AutoLinear_Inversion_ckpt",
        "UNetInverseModel_33M",
        "Invertible_XNet",
    ]
else:
    architecture_types = [
        "InversionNet",
        "InversionNet",
        "AutoLinearInverse",
        "UNetInverseModel",
        "UNetInverseModel",
        "IUNET",
        "IUNET",
    ]
    architecture_names = [
        "InversionNet",
        "Velocity_GAN",
        "AutoLinear_Inversion_ckpt",
        "UNetInverseModel_17M",
        "UNetInverseModel_33M",
        "Invertible_XNet",
        "Invertible_XNet_cycle_warmup",
    ]

# Display names used for subplot titles and legends.
plot_names = {
    "InversionNet": "InversionNet",
    "UNetInverseModel_17M": "Latent U-Net (Small)",
    "UNetInverseModel_33M": "Latent U-Net (Large)",
    "IUnetInverseModel": "IUnetInverseModel",
    "Invertible_XNet": "Invertible X-Net",
    "Invertible_XNet_cycle_warmup": "Invertible X-Net (Cycle)",
    "Velocity_GAN": "VelocityGAN",
    "ground_truth": "Ground Truth",
    "AutoLinear_Inversion_ckpt": "Auto-Linear",
}

# Row labels; Marmousi-style entries are keyed "<dataset>_<sample index>".
plot_dataset_names = {
    "flatvel-a": "FVA",
    "flatvel-b": "FVB",
    "curvevel-a": "CVA",
    "curvevel-b": "CVB",
    "flatfault-a": "FFA",
    "flatfault-b": "FFB",
    "curvefault-a": "CFA",
    "curvefault-b": "CFB",
    "style-a": "STA",
    "style-b": "STB",
    "marmousi_0": "Marmousi",
    "marmousi_smooth_0": "Marmousi smooth",
    "marmousi_1": "Overthrust",
    "marmousi_smooth_1": "Overthrust smooth",
}

# UNet size settings, keyed by checkpoint folder name (with a fallback).
architecture_params = {
    "UNetInverseModel_17M": {"unet_depth": 2, "unet_repeat_blocks": 1},
    "UNetInverseModel_33M": {"unet_depth": 2, "unet_repeat_blocks": 2},
    "default": {"unet_depth": 2, "unet_repeat_blocks": 2},
}

# Relative checkpoint directory for every (train dataset, architecture) pair.
model_paths = {
    train_name: {
        arch_name: os.path.join(train_name, arch_name, "fcn_l1loss_ffb")
        for arch_name in architecture_names
    }
    for train_name in model_train_datasets
}

In [None]:
# Named dash patterns in matplotlib's (offset, (on, off, ...)) form.
# 'dotted' was previously a duplicate of 'densely dotted' (0, (1, 1)); it now uses the
# (0, (1, 5)) pattern from matplotlib's linestyle reference table.
linestyle_tuple = [
     ('loosely dotted',        (0, (1, 10))),
     ('dotted',                (0, (1, 5))),
     ('densely dotted',        (0, (1, 1))),
     ('long dash with offset', (5, (10, 3))),
     ('loosely dashed',        (0, (5, 10))),
     ('dashed',                (0, (5, 5))),
     ('densely dashed',        (0, (5, 1))),

     ('loosely dashdotted',    (0, (3, 10, 1, 10))),
     ('dashdotted',            (0, (3, 5, 1, 5))),
     ('densely dashdotted',    (0, (3, 1, 1, 1))),

     ('dashdotdotted',         (0, (3, 5, 1, 5, 1, 5))),
     ('loosely dashdotdotted', (0, (3, 10, 1, 10, 1, 10))),
     ('densely dashdotdotted', (0, (3, 1, 1, 1, 1, 1)))]

linestyle_dict = dict(linestyle_tuple)

# Per-model line style for trace plots. Only keys present here are drawn in trace
# plots. A higher zorder is drawn on top.
plot_map = {
    "ground_truth": {
        "linestyle": "solid",
        "color": "black",
        "zorder": 0,
    },
    "InversionNet": {
        "linestyle": linestyle_dict["dashed"],
        "color": "blue",
        "zorder": 3,
    },
    "UNetInverseModel_33M": {
        "linestyle": linestyle_dict["densely dashed"],
        "color": "red",
        "zorder": 2,
    },
    "Invertible_XNet_cycle_warmup": {
        "linestyle": "dotted",  # matplotlib's built-in named style, not linestyle_dict
        "color": "darkgreen",
        "zorder": 1,
    },
}



In [None]:
# Display the relative checkpoint directory built for every (training dataset, architecture) pair.
model_paths

In [None]:
# Scalar error metrics between a prediction tensor and a broadcast-compatible target.
criterions = {
    'MAE': lambda pred, target: (pred - target).abs().mean(),
    'MSE': lambda pred, target: (pred - target).pow(2).mean(),
}



def get_dataset_path(dataset):
    """Return the ``(train, val)`` annotation-file paths for ``dataset``.

    The files are ``<dataset>_train.txt`` and ``<dataset>_val.txt`` inside
    ``../../train_test_splits/``.
    """
    split_dir = "../../train_test_splits/"
    train_path, val_path = (
        os.path.join(split_dir, f"{dataset}_{split}.txt") for split in ("train", "val")
    )
    return train_path, val_path


def get_transforms(dataset, return_ctx=False, config_path='../../dataset_config.json'):
    """Build the standard data/label normalisers for ``dataset``.

    Data uses ``T.Normalize(data_mean, data_std)`` and labels use
    ``T.MinMaxNormalize(label_min, label_max)``. Both take their statistics
    from the ``dataset`` entry of the JSON config.

    Args:
        dataset: Key into the JSON config file.
        return_ctx: When True, also return that config entry.
        config_path: Path of the per-dataset statistics JSON.

    Returns:
        ``(transform_data, transform_label)``, or
        ``(transform_data, transform_label, ctx)`` when ``return_ctx`` is True.
    """
    # Context manager so the config file handle is closed (it used to leak).
    with open(config_path) as f:
        ctx = json.load(f)[dataset]

    transform_data = T.Normalize(ctx['data_mean'], ctx['data_std'])
    transform_label = T.MinMaxNormalize(ctx['label_min'], ctx['label_max'])
    if return_ctx:
        return transform_data, transform_label, ctx
    return transform_data, transform_label



def get_transforms_vel_gan(dataset, return_ctx=False, k=1, config_path='../../dataset_config.json'):
    """Build the VelocityGAN-style normalisers for ``dataset``.

    Data is passed through ``T.LogTransform(k=k)``. It is then min-max scaled
    between the log-transformed ``data_min``/``data_max`` of the config entry.
    Labels are min-max scaled with ``label_min``/``label_max``.

    Args:
        dataset: Key into the JSON config file.
        return_ctx: When True, also return that config entry.
        k: Log-transform constant (previously a hard-coded local ``k=1``).
        config_path: Path of the per-dataset statistics JSON.

    Returns:
        ``(transform_data, transform_label)``, or
        ``(transform_data, transform_label, ctx)`` when ``return_ctx`` is True.
    """
    # Context manager so the config file handle is closed (it used to leak).
    with open(config_path) as f:
        ctx = json.load(f)[dataset]

    # Normalize data and label to [-1, 1] (original author's note).
    log_min = T.log_transform(ctx['data_min'], k=k)
    log_max = T.log_transform(ctx['data_max'], k=k)
    transform_data = Compose([
        T.LogTransform(k=k),
        T.MinMaxNormalize(log_min, log_max),
    ])
    transform_label = T.MinMaxNormalize(ctx['label_min'], ctx['label_max'])

    if return_ctx:
        return transform_data, transform_label, ctx
    return transform_data, transform_label

def get_transforms_auto_linear(dataset, return_ctx=False, k=1, config_path='../../dataset_config.json'):
    """Build the Auto-Linear normalisers for ``dataset``.

    Data is passed through ``T.LogTransform(k=k)``. It is then min-max scaled
    between the log-transformed ``data_min``/``data_max`` of the config entry.
    Labels are min-max scaled with ``label_min``/``label_max``.

    Args:
        dataset: Key into the JSON config file.
        return_ctx: When True, return the config entry instead of the log bounds.
        k: Log-transform constant. The default of 1 matches the module-level
            ``k`` the function previously read implicitly.
        config_path: Path of the per-dataset statistics JSON.

    Returns:
        ``(transform_data, transform_label, ctx)`` when ``return_ctx`` is True.
        Otherwise ``(transform_data, transform_label, min_log, max_log)``, where
        the last two are the log-transformed data min/max.
        NOTE(review): the two return shapes differ in length and meaning;
        both are kept for existing callers.
    """
    # Context manager so the config file handle is closed (it used to leak).
    with open(config_path) as f:
        ctx = json.load(f)[dataset]

    # Computed once and reused by both the transform and the return value.
    min_log = T.log_transform(ctx['data_min'], k=k)
    max_log = T.log_transform(ctx['data_max'], k=k)

    transform_data = Compose([
        T.LogTransform(k=k),
        T.MinMaxNormalize(min_log, max_log),
    ])
    transform_label = T.MinMaxNormalize(ctx['label_min'], ctx['label_max'])

    if return_ctx:
        return transform_data, transform_label, ctx
    return transform_data, transform_label, min_log, max_log


def get_dataloader(dataset, transform_mode="normal", train_dataset=None):
    """Load the validation split of ``dataset`` and wrap it in a DataLoader.

    The normalisation statistics (and ``file_size``) come from the config entry
    of ``train_dataset``, which defaults to ``dataset``. This lets one
    dataset's data be normalised with another dataset's statistics.

    Args:
        dataset: Name used to locate the ``<dataset>_val.txt`` annotation file.
        transform_mode: ``"normal"``, ``"VelocityGAN"`` or ``"AutoLinear"``.
            Selects the transform builder.
        train_dataset: Config key for the normalisation statistics.

    Returns:
        ``(dataset_valid, dataloader_valid, transform_data, transform_label)``.

    Raises:
        ValueError: If ``transform_mode`` is not one of the supported modes.
    """
    if train_dataset is None:
        train_dataset = dataset

    if transform_mode == "normal":
        transform_data, transform_label, ctx = get_transforms(train_dataset, return_ctx=True)
    elif transform_mode == "VelocityGAN":
        transform_data, transform_label, ctx = get_transforms_vel_gan(train_dataset, return_ctx=True)
    elif transform_mode == "AutoLinear":
        transform_data, transform_label, ctx = get_transforms_auto_linear(train_dataset, return_ctx=True)
    else:
        # Previously fell through and failed later with UnboundLocalError on ``ctx``.
        raise ValueError(
            f"Unknown transform_mode {transform_mode!r}; "
            "expected 'normal', 'VelocityGAN' or 'AutoLinear'")

    _, val_anno = get_dataset_path(dataset)

    print(f'Loading {dataset} validation data')
    dataset_valid = FWIDataset(
        val_anno,
        preload=True,
        sample_ratio=sample_temporal,
        file_size=ctx['file_size'],
        transform_data=transform_data,
        transform_label=transform_label,
    )

    # An explicit sampler is supplied, so ``shuffle`` must stay False.
    dataloader_valid = DataLoader(
        dataset_valid, batch_size=batch_size,
        sampler=RandomSampler(dataset_valid), num_workers=workers,
        pin_memory=True, collate_fn=default_collate, shuffle=False)

    print('Data loading over')
    return dataset_valid, dataloader_valid, transform_data, transform_label

In [None]:
def set_inverse_params(inverse_model_params, model_path=None):
    """Fill in constructor kwargs for the (I)UNet inverse models, in place.

    Sets ``cfg_path`` and ``latent_dim`` (module-level config) for both
    ``'IUnetInverseModel'`` and ``'UNetInverseModel'``. For
    ``'UNetInverseModel'`` it also sets ``unet_depth`` and
    ``unet_repeat_blocks`` (from ``architecture_params``) and ``skip``.

    Args:
        inverse_model_params: Dict of per-model kwargs. Missing model entries
            are created, and the dict is mutated in place.
        model_path: Checkpoint path. A path containing ``"UNetInverseModel_17M"``
            selects the small UNet settings. Any other path, including
            ``None``, uses ``architecture_params["default"]``.

    Returns:
        The same ``inverse_model_params`` object.
    """
    for model_key in ('IUnetInverseModel', 'UNetInverseModel'):
        model_kwargs = inverse_model_params.setdefault(model_key, {})
        model_kwargs['cfg_path'] = cfg_path
        model_kwargs['latent_dim'] = latent_dim

    # The default model_path=None used to raise TypeError in the ``in`` test.
    is_small = model_path is not None and "UNetInverseModel_17M" in model_path
    size = architecture_params["UNetInverseModel_17M" if is_small else "default"]

    unet_kwargs = inverse_model_params['UNetInverseModel']
    unet_kwargs['unet_depth'] = size["unet_depth"]
    unet_kwargs['unet_repeat_blocks'] = size["unet_repeat_blocks"]
    unet_kwargs['skip'] = skip  # module-level ``skip`` setting
    return inverse_model_params
    
    
def get_model(model_path, model_type):
    """Load an inverse-model checkpoint and return it in eval mode on ``device``.

    First tries ``inverse_network.model_dict[model_type]`` with kwargs from
    ``inverse_network.inverse_params`` (completed by ``set_inverse_params``).
    If any step fails, it falls back to
    ``inverse_network.model_dict[model_type + "_Legacy"]`` with
    ``inverse_network.inverse_params_legacy``.

    Args:
        model_path: Checkpoint file whose ``'model'`` entry is a state dict.
        model_type: Key into ``inverse_network.model_dict``.

    Returns:
        The loaded model, on ``device``, in eval mode.
    """
    try:
        print(model_path, model_type)
        inverse_model_params = set_inverse_params(inverse_network.inverse_params, model_path)
        model = inverse_network.model_dict[model_type](**inverse_model_params[model_type]).to(device)
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model'])
    except Exception as exc:
        # Was a bare ``except:``: it also caught KeyboardInterrupt and hid the cause.
        print("Failed to load new model. Falling back to Legacy Code.")
        print(f"  reason: {type(exc).__name__}: {exc}")
        # Note: this mutates the shared legacy kwargs dict owned by inverse_network.
        legacy_params = inverse_network.inverse_params_legacy
        size_key = "UNetInverseModel_17M" if "UNetInverseModel_17M" in model_path else "default"
        legacy_params['unet_depth'] = architecture_params[size_key]["unet_depth"]
        legacy_params['unet_repeat_blocks'] = architecture_params[size_key]["unet_repeat_blocks"]

        model = inverse_network.model_dict[model_type + "_Legacy"](**legacy_params).to(device)
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model'])

    model = model.to(device)
    model.eval()
    return model

def get_model_iunet_(amp_model, vel_model, latent_channels, model_type, architecture=(4, 4, 4, 4)):
    """Combine amplitude/velocity autoencoders with invertible UNet(s).

    Args:
        amp_model: Amplitude autoencoder.
        vel_model: Velocity autoencoder.
        latent_channels: ``in_channels`` for each ``iUNet``.
        model_type: ``"IUNET"`` gives one shared ``iUNet`` (``IUnetModel``).
            ``"Decouple_IUnet"`` gives one ``iUNet`` per branch
            (``Decouple_IUnetModel``).
        architecture: ``architecture`` argument passed to every ``iUNet``.

    Returns:
        The assembled model on ``device``.

    Raises:
        ValueError: If ``model_type`` is not recognised. It used to print and
            then fail with UnboundLocalError on ``model``.
    """
    if model_type == "IUNET":
        iunet_model = iUNet(in_channels=latent_channels, dim=2, architecture=architecture)
        model = iunet_network.IUnetModel(amp_model, vel_model, iunet_model).to(device)
        print("IUnet model initialized.")
    elif model_type == "Decouple_IUnet":
        amp_iunet_model = iUNet(in_channels=latent_channels, dim=2, architecture=architecture)
        vel_iunet_model = iUNet(in_channels=latent_channels, dim=2, architecture=architecture)
        model = iunet_network.Decouple_IUnetModel(amp_model, vel_model, amp_iunet_model, vel_iunet_model).to(device)
        print("Decoupled IUnetModel model initialized.")
    else:
        raise ValueError(f"Invalid Model: {model_type}")
    return model


def get_model_iunet(model_path, model_type):
    """Load an invertible-UNet joint model checkpoint in eval mode on ``device``.

    First tries config-driven autoencoders (``autoencoder.AutoEncoder`` built
    from ``get_config_name`` / ``get_latent_dim``). If anything fails, it falls
    back to the hard-coded legacy ``AmpAutoEncoder`` / ``VelAutoEncoder`` with
    128 latent channels.

    NOTE(review): ``autoencoder``, ``get_config_name`` and ``get_latent_dim``
    are not imported in this notebook's visible cells. The first branch
    therefore presumably always raises NameError and the legacy path is used.
    Confirm this and add the missing imports. The printed reason below makes
    the failure visible.

    Args:
        model_path: Checkpoint file whose ``'model'`` entry is a state dict.
        model_type: Passed to ``get_model_iunet_`` (``"IUNET"`` or ``"Decouple_IUnet"``).

    Returns:
        The loaded model, on ``device``, in eval mode.
    """
    try:
        print(model_path, model_type)
        amp_cfg_name = get_config_name(latent_dim, model_type="amplitude")
        amp_model = autoencoder.AutoEncoder(cfg_path, amp_cfg_name).to(device)

        # creating velocity autoencoder
        vel_cfg_name = get_config_name(latent_dim, model_type="velocity")
        vel_model = autoencoder.AutoEncoder(cfg_path, vel_cfg_name).to(device)

        latent_channels = get_latent_dim(cfg_path, amp_cfg_name)
        model = get_model_iunet_(amp_model, vel_model, latent_channels, model_type)
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model'])
    except Exception as exc:
        # Was a bare ``except:``: it also caught KeyboardInterrupt and hid the cause.
        print("Failed to load new model. Falling back to Legacy Code.")
        print(f"  reason: {type(exc).__name__}: {exc}")

        # Legacy amplitude autoencoder: 5 input channels, fixed channel schedule.
        amp_input_channel = 5
        amp_encoder_channel = [8, 16, 32, 64, 128]
        amp_decoder_channel = [128, 64, 32, 16, 5]
        amp_model = iunet_network.AmpAutoEncoder(amp_input_channel, amp_encoder_channel, amp_decoder_channel).to(device)

        # Legacy velocity autoencoder: 1 input channel.
        vel_input_channel = 1
        vel_encoder_channel = [8, 16, 32, 64, 128]
        vel_decoder_channel = [128, 64, 32, 16, 1]
        vel_model = iunet_network.VelAutoEncoder(vel_input_channel, vel_encoder_channel, vel_decoder_channel).to(device)

        latent_channels = 128  # last encoder width above
        model = get_model_iunet_(amp_model, vel_model, latent_channels, model_type)
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model'])

    model = model.to(device)
    model.eval()
    return model


def get_model_auto_linear(model_path, model_type):
    """Load the TorchScript Auto-Linear model stored at ``model_path``.

    ``model_type`` is accepted only for signature parity with the other
    loaders and is ignored. The returned module is switched to eval mode.
    """
    scripted_model = torch.jit.load(model_path)  # a TorchScript archive, not a state dict
    scripted_model.eval()
    return scripted_model

In [None]:
def permute_fn(tensor):
    """Identity hook applied to tensors before they are converted to numpy.

    The name suggests it once permuted axes. It currently hands back its
    argument unchanged.
    """
    unchanged = tensor
    return unchanged


def plot_instance(amp_true, vel_true, transform_label,
                  amp_true_vel_gan, vel_true_vel_gan, transform_label_vel_gan,
                  amp_true_auto_linear, vel_true_auto_linear, transform_label_auto_linear,
                  model_paths, model_train_dataset="FlatVel-A", evaluate_dataset="flatvel-b"):
    """Run every configured architecture on one batch and save a comparison grid.

    Each input triple (amp, vel, label transform) is the same batch normalised
    the way a family of checkpoints expects:
    * ``*_vel_gan``: used when ``architecture_name == "Velocity_GAN"``.
    * ``*_auto_linear``: used for ``"AutoLinearInverse"``.
    * the plain triple: used for everything else.
    ``vel_true_vel_gan`` is accepted but not used.

    Predictions are mapped back to physical units with the matching
    ``inverse_transform``. They are plotted with ``generate_plot_viz`` into
    ``<base_path>/<model_train_dataset>/<viz_save_path>``.

    Returns:
        Dict mapping ``"ground_truth"`` and each architecture name to its numpy array.
    """
    values_dict = {}

    vel_true_np = transform_label.inverse_transform(permute_fn(vel_true).detach().cpu().numpy())
    values_dict["ground_truth"] = vel_true_np

    for architecture_type, architecture_name in zip(architecture_types, architecture_names):
        print("Evaluating Architecture: ", architecture_name)

        model_path = os.path.join(base_path, model_paths[model_train_dataset][architecture_name], "latest_checkpoint.pth")
        print("Evaluating Model: ", model_path)

        # Inference only: no_grad avoids building an autograd graph for every model.
        with torch.no_grad():
            if architecture_type == "IUNET":
                model = get_model_iunet(model_path, architecture_type).to(device)
                vel_pred = model.inverse(amp_true)
                label_transform = transform_label
            elif architecture_type == "AutoLinearInverse":
                model = get_model_auto_linear(model_path, architecture_type).to(device)
                # The scripted model takes (data, velocity); element 4 of its output is the velocity prediction.
                vel_pred = model(amp_true_auto_linear, vel_true_auto_linear)[4]
                label_transform = transform_label_auto_linear
            else:
                model = get_model(model_path, architecture_type).to(device)
                if architecture_name == "Velocity_GAN":
                    # Different normalisation, so a different input and inverse transform.
                    vel_pred = model(amp_true_vel_gan)
                    label_transform = transform_label_vel_gan
                else:
                    vel_pred = model(amp_true)
                    label_transform = transform_label

        values_dict[architecture_name] = label_transform.inverse_transform(
            permute_fn(vel_pred).detach().cpu().numpy())
        # Release the model before the next architecture is loaded.
        del model, vel_pred

    # One directory per training dataset holds the all-models visualisation.
    vis_path = os.path.join(base_path, model_train_dataset, viz_save_path)
    os.makedirs(vis_path, exist_ok=True)
    generate_plot_viz(values_dict, vis_path, evaluate_dataset, num_images)
    return values_dict


from copy import deepcopy
def add_diff_plots(values_dict):
    """Return a deep copy of ``values_dict`` extended with difference maps.

    For every key other than ``"ground_truth"``, it adds
    ``"diff_<key>" = ground_truth - values_dict[key]``. When the models
    involved are present, it also adds two model-vs-model maps:
    ``"XNET_cyle-InversionNet"`` and ``"XNET_cyle-XNET"``. The key spelling is
    kept unchanged for any existing lookups.

    Args:
        values_dict: Must contain ``"ground_truth"``. Values must support
            subtraction (e.g. numpy arrays).

    Returns:
        A new dict. The input dict and its arrays are not modified.
    """
    values_dict_new = deepcopy(values_dict)
    ground_truth = values_dict["ground_truth"]
    for key, prediction in values_dict.items():
        if key != "ground_truth":
            values_dict_new[f"diff_{key}"] = ground_truth - prediction

    # These used to raise KeyError whenever a model was absent
    # (e.g. the "main_paper" selection has no cycle-trained X-Net).
    cross_diffs = {
        "XNET_cyle-InversionNet": ("Invertible_XNet_cycle_warmup", "InversionNet"),
        "XNET_cyle-XNET": ("Invertible_XNet_cycle_warmup", "Invertible_XNet"),
    }
    for new_key, (minuend, subtrahend) in cross_diffs.items():
        if minuend in values_dict_new and subtrahend in values_dict_new:
            values_dict_new[new_key] = values_dict_new[minuend] - values_dict_new[subtrahend]
    return values_dict_new


def generate_plot_viz(values_dict, vis_path, evaluate_dataset, num_images, plot=True):
    """Save a grid of velocity maps as ``<vis_path>/<evaluate_dataset>_<mode>.pdf``.

    Rows are samples ``0 .. num_images-1``. Columns are the entries of
    ``values_dict`` in insertion order. All panels in a row share that row's
    min/max colour range, shown by a single colour bar in a narrow extra grid
    column.

    Args:
        values_dict: Maps a column key to an array; ``values[i, 0]`` is the 2-D
            map for row ``i``. Titles come from ``plot_names`` and fall back to
            the key.
        vis_path: Existing output directory.
        evaluate_dataset: Dataset name, used in the filename and row labels.
        num_images: Number of rows (samples) to draw.
        plot: If True, also show the figure before closing it.
    """
    save_name = f"{evaluate_dataset}_{mode}"
    num_cols = len(values_dict)

    fig = plt.figure(figsize=(3.3 * num_cols, int(3.3 * num_images)), dpi=150)
    gs = gridspec.GridSpec(num_images, num_cols + 1, width_ratios=[1] * num_cols + [0.1])

    for i in range(num_images):
        # Shared colour scale per row so the panels are directly comparable.
        vel_min = min(values[i].min() for values in values_dict.values())
        vel_max = max(values[i].max() for values in values_dict.values())

        for j, (key, values) in enumerate(values_dict.items()):
            ax = fig.add_subplot(gs[i, j])
            img = ax.imshow(values[i, 0], aspect='auto', vmin=vel_min, vmax=vel_max, cmap=rainbow_cmap)
            ax.set_title(plot_names.get(key, key), fontsize=font_sizes["sub_plt_title"])
            if j == 0:
                # Only Marmousi-style datasets have per-row "<dataset>_<i>" labels. Others
                # fall back to the dataset abbreviation, then the raw name
                # (this used to raise KeyError).
                row_label = plot_dataset_names.get(
                    f"{evaluate_dataset}_{i}",
                    plot_dataset_names.get(evaluate_dataset, evaluate_dataset))
                ax.set_ylabel(row_label, fontsize=font_sizes["dataset_name"])
            ax.set_xticks([])
            ax.set_yticks([])

        cax = fig.add_subplot(gs[i, num_cols])
        cbar = fig.colorbar(img, cax=cax)
        cbar.ax.tick_params(labelsize=font_sizes["color_bar"])

    fig.tight_layout()
    fig.savefig(os.path.join(vis_path, f"{save_name}.pdf"))
    if plot:
        plt.show()
    plt.close(fig)

    
def generate_plot_trace(values_dict, vis_path, evaluate_dataset, num_images, plot=True,
                        horizontal_indices=(17, 34, 51), vertical_indices=(17, 34, 51)):
    """Save horizontal and vertical trace plots for each of ``num_images`` samples.

    Only models that have a style entry in ``plot_map`` are included.

    Args:
        values_dict: Maps model key to an array indexed ``[image, 0, row, col]``.
        vis_path: Output directory (created if missing).
        evaluate_dataset: Used as the filename prefix.
        num_images: Number of samples to plot.
        plot: Accepted for API parity; currently unused.
        horizontal_indices: Rows traced in the horizontal plots.
        vertical_indices: Columns traced in the vertical plots.
    """
    os.makedirs(vis_path, exist_ok=True)
    values_dict = {key: value for key, value in values_dict.items() if key in plot_map}
    for i in range(num_images):
        plot_trace_plots(values_dict, list(horizontal_indices), direction='horizontal',
                         image_id=i, vis_path=vis_path, save_name=evaluate_dataset)
        plot_trace_plots(values_dict, list(vertical_indices), direction='vertical',
                         image_id=i, vis_path=vis_path, save_name=evaluate_dataset)

    
def _plot_trace_overview(values_dict, indices, direction, image_id, vis_path, save_name):
    """Draw every map with dashed guide lines at ``indices`` and save it as ``*_compare.pdf``."""
    num_comparison_plots = len(values_dict)
    fig = plt.figure(figsize=(3 * num_comparison_plots, 3))
    gs = gridspec.GridSpec(1, num_comparison_plots + 1, width_ratios=[1] * num_comparison_plots + [0.1])

    # One colour bar is shown for all panels, so they must share a colour scale.
    # Previously each panel autoscaled and the bar only matched the last one.
    vmin = min(values[image_id, 0].min() for values in values_dict.values())
    vmax = max(values[image_id, 0].max() for values in values_dict.values())

    for i, (key, values) in enumerate(values_dict.items()):
        ax = fig.add_subplot(gs[i])
        im = ax.imshow(values[image_id, 0], aspect='auto', cmap=rainbow_cmap, vmin=vmin, vmax=vmax)
        ax.set_title(f'{plot_names[key]}')
        for idx in indices:
            if direction == 'horizontal':
                ax.axhline(y=idx, color='white', linestyle='--')
            else:
                ax.axvline(x=idx, color='white', linestyle='--')
        ax.set_xlabel("Sensors Locations (m)")
        if i == 0:
            ax.set_ylabel("Depth (m)")

    cax = fig.add_subplot(gs[num_comparison_plots])
    cbar = fig.colorbar(im, cax=cax)
    cbar.ax.tick_params(labelsize=font_sizes["color_bar"])
    fig.tight_layout()
    fig.savefig(os.path.join(vis_path, f"{save_name}_Image{image_id}_{direction}_compare.pdf"))
    plt.show()
    plt.close(fig)


def _plot_trace_lines(values_dict, indices, direction, image_id, vis_path, save_name):
    """Plot 1-D slices of every map at each index and save them as ``*_TracePlot.pdf``."""
    num_trace_plots = len(indices)
    # squeeze=False always gives a 2-D array, so one index needs no special case.
    fig, axes = plt.subplots(1, num_trace_plots, figsize=(3 * num_trace_plots, 2.7), squeeze=False)
    axes = axes[0]

    for ax, idx in zip(axes, indices):
        for key, values in values_dict.items():
            # Horizontal traces take row ``idx``; vertical traces take column ``idx``.
            if direction == 'horizontal':
                slice_ = values[image_id, 0, idx, :]
            else:
                slice_ = values[image_id, 0, :, idx]
            style = plot_map[key]
            ax.plot(slice_, linestyle=style["linestyle"], color=style["color"],
                    zorder=style["zorder"], label=f'{plot_names[key]}')
        ax.set_ylabel(f"Trace at index {idx} {direction}")
        ax.set_xlabel("Sensors Locations" if direction == "horizontal" else "Depth")

    # A single shared legend below the row of subplots.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, bbox_to_anchor=(0.5, 0.05), loc='upper center', ncol=4)
    fig.tight_layout()
    fig.savefig(os.path.join(vis_path, f"{save_name}_Image{image_id}_{direction}_TracePlot.pdf"),
                bbox_inches="tight")
    plt.show()
    plt.close(fig)


def plot_trace_plots(values_dict, indices, direction='horizontal', image_id=0, vis_path="", save_name=""):
    """
    Plot and save trace comparisons for one sample.

    Two figures are written to ``vis_path``:
    ``<save_name>_Image<image_id>_<direction>_compare.pdf`` shows every map with
    guide lines at ``indices``, and ``..._TracePlot.pdf`` shows the 1-D slices.

    Parameters:
    values_dict (dict): Maps a model key (must be present in ``plot_names`` and
        ``plot_map``) to an array indexed as ``[image, 0, row, col]``.
    indices (list of int): Rows (horizontal) or columns (vertical) to trace.
    direction (str): 'horizontal' or 'vertical'.
    image_id (int): Which sample (first array axis) to plot.
    vis_path (str): Output directory.
    save_name (str): Filename prefix.
    """
    _plot_trace_overview(values_dict, indices, direction, image_id, vis_path, save_name)
    _plot_trace_lines(values_dict, indices, direction, image_id, vis_path, save_name)
    
    
def do_everything(model_train_datasets, evaluate_datasets, model_paths):
    """Visualise predictions for datasets paired by position.

    ``evaluate_datasets[i]`` is evaluated with the checkpoints trained on
    ``model_train_datasets[i]``. For each pair, ``num_images`` validation
    samples are drawn at random and loaded under the three normalisations
    (normal, VelocityGAN, AutoLinear). They are then passed to ``plot_instance``.

    Returns:
        ``(value_dict, items_dict)``. ``value_dict`` is the prediction dict of the
        last pair processed, or ``None`` when ``evaluate_datasets`` is empty.
        ``items_dict`` maps each evaluate dataset to its sampled indices.

    Raises:
        ValueError: If there are fewer training datasets than evaluation datasets.
    """
    # Fail before any plotting instead of with IndexError part-way through.
    if len(model_train_datasets) < len(evaluate_datasets):
        raise ValueError(
            f"Need a training dataset for each of the {len(evaluate_datasets)} "
            f"evaluation datasets, got {len(model_train_datasets)}")

    items_dict = {}
    value_dict = None  # used to be unbound (UnboundLocalError) for empty input
    for evaluate_dataset, model_train_dataset in zip(evaluate_datasets, model_train_datasets):
        print("Target Dataset: ", evaluate_dataset)

        dataset_val, _, _, transform_label = get_dataloader(evaluate_dataset)
        # Draw distinct samples when possible; duplicates would repeat a figure row.
        items = np.random.choice(len(dataset_val), num_images,
                                 replace=len(dataset_val) < num_images)
        items_dict[evaluate_dataset] = items

        _, amp_true, vel_true = dataset_val[items]
        amp_true, vel_true = torch.tensor(amp_true).to(device), torch.tensor(vel_true).to(device)

        # VelocityGAN / Auto-Linear checkpoints use different normalisations,
        # so reload the same items under each one.
        by_mode = {}
        for transform_mode in ("VelocityGAN", "AutoLinear"):
            dataset_mode, _, _, transform_label_mode = get_dataloader(evaluate_dataset, transform_mode=transform_mode)
            _, amp_mode, vel_mode = dataset_mode[items]
            by_mode[transform_mode] = (torch.tensor(amp_mode).to(device),
                                       torch.tensor(vel_mode).to(device),
                                       transform_label_mode)

        print("Source Dataset: ", model_train_dataset)
        value_dict = plot_instance(amp_true, vel_true, transform_label,
                                   *by_mode["VelocityGAN"], *by_mode["AutoLinear"],
                                   model_paths, model_train_dataset, evaluate_dataset)
    return value_dict, items_dict



In [None]:
def evaluate_marmousi(evaluate_dataset="marmousi", model_train_dataset="Style-A", model_paths=model_paths):
    """Evaluate one training dataset's checkpoints on a Marmousi-style dataset.

    Both samples (indices 0 and 1) are loaded three times, once per
    normalisation (normal, VelocityGAN, AutoLinear). Each time the statistics
    of ``model_train_dataset`` are used. The results are then passed to
    ``plot_instance``, which saves the comparison figure.

    Args:
        evaluate_dataset: Dataset to evaluate on.
        model_train_dataset: Training dataset whose checkpoints are used. It is
            lower-cased to form the config key for normalisation statistics.
        model_paths: Checkpoint directory mapping. It defaults to the
            module-level dict as it was when this cell ran.

    Returns:
        The ``plot_instance`` result (model name -> predicted velocity array).
        It was previously discarded.
    """
    train_key = model_train_dataset.lower()
    items = np.array([0, 1])  # marmousi dataset has 2 items

    # Same items under each normalisation scheme, in the original load order.
    by_mode = {}
    for transform_mode in ("normal", "VelocityGAN", "AutoLinear"):
        dataset_mode, _, _, transform_label_mode = get_dataloader(
            dataset=evaluate_dataset, transform_mode=transform_mode, train_dataset=train_key)
        _, amp_mode, vel_mode = dataset_mode[items]
        by_mode[transform_mode] = (torch.tensor(amp_mode).to(device),
                                   torch.tensor(vel_mode).to(device),
                                   transform_label_mode)

    print("Source Dataset: ", model_train_dataset)
    value_dict = plot_instance(*by_mode["normal"], *by_mode["VelocityGAN"], *by_mode["AutoLinear"],
                               model_paths, model_train_dataset, evaluate_dataset)
    return value_dict
    
    

In [None]:
# Evaluate every training dataset's checkpoints on each dataset in evaluate_datasets.
for model_train_dataset in model_train_datasets:
    for evaluate_dataset in evaluate_datasets:
        evaluate_marmousi(
            evaluate_dataset=evaluate_dataset,
            model_train_dataset=model_train_dataset,
            model_paths=model_paths,
        )