In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os

from plot_settings import global_setting
from matplotlib.lines import Line2D

global_setting()

In [None]:
# Benchmark datasets: four families, each listed as its A variant followed by its B variant.
model_names = [
    f"{family}-{variant}"
    for family in ("FlatVel", "CurveVel", "FlatFault", "CurveFault")
    for variant in ("A", "B")
]

# Metrics tabulated for every (dataset, model) pair.
metrics = ["MSE", "MAE", "SSIM"]


In [None]:
# One table per metric: a row per dataset, tagged with the metric name.
# Model columns are appended later, once the evaluation grids are loaded.
results_dict = {
    metric_name: pd.DataFrame({"Dataset": model_names, "Metric": metric_name})
    for metric_name in metrics
}


In [None]:
# Peek at the freshly initialised MSE table (only the Dataset and Metric columns exist at this point).
results_dict["MSE"]

In [None]:
# Model keys are "<size>_<latent dim>"; values are the file stems of their evaluation grids.
# The latent-70 entries use the un-suffixed model names.
model_dict = {"Small_70": "UNetInverseModel_17M",
              "Large_70": "UNetInverseModel_33M",
              "Small_64": "UNetInverseModel_17M_Latent64",
              "Large_64": "UNetInverseModel_33M_Latent64",
              "Small_32": "UNetInverseModel_17M_Latent32",
              "Large_32": "UNetInverseModel_33M_Latent32",
              "Small_16": "UNetInverseModel_17M_Latent16",
              "Large_16": "UNetInverseModel_33M_Latent16",
              "Small_8": "UNetInverseModel_17M_Latent8",
              "Large_8": "UNetInverseModel_33M_Latent8",
              }

# Each grid file is <basepath>/<prefix><model file stem>.xlsx (no separator between
# prefix and stem).
basepath = "../Metrics_final/"
prefix = "eval_metric"
mode = "Velocity"

# Sheets to read from every grid file, derived from `mode` so switching the quantity
# only needs one edit. The text after the last "_" is used as the metric key.
sheets = [f"{mode}_unnorm_MAE", f"{mode}_unnorm_MSE", f"{mode}_SSIM"]

In [None]:
# For every model and metric sheet, collect the "diagonal" of the dataset-vs-dataset
# grid (the cell where row and column name the same dataset) into
# results_dict[<metric>][<model key>]. NOTE(review): which grid axis is train vs. test
# is not visible here; TODO confirm.
for model_key, model_file_stem in model_dict.items():
    grid_file = os.path.join(basepath, prefix + model_file_stem + ".xlsx")
    # A list `sheet_name` reads every metric sheet in a single pass over the workbook
    # and returns {sheet name: DataFrame}.
    grid_sheets = pd.read_excel(grid_file, engine='openpyxl', header=0, index_col=0,
                                sheet_name=sheets)
    for sheet_name, grid in grid_sheets.items():
        metric_type = sheet_name.split("_")[-1]  # e.g. "Velocity_unnorm_MAE" -> "MAE"
        # Rows are labelled with the dataset name, columns with its lower-cased form.
        results_dict[metric_type][model_key] = [
            grid.loc[dataset, dataset.lower()] for dataset in model_names
        ]

In [None]:
# Peek at the MSE table after loading: one extra column per model_dict key, one row per dataset.
results_dict["MSE"]

In [None]:
def extract_and_sort(df, key="Small"):
    """Select the `key` model columns of a results frame, ordered by latent dimension.

    Columns whose name contains `key` are kept; each is expected to end in
    "_<int>" (e.g. "Small_32"), and that integer is the sort key.

    Returns:
        (latent_dims, values): the sorted latent dimensions as a list of ints, and
        the first row of the selected columns as a NumPy array in the same order.
    """
    def latent_of(column_name):
        return int(column_name.split('_')[-1])

    matching = df.filter(like=key)
    ordered_cols = sorted(matching.columns, key=latent_of)
    dims = [latent_of(col) for col in ordered_cols]
    return dims, matching[ordered_cols].values[0]

# Line colour per dataset: one colour pair per dataset family (A variant, then B variant).
color_dict = {
    'FlatVel-A': "lightcoral",
    'FlatVel-B': "firebrick",
    'CurveVel-A': "limegreen",
    'CurveVel-B': "darkgreen",
    'FlatFault-A': "deepskyblue",
    'FlatFault-B': "blue",
    'CurveFault-A': "violet",
    'CurveFault-B': "purple",
}

# Marker per dataset: squares for the "-A" variants, crosses for the "-B" variants.
marker_dict = {name: ("s" if name.endswith("-A") else "x") for name in color_dict}


In [None]:
# Output folder for every figure saved below; created if missing, reused otherwise.
save_dir = "../PaperFigures/AAAI-LatentDim/"
os.makedirs(name=save_dir, exist_ok=True)

# Large U-Nets

In [None]:
# Metric vs. latent dimension for every dataset, one U-Net size per figure.
metric = "SSIM"
unet_type = "Large"

# All datasets are plotted, so the metric table is used unfiltered.
metric_table = results_dict[metric]

fig, ax = plt.subplots(figsize=(4, 3))
for dataset in metric_table["Dataset"]:
    # One-row slice holding this dataset's score for every model column.
    dataset_df = metric_table[metric_table["Dataset"] == dataset]
    latent_dims, metric_values = extract_and_sort(dataset_df, key=unet_type)
    ax.plot(latent_dims, metric_values, linestyle="dashed",
            marker=marker_dict[dataset], label=dataset, color=color_dict[dataset])

# Fixed y-range for this SSIM panel.
ax.set_ylim(0.49, 1.05)
# Legend anchored outside the axes on the right; bbox_inches="tight" expands the saved
# area to include it.
ax.legend(loc="upper right", ncol=1, fontsize=12, bbox_to_anchor=[1.6, .93])
ax.grid(True, alpha=0.3)
ax.set_ylabel(metric, fontsize=15)
ax.set_xlabel("Latent Dimension", fontsize=15)
ax.set_title(f"Latent U-Net ({unet_type})", fontsize=15)
fig.savefig(save_dir + f"effect_latent_dim_unet_{unet_type}_{metric}.png", dpi=150, bbox_inches="tight")


In [None]:
# Metric vs. latent dimension for every dataset, one U-Net size per figure.
metric = "MSE"
unet_type = "Large"

# All datasets are plotted, so the metric table is used unfiltered.
metric_table = results_dict[metric]

fig, ax = plt.subplots(figsize=(4, 3))
for dataset in metric_table["Dataset"]:
    # One-row slice holding this dataset's score for every model column.
    dataset_df = metric_table[metric_table["Dataset"] == dataset]
    latent_dims, metric_values = extract_and_sort(dataset_df, key=unet_type)
    ax.plot(latent_dims, metric_values, linestyle="dashed",
            marker=marker_dict[dataset], label=dataset, color=color_dict[dataset])

# Legend anchored outside the axes on the right; bbox_inches="tight" expands the saved
# area to include it.
ax.legend(loc="upper right", ncol=1, fontsize=12, bbox_to_anchor=[1.6, .93])
ax.grid(True, alpha=0.3)
ax.set_ylabel(metric, fontsize=15)
ax.set_xlabel("Latent Dimension", fontsize=15)
ax.set_title(f"Latent U-Net ({unet_type})", fontsize=15)
fig.savefig(save_dir + f"effect_latent_dim_unet_{unet_type}_{metric}.png", dpi=150, bbox_inches="tight")

In [None]:
# Metric vs. latent dimension for every dataset, one U-Net size per figure.
metric = "MAE"
unet_type = "Large"

# All datasets are plotted, so the metric table is used unfiltered.
metric_table = results_dict[metric]

fig, ax = plt.subplots(figsize=(4, 3))
for dataset in metric_table["Dataset"]:
    # One-row slice holding this dataset's score for every model column.
    dataset_df = metric_table[metric_table["Dataset"] == dataset]
    latent_dims, metric_values = extract_and_sort(dataset_df, key=unet_type)
    ax.plot(latent_dims, metric_values, linestyle="dashed",
            marker=marker_dict[dataset], label=dataset, color=color_dict[dataset])

# Legend anchored outside the axes on the right; bbox_inches="tight" expands the saved
# area to include it.
ax.legend(loc="upper right", ncol=1, fontsize=12, bbox_to_anchor=[1.6, .93])
ax.grid(True, alpha=0.3)
ax.set_ylabel(metric, fontsize=15)
ax.set_xlabel("Latent Dimension", fontsize=15)
ax.set_title(f"Latent U-Net ({unet_type})", fontsize=15)
fig.savefig(save_dir + f"effect_latent_dim_unet_{unet_type}_{metric}.png", dpi=150, bbox_inches="tight")

# Small U-Nets

In [None]:
# Metric vs. latent dimension for every dataset, one U-Net size per figure.
metric = "SSIM"
unet_type = "Small"

# All datasets are plotted, so the metric table is used unfiltered.
metric_table = results_dict[metric]

fig, ax = plt.subplots(figsize=(4, 3))
for dataset in metric_table["Dataset"]:
    # One-row slice holding this dataset's score for every model column.
    dataset_df = metric_table[metric_table["Dataset"] == dataset]
    latent_dims, metric_values = extract_and_sort(dataset_df, key=unet_type)
    ax.plot(latent_dims, metric_values, linestyle="dashed",
            marker=marker_dict[dataset], label=dataset, color=color_dict[dataset])

# Legend anchored outside the axes on the right; bbox_inches="tight" expands the saved
# area to include it.
ax.legend(loc="upper right", ncol=1, fontsize=12, bbox_to_anchor=[1.6, .93])
ax.grid(True, alpha=0.3)
ax.set_ylabel(metric, fontsize=15)
ax.set_xlabel("Latent Dimension", fontsize=15)
ax.set_title(f"Latent U-Net ({unet_type})", fontsize=15)
fig.savefig(save_dir + f"effect_latent_dim_unet_{unet_type}_{metric}.png", dpi=150, bbox_inches="tight")

In [None]:
# Metric vs. latent dimension for every dataset, one U-Net size per figure.
metric = "MSE"
unet_type = "Small"

# All datasets are plotted, so the metric table is used unfiltered.
metric_table = results_dict[metric]

fig, ax = plt.subplots(figsize=(4, 3))
for dataset in metric_table["Dataset"]:
    # One-row slice holding this dataset's score for every model column.
    dataset_df = metric_table[metric_table["Dataset"] == dataset]
    latent_dims, metric_values = extract_and_sort(dataset_df, key=unet_type)
    ax.plot(latent_dims, metric_values, linestyle="dashed",
            marker=marker_dict[dataset], label=dataset, color=color_dict[dataset])

# Legend anchored outside the axes on the right; bbox_inches="tight" expands the saved
# area to include it.
ax.legend(loc="upper right", ncol=1, fontsize=12, bbox_to_anchor=[1.6, .93])
ax.grid(True, alpha=0.3)
ax.set_ylabel(metric, fontsize=15)
ax.set_xlabel("Latent Dimension", fontsize=15)
ax.set_title(f"Latent U-Net ({unet_type})", fontsize=15)
fig.savefig(save_dir + f"effect_latent_dim_unet_{unet_type}_{metric}.png", dpi=150, bbox_inches="tight")

In [None]:
# Metric vs. latent dimension for every dataset, one U-Net size per figure.
metric = "MAE"
unet_type = "Small"

# All datasets are plotted, so the metric table is used unfiltered.
metric_table = results_dict[metric]

fig, ax = plt.subplots(figsize=(4, 3))
for dataset in metric_table["Dataset"]:
    # One-row slice holding this dataset's score for every model column.
    dataset_df = metric_table[metric_table["Dataset"] == dataset]
    latent_dims, metric_values = extract_and_sort(dataset_df, key=unet_type)
    ax.plot(latent_dims, metric_values, linestyle="dashed",
            marker=marker_dict[dataset], label=dataset, color=color_dict[dataset])

# Legend anchored outside the axes on the right; bbox_inches="tight" expands the saved
# area to include it.
ax.legend(loc="upper right", ncol=1, fontsize=12, bbox_to_anchor=[1.6, .93])
ax.grid(True, alpha=0.3)
ax.set_ylabel(metric, fontsize=15)
ax.set_xlabel("Latent Dimension", fontsize=15)
ax.set_title(f"Latent U-Net ({unet_type})", fontsize=15)
fig.savefig(save_dir + f"effect_latent_dim_unet_{unet_type}_{metric}.png", dpi=150, bbox_inches="tight")

# Comparing Large and Small U-Nets on the B Datasets

In [None]:
# Small vs. Large U-Net on the B datasets.
# Solid line + "^" marker = Small, dashed line + "s" marker = Large; colour = dataset.
metric = "MAE"

b_table = results_dict[metric][results_dict[metric]["Dataset"].str.endswith("-B")]

fig, ax = plt.subplots(figsize=(4, 3))
for dataset in b_table["Dataset"]:
    dataset_df = b_table[b_table["Dataset"] == dataset]
    # Keep each size's own latent dims so every curve is plotted against its own
    # x-values, even if the Small and Large model sets ever differ.
    small_dims, small_metrics = extract_and_sort(dataset_df, key="Small")
    large_dims, large_metrics = extract_and_sort(dataset_df, key="Large")
    ax.plot(small_dims, small_metrics, linestyle="solid", marker="^", color=color_dict[dataset])
    ax.plot(large_dims, large_metrics, linestyle="dashed", marker="s",
            label=dataset, color=color_dict[dataset])

# Dataset (colour) legend. A second ax.legend() call replaces the Axes' legend, so this
# one is re-attached as a plain artist to keep both visible.
dataset_legend = ax.legend(loc="upper right", ncol=1, fontsize=12, bbox_to_anchor=[1.6, .93])
ax.add_artist(dataset_legend)
# Model-size legend; linestyles and markers mirror the ax.plot calls above.
size_handles = [
    Line2D([0], [0], linestyle="solid", color="black", label="Small", marker="^"),
    Line2D([0], [0], linestyle="dashed", color="black", label="Large", marker="s"),
]
ax.legend(handles=size_handles, loc="upper right", ncol=1, fontsize=12, bbox_to_anchor=[1.6, 0.4])

ax.grid(True, alpha=0.3)
ax.set_ylabel(metric, fontsize=15)
ax.set_xlabel("Latent Dimension", fontsize=15)
ax.set_title("Latent U-Net (Small vs Large)", fontsize=15)
fig.savefig(save_dir + f"effect_latent_dim_unet_large_vs_small_{metric}.png", dpi=150, bbox_inches="tight")

In [None]:
# Small vs. Large U-Net on the B datasets.
# Solid line + "^" marker = Small, dashed line + "s" marker = Large; colour = dataset.
metric = "MSE"

b_table = results_dict[metric][results_dict[metric]["Dataset"].str.endswith("-B")]

fig, ax = plt.subplots(figsize=(4, 3))
for dataset in b_table["Dataset"]:
    dataset_df = b_table[b_table["Dataset"] == dataset]
    # Keep each size's own latent dims so every curve is plotted against its own
    # x-values, even if the Small and Large model sets ever differ.
    small_dims, small_metrics = extract_and_sort(dataset_df, key="Small")
    large_dims, large_metrics = extract_and_sort(dataset_df, key="Large")
    ax.plot(small_dims, small_metrics, linestyle="solid", marker="^", color=color_dict[dataset])
    ax.plot(large_dims, large_metrics, linestyle="dashed", marker="s",
            label=dataset, color=color_dict[dataset])

# Dataset (colour) legend. A second ax.legend() call replaces the Axes' legend, so this
# one is re-attached as a plain artist to keep both visible.
dataset_legend = ax.legend(loc="upper right", ncol=1, fontsize=12, bbox_to_anchor=[1.6, .93])
ax.add_artist(dataset_legend)
# Model-size legend; linestyles and markers mirror the ax.plot calls above.
size_handles = [
    Line2D([0], [0], linestyle="solid", color="black", label="Small", marker="^"),
    Line2D([0], [0], linestyle="dashed", color="black", label="Large", marker="s"),
]
ax.legend(handles=size_handles, loc="upper right", ncol=1, fontsize=12, bbox_to_anchor=[1.6, 0.4])

ax.grid(True, alpha=0.3)
ax.set_ylabel(metric, fontsize=15)
ax.set_xlabel("Latent Dimension", fontsize=15)
ax.set_title("Latent U-Net (Small vs Large)", fontsize=15)
fig.savefig(save_dir + f"effect_latent_dim_unet_large_vs_small_{metric}.png", dpi=150, bbox_inches="tight")

In [None]:
# Small vs. Large U-Net on the B datasets.
# Solid line + "^" marker = Small, dashed line + "s" marker = Large; colour = dataset.
# (Line2D comes from the imports cell at the top of the notebook.)
metric = "SSIM"

b_table = results_dict[metric][results_dict[metric]["Dataset"].str.endswith("-B")]

fig, ax = plt.subplots(figsize=(4, 3))
for dataset in b_table["Dataset"]:
    dataset_df = b_table[b_table["Dataset"] == dataset]
    # Keep each size's own latent dims so every curve is plotted against its own
    # x-values, even if the Small and Large model sets ever differ.
    small_dims, small_metrics = extract_and_sort(dataset_df, key="Small")
    large_dims, large_metrics = extract_and_sort(dataset_df, key="Large")
    ax.plot(small_dims, small_metrics, linestyle="solid", marker="^", color=color_dict[dataset])
    ax.plot(large_dims, large_metrics, linestyle="dashed", marker="s",
            label=dataset, color=color_dict[dataset])

# Dataset (colour) legend. A second ax.legend() call replaces the Axes' legend, so this
# one is re-attached as a plain artist to keep both visible.
dataset_legend = ax.legend(loc="upper right", ncol=1, fontsize=12, bbox_to_anchor=[1.6, .93])
ax.add_artist(dataset_legend)
# Model-size legend; linestyles and markers mirror the ax.plot calls above.
size_handles = [
    Line2D([0], [0], linestyle="solid", color="black", label="Small", marker="^"),
    Line2D([0], [0], linestyle="dashed", color="black", label="Large", marker="s"),
]
ax.legend(handles=size_handles, loc="upper right", ncol=1, fontsize=12, bbox_to_anchor=[1.6, 0.4])

ax.grid(True, alpha=0.3)
ax.set_ylabel(metric, fontsize=15)
ax.set_xlabel("Latent Dimension", fontsize=15)
ax.set_title("Latent U-Net (Small vs Large)", fontsize=15)
fig.savefig(save_dir + f"effect_latent_dim_unet_large_vs_small_{metric}.png", dpi=150, bbox_inches="tight")

In [None]:
# Large U-Net, all datasets on one panel: colour = dataset family, marker = A/B variant
# ("s" for A, "x" for B, as defined in marker_dict).
# The per-family palette gets its own name so the per-dataset `color_dict` used by the
# other plotting cells is not overwritten when this cell runs.
family_color_dict = {
    'FlatVel-A': "firebrick",
    'FlatVel-B': "firebrick",
    'CurveVel-A': "darkgreen",
    'CurveVel-B': "darkgreen",
    'FlatFault-A': "blue",
    'FlatFault-B': "blue",
    'CurveFault-A': "purple",
    'CurveFault-B': "purple",
}

metric = "MAE"
unet_type = "Large"

metric_table = results_dict[metric]

fig, ax = plt.subplots(figsize=(4, 3))
for dataset in metric_table["Dataset"]:
    dataset_df = metric_table[metric_table["Dataset"] == dataset]
    latent_dims, metric_values = extract_and_sort(dataset_df, key=unet_type)
    ax.plot(latent_dims, metric_values, linestyle="dashed",
            marker=marker_dict[dataset], label=dataset, color=family_color_dict[dataset])

# Variant legend only; handles match the dashed lines and the "s"/"x" markers above.
# NOTE(review): no colour (family) legend is drawn in this figure — confirm intended.
variant_handles = [
    Line2D([0], [0], linestyle="dashed", color="black", marker="s", label="A"),
    Line2D([0], [0], linestyle="dashed", color="black", marker="x", label="B"),
]
ax.legend(handles=variant_handles, loc="upper right", ncol=1, fontsize=12)

ax.grid(True, alpha=0.3)
ax.set_ylabel(metric, fontsize=15)
ax.set_xlabel("Latent Dimension", fontsize=15)
# NOTE(review): this is the same filename the earlier Large/MAE cell writes, so Run All
# keeps only this version — confirm which one the paper uses.
fig.savefig(save_dir + f"effect_latent_dim_unet_{unet_type}_{metric}.png", dpi=150, bbox_inches="tight")

In [None]:
# Small vs. Large U-Net on the B datasets (taller figure, size legend inside the axes).
# Solid line + "^" marker = Small, dashed line + "s" marker = Large; colour = dataset.
metric = "MAE"

b_table = results_dict[metric][results_dict[metric]["Dataset"].str.endswith("-B")]

fig, ax = plt.subplots(figsize=(4, 3.5))
for dataset in b_table["Dataset"]:
    dataset_df = b_table[b_table["Dataset"] == dataset]
    # Keep each size's own latent dims so every curve is plotted against its own
    # x-values, even if the Small and Large model sets ever differ.
    small_dims, small_metrics = extract_and_sort(dataset_df, key="Small")
    large_dims, large_metrics = extract_and_sort(dataset_df, key="Large")
    ax.plot(small_dims, small_metrics, linestyle="solid", marker="^", color=color_dict[dataset])
    ax.plot(large_dims, large_metrics, linestyle="dashed", marker="s",
            label=dataset, color=color_dict[dataset])

# Model-size legend inside the axes; handles mirror the ax.plot calls above.
size_handles = [
    Line2D([0], [0], linestyle="solid", color="black", label="Small", marker="^"),
    Line2D([0], [0], linestyle="dashed", color="black", label="Large", marker="s"),
]
size_legend = ax.legend(handles=size_handles, loc="upper right", ncol=1, fontsize=12)
# The next ax.legend() call replaces the Axes' legend, so keep this one as a plain artist.
ax.add_artist(size_legend)
# Dataset (colour) legend, anchored outside the axes on the right.
ax.legend(loc="upper right", ncol=1, fontsize=12, bbox_to_anchor=[1.6, .93])

ax.grid(True, alpha=0.3)
ax.set_ylabel(metric, fontsize=15)
ax.set_xlabel("Latent Dimension", fontsize=15)
# NOTE(review): this is the same filename the earlier Small-vs-Large MAE cell writes, so
# Run All keeps only this version — confirm which one the paper uses.
fig.savefig(save_dir + f"effect_latent_dim_unet_large_vs_small_{metric}.png", dpi=150, bbox_inches="tight")