# Step 2: Compute ground truths (Fermat potentials, time delays, kinematics) #

### Note: This notebook runs in the `forecast_env` environment ###

In [None]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import h5py
import copy
import os
import sys
sys.path.insert(0, '/Users/smericks/Desktop/StrongLensing/darkenergy-from-LAGN/')
import Utils.ground_truth_utils as gt_utils
from astropy.cosmology import FlatLambdaCDM

Choice of ground-truth cosmology:

In [None]:
# Ground-truth cosmology used for every truth quantity computed below:
# flat LambdaCDM with H0 = 70 km/s/Mpc (astropy's unit for a bare float) and Omega_m = 0.3.
gt_cosmo = FlatLambdaCDM(Om0=0.3, H0=70.0)

Compute Fermat potentials, time delays, and kinematics for the gold and silver samples.

In [None]:
# Read the truth metadata for both data-quality tiers.
metadata_gold = pd.read_csv('DataVectors/gold/truth_metadata.csv')
metadata_silver = pd.read_csv('DataVectors/silver/truth_metadata.csv')

# Attach ground-truth quantities to each table. The helpers' return values are
# discarded, so they presumably add columns to the DataFrame in place (later cells
# read e.g. 'td01') -- TODO confirm in Utils/ground_truth_utils.py.
for tier_df in (metadata_gold, metadata_silver):
    gt_utils.populate_fermat_differences(tier_df)
    gt_utils.populate_truth_Ddt_timedelays(tier_df, gt_cosmo)

Check for lenses whose time delays disagree between the gold and silver catalogs (a leftover artifact of the numerical lens solving).

In [None]:
# Gold and silver time delays are compared row by row (pandas aligns on the index),
# which is only meaningful if both tables list the same lenses in the same order.
assert np.array_equal(metadata_gold['catalog_idx'].to_numpy(),
                      metadata_silver['catalog_idx'].to_numpy()), \
    'gold and silver metadata must list the same lenses in the same order'

# Largest |gold - silver| td01 difference still treated as agreement
# (same units as the td01 column).
TD_DIFF_TOL = 0.01


def plot_td_diff(diff, title, pair):
    """Histogram of gold-minus-silver time-delay differences for one image pair."""
    fig, ax = plt.subplots()
    ax.hist(diff, bins=20)
    ax.set_title(title)
    ax.set_xlabel(f'{pair} (gold) - {pair} (silver)')
    ax.set_ylabel('Number of lenses')
    plt.show()


diff_td01 = metadata_gold.loc[:, 'td01'] - metadata_silver.loc[:, 'td01']
# Raw strings: '\D' is an invalid escape sequence (SyntaxWarning on Python >= 3.12).
plot_td_diff(diff_td01, r'Doubles + Quads $\Delta$td01', 'td01')

# Flag lenses whose td01 disagrees beyond the tolerance. np.where yields positions,
# so use .iloc (not .loc): the index may have gaps if this cell is re-run after the
# removal cell below. keep_idx uses <= so a difference exactly at the tolerance is
# kept rather than falling into neither set (NaN differences still fall into neither).
abs_diff_td01 = np.abs(diff_td01.to_numpy())
remove_idx = np.where(abs_diff_td01 > TD_DIFF_TOL)[0]
remove_catalog_idx = metadata_gold['catalog_idx'].iloc[remove_idx].to_numpy()
keep_idx = np.where(abs_diff_td01 <= TD_DIFF_TOL)[0]
print(f'{len(remove_catalog_idx)} lens(es) flagged for removal: {remove_catalog_idx}')

# Quads also carry td03; check it on the four-image systems only.
quads_idx = np.where(metadata_gold['point_source_parameters_num_images'].to_numpy() == 4)[0]
diff_td03 = metadata_gold['td03'].iloc[quads_idx] - metadata_silver['td03'].iloc[quads_idx]
plot_td_diff(diff_td03, r'Quads $\Delta$td03', 'td03')

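Remove the discrepant lenses from the data vectors and save the updated products. Note: this overwrites `truth_metadata.csv` and `image_models.h5` in place.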

In [None]:
# Drop the flagged lenses from both metadata tables and overwrite the CSVs.
# index=False: otherwise pandas writes the row index as an extra unnamed column,
# which read_csv then loads as 'Unnamed: 0' -- adding one more such column per re-run.
metadata_gold = metadata_gold[~metadata_gold['catalog_idx'].isin(remove_catalog_idx)]
metadata_gold.to_csv('DataVectors/gold/truth_metadata.csv', index=False)
metadata_silver = metadata_silver[~metadata_silver['catalog_idx'].isin(remove_catalog_idx)]
metadata_silver.to_csv('DataVectors/silver/truth_metadata.csv', index=False)

# Datasets in image_models.h5 whose rows (axis 0) are assumed to line up with 'catalog_idx'.
H5_PER_LENS_DATASETS = ['catalog_idx', 'images_array', 'mu_npe', 'cov_npe']


def _drop_h5_rows(h5, row_idx, dataset_names):
    """Rewrite each named dataset in the open file `h5` without the rows in `row_idx`."""
    for data_name in dataset_names:
        old_attrs = dict(h5[data_name].attrs)
        trimmed = np.delete(h5[data_name][()], row_idx, axis=0)
        # Non-resizable datasets cannot shrink in place, so delete and recreate,
        # carrying the attributes over. HDF5 does not reclaim the freed space;
        # run h5repack if file size matters.
        del h5[data_name]
        h5.create_dataset(data_name, data=trimmed).attrs.update(old_attrs)


for quality in ['gold', 'silver']:
    images_path = 'DataVectors/' + quality + '/image_models.h5'
    # Context manager guarantees the file is closed even if a step below raises.
    with h5py.File(images_path, 'r+') as h5:
        # Match by catalog_idx value rather than row position, so re-running this cell
        # after a successful removal finds no matches and leaves the file untouched.
        # Duplicated catalog_idx entries are all removed.
        h5_catalog_idx = h5['catalog_idx'][()]  # read into memory once
        h5_idx_to_remove = np.flatnonzero(np.isin(h5_catalog_idx, remove_catalog_idx))
        if h5_idx_to_remove.size > 0:
            _drop_h5_rows(h5, h5_idx_to_remove, H5_PER_LENS_DATASETS)