In [1]:
import numpy as np

import wf_psf as wf

import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.ticker as mtick
import seaborn as sns


import sys
import numpy as np
import time
import tensorflow as tf
import tensorflow_addons as tfa


In [3]:
def define_plot_style(font_scale=1.0):
    """Apply the seaborn theme, then this notebook's matplotlib rc overrides.

    Parameters
    ----------
    font_scale : float, optional
        Passed to ``sns.set``. ``sns.set`` rewrites font sizes, line widths
        and marker sizes, so scaling fonts here (before the overrides below)
        is how to combine a font scale with this style. Calling
        ``sns.set(font_scale=...)`` afterwards overwrites several of the
        values set here. The default 1.0 is equivalent to a plain ``sns.set()``.
    """
    # Seaborn base theme first, so the explicit rcParams below take precedence.
    sns.set(font_scale=font_scale)

    # Define plot parameters (applied on top of the seaborn theme).
    plot_style = {
        'figure.figsize': (12,8),
        'figure.dpi': 200,
        'figure.autolayout':True,
        'lines.linewidth': 2,
        'lines.linestyle': '-',
        'lines.marker': 'o',
        'lines.markersize': 10,
        'legend.fontsize': 20,
        'legend.loc': 'best',
        'axes.titlesize': 24,
        'font.size': 22
    }
    mpl.rcParams.update(plot_style)

# Folder used below for every .npy result file (np.save / np.load).
saving_folder ='../figures/'

define_plot_style()
# NOTE(review): sns.set() re-applies seaborn's theme and scaled plotting
# context. That context includes 'font.size', 'legend.fontsize',
# 'axes.titlesize', 'lines.linewidth' and 'lines.markersize', so this call
# overwrites those values from define_plot_style(). If the custom values are
# meant to win, apply the font scale before the custom rcParams update.
sns.set(font_scale=1.5)


In [4]:
# Test sets simulated at two WFE resolutions: 4096 (used as ground truth) and 256.
# Each .npy file holds a pickled dict, hence allow_pickle and .item().
dataset_base_path = '/gpfswork/rech/ynx/ulx23va/repo/wf-SEDs/WFE_sampling_test/multires_dataset/4096/'

data_4096, data_256 = (
    np.load(dataset_base_path + file_name, allow_pickle=True).item()
    for file_name in (
        'test_Euclid_res_id_009_wfeRes_4096.npy',
        'test_Euclid_res_id_009_wfeRes_256.npy',
    )
)

data_4096.keys()


dict_keys(['stars', 'super_res_stars', 'positions', 'SEDs', 'zernike_coef', 'C_poly', 'parameters'])

In [5]:

# Matching training sets (2000 stars) at WFE resolutions 4096 and 256.
train_data_4096, train_data_256 = (
    np.load(dataset_base_path + file_name, allow_pickle=True).item()
    for file_name in (
        'train_Euclid_res_2000_TrainStars_id_009_wfeRes_4096.npy',
        'train_Euclid_res_2000_TrainStars_id_009_wfeRes_256.npy',
    )
)

train_data_4096.keys()


dict_keys(['stars', 'noisy_stars', 'super_res_stars', 'positions', 'SEDs', 'zernike_coef', 'C_poly', 'parameters'])

In [6]:
def compute_stats(GT_preds, preds, verbose=True):
    """Per-star RMSE statistics of ``preds`` against the ground truth ``GT_preds``.

    Each star's error is the pixel RMSE taken over every axis except the
    first (the star index). The relative error divides it by the RMS of the
    corresponding ground-truth star. A ground-truth star that is identically
    zero yields inf/nan in the relative terms.

    Parameters
    ----------
    GT_preds : array_like
        Ground-truth stamps, shape ``(n_stars, ...)``.
    preds : array_like
        Predicted stamps, with the same shape as ``GT_preds``.
    verbose : bool, optional
        Print the two summary lines (default True).

    Returns
    -------
    rmse, rel_rmse, std_rmse, std_rel_rmse : float
        Mean and standard deviation across stars of the absolute RMSE, and
        of the relative RMSE (the relative values are in percent).

    Raises
    ------
    ValueError
        If the inputs differ in shape (broadcasting would silently compare
        the wrong pixels) or have fewer than two dimensions.
    """
    GT_preds = np.asarray(GT_preds)
    preds = np.asarray(preds)
    if GT_preds.shape != preds.shape:
        raise ValueError(
            'Shape mismatch: GT_preds %s vs preds %s' % (GT_preds.shape, preds.shape)
        )
    if GT_preds.ndim < 2:
        raise ValueError(
            'Expected a stack of stars with shape (n_stars, ...), got %s' % (GT_preds.shape,)
        )

    # Reduce over all per-star axes: (1, 2) for the usual (n_stars, h, w) stack.
    star_axes = tuple(range(1, GT_preds.ndim))

    # Calculate residuals
    residuals = np.sqrt(np.mean((GT_preds - preds)**2, axis=star_axes))
    GT_star_mean = np.sqrt(np.mean((GT_preds)**2, axis=star_axes))
    rel_residuals = residuals / GT_star_mean

    # RMSE calculations
    rmse = np.mean(residuals)
    rel_rmse = 100. * np.mean(rel_residuals)

    # STD calculations
    std_rmse = np.std(residuals)
    std_rel_rmse = 100. * np.std(rel_residuals)

    if verbose:
        print('Absolute RMSE:\t %.4e \t +/- %.4e' % (rmse, std_rmse))
        print('Relative RMSE:\t %.4e %% \t +/- %.4e %%' % (rel_rmse, std_rel_rmse))

    return rmse, rel_rmse, std_rmse, std_rel_rmse



## Baseline error at resolution x1 (WFE 4096 vs 256)

In [7]:

# Observed-resolution test stars: the 4096 WFE-resolution set is the ground truth, compared against 256.
GT_preds, preds = data_4096['stars'], data_256['stars']
_ = compute_stats(GT_preds, preds)



Absolute RMSE:	 3.8303e-05 	 +/- 7.4604e-06
Relative RMSE:	 5.1611e-01 % 	 +/- 1.1823e-01 %


In [8]:
# Inspect the shape of the `preds` stack compared in the cell above.
preds.shape

(400, 32, 32)

## Baseline error at resolution x3 (WFE 4096 vs 256)

In [11]:

# Training sets, 4096 vs 256 WFE resolution: first the super-resolved
# stamps, then the observed-resolution stamps (blank line in between).
for key_pos, star_key in enumerate(['super_res_stars', 'stars']):
    if key_pos:
        print('')
    GT_preds = train_data_4096[star_key]
    preds = train_data_256[star_key]
    _ = compute_stats(GT_preds, preds)


Absolute RMSE:	 9.3356e-06 	 +/- 1.6600e-06
Relative RMSE:	 6.1320e-01 % 	 +/- 1.0922e-01 %

Absolute RMSE:	 3.8210e-05 	 +/- 7.7971e-06
Relative RMSE:	 5.1518e-01 % 	 +/- 1.2361e-01 %


In [10]:

# Same super-resolved (x3) baseline, on the test set.
GT_preds, preds = data_4096['super_res_stars'], data_256['super_res_stars']
_ = compute_stats(GT_preds, preds)


Absolute RMSE:	 9.3197e-06 	 +/- 1.6003e-06
Relative RMSE:	 6.1254e-01 % 	 +/- 1.0338e-01 %


## Generate x3 reconstructions

In [14]:
# Configuration of the trained runs being evaluated. This notebook reads only
# the model-construction, path and evaluation entries. The rest is not used
# here (presumably kept to mirror the training configuration).
args = dict(
    model='poly',
    model_eval='poly',
    base_id_name='_wfe_study_id009_256_bis_',
    id_name='_wfe_study_id009_bis_1',
    train_dataset_file='train_Euclid_res_2000_TrainStars_id_009_wfeRes_4096.npy',
    test_dataset_file='test_Euclid_res_id_009_wfeRes_4096.npy',
    n_epochs_param=[15, 15],
    n_epochs_non_param=[100, 50],
    n_zernikes=15,
    gt_n_zernikes=45,
    pupil_diameter=256,
    oversampling_rate=3.,
    output_q=3.,
    output_dim=32,
    batch_size=32,
    d_max=2,
    x_lims=[0, 1e3],
    y_lims=[0, 1e3],
    d_max_nonparam=5,
    n_bins_lda=20,
    eval_batch_size=16,
    interpolation_type='none',
    l_rate_param=[0.01, 0.004],
    l_rate_non_param=[0.1, 0.06],
    saved_model_type='checkpoint',
    saved_cycle='cycle2',
    total_cycles=2,
    use_sample_weights=True,
    l2_param=0.,
    cycle_def='complete',
    suffix_id_name=[str(n) for n in range(1, 10)],
    star_numbers=list(range(1, 10)),
    train_opt=True,
    eval_opt=True,
    plot_opt=True,
    base_path='/gpfswork/rech/ynx/ulx23va/repo/wf-SEDs/model_WFE_size/wf-outputs/',
    dataset_folder='/gpfswork/rech/ynx/ulx23va/repo/wf-SEDs/WFE_sampling_test/multires_dataset/4096/',
    metric_base_path='/gpfswork/rech/ynx/ulx23va/repo/wf-SEDs/model_WFE_size/wf-outputs/metrics/wfe_study_id009_bis/',
    chkp_save_path='/gpfswork/rech/ynx/ulx23va/repo/wf-SEDs/model_WFE_size/wf-outputs/chkp/wfe_study_id009_bis/',
    log_folder='log-files/wfe_study_id009_bis/',
    model_folder='chkp/wfe_study_id009_bis/',
    optim_hist_folder='optim-hist/wfe_study_id009_bis/',
    plots_folder='plots/wfe_study_id009_bis/',
)


In [15]:
# Load models
test_dataset = np.load(args['dataset_folder'] + args['test_dataset_file'], allow_pickle=True)[()]
test_stars = test_dataset['stars']
test_pos = test_dataset['positions']
test_SEDs = test_dataset['SEDs']
# test_zernike_coef = test_dataset['zernike_coef']
test_C_poly = test_dataset['C_poly']
test_parameters = test_dataset['parameters']


tf_test_pos = tf.convert_to_tensor(test_dataset['positions'], dtype=tf.float32)


2022-08-26 09:51:48.911991: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-08-26 09:51:49.565846: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30986 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:1c:00.0, compute capability: 7.0


In [17]:

# Iterate through the model (bis_id), then through the pupil_diameter
# Models to evaluate. Each dataset tag (used in the run/checkpoint names) is
# paired, index by index, with the pupil diameter used to rebuild that model.
dataset_list = ['064', '128', '256', '256_benchmark']
pupil_diameter_list = [64, 128, 256, 256]

# Trained realisations per model (run ids ..._bis_1 to ..._bis_9).
n_realisations = 9



In [18]:

# Evaluation-time settings shared by every trained model evaluated below.
n_bins_lda = args['n_bins_lda']
output_Q = 1     # output sampling factor requested from the models at prediction time
output_dim = 64  # side length, in pixels, of the predicted stamps
batch_size = args['eval_batch_size']

for pupil_diameter, dataset_name in zip(pupil_diameter_list, dataset_list):
    result_x3_list = []

    ## Per-dataset setup. It depends only on `pupil_diameter` and `args`, not
    ## on the realisation, so it is built once rather than n_realisations times.

    # Zernike maps stacked into a (n_maps, rows, cols) cube, with NaNs zeroed.
    zernikes = wf.utils.zernike_generator(
        n_zernikes=args['n_zernikes'], wfe_dim=pupil_diameter
    )
    np_zernike_cube = np.zeros((len(zernikes), zernikes[0].shape[0], zernikes[0].shape[1]))
    for it, zernike_map in enumerate(zernikes):
        np_zernike_cube[it, :, :] = zernike_map
    np_zernike_cube[np.isnan(np_zernike_cube)] = 0
    tf_zernike_cube = tf.convert_to_tensor(np_zernike_cube, dtype=tf.float32)

    # Numpy PSF simulator, used to build the obscurations and to pack the SEDs.
    simPSF_np = wf.SimPSFToolkit(
        zernikes,
        max_order=args['n_zernikes'],
        pupil_diameter=pupil_diameter,
        output_dim=args['output_dim'],
        oversampling_rate=args['oversampling_rate'],
        output_Q=args['output_q']
    )
    simPSF_np.gen_random_Z_coeffs(max_order=args['n_zernikes'])
    z_coeffs = simPSF_np.normalize_zernikes(simPSF_np.get_z_coeffs(), simPSF_np.max_wfe_rms)
    simPSF_np.set_z_coeffs(z_coeffs)
    simPSF_np.generate_mono_PSF(lambda_obs=0.7, regen_sample=False)

    # Obscurations
    obscurations = simPSF_np.generate_pupil_obscurations(N_pix=pupil_diameter, N_filter=2)
    tf_obscurations = tf.convert_to_tensor(obscurations, dtype=tf.complex64)

    # Packed SEDs of the test stars, plus the model inputs.
    # NOTE(review): building these once per dataset assumes generate_packed_elems
    # does not depend on the random Zernike coefficients drawn above — confirm.
    packed_SED_data = [
        wf.utils.generate_packed_elems(_sed, simPSF_np, n_bins=n_bins_lda) for _sed in test_SEDs
    ]
    tf_packed_SED_data = tf.convert_to_tensor(packed_SED_data, dtype=tf.float32)
    tf_packed_SED_data = tf.transpose(tf_packed_SED_data, perm=[0, 2, 1])
    pred_inputs = [tf_test_pos, tf_packed_SED_data]

    ## Per realisation: rebuild the model, load its trained weights, predict.
    for idx in range(1, n_realisations + 1):
        run_id_name = 'poly_wfe_study_id009_' + dataset_name + '_bis_' + str(idx)
        print(run_id_name)

        weights_paths = args['chkp_save_path'] + 'chkp_callback_' + run_id_name + '_' + args['saved_cycle']

        # Initialize the model with the training-time settings, then load its weights.
        tf_semiparam_field = wf.tf_psf_field.TF_SemiParam_field(
            zernike_maps=tf_zernike_cube,
            obscurations=tf_obscurations,
            batch_size=args['batch_size'],
            output_Q=args['output_q'],
            d_max_nonparam=args['d_max_nonparam'],
            l2_param=args['l2_param'],
            output_dim=args['output_dim'],
            n_zernikes=args['n_zernikes'],
            d_max=args['d_max'],
            x_lims=args['x_lims'],
            y_lims=args['y_lims']
        )
        tf_semiparam_field.load_weights(weights_paths)

        # Switch to the evaluation output sampling/size. The model must be
        # compiled again after this change.
        tf_semiparam_field.set_output_Q(output_Q=output_Q, output_dim=output_dim)
        tf_semiparam_field = wf.tf_psf_field.build_PSF_model(tf_semiparam_field)

        # Use the configured evaluation batch size rather than a hard-coded value.
        predictions = tf_semiparam_field.predict(x=pred_inputs, batch_size=batch_size)
        result_x3_list.append(predictions)

    # One file per dataset, stacking the predictions of all realisations.
    np.save(
        saving_folder + 'model_id009_res_' + dataset_name + '.npy',
        np.array(result_x3_list),
        allow_pickle=True
    )




poly_wfe_study_id009_064_bis_1
poly_wfe_study_id009_064_bis_2
poly_wfe_study_id009_064_bis_3
poly_wfe_study_id009_064_bis_4
poly_wfe_study_id009_064_bis_5
poly_wfe_study_id009_064_bis_6
poly_wfe_study_id009_064_bis_7
poly_wfe_study_id009_064_bis_8


poly_wfe_study_id009_064_bis_9
poly_wfe_study_id009_128_bis_1
poly_wfe_study_id009_128_bis_2
poly_wfe_study_id009_128_bis_3
poly_wfe_study_id009_128_bis_4
poly_wfe_study_id009_128_bis_5
poly_wfe_study_id009_128_bis_6
poly_wfe_study_id009_128_bis_7
poly_wfe_study_id009_128_bis_8
poly_wfe_study_id009_128_bis_9
poly_wfe_study_id009_256_bis_1
poly_wfe_study_id009_256_bis_2




poly_wfe_study_id009_256_bis_3
poly_wfe_study_id009_256_bis_4
poly_wfe_study_id009_256_bis_5


poly_wfe_study_id009_256_bis_6
poly_wfe_study_id009_256_bis_7
poly_wfe_study_id009_256_bis_8
poly_wfe_study_id009_256_bis_9
poly_wfe_study_id009_256_benchmark_bis_1
poly_wfe_study_id009_256_benchmark_bis_2
poly_wfe_study_id009_256_benchmark_bis_3
poly_wfe_study_id009_256_benchmark_bis_4
poly_wfe_study_id009_256_benchmark_bis_5




poly_wfe_study_id009_256_benchmark_bis_6
poly_wfe_study_id009_256_benchmark_bis_7
poly_wfe_study_id009_256_benchmark_bis_8
poly_wfe_study_id009_256_benchmark_bis_9


# Evaluate the results

In [19]:
# Reference point for the model errors below: the x3 error between the 4096
# and 256 WFE-resolution test sets (same comparison as the earlier baseline cell).
GT_preds, preds = data_4096['super_res_stars'], data_256['super_res_stars']
_ = compute_stats(GT_preds, preds)


Absolute RMSE:	 9.3197e-06 	 +/- 1.6003e-06
Relative RMSE:	 6.1254e-01 % 	 +/- 1.0338e-01 %


In [20]:


# Models to score (the same tags as in the generation loop).
dataset_list = ['064', '128', '256', '256_benchmark']

# Ground truth for every model: super-resolved stars at 4096 WFE resolution.
GT_preds = data_4096['super_res_stars']

# save_dict['model_<tag>_rel_rmse'] and save_dict['model_<tag>_std_rel_rmse'] hold
# one value per realisation. compute_stats returns (rmse, rel_rmse, std_rmse, std_rel_rmse).
save_dict = {}

for model_tag in dataset_list:
    model = np.load(saving_folder + 'model_id009_res_' + model_tag + '.npy', allow_pickle=True)

    rep_stats = [compute_stats(GT_preds, model[rep, :, :, :]) for rep in range(model.shape[0])]

    save_dict['model_' + model_tag + '_rel_rmse'] = np.array([stats[1] for stats in rep_stats])
    save_dict['model_' + model_tag + '_std_rel_rmse'] = np.array([stats[3] for stats in rep_stats])


Absolute RMSE:	 7.5466e-05 	 +/- 3.1709e-05
Relative RMSE:	 5.0871e+00 % 	 +/- 2.6162e+00 %
Absolute RMSE:	 8.4209e-05 	 +/- 2.5405e-05
Relative RMSE:	 5.6145e+00 % 	 +/- 1.8782e+00 %
Absolute RMSE:	 7.5631e-05 	 +/- 2.9636e-05
Relative RMSE:	 5.0870e+00 % 	 +/- 2.4614e+00 %
Absolute RMSE:	 5.2935e-05 	 +/- 2.2880e-05
Relative RMSE:	 3.5404e+00 % 	 +/- 1.6218e+00 %
Absolute RMSE:	 8.5815e-05 	 +/- 2.4373e-05
Relative RMSE:	 5.7373e+00 % 	 +/- 1.8906e+00 %
Absolute RMSE:	 8.5173e-05 	 +/- 2.5255e-05
Relative RMSE:	 5.6848e+00 % 	 +/- 2.0643e+00 %
Absolute RMSE:	 6.5301e-05 	 +/- 2.9490e-05
Relative RMSE:	 4.3856e+00 % 	 +/- 2.3062e+00 %
Absolute RMSE:	 7.9238e-05 	 +/- 2.6394e-05
Relative RMSE:	 5.2786e+00 % 	 +/- 1.9645e+00 %
Absolute RMSE:	 5.2678e-05 	 +/- 2.2054e-05
Relative RMSE:	 3.5187e+00 % 	 +/- 1.6031e+00 %
Absolute RMSE:	 1.6967e-05 	 +/- 3.6196e-06
Relative RMSE:	 1.1420e+00 % 	 +/- 3.7159e-01 %
Absolute RMSE:	 5.9363e-05 	 +/- 1.8160e-05
Relative RMSE:	 4.0131e+00 % 	 +/- 1

In [21]:

# Persist the per-realisation statistics (a dict, hence allow_pickle).
np.save(file=saving_folder + 'result_dict_id009_total_reps.npy', arr=save_dict, allow_pickle=True)


In [22]:
# Display the per-model relative-RMSE arrays collected above.
save_dict

{'model_064_rel_rmse': array([5.08705677, 5.61449862, 5.08701694, 3.54040342, 5.73725727,
        5.68484656, 4.38561155, 5.27863165, 3.51873333]),
 'model_064_std_rel_rmse': array([2.61616571, 1.87823922, 2.46136524, 1.62178299, 1.89059537,
        2.06433733, 2.30615094, 1.96449952, 1.60309463]),
 'model_128_rel_rmse': array([1.14195661, 4.01311825, 2.97478546, 2.65130685, 2.86941983,
        2.71096396, 3.93740918, 4.30978019, 1.63257378]),
 'model_128_std_rel_rmse': array([0.37159136, 1.62235134, 3.17561218, 0.77001497, 1.78601803,
        1.04034929, 2.27665594, 1.86949049, 0.53014815]),
 'model_256_rel_rmse': array([2.62705269, 2.23776978, 4.27448855, 2.58015316, 1.04575128,
        3.15205333, 3.49616721, 1.66190468, 4.36528636]),
 'model_256_std_rel_rmse': array([0.98672247, 0.68250382, 2.72047781, 0.87204603, 0.35602997,
        2.33784353, 2.00473024, 0.4836183 , 1.68170574]),
 'model_256_benchmark_rel_rmse': array([1.31415853, 2.97896004, 4.22073245, 3.6112005 , 3.38505149,


In [23]:
# Summary over all repetitions of each model: mean, median and spread of the
# per-repetition relative RMSE, plus the mean of the per-repetition
# (across-star) standard deviations. One loop replaces four copy-pasted
# stanzas and labels each block with its model.
print('All repetitions:')
print('')
for model_name in dataset_list:
    rel_rmse_reps = save_dict['model_' + model_name + '_rel_rmse']
    std_rel_rmse_reps = save_dict['model_' + model_name + '_std_rel_rmse']
    print('Model: ', model_name)
    print('Mean: ', np.mean(rel_rmse_reps))
    print('Median: ', np.median(rel_rmse_reps))
    print('Std dev reps: ', np.std(rel_rmse_reps))
    print('Mean Std dev: ', np.mean(std_rel_rmse_reps))
    print('')


All repetitions:

Mean:  4.881561790468762
Median:  5.08705676814964
Std dev reps:  0.820484671345332
Mean Std dev:  2.0451367725267504

Mean:  2.9157015671371584
Median:  2.869419826391799
Std dev reps:  1.0056345746787356
Mean Std dev:  1.4935813063611885

Mean:  2.8267363366105602
Median:  2.6270526924378013
Std dev reps:  1.0537441434680577
Mean Std dev:  1.347297545454857

Mean:  3.205391193848263
Median:  3.2774441216955488
Std dev reps:  1.049751027973218
Mean Std dev:  1.6799569884941763



In [21]:
# Same summary restricted to the later repetitions: the first `first_rep`
# repetitions of each model are dropped. With the 9 realisations per model
# produced above, this keeps the last 4. The header is therefore derived
# from the slice instead of being hard-coded.
first_rep = 5

print('Repetitions from #%d onward:' % (first_rep + 1))
print('')
for model_name in dataset_list:
    rel_rmse_reps = save_dict['model_' + model_name + '_rel_rmse'][first_rep:]
    std_rel_rmse_reps = save_dict['model_' + model_name + '_std_rel_rmse'][first_rep:]
    print('Model: %s (last %d repetitions)' % (model_name, len(rel_rmse_reps)))
    print('Mean: ', np.mean(rel_rmse_reps))
    print('Median: ', np.median(rel_rmse_reps))
    print('Std dev reps: ', np.std(rel_rmse_reps))
    print('Mean Std dev: ', np.mean(std_rel_rmse_reps))
    print('')


Last 10 repetitions:

Mean:  5.800979539380679
Median:  5.686010443418695
Std dev reps:  0.7128806375529859
Mean Std dev:  2.3639498659359495

Mean:  4.292202804962117
Median:  4.385432541559398
Std dev reps:  0.9503749912314566
Mean Std dev:  1.9218999862259687

Mean:  4.466562247264386
Median:  4.339361947000835
Std dev reps:  0.7505609789278637
Mean Std dev:  2.021885716077639

Mean:  4.936639088891042
Median:  4.538246798904764
Std dev reps:  1.0475086011399142
Mean Std dev:  2.493699321761354



In [17]:

# Older results dictionary (result_dict.npy), for comparison with `save_dict`.
old_save_dict = np.load(saving_folder + 'result_dict.npy', allow_pickle=True).item()


In [18]:
# Display the older results dictionary for comparison with `save_dict`.
old_save_dict

{'model_064_rel_rmse': array([6.43958849, 4.5148942 , 7.03129339, 5.18500449, 5.66809588]),
 'model_064_std_rel_rmse': array([2.46884952, 2.04078886, 4.19700985, 1.8659101 , 1.69467889]),
 'model_128_rel_rmse': array([5.42097046, 2.11293908, 2.10626005, 3.45872813, 3.28045238]),
 'model_128_std_rel_rmse': array([4.0749473 , 1.29302873, 1.20299272, 1.18403576, 1.43188567]),
 'model_256_rel_rmse': array([4.18412105, 6.33608123, 3.2794323 , 4.67587151, 4.80224664]),
 'model_256_std_rel_rmse': array([1.85443559, 3.54970007, 1.32788761, 2.23428099, 2.3315903 ]),
 'model_256_benchmark_rel_rmse': array([2.58978528, 3.92488443, 5.63453092, 3.48876644, 2.70354226]),
 'model_256_benchmark_std_rel_rmse': array([1.40024367, 1.20963646, 3.8965619 , 1.1530999 , 1.07147565])}

In [19]:

# Summary for the benchmark model stored in the older results file.
old_bench_rel_rmse = old_save_dict['model_256_benchmark_rel_rmse']
old_bench_std_rel_rmse = old_save_dict['model_256_benchmark_std_rel_rmse']

print('Mean: ', np.mean(old_bench_rel_rmse))
print('Median: ', np.median(old_bench_rel_rmse))
print('Std dev reps: ', np.std(old_bench_rel_rmse))
print('Mean Std dev: ', np.mean(old_bench_std_rel_rmse))
print('')



Mean:  3.668301864551181
Median:  3.4887664380025654
Std dev reps:  1.1007346096343087
Mean Std dev:  1.7462035157467848



In [20]:
# x3 baseline error between the 4096 and 256 WFE-resolution test sets.
# NOTE(review): the stored output of this cell differs from the identical
# computation earlier in the notebook. The outputs presumably come from
# different kernel sessions or data versions; re-run top to bottom to reconcile.
GT_preds, preds = data_4096['super_res_stars'], data_256['super_res_stars']
_ = compute_stats(GT_preds, preds)

Absolute RMSE:	 8.5248e-06 	 +/- 1.3512e-06
Relative RMSE:	 5.7081e-01 % 	 +/- 1.0100e-01 %
