In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import h5py
from paltas.Utils.distribution_utils import geometric_average
import sys
sys.path.insert(0, '/Users/smericks/Desktop/StrongLensing/lens-npe/')
from network_predictions import NetworkPredictions, generate_sequential_predictions, generate_narrow_sequential_predictions, generate_data_sequential_predictions

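This notebook demonstrates how to generate network predictions from trained models: first NPE predictions, then SNPE training configurations built from those predictions, and finally SNPE predictions.
We use the shifted test set for this example.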

In [None]:
# Root of the Zenodo data release (trained model weights, normalization files, test sets).
# The cells below build paths by plain string concatenation onto this prefix,
# so it must keep its trailing '/'.
zenodo_filepath = '../Paper/from_zenodo/'

## NPE Predictions ##

In [None]:
# Trained NPE network (diagonal-covariance loss) and the normalization file it was trained with.
path_to_weights = f'{zenodo_filepath}trained_models/npe/diag/xresnet34_068--14.58.h5'
path_to_norms = f'{zenodo_filepath}trained_models/npe/diag/norms.csv'

# Parameters the network was trained to predict, in output order:
# eight main-deflector parameters followed by the two source-position parameters.
_deflector_keys = ('theta_E', 'gamma1', 'gamma2', 'gamma', 'e1', 'e2', 'center_x', 'center_y')
learning_params = (
    [f'main_deflector_parameters_{key}' for key in _deflector_keys]
    + ['source_parameters_center_x', 'source_parameters_center_y']
)
# LaTeX labels for the parameters above (same order), for plotting.
learning_params_names = [
    r'$\theta_\mathrm{E}$', r'$\gamma_1$', r'$\gamma_2$', r'$\gamma_\mathrm{lens}$',
    r'$e_1$', r'$e_2$', r'$x_{lens}$', r'$y_{lens}$', r'$x_{src}$', r'$y_{src}$',
]

# Wrap the trained network; predictions are computed with paltas under the hood.
predictions = NetworkPredictions(
    path_to_weights,
    path_to_norms,
    learning_params,
    loss_type='diag',
    model_type='xresnet34',
    norm_type='lognorm',
)

# Output folder for the predictions (the next section reads
# notebook_data/narrow_predictions.h5 back from here).
write_folder = 'notebook_data/'

# Shifted test-set images.
path_to_narrow_set = f'{zenodo_filepath}test_sets/shifted/'

# The second positional argument is None; presumably this skips a second (broad)
# test set and predicts only on the shifted set. TODO confirm against NetworkPredictions.
predictions.generate_all_predictions(write_folder, None, path_to_narrow_set)

## Generating SNPE Training Configs from NPE Predictions ##
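This is an intermediate step: for each test lens, the NPE posterior (mean and standard deviation) is combined with the training prior via a geometric average.
The result is written into a per-lens training config, which is used to generate new training images for SNPE.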

In [None]:
# Load the NPE (broad-training) network's predictions on the shifted test set.
# `Dataset.value` was removed in h5py 3.0; `dset[()]` reads the whole dataset on both
# old and new h5py. The context manager closes the file even if a read fails.
file_prefix = 'notebook_data/'
file_path = file_prefix + 'narrow_predictions.h5'
with h5py.File(file_path, 'r') as h5f:
    y_pred_narrow = h5f['y_pred'][()].astype(np.float64)
    std_pred_narrow = h5f['std_pred'][()].astype(np.float64)

# get rid of scientific notation when arrays are displayed in this kernel
# (the config strings written below do not depend on this setting)
np.set_printoptions(suppress=True)

# Training-prior mean and width for each of the 10 predicted parameters
# (presumably in learning_params order -- TODO confirm).
mu_prior = np.asarray([0.8,0.,0.,2.0,0.,0.,0.,0.,0.,0.])
sigma_prior = np.asarray([0.15,0.12,0.12,0.2,0.2,0.2,0.07,0.07,0.1,0.1])

# Number of test lenses to build SNPE configs for (rows 0..19 of the predictions).
n_test_lenses = 20
# Weight given to the prior in the geometric average; 0 means "use the NPE posterior as-is".
# NOTE: the output filename does not encode the factor, so with several factors only the
# last one's config survives for each lens.
prior_factors = [2]

base_config_path = 'notebook_data/sequential_config_base.py'
# 0-indexed lines of the base config that are overwritten with seq_mus / seq_sigmas.
# Assumes the base config keeps those placeholders on exactly these lines -- TODO confirm.
SEQ_MUS_LINE, SEQ_SIGMAS_LINE = 11, 12


def _format_array(values):
    """Render `values`, rounded to 3 decimals, as a one-line Python list literal.

    Unlike slicing the numpy repr (`repr(arr)[6:-1]`), this does not depend on the
    `array(` prefix or on numpy's line wrapping. Values rounded to 3 decimals never
    fall in the range where Python's float repr switches to scientific notation.
    """
    return repr(np.round(values, 3).tolist())


# The base config is the same for every lens: read it once, copy it per output file.
with open(base_config_path) as base_file:
    base_lines = [line.rstrip() for line in base_file]

for narrow_idx in range(n_test_lenses):

    for prior_factor in prior_factors:

        mus = y_pred_narrow[narrow_idx, :]
        sigmas = std_pred_narrow[narrow_idx, :]
        if prior_factor != 0:
            mus, sigmas = geometric_average(mus, sigmas, mu_prior, sigma_prior,
                                            weight_wide=prior_factor)

        lines = list(base_lines)
        lines[SEQ_MUS_LINE] = 'seq_mus = ' + _format_array(mus)
        lines[SEQ_SIGMAS_LINE] = 'seq_sigmas = ' + _format_array(sigmas)

        with open('notebook_data/config_shifted%03d.py' % (narrow_idx), 'w') as f:
            f.writelines(f"{line}\n" for line in lines)

## SNPE Predictions ##

In [None]:

# Run the SNPE networks on the shifted test set. Each prior tag selects one family of
# SNPE models (presumably the prior weighting used when building their training
# configs -- TODO confirm), with one trained model per test lens.
# Results are stacked along the first axis in the order of prior_tags: 0x, 1x, 2x, 4x.
prior_tags = ['0x', '1x', '2x', '4x']
epoch = 10  # training epoch of the SNPE checkpoints to load
test_set_indices = range(0, 20)

# Loop-invariant inputs: the shifted test images and the NPE normalization file.
narrow_image_folder = zenodo_filepath + 'test_sets/shifted/'
norm_path = zenodo_filepath + 'trained_models/npe/diag/norms.csv'

y_pred_narrow_seq_list = []
std_pred_narrow_seq_list = []
prec_pred_narrow_seq_list = []

for prior_tag in prior_tags:
    # One weights file per test-set lens, all at the same epoch.
    weights_paths = [
        zenodo_filepath + 'trained_models/snpe/shifted/gem_avg_' + prior_tag
        + '/xresnet34_%03d_narrow%03d.h5' % (epoch, i)
        for i in test_set_indices
    ]

    y_pred_narrow_seq, std_pred_narrow_seq, prec_pred_narrow_seq = generate_narrow_sequential_predictions(
        weights_paths, narrow_image_folder, image_indices=test_set_indices,
        norms_path=norm_path, loss_type='diag', image_type='h5')
    y_pred_narrow_seq_list.append(y_pred_narrow_seq)
    std_pred_narrow_seq_list.append(std_pred_narrow_seq)
    prec_pred_narrow_seq_list.append(prec_pred_narrow_seq)

# Convert all three result lists the same way; np.save would convert them anyway.
y_pred_narrow_seq_list = np.asarray(y_pred_narrow_seq_list)
std_pred_narrow_seq_list = np.asarray(std_pred_narrow_seq_list)
prec_pred_narrow_seq_list = np.asarray(prec_pred_narrow_seq_list)

# Output names are derived from `epoch` so they stay correct if the epoch changes
# (epoch = 10 yields the original *_epoch10.npy names).
np.save('notebook_data/y_pred_list_epoch%d.npy' % epoch,
        y_pred_narrow_seq_list)
np.save('notebook_data/std_pred_list_epoch%d.npy' % epoch,
        std_pred_narrow_seq_list)
np.save('notebook_data/prec_pred_list_epoch%d.npy' % epoch,
        prec_pred_narrow_seq_list)