In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load the autoreload extension and reload imported modules before each cell
# runs, so edits to local packages are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [8]:
import os
import re
from tqdm import tqdm, trange
import jax.numpy as np
from jax import ops
from jax import config
import numpy as onp
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from gzbuilder_analysis.fitting.misc import psf_conv, downsample
from gzbuilder_analysis.fitting.optimizer import get_spirals, render_comps, Optimizer

In [9]:
# Turn on JAX's 64-bit mode (JAX defaults to 32-bit precision).
config.update("jax_enable_x64", True)

In [10]:
# Load the pickled fitting metadata. It is later indexed by subject id with
# `.loc`, and the 'psf', 'galaxy_data' and 'sigma_image' entries are read from it.
# NOTE(review): pickle files execute code on load; only read trusted files.
fitting_metadata = pd.read_pickle('lib/fitting_metadata.pkl')

In [14]:
# Load every per-subject aggregation result found in `agg_res_path` into a
# Series keyed by the integer subject id taken from the file name.
# NOTE(review): pickle files execute code on load; only read trusted files.
agg_res_path = 'output_files/aggregation_results'
agg_results = pd.Series([], dtype=object)
with tqdm(os.listdir(agg_res_path), desc='Loading aggregation results') as bar:
    for f in bar:
        # fullmatch with both dots escaped: accept exactly "<digits>.pkl.gz".
        # The previous pattern left the second dot unescaped and was not
        # anchored at the end, so stray files such as "123.pkl.gz.bak" were
        # also picked up and loaded.
        if re.fullmatch(r'[0-9]+\.pkl\.gz', f):
            agg_results[int(f.split('.')[0])] = pd.read_pickle(
                os.path.join(agg_res_path, f)
            )

Loading aggregation results: 100%|██████████| 296/296 [00:02<00:00, 102.15it/s]


In [15]:
# Load every per-subject fitting result found in `fit_model_path` into a
# Series keyed by the integer subject id taken from the file name, then
# expand each loaded record into columns (one row per subject).
# NOTE(review): pickle files execute code on load; only read trusted files.
fit_model_path = 'output_files/tuning_results'
fit_models = pd.Series([], dtype=object)
with tqdm(os.listdir(fit_model_path), desc='Loading fitting results') as bar:
    for f in bar:
        # fullmatch with both dots escaped: accept exactly "<digits>.pickle.gz".
        # The previous pattern left the second dot unescaped and was not
        # anchored at the end, so stray files such as "123.pickle.gz.bak"
        # were also picked up and loaded.
        if re.fullmatch(r'[0-9]+\.pickle\.gz', f):
            fit_models[int(f.split('.')[0])] = pd.read_pickle(
                os.path.join(fit_model_path, f)
            )
fit_models = fit_models.apply(pd.Series)

Loading fitting results: 100%|██████████| 296/296 [00:04<00:00, 61.33it/s]


In [13]:
# Subject examined in the cells below.
subject_id = 20_902_040

In [16]:
# Build the optimizer for the chosen subject from its metadata row and its
# aggregation result.
fm = fitting_metadata.loc[subject_id]
agg_res = pd.read_pickle('output_files/aggregation_results/{}.pkl.gz'.format(subject_id))
# The psf, galaxy_data and sigma_image entries are passed positionally, in that order.
o = Optimizer(agg_res, *fm[['psf', 'galaxy_data', 'sigma_image']], oversample_n=5)



In [35]:
# Pick out the fitting result for the chosen subject, then wrap it as a Series.
fit_res = fit_models.loc[subject_id]
fit_res = pd.Series(fit_res)

## Hessian Errors

In [42]:
# Per-parameter error estimate: the diagonal of the optimizer's inverse-Hessian
# estimate, scaled by `ftol` and by max(1, |final objective value|), then
# square-rooted.
# NOTE(review): `ftol` is not defined in any cell visible here; confirm it is
# set (presumably to the tolerance the optimizer was run with) before this
# cell executes on a fresh kernel.
errs = np.sqrt(
    max(1, abs(fit_res.res.fun))
    * ftol
    * np.diag(fit_res.res.hess_inv.todense())
)
# One line per fitted parameter: "<name> = <value> ± <error>".
print(
    *(
        '{: >15} = {:.4f} ± {:.4f}'.format(' '.join(param_key), fitted, sigma)
        for param_key, fitted, sigma in zip(fit_res['keys'], fit_res.res.x, errs)
    ),
    sep='\n',
)

          bar c = 4.1750 ± 0.0861
       bar frac = 0.0495 ± 0.0049
          bar n = 0.9309 ± 0.1095
          bar q = 0.2265 ± 0.0174
       bar roll = 2.9168 ± 0.0427
      bar scale = 0.1340 ± 0.0118
     bulge frac = 0.0230 ± 0.0084
        bulge n = 0.5000 ± 0.0148
        bulge q = 0.6126 ± 0.0169
     bulge roll = 2.9028 ± 0.2797
    bulge scale = 0.1474 ± 0.0155
     centre mux = 128.2383 ± 0.0718
     centre muy = 127.8199 ± 0.0258
         disk L = 673.8897 ± 0.3820
        disk Re = 45.1188 ± 0.2895
       disk mux = 128.2046 ± 0.1486
       disk muy = 129.3674 ± 0.3207
         disk q = 0.7666 ± 0.0123
      disk roll = 2.9676 ± 0.0289
     spiral A.0 = 10.4750 ± 0.0282
     spiral A.1 = 20.6104 ± 0.0155
     spiral I.0 = 0.9687 ± 0.0625
     spiral I.1 = 0.6812 ± 0.0246
   spiral phi.0 = 14.6669 ± 0.2985
   spiral phi.1 = 21.1346 ± 0.5153
spiral spread.0 = 3.1251 ± 0.1480
spiral spread.1 = 4.0953 ± 0.0957
 spiral t_max.0 = 6.0570 ± 0.0351
 spiral t_max.1 = 2.3915 ± 0.0280