# Evaluation of prediction error (all wavelengths)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

# Results file written by the evaluation script (wvlwise_field_prediction.py);
# point this at another .pt file to evaluate a different run.
data = torch.load(
    'test_wvlwise_results/wino/nmse_waveprior_64dim_12layer_256_5060_auggroup4_weightsharing.pt'
)

# Wavelength grid in nm: every integer wavelength from 400 to 700 inclusive
# (visible range); every `step`-th entry is an observed wavelength.
step = 20  # default setting
all_wvls = [*range(400, 701)]
observe_wvls = all_wvls[::step]


def get_set(data_dict, wvls=None, observed=None):
    """Split per-wavelength metric values into observed and unobserved lists.

    Args:
        data_dict: Mapping from a wavelength string (e.g. ``'400'``) to a
            tensor holding the metric for that wavelength.
        wvls: Wavelengths (nm) to collect, in order. Defaults to the
            module-level ``all_wvls``.
        observed: Wavelengths counted as observed. Defaults to the
            module-level ``observe_wvls``.

    Returns:
        ``(obs_data, non_obs_data)``: two lists of Python floats in the order
        of ``wvls``. A multi-element tensor (e.g. one value per test sample)
        is reduced to its mean first.
    """
    if wvls is None:
        wvls = all_wvls
    # A set gives O(1) membership tests instead of scanning a list each time.
    observed = set(observe_wvls if observed is None else observed)

    obs_data = []
    non_obs_data = []
    for wvl in wvls:
        value = data_dict[str(wvl)]
        # .item() raises on multi-element tensors, so average those first.
        value = value.item() if value.numel() == 1 else value.mean().item()
        if wvl in observed:
            obs_data.append(value)
        else:
            non_obs_data.append(value)
    return obs_data, non_obs_data


def one_test_value(data, wvls=None, observed=None):
    """Print and return mean NMSE / structure NMSE for observed vs. unobserved wavelengths.

    Args:
        data: Results dict with ``'nmse_val_dict'`` and
            ``'structure_nmse_val_dict'`` entries, each mapping a wavelength
            string (e.g. ``'400'``) to a tensor.
        wvls: Wavelengths (nm) to evaluate. Defaults to the module-level
            ``all_wvls``.
        observed: Wavelengths counted as observed. Defaults to the
            module-level ``observe_wvls``.

    Returns:
        Dict with keys ``'observed_nmse'``, ``'unobserved_nmse'``,
        ``'observed_structure_nmse'`` and ``'unobserved_structure_nmse'``.
        A group with no wavelengths yields ``nan``.
    """
    if wvls is None:
        wvls = all_wvls
    observed = set(observe_wvls if observed is None else observed)

    def _mean(values):
        # np.mean([]) warns and returns nan; return nan without the warning.
        return float(np.mean(values)) if values else float('nan')

    def _group_means(metric_dict):
        # Mean of one metric over the observed and over the unobserved wavelengths.
        obs_vals, unobs_vals = [], []
        for wvl in wvls:
            value = metric_dict[str(wvl)]
            # .item() raises on multi-element tensors (e.g. per-sample values), so average those.
            value = value.item() if value.numel() == 1 else value.mean().item()
            (obs_vals if wvl in observed else unobs_vals).append(value)
        return _mean(obs_vals), _mean(unobs_vals)

    obs_nmse, unobs_nmse = _group_means(data['nmse_val_dict'])
    obs_struct, unobs_struct = _group_means(data['structure_nmse_val_dict'])

    print(f"OBSERVED WVL NMSE : {obs_nmse}")
    print(f"UNOBSERVED WVL NMSE : {unobs_nmse}")
    print(f"OBSERVED WVL STRUCTURE NMSE : {obs_struct}")
    print(f"UNOBSERVED WVL STRUCTURE NMSE : {unobs_struct}")

    return {
        'observed_nmse': obs_nmse,
        'unobserved_nmse': unobs_nmse,
        'observed_structure_nmse': obs_struct,
        'unobserved_structure_nmse': unobs_struct,
    }
    

# Summarise the loaded results: mean NMSE on observed vs. unobserved wavelengths.
one_test_value(data=data)

# Visualization of field prediction error (error map)

In [None]:
import os
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
import plotly.express as px

def get_design_range(resolution):
    """Return the (start, end) row indices of the design layer on the plotted field.

    ``resolution`` is the number of grid cells per 1e-6 length unit
    (presumably pixels per µm). The layout modelled is a 1-unit PML, then a
    0.4-unit substrate, then a 0.12-unit design layer. Indices are measured
    from the end of the PML — presumably because the stored fields exclude
    it; confirm against the data pipeline.
    """
    dl = 1e-6 / resolution  # grid spacing
    # The PML width, int(1/1e6/dl), would be added to both ends and then
    # subtracted again; it cancels exactly in integer arithmetic, so only the
    # substrate and design thicknesses matter.
    # NOTE(review): int() truncates, so a thickness meant to be a whole number
    # of cells can come out one short if the float quotient lands just below it.
    substrate_px = int(0.4 / 1e6 / dl)
    design_px = int(0.12 / 1e6 / dl)
    return substrate_px, substrate_px + design_px
    
def plotly_imshow_func(data, wavelength, name=None, out_dir="saved_figures"):
    """Plot |prediction - target| for one wavelength, show it, and save it as a PNG.

    Args:
        data: Results dict whose ``'fields_dict'`` maps a wavelength string to
            a pair indexed as ``[0]`` (prediction) and ``[1]`` (target); each
            is indexed ``[0][0]`` again below — presumably first sample, first
            channel; confirm against the results format.
        wavelength: Wavelength (nm) to plot.
        name: Optional model name used as the filename prefix, so plots of
            different models do not overwrite each other.
        out_dir: Directory the PNG is written to; created if missing.

    Returns:
        The plotly figure.
    """
    design_range = get_design_range(40)

    fields = data['fields_dict'][f'{wavelength}']
    output_field = fields[0][0]
    target_field = fields[1][0]
    font = dict(family="Helvetica", size=45)

    error_field = torch.abs(output_field[0] - target_field[0])
    # Hand plotly a plain CPU numpy array so GPU / grad-tracking tensors also work.
    error_field = error_field.detach().cpu().numpy()
    fig = px.imshow(error_field, color_continuous_scale='inferno', zmin=0, zmax=0.175)
    # Outline the design-layer rows.
    fig.add_hrect(y0=design_range[0], y1=design_range[1], line_color="green", line_width=4)
    fig.update_layout(autosize=False, coloraxis_showscale=True, margin={'t': 0, 'l': 0, 'b': 0, 'r': 0}, font=font, coloraxis_colorbar=dict(dtick=0.08))

    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)

    fig.show()
    os.makedirs(out_dir, exist_ok=True)
    filename = f"{name}_field_error_map.png" if name else "field_error_map.png"
    fig.write_image(os.path.join(out_dir, filename))
    return fig


os.makedirs("saved_figures", exist_ok=True)
names = ['wino', 'fno2d', 'fno2dfactor', 'neurolight', 'unet']
paths = [
    'test_wvlwise_results/wino/nmse_waveprior_64dim_12layer_256_5060_auggroup4_weightsharing.pt',
    'test_wvlwise_results/fno2d/nmse_waveprior_32dim_5layer_3210.pt',
    'test_wvlwise_results/fno2dfactor/nmse_wp_64_12layer_256_mode5060.pt',
    'test_wvlwise_results/neurolight/nmse_wp_64_16layer_256_mode5060_dp01_bs32_ressetm.pt',
    'test_wvlwise_results/unet/nmse_waveprior_16dim.pt',
]
wavelength = 410  # nanometer
for name, path in zip(names, paths):
    data = torch.load(path)
    # The model name prefixes the output file, giving one PNG per model.
    plotly_imshow_func(data, wavelength, name=name)

# WINO vs. NeurOLight: NMSE plot across wavelengths

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import pandas as pd
import plotly.express as px

import os

data = torch.load('test_wvlwise_results/wino/nmse_waveprior_64dim_12layer_256_5060_group8.pt')
neurolight = torch.load('test_wvlwise_results/neurolight/nmse_wp_64_16layer_256_mode5060_dp01_bs32_ressetm.pt')

all_wvls = list(range(400, 701))
# Dashed vertical markers every 20 nm from 400 nm — presumably the observed
# wavelengths (step=20 default); confirm if the step changes.
observed_marks = range(400, 701, 20)

# One row per (model, wavelength). Both models are read from 'nmse_val_dict'
# so the two curves show the same metric as the NMSE-labelled y-axis.
data_nmses = []
for results, model in ((data, 'WINO'), (neurolight, 'NeurOLight')):
    for wvl in all_wvls:
        data_nmses.append([results['nmse_val_dict'][str(wvl)].item(), wvl, model])

df = pd.DataFrame(data=data_nmses, columns=['NMSE', 'wavelength (nm)', 'model'])

fig = px.line(df, x="wavelength (nm)", y="NMSE", title='NMSE results by wavelengths', color='model', height=900, width=2400)

fig.update_layout(
    font_family="Helvetica",
    title_font_family="Helvetica",
    font=dict(size=60),
)
fig.update_layout(
    template='plotly_white',
    plot_bgcolor='rgb(220, 220, 220)',
    paper_bgcolor='rgb(255, 255, 255)',
    xaxis=dict(showgrid=False),
    yaxis=dict(showgrid=False),
)

for x in observed_marks:
    fig.add_vline(x=x, line_dash="dash", line_color="green", line_width=6)
fig.update_traces(line={'width': 6})
fig.show()
# The figure cell above may not have run, so make sure the output directory exists.
os.makedirs("saved_figures", exist_ok=True)
fig.write_image("saved_figures/entire_field_wino_vs_neurolight_wvl_errorplot.png")



# Temporary check: inspect a raw results file

In [2]:
import torch
# Load a results file from the retain_batch_test_wvlwise_results directory for manual inspection.
# NOTE(review): absolute path ties this cell to one machine layout — confirm before reuse.
# In the printed output, each wavelength key maps to a multi-element tensor
# (presumably one value per test sample) rather than a single scalar.
data = torch.load('/root/WINO_interp/retain_batch_test_wvlwise_results/wino/nmse_waveprior_64dim_12layer_256_5060_auggroup4_weightsharing.pt')

In [3]:
# Display the loaded results dict.
data

{'nmse_val_dict': {'400': tensor([0.0040, 0.0040, 0.0034, 0.0058, 0.0044, 0.0052, 0.0056, 0.0036, 0.0058,
          0.0060, 0.0064, 0.0037, 0.0051, 0.0033, 0.0065, 0.0104, 0.0040, 0.0046,
          0.0044, 0.0036], dtype=torch.float64),
  '401': tensor([0.0059, 0.0053, 0.0046, 0.0051, 0.0056, 0.0053, 0.0042, 0.0054, 0.0036,
          0.0054, 0.0076, 0.0053, 0.0044, 0.0054, 0.0043, 0.0048, 0.0050, 0.0062,
          0.0053, 0.0044], dtype=torch.float64),
  '402': tensor([0.0047, 0.0064, 0.0056, 0.0062, 0.0055, 0.0046, 0.0054, 0.0053, 0.0055,
          0.0055, 0.0058, 0.0067, 0.0063, 0.0062, 0.0074, 0.0054, 0.0054, 0.0061,
          0.0082, 0.0052], dtype=torch.float64),
  '403': tensor([0.0074, 0.0058, 0.0069, 0.0065, 0.0063, 0.0067, 0.0060, 0.0057, 0.0093,
          0.0061, 0.0059, 0.0062, 0.0068, 0.0066, 0.0063, 0.0063, 0.0060, 0.0062,
          0.0070, 0.0068], dtype=torch.float64),
  '404': tensor([0.0073, 0.0071, 0.0074, 0.0071, 0.0066, 0.0073, 0.0073, 0.0086, 0.0123,
          0.00