In [1]:
%load_ext autoreload
%autoreload 2

from sbifitter import SBI_Fitter
from astropy.table import Table
import numpy as np
import matplotlib.pyplot as plt


Device: cuda
Pytorch version: 2.4.1
ROCM HIP version: 6.1.40093-e3dc58bf0


In [2]:
grid_path = '/home/tharvey/work/output/grid_BPASS_DelayedExponential_SFH_0.01_z_12_logN_5.7_Chab_CF00_v1.hdf5'

fitter = SBI_Fitter.init_from_hdf5('BPASS_Chab_LogNorm_5_z_12_phot_grid2', 
                                grid_path, return_output=False)

model_name = 'BPASS_Chab_DelayedExpSFH_0.01_z_12_CF00_v1sbipp_settings_posterior'

fitter.load_model_from_pkl(f'/home/tharvey/work/ltu-ili_testing/models/BPASS_Chab_DelayedExpSFH_0.01_z_12_CF00_v1/{model_name}.pkl');

In [3]:
# Do the cheeky thing of setting the flags explicity

fitter.feature_array_flags = dict(
            normalize_method = None,
            extra_features= [],
            normed_flux_units= "AB",
            normalization_unit= "AB",
            scatter_fluxes= True,
            empirical_noise_models= True,
            depths= None,
            include_errors_in_feature_array = True,
            min_flux_pc_error = 0.9,
            simulate_missing_fluxes = False,
            missing_flux_value = 99.0,
            missing_flux_fraction = 0.9,
            missing_flux_options = None,
            include_flags_in_feature_array = False,
            override_phot_grid = None,
            override_phot_grid_units = None,
            norm_mag_limit = 40,
            remove_nan_inf = None,
            parameters_to_remove = [],
            photometry_to_remove = None,
            drop_dropouts = False,
            drop_dropout_fraction = 1.0,
            raw_photometry_names = fitter.feature_names[:9],
            error_names = fitter.feature_names[9:],
            flag_names = [],
            norm_name = None,
)

# Test on real galaxies

Firstly load our real catalogue, and setup a mapping between feature names and column names using a dictionary.

In [4]:
file = '/home/tharvey/Downloads/JADES-Deep-GS_MASTER_Sel-f277W+f356W+f444W_v9_loc_depth_masked_10pc_EAZY_matched_selection_ext_src_UV.fits'
table = Table.read(file)

new_table = Table()
mag_cols_syntax = lambda band: f"MAG_APER_{band}_aper_corr"
magerr_cols_syntax = lambda band: f"MAGERR_APER_{band}_u1_loc_depth"
table = table[(table['selected_gal_all_criteria_delta_chi2_4_fsps_larson_no_bd'] == True) & (table['sigma_f444W'][:, 0] > 10) & (table['sigma_f277W'][:, 0] > 10) & (table['sigma_f356W'][:, 0] > 10)]
bands = [i.split("_")[-1] for i in table.colnames if i.startswith("loc_depth")]
new_band_names = ["HST/ACS_WFC.F606W"] + [f"JWST/NIRCam.{band.upper()}" for band in bands[1:]]

for band in bands:
    new_table[mag_cols_syntax(band)] = table[mag_cols_syntax(band)][:, 0]
    new_table[magerr_cols_syntax(band)] = table[magerr_cols_syntax(band)][:, 0]

new_table['z_eazy'] = table['zbest_fsps_larson']

conversion_dict = {mag_cols_syntax(band):new_band for band, new_band in zip(bands, new_band_names)}
conversion_dict.update({magerr_cols_syntax(band):f"unc_{new_band}" for band, new_band in zip(bands, new_band_names)})

 Features is catalogue converted to match features. Mask is a 1D boolean array showing rows which were removed
due to missing data

In [5]:
obs_features, obs_mask = fitter.create_features_from_observations(new_table, columns_to_feature_names=conversion_dict, flux_units='AB')

Removing 49 observations with missing data.


We can check if any observations are out of distribution given our training data

In [6]:
from sbifitter import test_out_of_distribution

a, b = test_out_of_distribution(fitter.feature_array, obs_features, sigma_threshold=5.0);


Original number of samples: 448
Number of outliers removed (5.0-sigma): 20
Number of samples remaining: 428


Or we can let the code do this internally and it will add the columns to the table for us

In [7]:
output = fitter.fit_catalogue(new_table,  
                            columns_to_feature_names=conversion_dict, 
                            flux_units='AB', 
                            sample_method='direct',
                            timeout_seconds_per_row=10)

Removing 49 observations with missing data.
torch.return_types.min(
values=tensor([[ 1.1139e+00,  7.5936e+00, -4.1183e-01, -2.5371e-01,  2.2437e+02,
          5.5661e+02, -3.2788e+00]], device='cuda:0'),
indices=tensor([[603, 603, 250, 250, 775, 754, 810]], device='cuda:0'))
tensor([[ 6.1017e+00,  8.1214e+00, -7.4543e-02, -5.9354e-02,  9.9005e+02,
          7.6850e+02, -2.5547e+00]], device='cuda:0')
torch.return_types.max(
values=tensor([[ 7.1637e+00,  8.4432e+00,  3.2980e-02,  6.2265e-02,  1.6504e+03,
          5.2127e+03, -1.4724e+00]], device='cuda:0'),
indices=tensor([[754, 494, 754, 520,  60, 603, 503]], device='cuda:0'))

torch.return_types.min(
values=tensor([[ 9.8384e-01,  7.4450e+00, -5.1909e-01, -3.3764e-01,  1.9143e+02,
          5.1072e+02, -3.3405e+00]], device='cuda:0'),
indices=tensor([[8472, 1548, 6027, 6027, 9777, 5444, 2974]], device='cuda:0'))
tensor([[ 6.0936e+00,  8.1190e+00, -7.6316e-02, -6.0151e-02,  9.9838e+02,
          7.7527e+02, -2.5531e+00]], device='cuda:

                    accepted. It may take a long time to collect the remaining
                    -6 samples. Consider interrupting (Ctrl-C) and switching to
                    `build_posterior(..., sample_with='mcmc')`.


torch.return_types.min(
values=tensor([[ 1.1139e+00,  7.4204e+00, -5.4955e-01, -3.6289e-01,  1.5588e+02,
          4.8412e+02, -3.4350e+00]], device='cuda:0'),
indices=tensor([[2383, 3996, 3996, 3996, 1540, 5249,  771]], device='cuda:0'))
tensor([[ 6.0920e+00,  8.1182e+00, -7.6837e-02, -6.0596e-02,  9.9867e+02,
          7.7780e+02, -2.5578e+00]], device='cuda:0')
torch.return_types.max(
values=tensor([[ 7.7299e+00,  8.4108e+00,  8.4922e-02,  7.8596e-02,  1.7021e+03,
          5.2096e+03, -1.3752e+00]], device='cuda:0'),
indices=tensor([[5249, 1828, 2657, 2693, 2149, 2693, 1154]], device='cuda:0'))

drawin samples


  0%|          | 0/447 [00:00<?, ?it/s]

torch.return_types.min(
values=tensor([[ 1.1885e+00,  7.3662e+00, -4.6536e-01, -3.1607e-01,  2.5317e+02,
          7.8395e+02, -3.1704e+00]], device='cuda:0'),
indices=tensor([[ 64, 419, 718, 357, 193, 118, 192]], device='cuda:0'))
tensor([[ 4.7573e+00,  8.0127e+00, -7.6533e-02, -5.2511e-02,  9.8305e+02,
          1.1724e+03, -2.1682e+00]], device='cuda:0')
torch.return_types.max(
values=tensor([[ 5.8467e+00,  8.2989e+00,  8.5364e-02,  1.6760e-01,  1.8118e+03,
          4.9593e+03, -1.4273e+00]], device='cuda:0'),
indices=tensor([[118, 774, 911,  88,  35,  64, 204]], device='cuda:0'))

torch.return_types.min(
values=tensor([[ 1.0741e+00,  7.3105e+00, -6.3970e-01, -4.1181e-01,  1.5178e+02,
          7.1556e+02, -3.3799e+00]], device='cuda:0'),
indices=tensor([[6696, 3261, 2176, 7116, 7346, 6138, 9515]], device='cuda:0'))
tensor([[ 4.7748e+00,  8.0159e+00, -7.3972e-02, -5.1125e-02,  9.7823e+02,
          1.1618e+03, -2.1649e+00]], device='cuda:0')
torch.return_types.max(
values=tensor([[

  0%|          | 1/447 [00:00<03:08,  2.37it/s]

torch.return_types.min(
values=tensor([[ 1.0134e+00,  7.2660e+00, -6.1580e-01, -3.6419e-01,  7.6562e+01,
          6.8279e+02, -3.3318e+00]], device='cuda:0'),
indices=tensor([[1109,  855, 6048, 8184, 9015, 2421, 1498]], device='cuda:0'))
tensor([[ 4.7949e+00,  8.0190e+00, -7.2954e-02, -5.0420e-02,  9.7406e+02,
          1.1507e+03, -2.1591e+00]], device='cuda:0')
torch.return_types.max(
values=tensor([[ 6.3749e+00,  8.3569e+00,  1.7196e-01,  2.7229e-01,  1.8960e+03,
          5.6233e+03, -1.3401e+00]], device='cuda:0'),
indices=tensor([[2421, 1532,  612,  612, 4784, 1109, 8997]], device='cuda:0'))

torch.return_types.min(
values=tensor([[ 3.6942e+00,  8.0647e+00, -4.1751e-01, -2.3883e-01,  4.8239e+02,
          5.9973e+02, -3.2392e+00]], device='cuda:0'),
indices=tensor([[481, 481, 972, 972, 852, 808, 657]], device='cuda:0'))
tensor([[ 6.1226e+00,  8.3155e+00, -1.1359e-01, -9.8035e-02,  1.3122e+03,
          7.3153e+02, -2.3818e+00]], device='cuda:0')
torch.return_types.max(
values=te

  0%|          | 2/447 [00:10<44:55,  6.06s/it]

Timeout exceeded for sample 2. Returning empty array for this sample.
torch.return_types.min(
values=tensor([[ 1.0084e+00,  7.6014e+00, -7.6264e-01, -3.3575e-01,  3.8533e+02,
          5.5869e+02, -3.2841e+00]], device='cuda:0'),
indices=tensor([[9792, 5950, 5950, 5950,  437, 7982, 5532]], device='cuda:0'))
tensor([[ 6.1103e+00,  8.3118e+00, -1.1676e-01, -9.9488e-02,  1.3072e+03,
          7.3565e+02, -2.3857e+00]], device='cuda:0')
torch.return_types.max(
values=tensor([[ 7.1204e+00,  8.6064e+00,  2.6946e-02,  3.9028e-03,  2.3870e+03,
          5.5498e+03, -1.4890e+00]], device='cuda:0'),
indices=tensor([[7982,    2,  210, 5107, 4505, 9792, 9016]], device='cuda:0'))

torch.return_types.min(
values=tensor([[ 9.5141e-01,  7.7007e+00, -6.1167e-01, -2.7327e-01,  4.0908e+02,
          5.8932e+02, -3.2527e+00]], device='cuda:0'),
indices=tensor([[6932, 6932, 8023, 8023, 3306, 6315, 7148]], device='cuda:0'))
tensor([[ 6.1183e+00,  8.3127e+00, -1.1673e-01, -9.9398e-02,  1.3081e+03,
          

  1%|          | 3/447 [00:11<27:51,  3.77s/it]

torch.return_types.min(
values=tensor([[ 1.6665e+00,  7.8116e+00, -4.7325e-01, -2.7574e-01,  2.7085e+02,
          5.3727e+02, -3.3578e+00]], device='cuda:0'),
indices=tensor([[5471, 5471, 9648, 5471, 7865, 2697, 2310]], device='cuda:0'))
tensor([[ 6.1139e+00,  8.3129e+00, -1.1667e-01, -9.9377e-02,  1.3014e+03,
          7.3430e+02, -2.3854e+00]], device='cuda:0')
torch.return_types.max(
values=tensor([[ 7.2775e+00,  8.6066e+00,  3.8768e-02,  2.1143e-03,  2.3009e+03,
          3.7137e+03, -1.4892e+00]], device='cuda:0'),
indices=tensor([[2697, 3772, 3772, 4250, 9640, 7876, 6006]], device='cuda:0'))

torch.return_types.min(
values=tensor([[ 1.1780e+00,  7.7013e+00, -3.1279e-01, -2.2166e-01,  2.6855e+02,
          5.3923e+02, -3.2523e+00]], device='cuda:0'),
indices=tensor([[2101, 1919,  213,  213, 2359,  870,  189]], device='cuda:0'))
tensor([[ 6.1796e+00,  8.2919e+00, -2.1128e-02, -3.2146e-02,  9.3564e+02,
          7.5115e+02, -2.4378e+00]], device='cuda:0')
torch.return_types.max(
va

  1%|          | 4/447 [00:12<18:38,  2.52s/it]

torch.return_types.min(
values=tensor([[ 1.4600e+00,  7.7328e+00, -5.7490e-01, -2.9436e-01,  4.2949e+02,
          5.4320e+02, -3.2364e+00]], device='cuda:0'),
indices=tensor([[2571, 2571, 5615, 2571,  928, 7862, 1845]], device='cuda:0'))
tensor([[ 6.1163e+00,  8.3127e+00, -1.1673e-01, -9.9375e-02,  1.3063e+03,
          7.3409e+02, -2.3797e+00]], device='cuda:0')
torch.return_types.max(
values=tensor([[ 7.2351e+00,  8.6387e+00,  2.1813e-02, -1.3142e-02,  2.3498e+03,
          4.1505e+03, -1.3759e+00]], device='cuda:0'),
indices=tensor([[7862, 7491, 5775, 5775, 9791, 2571, 2311]], device='cuda:0'))

torch.return_types.min(
values=tensor([[ 6.2762e+00,  9.1828e+00,  9.8876e-02, -4.6500e-02,  4.0565e+02,
          5.8657e+02, -3.4029e+00]], device='cuda:0'),
indices=tensor([[6157, 4189, 4189, 1339, 9836, 6010, 9741]], device='cuda:0'))
tensor([[ 6.5372e+00,  9.2963e+00,  2.7256e-01,  1.0918e-01,  1.3148e+03,
          6.5407e+02, -2.4775e+00]], device='cuda:0')
torch.return_types.max(
va

  1%|          | 5/447 [00:13<19:53,  2.70s/it]

torch.return_types.min(
values=tensor([[ 1.3932e+00,  7.7714e+00, -4.7512e-01, -2.9745e-01,  3.4054e+02,
          5.4613e+02, -3.3617e+00]], device='cuda:0'),
indices=tensor([[4172,   62, 8912,   62, 9624, 6073, 6614]], device='cuda:0'))
tensor([[ 6.1166e+00,  8.3127e+00, -1.1637e-01, -9.9009e-02,  1.3018e+03,
          7.3330e+02, -2.3847e+00]], device='cuda:0')
torch.return_types.max(
values=tensor([[ 7.2092e+00,  8.6228e+00,  7.0696e-03, -3.7914e-03,  2.4833e+03,
          4.3112e+03, -1.4715e+00]], device='cuda:0'),
indices=tensor([[6073, 9810,  650, 6591, 1534, 4172, 9624]], device='cuda:0'))

torch.return_types.min(
values=tensor([[ 1.7159e+00,  7.6042e+00, -4.8560e-01, -3.1457e-01,  7.5781e+01,
          3.7710e+02, -3.7222e+00]], device='cuda:0'),
indices=tensor([[3794, 3794, 3053, 3053, 3053, 9102, 6147]], device='cuda:0'))
tensor([[ 7.3207e+00,  8.1515e+00, -6.8888e-02, -5.1458e-02,  7.8830e+02,
          5.3920e+02, -2.4056e+00]], device='cuda:0')
torch.return_types.max(
va




torch.return_types.min(
values=tensor([[ 1.7350e+00,  7.7793e+00, -4.9265e-01, -2.5758e-01,  3.5141e+02,
          5.2873e+02, -3.3096e+00]], device='cuda:0'),
indices=tensor([[7101, 7101, 4909, 4909,  977, 2930, 5486]], device='cuda:0'))
tensor([[ 6.1147e+00,  8.3114e+00, -1.1668e-01, -9.9321e-02,  1.3076e+03,
          7.3378e+02, -2.3897e+00]], device='cuda:0')
torch.return_types.max(
values=tensor([[ 7.3363e+00,  8.6051e+00,  4.8971e-03, -5.8829e-03,  2.3778e+03,
          3.5548e+03, -1.4537e+00]], device='cuda:0'),
indices=tensor([[2930, 3736, 1852, 2324, 4590, 7101, 2324]], device='cuda:0'))

torch.return_types.min(
values=tensor([[ 1.9580e+00,  7.8366e+00, -3.7435e-01, -2.1425e-01,  4.3260e+02,
          4.4213e+02, -3.1564e+00]], device='cuda:0'),
indices=tensor([[110, 110, 110, 110, 801, 133, 702]], device='cuda:0'))
tensor([[ 6.3425e+00,  8.2886e+00, -8.0040e-02, -6.7189e-02,  1.0686e+03,
          6.9034e+02, -2.4798e+00]], device='cuda:0')
torch.return_types.max(
values=te

In [None]:
output_masked = output[output['redshift_50'] != 0.0]
plt.scatter(output_masked['z_eazy'], output_masked['redshift_50'])
plt.xlabel('z_eazy')

plt.plot(plt.xlim(), plt.xlim(), ls='--', color='k')
plt.ylabel('z_sbi')
plt.xlabel('z_eazy')


In [None]:
import torch
x = torch.tensor(obs_features, dtype=torch.float32, device=fitter.device)
fitter.posteriors.sample_batched(x=x, sample_shape=(1000,))

In [9]:
fitter._prior.dist.low, fitter._prior.dist.high

for i in range(len(fitter.simple_fitted_parameter_names)):
    print(fitter.simple_fitted_parameter_names[i], 
          fitter._prior.dist.low[i].item(), 
          fitter._prior.dist.high[i].item())
    

(tensor([ 1.0017e-02,  6.0020e+00,  4.3937e-06,  5.5227e-06,  1.0007e+01,
          1.8859e+02, -3.0000e+00], device='cuda:0'),
 tensor([ 1.2000e+01,  1.2000e+01,  3.6832e+00,  7.3544e+00,  9.9999e+02,
          1.3465e+04, -1.3900e+00], device='cuda:0'))

redshift 0.010017454624176025 11.999988555908203
log_mass 6.001991271972656 11.999991416931152
tau_v_ism 4.393674771563383e-06 3.6832313537597656
tau_v_birth 5.522712854144629e-06 7.354403972625732
tau 10.007433891296387 999.9893798828125
max_age 188.58551025390625 13465.0634765625
log10metallicity -2.9999985694885254 -1.3900014162063599


In [18]:
import torch
# max
7.1637e+00,  8.4432e+00,  3.2980e-02,  6.2265e-02,  1.6504e+03,
5.2127e+03, -1.4724e+00

# min
9.8384e-01,  7.4450e+00, -5.1909e-01, -3.3764e-01,  1.9143e+02,
5.1072e+02, -3.3405e+00

# mean
a = [6.1481e+00,  8.1235e+00, -7.4607e-02, -5.8678e-02,  9.8795e+02,
7.5202e+02, -2.5445e+00]
a = torch.tensor(a, dtype=torch.float32, device=fitter.device)

help(fitter._prior.support.check)#(a)

Help on method check in module torch.distributions.constraints:

check(value) method of torch.distributions.constraints._IndependentConstraint instance
    Returns a byte tensor of ``sample_shape + batch_shape`` indicating
    whether each event in value satisfies this constraint.



Help on _IndependentConstraint in module torch.distributions.constraints object:

class _IndependentConstraint(Constraint)
 |  _IndependentConstraint(base_constraint, reinterpreted_batch_ndims)
 |  
 |  Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many
 |  dims in :meth:`check`, so that an event is valid only if all its
 |  independent entries are valid.
 |  
 |  Method resolution order:
 |      _IndependentConstraint
 |      Constraint
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __init__(self, base_constraint, reinterpreted_batch_ndims)
 |      Initialize self.  See help(type(self)) for accurate signature.
 |  
 |  __repr__(self)
 |      Return repr(self).
 |  
 |  check(self, value)
 |      Returns a byte tensor of ``sample_shape + batch_shape`` indicating
 |      whether each event in value satisfies this constraint.
 |  
 |  ----------------------------------------------------------------------
 |  Readonly properties defined here:
 |  
 

In [8]:
fitter.grid_path

'/home/tharvey/work/output/grid_BPASS_DelayedExponential_SFH_0.01_z_12_logN_5.7_Chab_CF00_v1.hdf5'