In [None]:
# Progress marker: printed as soon as the notebook/script starts executing.
print('Begin script')

In [None]:
# usual imports
import os
import glob
import h5py
import numpy as np
import matplotlib.pyplot as plt
from rail.core.utils import RAILDIR
from rail.core import RailStage
from rail.core.data import TableHandle
#from rail.estimation.algos.sompz_version.utils import RAIL_SOMPZ_DIR
#from rail.pipelines.estimation.estimate_all import EstimatePipeline
#from rail.core import common_params
#from rail.pipelines.utils.name_factory import NameFactory, DataType, CatalogType, ModelType, PdfType
import qp
import ceci

In [None]:
from rail.estimation.algos.sompz import SOMPZEstimator

In [None]:
import rail.estimation.algos.sompz as sompz_

In [None]:
from rail.sompz.utils import mean_of_hist

In [None]:
# Display the filesystem path the sompz module was imported from
# (shows which installation of the package is actually in use).
sompz_.__file__

In [None]:
# Grab the data store exposed by RailStage and switch on `allow_overwrite`
# at class level, so the flag applies to the store's class rather than to
# this one instance only.
DS = RailStage.data_store
type(DS).allow_overwrite = True

In [None]:
#from rail.core.utils import find_rail_file

# change to your rail location
RAIL_SOMPZ_DIR = "/global/u2/j/jmyles/repositories/LSSTDESC/rail_sompz/src"

# `datadir` holds the input catalogs; `datadir_run` / `outdir` is the
# per-run directory used for the model file and for everything written below.
datadir = '/pscratch/sd/j/jmyles/sompz_buzzard/2024-06-24/'
datadir_run = os.path.join(datadir, 'run-2024-07-01')
outdir = datadir_run
# Create the run directory (and any missing parents) in-process. Unlike
# `os.system('mkdir -p ...')` this does not pass the path through a shell,
# and a real failure (e.g. permission denied) raises instead of being ignored.
os.makedirs(outdir, exist_ok=True)

# Input catalog files (the trailing comments record earlier demo inputs).
testFileSpec = os.path.join(datadir, 'spec_data.h5') #'./datafiles/romandesc_deep_data_3700.hdf5'
testFileBalrog = os.path.join(datadir, 'balrog_data_subcatalog.h5') #'./datafiles/romandesc_deep_data_3700.hdf5'
testFileWide = os.path.join(datadir, 'wide_data_subsample.hdf5') #'./datafiles/romandesc_wide_data_5000.hdf5'

# Read each catalog through the data store as a TableHandle; the first
# argument is the tag passed to the store for that catalog.
spec_data = DS.read_file("input_spec_data", TableHandle, testFileSpec)
balrog_data = DS.read_file("input_balrog_data", TableHandle, testFileBalrog)
wide_data = DS.read_file("input_wide_data", TableHandle, testFileWide)

# Model file handed to the estimator stage via its `model` argument.
model_file = os.path.join(datadir_run, "DEMO_CARDINAL_model_2024-06-24.pkl")

print('Catalogs specified')

### Define metadata for SOMPZ inference

In [None]:
# Column names used as SOMPZ inputs.
# (An earlier configuration, now retired, built FLUX_{band}/FLUX_ERR_{band}
# columns from the bands U, G, R, I, Z, Y, J, H, K for the deep sample and
# from the first six of those for the wide sample.)

# Filter names of the deep sample.
bands_deep = [
    'lsst_u', 'lsst_g', 'lsst_r', 'lsst_i', 'lsst_z',
    'VISTA_Filters_at80K_forETC_Y',
    'VISTA_Filters_at80K_forETC_J',
    'VISTA_Filters_at80K_forETC_H',
    'VISTA_Filters_at80K_forETC_Ks',
]
# Filter names of the wide sample (U, Y, J, H, K are not used here).
bands_wide = ['G', 'R', 'I', 'Z']

# Deep sample: TRUEMAG columns, their error columns, and one zero point
# of 30 per band.
deepbands = [f'TRUEMAG_{name}' for name in bands_deep]
deeperrs = [f'TRUEMAG_ERR_{name}' for name in bands_deep]
zeropts = [30. for _ in bands_deep]

# Wide sample: FLUX columns and their error columns.
widebands = [f'FLUX_{name}' for name in bands_wide]
wideerrs = [f'FLUX_ERR_{name}' for name in bands_wide]

# Reference bands: index 3 of the deep list and index 2 of the wide list
# (the i band in both cases).
refband_deep = deepbands[3]
refband_wide = widebands[2]

In [None]:
# Keyword arguments forwarded verbatim to SOMPZEstimator.make_stage.
sompz_params = {
    # deep-sample columns and zero points
    'inputs_deep': deepbands,
    'input_errs_deep': deeperrs,
    'zero_points_deep': zeropts,
    # wide-sample columns
    'inputs_wide': widebands,
    'input_errs_wide': wideerrs,
    # flux-conversion switches (on for deep, off for wide)
    'convert_to_flux_deep': True,
    'convert_to_flux_wide': False,
    # deep-sample threshold
    'set_threshold_deep': True,
    'thresh_val_deep': 1.e-5,
    # wide SOM configuration
    'som_shape_wide': (32, 32),
    'som_minerror_wide': 0.005,
    'som_take_log_wide': False,
    'som_wrap_wide': False,
    # redshift column name and debug switch
    'specz_name': 'Z',
    'debug': False,
}

In [None]:
# bands = 'grizy'
# maglims = [27.66, 27.25, 26.6, 26.24, 25.35]
# maglim_dict={}
# for band,limx in zip(bands, maglims):
#     maglim_dict[f"HSC{band}_cmodel_dered"] = limx

### Prepare and run SOMPZ Estimation

In [None]:
print('make stage')
# Stage-level options: the spec and balrog tables are addressed through the
# group name "key", while the wide catalog uses an empty group name.
# TODO enable setting outdir for output files (data_path=outdir)
stage_options = {
    'name': "cardinal_som_estimator",
    'spec_groupname': "key",
    'balrog_groupname': "key",
    'wide_groupname': "",  # "key"
    'model': model_file,
}
som_estimate = SOMPZEstimator.make_stage(**stage_options, **sompz_params)

In [None]:
# Preview the first rows of the table stored under 'key' in the spec-z catalog.
spec_data.data['key'].head()

In [None]:
# Largest value of the Ks-band TRUEMAG column in the spec-z table.
spec_data.data['key']['TRUEMAG_VISTA_Filters_at80K_forETC_Ks'].max()

In [None]:
#som_estimate.estimate?

In [None]:
# Estimation runs only when the run directory holds no '*_estimator.hdf5'
# file yet; otherwise it is skipped and `output` is not (re)assigned here.
outfiles = sorted(glob.glob(os.path.join(datadir_run, '*_estimator.hdf5')))
print(len(outfiles))
if outfiles:
    print('Estimation already done. Skipping estimation.')
else:
    print('estimate')
    output = som_estimate.estimate(spec_data, balrog_data, wide_data)

In [None]:
# Persist the estimator return value once; an existing file is never overwritten.
outfile = os.path.join(outdir, 'output.npy')
if os.path.exists(outfile):
    print('Output file already exists. Skipping write to disk.')
elif globals().get('output') is None:
    # `output` is only assigned when the estimation actually ran in this
    # session. If estimation was skipped, there is nothing to save; without
    # this guard np.save would fail with a NameError.
    print('No estimation output in memory (estimation was skipped). Nothing to write.')
else:
    np.save(outfile, output)

In [None]:
# Progress marker between the estimation cells and the plotting cells.
print('Finished Estimation. Proceed to plotting.')

## display $n(z)$

In [None]:
# Read the n(z) hdf5 file from the run directory directly with qp.
# (Alternative location used previously:
#  os.path.join(RAIL_SOMPZ_DIR, '../examples/nz_cardinal_som_estimator.hdf5'))
qp_file = os.path.join(datadir_run, 'nz_cardinal_som_estimator.hdf5') # #os.path.join(RAIL_SOMPZ_DIR, '../examples/nz_cardinal_som_estimator.hdf5')
print(qp_file)
qp_single_nz_sompz = qp.read(qp_file)

In [None]:
# Evaluate the loaded n(z) on a regular redshift grid:
# `nbins` evenly spaced points covering 0 <= z <= 6 (both ends included).
nbins = 600
z_grid = np.linspace(0, 6, num=nbins)
nz_sompz_grid = qp_single_nz_sompz.pdf(z_grid)

In [None]:
# Some spectroscopic redshift measurements failed and carry z = -99.
# Keep only the rows with a positive redshift.
spec_table = spec_data.data['key']
has_valid_z = spec_table['Z'] > 0.0
specz_good = spec_table[has_valid_z]['Z']

### Plot the SOMPZ $n(z)$ estimate against the true $n(z)$ and the spec-z calibration sample

In [None]:
# Load the tomographic-bin assignment of the wide-field galaxies so that the
# true n(z | tomographic bin) can be shown next to the estimate.
infile_nz_tomo_binning_sompz = os.path.join(outdir, 'tomo_bin_mask_wide_data_cardinal_som_estimator.hdf5')
# Read the whole 'bin' dataset into memory and close the file right away
# (the handle was previously left open for the lifetime of the kernel).
with h5py.File(infile_nz_tomo_binning_sompz, mode='r') as finfile_nz_tomo_binning_sompz:
    nz_tomo_binning_sompz = finfile_nz_tomo_binning_sompz['bin'][:]
print(f'binning info for {len(nz_tomo_binning_sompz):,} wide field galaxies loaded')

# One boolean mask per tomographic bin (bin labels 0..3 in the file).
select_bin1 = nz_tomo_binning_sompz == 0
select_bin2 = nz_tomo_binning_sompz == 1
select_bin3 = nz_tomo_binning_sompz == 2
select_bin4 = nz_tomo_binning_sompz == 3
select_bins = [select_bin1, select_bin2, select_bin3, select_bin4]

# Random subsample of at most `nsamp` galaxies, expressed as a boolean mask
# over the full catalog. The shuffle happens IN PLACE and returns None, so
# its return value must not be assigned (doing so made `select_samp` None
# and the counts below silently reported the full bin sizes).
nsamp = 1_000_000
SUBSAMPLE_SEED = 42  # fixed seed so the subsample is reproducible across runs
select_samp = np.full(len(nz_tomo_binning_sompz), False)
select_samp[:nsamp] = True
np.random.default_rng(SUBSAMPLE_SEED).shuffle(select_samp)

# Number of subsampled galaxies that fall in each tomographic bin.
for select_bin in select_bins:
    print(np.count_nonzero(select_bin & select_samp))

In [None]:
# Optional personal matplotlib style sheet. '~' has to be expanded explicitly:
# os.path.exists() does not interpret it, so the unexpanded path never
# matched and the style was never applied.
mpl_style_file = os.path.expanduser('~/.matplotlib/stylelib/jmyles.mplstyle')
if os.path.exists(mpl_style_file):
    try:
        plt.style.use(mpl_style_file)
    except (OSError, ValueError) as err:
        # Styling is cosmetic: report the problem and continue with defaults
        # instead of silently swallowing every possible exception.
        print(f'Could not apply matplotlib style {mpl_style_file}: {err}')

colors = ['tab:blue', 'tab:orange', 'tab:red', 'tab:green']

# One panel per tomographic bin.
fig, axarr = plt.subplots(4, 1, figsize=(16, 12))

# Handles/labels of the mean-redshift lines, collected separately for
# even-indexed ("top") and odd-indexed ("bot") panels; they feed the
# optional mean-z legends that are currently disabled further down.
handles_top = []
labels_top = []

handles_bot = []
labels_bot = []
for i, select_bin in enumerate(select_bins):
    ax = axarr[i]

    # spec-z calibration sample: drawn on the first panel only
    if i == 0:
        n_, bins_, patches0 = ax.hist(specz_good, density=True, bins=nbins, histtype='step',
                                      label='SOMPZ spec-z calibration sample',
                                      color='k', alpha=0.25)

    # True redshifts of the wide-field galaxies assigned to this bin
    # (full bin, not the random subsample). Only the first panel carries a
    # legend label for the truth histogram.
    z_true_bin = wide_data.data['Z'][select_bin]
    n_, bins_, patches1 = ax.hist(z_true_bin, bins=100, lw=3, histtype='step',
                                  label=f'Truth -- Bin {i+1}' if i == 0 else '',
                                  color=colors[i], ls='-', density=True)
    meanz_true = z_true_bin.mean()
    print(meanz_true)
    # A label starting with '_' keeps the line out of the automatic legend.
    meanz_true_line = ax.axvline(meanz_true, color=colors[i], ls='-', lw=2, label=f'_{meanz_true:.4f}')

    # SOMPZ n(z) for this bin: mean redshift as the n(z)-weighted average
    # of the grid points, then the curve itself.
    meanz_est = np.sum(nz_sompz_grid[i] * z_grid / np.sum(nz_sompz_grid[i]))  # mean_of_hist(nz_sompz_grid[i], z_grid)
    print(meanz_est)
    meanz_est_line = ax.axvline(meanz_est, color=colors[i], ls='-', alpha=0.5, lw=2., label=f'_{meanz_est:.4f}')
    handle0, = ax.plot(z_grid, nz_sompz_grid[i], label=f'Estimate (SOMPZ) -- Bin {i+1}',
                       color=colors[i], ls='-', alpha=0.5, lw=3)

    if i % 2 == 1:
        handles_bot.append(meanz_true_line)
        handles_bot.append(meanz_est_line)
        labels_bot.append(f'{meanz_true:.4f}')
        labels_bot.append(f'{meanz_est:.4f}')
    else:
        handles_top.append(meanz_true_line)
        handles_top.append(meanz_est_line)
        labels_top.append(f'{meanz_true:.4f}')
        labels_top.append(f'{meanz_est:.4f}')

    #handles0.extend([patches0[0], patches1[0], handle0])
    ax.set_xlim(0, 2.25)
    ax.set_ylim(0, 3.25)
    ax.set_ylabel('prob. density')
    ax.set_yticks([])

    # Register the legend as a plain artist as well, so that a later
    # legend() call on the same axes (e.g. the disabled mean-z legends
    # below) would not replace it.
    main_legend = ax.legend() #handles = handles0
    ax.add_artist(main_legend)

    # # # Create a legend for the vertical lines
    # if i == 2:
    #     meanz_legend1 = axarr[0].legend(handles=handles_top, labels=labels_top, loc='lower right') # handles
    #     axarr[0].add_artist(meanz_legend1)
    # elif i == 3:
    #     meanz_legend2 = axarr[1].legend(handles=handles_bot, labels=labels_bot, loc='lower right') # handles
    #     axarr[1].add_artist(meanz_legend2)

# Only the bottom panel carries the x-axis label (hoisted out of the loop,
# where it was re-set on every iteration).
axarr[-1].set_xlabel('redshift')
fig.text(0.5, 0.5, 'preliminary', fontsize=80, color='k', alpha=0.25, ha='center', va='center', rotation=45)
axarr[0].set_title('Cardinal Simulation -- LSST Y1-like sample (TBR)')
# Add the legend manually to the Axes.
#axarr[0].add_artist(meanz_legend)
#axarr[0].legend(handles=handles0)

outfile = os.path.join(outdir, 'nz_sompz_est_script.png')
fig.savefig(outfile, dpi=150)
print(f'Wrote {outfile}')