# Visualization of the field error map

In [None]:
import os
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
import plotly.express as px

def get_design_range(resolution, start=0.4, width=0.275):
    """Return the ``(begin, end)`` grid indices that bound the design region.

    The arithmetic treats ``resolution`` as grid cells per micrometre (cell
    size ``1e-6 / resolution`` metres). ``start`` and ``width`` are lengths in
    micrometres. The region begins ``start`` µm from index 0 and spans
    ``width`` µm.

    Args:
        resolution: Grid cells per micrometre.
        start: Offset of the design region from index 0, in µm.
        width: Thickness of the design region, in µm.

    Returns:
        Tuple ``(begin, end)`` of integer cell indices, with
        ``end = begin + cells(width)``. Fractional cell counts are truncated.
    """
    def _to_cells(length_um):
        # Convert a length in µm to a whole number of cells, truncating like
        # int() does. Rounding to 9 decimals first absorbs floating-point
        # noise (e.g. 15.999999999999998), which a bare int() would truncate
        # to the wrong cell. Genuine fractions such as 4.8 still become 4.
        return int(round(length_um * resolution, 9))

    design_begin = _to_cells(start)
    return (design_begin, design_begin + _to_cells(width))
    
def plotly_imshow_func(name, data, wavelength, sim_name, design_range_dict, zmax, resolution=40):
    """Show and save the absolute field-error map of one model at one wavelength.

    Args:
        name: Model name; used as the output file stem.
        data: Loaded result dict. ``data['fields_dict'][str(wavelength)]`` is
            indexed as ``[0][0]`` for the model output and ``[1][0]`` for the
            target field (layout inferred from the indexing below; confirm
            against the code that saved these files).
        wavelength: Key (nm) into ``data['fields_dict']``.
        sim_name: Simulation name; used as the output sub-directory.
        design_range_dict: The per-simulation entry with ``'design_start'``
            and ``'width'`` (µm), forwarded to ``get_design_range``.
        zmax: Upper limit of the colour scale. Colour-bar ticks are spaced
            ``round(zmax / 4, 2)`` apart.
        resolution: Grid cells per µm for ``get_design_range``. Defaults to
            40, the previously hard-coded value.

    Side effects:
        Displays the figure and writes
        ``saved_figures/field_error_map/{sim_name}/{name}.pdf``, creating the
        directory if needed.
    """
    design_range = get_design_range(resolution, design_range_dict['design_start'], design_range_dict['width'])

    fields = data['fields_dict'][f'{wavelength}']
    output_field = fields[0][0]
    target_field = fields[1][0]

    # Pixel-wise magnitude of the prediction error, taken on the first entry
    # along each field's leading axis.
    error_field = torch.abs(output_field[0] - target_field[0])

    fig = px.imshow(error_field, color_continuous_scale='magma', zmin=0, zmax=zmax)
    # Green outline around the rows occupied by the design region.
    fig.add_hrect(y0=design_range[0], y1=design_range[1], line_color="green", line_width=4)
    fig.update_layout(
        autosize=False,
        coloraxis_showscale=True,
        margin={'t': 0, 'l': 0, 'b': 0, 'r': 0},
        coloraxis_colorbar=dict(dtick=round(zmax / 4, 2)),
        font=dict(family="Helvetica", size=24, color="black"),
    )
    # The image is shown without axis tick labels.
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)

    fig.show()
    out_dir = f"saved_figures/field_error_map/{sim_name}"
    os.makedirs(out_dir, exist_ok=True)
    fig.write_image(f"{out_dir}/{name}.pdf")

# Simulations to render. For a quick single-simulation run, use e.g. ['image_sensor'].
sim_names = ['single_layer', 'triple_layer', 'straight_waveguide', 'image_sensor']
# Colour-scale maximum per simulation; all models of one simulation share it.
sim_zmax_v_dict = {'single_layer': 0.25, 'triple_layer': 0.12, 'straight_waveguide': 0.5, 'image_sensor': 0.7}
# Design-region start and width (µm) per simulation, consumed by get_design_range.
design_range_dict = {'single_layer': {'design_start': 0.4, 'width': 0.12}, 'triple_layer': {'design_start': 0.4, 'width': 0.12}, 'straight_waveguide': {'design_start': 1.0, 'width': 4.85}, 'image_sensor': {'design_start': 0.6, 'width': 3.5}}
# Model sub-directory names; also used as the output file stems.
names = ['wino', 'fno2d', 'fno2dfactor', 'neurolight', 'unet']
RESULTS_ROOT = '/data/joon/Results/WINO/three_main_results/three_main_results/retain_batch_test_wvlwise_results'
# Result file of each model, located at RESULTS_ROOT/{sim_name}/{model}/{file}.
MODEL_RESULT_FILES = {
    'wino': 'nmse_waveprior_64dim_12layer_256_5060_auggroup4_weightsharing.pt',
    'fno2d': 'nmse_waveprior_32dim_5layer_3210.pt',
    'fno2dfactor': 'nmse_wp_64_12layer_256_mode5060.pt',
    'neurolight': 'nmse_wp_64_16layer_256_mode5060_dp01_bs32_ressetm.pt',
    'unet': 'nmse_waveprior_16dim.pt',
}
wavelength = 410  # nanometer; the single wavelength rendered for every model

for sim_name in sim_names:
    os.makedirs(f"saved_figures/field_error_map/{sim_name}", exist_ok=True)
    for name in names:
        data = torch.load(f'{RESULTS_ROOT}/{sim_name}/{name}/{MODEL_RESULT_FILES[name]}')
        plotly_imshow_func(name, data, wavelength, sim_name, design_range_dict[sim_name], zmax=sim_zmax_v_dict[sim_name])

# Box Plot Visualization

In [None]:
# box plot
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

# Load data
sim_names = ['single_layer', 'triple_layer', 'straight_waveguide', 'image_sensor']
# sim_names = ['image_sensor']  # toggle for a quick single-simulation run
# Not read by this cell; kept so the notebook namespace matches the other cells.
design_range_dict = {'single_layer': {'design_start': 0.4, 'width': 0.12}, 'triple_layer': {'design_start': 0.4, 'width': 0.12}, 'straight_waveguide': {'design_start': 1.0, 'width': 4.85}, 'image_sensor': {'design_start': 0.6, 'width': 3.5}}
# sim_name -> DataFrame of the most recently plotted eval_mode for that simulation.
dfs = {}
# Wavelengths (nm) labelled "Trained"; every other evaluated wavelength is "Untrained".
trained_wavelengths = list(range(400, 701, 20))
# Not read by this cell: box labels come from BOX_PLOT_MODEL_FILES, which labels WINO as "Ours".
models = ['UNet', 'FNO2d', 'F-FNO2d', 'NeurOLight', 'WINO']
eval_modes = ['nmse_val_dict', 'structure_nmse_val_dict', 'near_nmse_val_dict']

file_type = ['png', 'svg', 'pdf']

RESULTS_ROOT = '/data/joon/Results/WINO/three_main_results/three_main_results/retain_batch_test_wvlwise_results'
# Box label -> result file under RESULTS_ROOT/{sim_name}/. Plotly orders
# categories by first appearance, so insertion order is the left-to-right box order.
BOX_PLOT_MODEL_FILES = {
    'UNet': 'unet/nmse_waveprior_16dim.pt',
    'FNO2d': 'fno2d/nmse_waveprior_32dim_5layer_3210.pt',
    'F-FNO2d': 'fno2dfactor/nmse_wp_64_12layer_256_mode5060.pt',
    'NeurOLight': 'neurolight/nmse_wp_64_16layer_256_mode5060_dp01_bs32_ressetm.pt',
    'Ours': 'wino/nmse_waveprior_64dim_12layer_256_5060_auggroup4_weightsharing.pt',
}
# Every evaluated wavelength (nm), 1 nm apart.
all_wvls = list(range(400, 701))
# Set for O(1) membership checks inside the row-building loop.
trained_wvl_set = set(trained_wavelengths)
# Shared axis styling: thick black axis lines with long, thick outward ticks.
box_axis_style = dict(
    showline=True,
    linecolor='black',
    linewidth=4,
    ticks='outside',
    tickwidth=4,
    ticklen=8,
    tickcolor='black',
)

for sim_name in sim_names:
    data_dict = {
        label: torch.load(f'{RESULTS_ROOT}/{sim_name}/{rel_path}')
        for label, rel_path in BOX_PLOT_MODEL_FILES.items()
    }
    out_dir = f"saved_figures/box_plot_comparison/{sim_name}"
    os.makedirs(out_dir, exist_ok=True)

    for eval_mode in eval_modes:
        # One row per sample: [NMSE, wavelength, model label, "Trained"/"Untrained"], for every model.
        data_nmses = []
        for wvl in all_wvls:
            trained_str = "Trained" if wvl in trained_wvl_set else "Untrained"
            for label, results in data_dict.items():
                nmses = results[eval_mode][str(wvl)]
                if eval_mode == 'near_nmse_val_dict':
                    # Flatten to 1-D (the original .view(-1) implies this metric is stored as a tensor).
                    nmses = nmses.view(-1)
                # float() turns 0-d tensors into plain numbers so pandas gets a
                # numeric column instead of an object column of tensors.
                data_nmses.extend([float(nmse), wvl, label, trained_str] for nmse in nmses)

        df = pd.DataFrame(data=data_nmses, columns=['NMSE', 'wavelength (nm)', 'model', 'trained'])
        dfs[sim_name] = df

        # One box per (model, trained/untrained) pair.
        fig = px.box(
            df,
            x='model',
            y='NMSE',
            color='trained',
            color_discrete_map={'Trained': 'rgba(1,105,169,1)',
                                'Untrained': 'rgba(194,55,55,1)'},
        )
        fig.update_traces(selector=dict(type="box"), line=dict(width=4))  # thick box outlines
        fig.update_layout(
            xaxis_title='Model',
            yaxis_title='NMSE',
            xaxis=box_axis_style,
            yaxis=box_axis_style,
            template='simple_white',
            legend=dict(
                x=0.8,  # legend placed inside the plot area, top right
                y=0.97,
                traceorder='normal',
                font=dict(size=32),
                bgcolor='rgba(255, 255, 255, 1)',  # opaque white background
                bordercolor='rgba(0, 0, 0, 0.5)',  # half-transparent black border
                borderwidth=4,
                tracegroupgap=0,  # no gap between legend items
            ),
            legend_title="",  # hide the "trained" legend title
            height=1000,
            width=1000,
            font=dict(family="Helvetica", size=32, color="black"),
        )
        for ft in file_type:
            fig.write_image(f"{out_dir}/{eval_mode}.{ft}")

# WINO vs. NeurOLight: NMSE across wavelengths

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

sim_names = ['single_layer', 'triple_layer', 'straight_waveguide', 'image_sensor']
# sim_names = ['image_sensor']  # toggle for a quick single-simulation run
# Not read by this cell; kept so the notebook namespace matches the other cells.
design_range_dict = {'single_layer': {'design_start': 0.4, 'width': 0.12}, 'triple_layer': {'design_start': 0.4, 'width': 0.12}, 'straight_waveguide': {'design_start': 1.0, 'width': 4.85}, 'image_sensor': {'design_start': 0.6, 'width': 3.5}}
eval_modes = ['nmse_val_dict', 'structure_nmse_val_dict', 'near_nmse_val_dict']
file_type = ['png', 'svg', 'pdf']

RESULTS_ROOT = '/data/joon/Results/WINO/three_main_results/three_main_results/retain_batch_test_wvlwise_results'
# Model label -> result file under RESULTS_ROOT/{sim_name}/.
PAIR_MODEL_FILES = {
    'WINO': 'wino/nmse_waveprior_64dim_12layer_256_5060_auggroup4_weightsharing.pt',
    'NeurOLight': 'neurolight/nmse_wp_64_16layer_256_mode5060_dp01_bs32_ressetm.pt',
}
# Model label -> (mean-line colour, band fill colour): WINO red, NeurOLight blue.
PAIR_COLORS = {
    'WINO': ('rgba(194,55,55,1)', 'rgba(194,55,55,0.3)'),
    'NeurOLight': ('rgba(1,105,169,1)', 'rgba(1,105,169,0.3)'),
}
# Every evaluated wavelength (nm), 1 nm apart.
all_wvls = list(range(400, 701))

for sim_name in sim_names:
    out_dir = f"saved_figures/varying_wavelengths/{sim_name}"
    os.makedirs(out_dir, exist_ok=True)
    # Load data
    results_by_model = {
        label: torch.load(f'{RESULTS_ROOT}/{sim_name}/{rel_path}')
        for label, rel_path in PAIR_MODEL_FILES.items()
    }

    for eval_mode in eval_modes:
        # One row per sample: [NMSE, wavelength, model label], for both models.
        data_nmses = []
        for model_name, results in results_by_model.items():
            for wvl in all_wvls:
                nmses = results[eval_mode][str(wvl)]
                if eval_mode == 'near_nmse_val_dict':
                    # Flatten to 1-D (the original .view(-1) implies a tensor).
                    nmses = nmses.view(-1)
                # float() keeps the NMSE column numeric instead of an object column of tensors.
                data_nmses.extend([float(nmse), wvl, model_name] for nmse in nmses)

        df = pd.DataFrame(data=data_nmses, columns=['NMSE', 'wavelength (nm)', 'model'])

        # Mean and sample standard deviation per (wavelength, model); groupby sorts by both keys.
        df_merged = df.groupby(['wavelength (nm)', 'model']).agg(
            NMSE_mean=('NMSE', 'mean'),
            NMSE_std=('NMSE', 'std'),
        ).reset_index()

        fig = go.Figure()
        for model in df_merged['model'].unique():
            model_data = df_merged[df_merged['model'] == model]
            line_color, band_color = PAIR_COLORS[model]
            wvls = model_data['wavelength (nm)'].tolist()
            upper = (model_data['NMSE_mean'] + model_data['NMSE_std'] * 2).tolist()
            lower = (model_data['NMSE_mean'] - model_data['NMSE_std'] * 2).tolist()

            # Mean NMSE curve.
            fig.add_trace(go.Scatter(
                x=wvls,
                y=model_data['NMSE_mean'],
                mode='lines',
                name=model,
                line=dict(color=line_color, width=4),
            ))
            # Shaded mean ± 2·std band. The upper edge runs left-to-right and
            # the lower edge right-to-left, so fill='toself' closes the polygon.
            fig.add_trace(go.Scatter(
                x=wvls + wvls[::-1],
                y=upper + lower[::-1],
                fill='toself',
                fillcolor=band_color,
                line=dict(color='rgba(255,255,255,0)'),  # invisible outline
                name=f'{model} Std',
                showlegend=False,
            ))

        fig.update_layout(
            xaxis_title='Wavelength (nm)',
            yaxis_title='NMSE',
            title_font_family="Helvetica",
            font=dict(family="Helvetica", size=40, color="black"),
            template='plotly_white',
            plot_bgcolor='rgb(255, 255, 255)',
            paper_bgcolor='rgb(255, 255, 255)',
            xaxis=dict(
                showgrid=False,
                showline=True,
                linewidth=4,
                linecolor='black',
                ticks='outside',
                tickwidth=4,
                ticklen=8,
                tickcolor='black',
                # Tick labels every 40 nm; the dashed lines below mark every 20 nm.
                tickvals=list(range(400, 701, 40)),
                ticktext=[str(i) for i in range(400, 701, 40)],
            ),
            yaxis=dict(
                showgrid=False,
                showline=False,  # y-axis line hidden; width/colour apply only if it is enabled
                linewidth=4,
                linecolor='black',
                ticks='outside',
                tickwidth=4,
                ticklen=8,
                tickcolor='black',
            ),
            # Legend styling is kept for when showlegend is switched back on.
            legend=dict(
                x=0.65,
                y=0.97,
                traceorder='normal',
                font=dict(size=40),
                bgcolor='rgba(255, 255, 255, 1)',  # opaque white background
                bordercolor='rgba(0, 0, 0, 0.5)',  # half-transparent black border
                borderwidth=4,
                tracegroupgap=0,
            ),
            legend_orientation="h",
            legend_title="",
            showlegend=False,
            height=700,
            width=1400,
        )

        # Dashed green markers every 20 nm from 400 to 700 nm (the same grid as
        # the box-plot cell's trained_wavelengths).
        for i in range(400, 701, 20):
            fig.add_vline(x=i, line_dash="dash", line_color="green", line_width=4)

        for ft in file_type:
            fig.write_image(f"{out_dir}/{eval_mode}.{ft}")

# WINO vs. other models: NMSE across wavelengths

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

import os  # needed for os.makedirs below; this cell's own imports did not include it

sim_names = ['single_layer', 'triple_layer', 'straight_waveguide', 'image_sensor']
# sim_names = ['image_sensor']  # toggle for a quick single-simulation run
# Not read by this cell; kept so the notebook namespace matches the other cells.
design_range_dict = {'single_layer': {'design_start': 0.4, 'width': 0.12}, 'triple_layer': {'design_start': 0.4, 'width': 0.12}, 'straight_waveguide': {'design_start': 1.0, 'width': 4.85}, 'image_sensor': {'design_start': 0.6, 'width': 3.5}}
eval_modes = ['nmse_val_dict', 'structure_nmse_val_dict', 'near_nmse_val_dict']
file_type = ['png', 'svg', 'pdf']

# Per-model colours for the mean line and the translucent band.
color_dict = {
    "WINO":  {"line": "rgba(194,55,55,1)", "fill": "rgba(194,55,55,0.2)"},
    "NeurOLight":  {"line": "rgba(1,105,169,1)", "fill": "rgba(1,105,169,0.2)"},
    "UNet": {"line": "rgba(105,105,105,1)", "fill": "rgba(105,105,105,0.2)"},
    "FNO2d": {"line": "rgba(63,81,181,1)", "fill": "rgba(63,81,181,0.2)"},
    "F-FNO2d": {"line": "rgba(255,87,34,1)", "fill": "rgba(255,87,34,0.2)"}
}

# Model labels, in the order their result files are loaded.
models = ['UNet', 'FNO2d', 'F-FNO2d', 'NeurOLight', 'WINO']

RESULTS_ROOT = '/data/joon/Results/WINO/three_main_results/three_main_results/retain_batch_test_wvlwise_results'
# Model label -> result file under RESULTS_ROOT/{sim_name}/.
ALL_MODEL_FILES = {
    'UNet': 'unet/nmse_waveprior_16dim.pt',
    'FNO2d': 'fno2d/nmse_waveprior_32dim_5layer_3210.pt',
    'F-FNO2d': 'fno2dfactor/nmse_wp_64_12layer_256_mode5060.pt',
    'NeurOLight': 'neurolight/nmse_wp_64_16layer_256_mode5060_dp01_bs32_ressetm.pt',
    'WINO': 'wino/nmse_waveprior_64dim_12layer_256_5060_auggroup4_weightsharing.pt',
}
# Every evaluated wavelength (nm), 1 nm apart.
all_wvls = list(range(400, 701))

for sim_name in sim_names:
    out_dir = f"saved_figures/varying_wavelengths_allmodels/{sim_name}"
    os.makedirs(out_dir, exist_ok=True)
    # Load data
    data_dict = {m: torch.load(f'{RESULTS_ROOT}/{sim_name}/{ALL_MODEL_FILES[m]}') for m in models}

    for eval_mode in eval_modes:
        # One row per sample: [NMSE, wavelength, model label], for every model.
        data_nmses = []
        for k, v in data_dict.items():
            for wvl in all_wvls:
                nmses = v[eval_mode][str(wvl)]
                if eval_mode == 'near_nmse_val_dict':
                    # Flatten to 1-D (the original .view(-1) implies a tensor).
                    nmses = nmses.view(-1)
                # float() keeps the NMSE column numeric instead of an object column of tensors.
                data_nmses.extend([float(nmse), wvl, k] for nmse in nmses)

        df = pd.DataFrame(data=data_nmses, columns=['NMSE', 'wavelength (nm)', 'model'])

        # Mean and sample standard deviation per (wavelength, model); groupby sorts by both keys.
        df_merged = df.groupby(['wavelength (nm)', 'model']).agg(
            NMSE_mean=('NMSE', 'mean'),
            NMSE_std=('NMSE', 'std'),
        ).reset_index()

        fig = go.Figure()
        for model in df_merged['model'].unique():
            model_data = df_merged[df_merged['model'] == model]
            wvls = model_data['wavelength (nm)'].tolist()
            upper = (model_data['NMSE_mean'] + model_data['NMSE_std'] * 2).tolist()
            lower = (model_data['NMSE_mean'] - model_data['NMSE_std'] * 2).tolist()

            # Mean NMSE curve.
            fig.add_trace(go.Scatter(
                x=wvls,
                y=model_data['NMSE_mean'],
                mode='lines',
                name=model,
                line=dict(color=color_dict[model]['line'], width=4),
            ))
            # Shaded mean ± 2·std band. The upper edge runs left-to-right and
            # the lower edge right-to-left, so fill='toself' closes the polygon.
            fig.add_trace(go.Scatter(
                x=wvls + wvls[::-1],
                y=upper + lower[::-1],
                fill='toself',
                fillcolor=color_dict[model]['fill'],
                line=dict(color='rgba(255,255,255,0)'),  # invisible outline
                name=f'{model} Std',
                showlegend=False,
            ))

        fig.update_layout(
            xaxis_title='Wavelength (nm)',
            yaxis_title='NMSE',
            title_font_family="Helvetica",
            font=dict(family="Helvetica", size=40, color="black"),
            template='plotly_white',
            plot_bgcolor='rgb(255, 255, 255)',
            paper_bgcolor='rgb(255, 255, 255)',
            # Tick width/length/colour match the other figure cells in this notebook.
            xaxis=dict(
                showgrid=False,
                showline=True,
                linewidth=4,
                linecolor='black',
                ticks='outside',
                tickwidth=4,
                ticklen=8,
                tickcolor='black',
                # Tick labels every 40 nm; the dashed lines below mark every 20 nm.
                tickvals=list(range(400, 701, 40)),
                ticktext=[str(i) for i in range(400, 701, 40)],
            ),
            yaxis=dict(
                showgrid=False,
                showline=False,  # y-axis line hidden; width/colour apply only if it is enabled
                linewidth=4,
                linecolor='black',
                ticks='outside',
                tickwidth=4,
                ticklen=8,
                tickcolor='black',
            ),
            # Legend styling is kept for when showlegend is switched back on.
            legend=dict(
                x=0.65,
                y=0.97,
                traceorder='normal',
                font=dict(size=40),
                bgcolor='rgba(255, 255, 255, 1)',  # opaque white background
                bordercolor='rgba(0, 0, 0, 0.5)',  # half-transparent black border
                borderwidth=4,
                tracegroupgap=0,
            ),
            legend_orientation="h",
            legend_title="",
            showlegend=False,
            height=700,
            width=1400,
        )

        # Dashed green markers every 20 nm from 400 to 700 nm (the same grid as
        # the box-plot cell's trained_wavelengths).
        for i in range(400, 701, 20):
            fig.add_vline(x=i, line_dash="dash", line_color="green", line_width=4)

        for ft in file_type:
            fig.write_image(f"{out_dir}/{eval_mode}.{ft}")