In [1]:
# Reload edited project modules (e.g. gzbuilder_analysis) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import jax.numpy as np
import pandas as pd
from jax import jit, random, ops, jacfwd, jacrev, vmap
from scipy.optimize import minimize
from tqdm.auto import tqdm

import gzbuilder_analysis.parsing as pg
import gzbuilder_analysis.rendering.jax.fit as fit
from gzbuilder_analysis.rendering.jax.spiral import vmap_polyline_distance

# Enable 64-bit (float64) arithmetic in JAX; the DeviceArray outputs further
# down report dtype=float64. Set here, before any arrays are created.
from jax.config import config
config.update("jax_enable_x64", True)

In [108]:
# JAX randomness is explicit: this root key is split later to draw the
# synthetic noise, so a fixed seed makes that draw reproducible.
PRNG_SEED = 0
key = random.PRNGKey(PRNG_SEED)

In [3]:
# Table of per-subject fitting inputs (rows are looked up by subject id below).
# NOTE(review): read_pickle can execute arbitrary code; only load trusted files.
diff_data_df = pd.read_pickle(
    'lib/fitting_metadata.pkl'
)

In [4]:
# Subject whose galaxy is fitted throughout this notebook.
subject_id = 20_902_040

In [5]:
# Pull this subject's PSF, image, mask and sigma image out of the metadata table.
diff_data = diff_data_df.loc[subject_id]
galaxy_data = diff_data['galaxy_data']  # presumably a numpy masked array — confirm
psf = diff_data['psf']
target = np.asarray(galaxy_data.data)
mask = np.asarray(galaxy_data.mask)
sigma = np.asarray(diff_data['sigma_image'].data)



In [6]:
# Aggregation result for this subject; pandas infers gzip compression from the
# ".gz" suffix.
agg_res_path = f'output_files/aggregation_results/{subject_id}.pkl.gz'
agg_res = pd.read_pickle(agg_res_path)

In [None]:
# Ground-truth component parameters for the synthetic-galaxy test.
# The bulge and the bar both have frac set to zero, and there are no spiral arms.
fake_model = {
    'disk': {'mux': 125., 'muy': 125., 'q': 0.9, 'roll': 0., 'Re': 80., 'L': 500.},
    'bulge': {'mux': 125., 'muy': 125., 'q': 0.9, 'roll': 0., 'scale': 0.3, 'frac': 0, 'n': 1.},
    'bar': {'mux': 50., 'muy': 50., 'q': 0.3, 'roll': 0., 'scale': 0.5, 'frac': 0.0, 'n': 1., 'c': 2.},
    'spiral': {},
}
# Per-parameter uncertainties for the synthetic model, keyed like fake_model.
# NOTE(review): np.inf presumably marks a parameter as unconstrained — confirm in fit.
fake_errors = {
    'disk': {'mux': 5., 'muy': 5., 'q': 0.05, 'roll': 0.01, 'Re': 50, 'L': np.inf},
    'bulge': {'mux': 5., 'muy': 5., 'q': 0.05, 'roll': 0.01, 'scale': 0.1, 'frac': np.inf, 'n': np.inf},
    'bar': {'mux': 2., 'muy': 2., 'q': 0.05, 'roll': 0.01, 'scale': 0.1, 'frac': np.inf, 'n': np.inf, 'c': np.inf},
}
# Perturbed version of fake_model. The disk mux, muy, Re and L differ from the
# truth, as do the bulge scale and frac; the bar is unchanged.
starting_guess = {
    'disk': {'mux': 125.6, 'muy': 122.8, 'q': 0.9, 'roll': 0., 'Re': 85., 'L': 520.},
    'bulge': {'mux': 125., 'muy': 125., 'q': 0.9, 'roll': 0., 'scale': 0.8, 'frac': 0.2, 'n': 1.0},
    'bar': {'mux': 50., 'muy': 50., 'q': 0.3, 'roll': 0., 'scale': 0.5, 'frac': 0.0, 'n': 1., 'c': 2.},
    'spiral': {},
}

In [None]:
# Free parameters for the synthetic test: only the disk's, as ordered
# (component, parameter) pairs, plus their true values.
keys = [
    (component, param)
    for component in fake_model
    if component in {'disk'}
    for param in fake_model[component]
]
p_true = np.array([fake_model[component][param] for component, param in keys])

In [None]:
# Render the synthetic galaxy from the true disk parameters. Noise is always
# drawn (so the PRNG stream is the same either way) but is only added to the
# target when ADD_NOISE is True.
ADD_NOISE = False
NOISE_SCALE = 0.02  # noise standard deviation as a fraction of the sigma image

key, subkey = random.split(key)
# NOTE(review): the trailing 5 is passed straight to fit.create_model
# (possibly an oversampling factor) — confirm.
fake_galaxy = fit.create_model(p_true, keys, 0, fake_model, psf, target, 5)
noise = random.normal(subkey, shape=fake_galaxy.shape) * sigma * NOISE_SCALE
fake_target = fake_galaxy + noise if ADD_NOISE else fake_galaxy

fig, ax = plt.subplots()
im = ax.imshow(fake_target)
fig.colorbar(im, ax=ax)
ax.set_title('Synthetic target (with noise)' if ADD_NOISE else 'Synthetic target (noiseless)');

In [None]:
# Initial parameter vector for the synthetic fit. It is taken from the
# perturbed `starting_guess`: taking it from `fake_model` would make it equal
# to `p_true`, so the optimizer would start at the answer it should recover.
p2 = np.array([starting_guess[k0][k1] for k0, k1 in keys])
# Matching per-parameter uncertainties.
sp2 = np.array([fake_errors[k0][k1] for k0, k1 in keys])

In [None]:
# Extra arguments for fit.step on the synthetic problem, in the same order as
# the real-model `args` tuple: (keys, number of spiral arms, model, errors,
# psf, mask, target, sigma). fake_model has no spiral arms, hence 0.
# (Previously referenced the undefined name `lims_fake_errors`.)
# NOTE(review): `starting_guess` (not `fake_model`) is the model passed here, so
# non-fitted components (e.g. bulge frac=0.2 vs 0) presumably differ from the
# ones used to render fake_galaxy — confirm this is intended.
args_fake = (keys, 0, starting_guess, fake_errors, psf, mask, fake_target, sigma)
# Reverse-mode derivative of the scalar loss w.r.t. its first argument (the parameter vector).
_jac = jacrev(fit.step)

In [None]:
# Fit the synthetic galaxy from `p2` using the analytic gradient `_jac`.
# `minimize` and `tqdm` come from the imports cell at the top of the notebook.
with tqdm(desc='Fitting') as pbar:
    def callback(*_args, **_kwargs):
        """Advance the progress bar once per optimizer iteration."""
        pbar.update(1)

    res = minimize(fit.step, p2, args=args_fake, jac=_jac, callback=callback)

In [None]:
# Residual between the model rendered from the fitted parameters and the
# noiseless synthetic galaxy.
# NOTE(review): this renders with `fake_model` as the model, whereas `args_fake`
# passes `starting_guess` to the fit — confirm which is intended.
fitted_fake = fit.create_model(res['x'], keys, 0, fake_model, psf, target, 5)
fig, ax = plt.subplots()
im = ax.imshow(fitted_fake - fake_galaxy)
fig.colorbar(im, ax=ax)
ax.set_title('Fitted model - synthetic galaxy');

In [None]:
# True vs fitted parameters side by side. Rows are presumably indexed by
# (component, parameter); rows with a missing value in either column are dropped.
pd.concat(
    (
        pd.DataFrame(fit.to_dict(p_true, keys)).unstack(),
        pd.DataFrame(fit.to_dict(res['x'], keys)).unstack(),
    ),
    axis=1,
    keys=['true', 'fit'],
).dropna()

## Fitting a real (aggregated) model

In [86]:
# Convert the aggregation result into the fitter's parametrization, plus the
# matching per-parameter errors.
reparametrized_model = fit.to_reparametrization(agg_res)
reparametrized_errors = fit.get_reparametrized_erros(agg_res)  # (sic) the function name is spelled this way in fit

# Flatten the (presumably {component: {parameter: value}}) structure into an
# ordered key list and aligned value / error vectors.
keys = [
    (component, param)
    for component in reparametrized_model
    for param in reparametrized_model[component]
]
p0 = np.array([reparametrized_model[c][p] for c, p in keys])
sigma_param = np.array([reparametrized_errors[c][p] for c, p in keys])
mu_param = p0.copy()

In [81]:
# Extra positional arguments for fit.step on the real galaxy.
args = (
    keys,
    len(agg_res.spiral_arms),  # number of spiral arms in the aggregate
    reparametrized_model,
    reparametrized_errors,
    psf,
    mask,
    target,
    sigma,
)

Call the fitting function once so that it is compiled before we time it:

In [82]:
# First evaluation at the starting point. JAX dispatches work asynchronously,
# so wait for the result before displaying it.
initial_loss = fit.step(p0, *args)
initial_loss.block_until_ready()

DeviceArray(-375264.76338677, dtype=float64)

Time a single evaluation of the compiled fitting function:

In [10]:
# block_until_ready() makes the timing include the asynchronously dispatched computation itself.
%timeit fit.step(p0, *args).block_until_ready()

388 ms ± 97 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


We can also obtain the Jacobian of the loss (its gradient with respect to the parameters). It also needs to be compiled before profiling; running this only on a GPU is recommended:

In [126]:
%%time
_jac = jacrev(fit.step)
_jac(p2, *args2).block_until_ready();

CPU times: user 9.14 s, sys: 450 ms, total: 9.59 s
Wall time: 9.13 s


DeviceArray([-1.99042703e+01, -1.92961206e+01, -1.00000000e+00,
              6.27103160e+01, -2.79930050e+04,  5.39481157e-22,
             -2.00047934e+01, -1.90901787e+01,  1.49857005e+02,
             -3.02955479e+04,  0.00000000e+00,  1.85596299e+01,
              3.23550227e-34], dtype=float64)

In [128]:
%%time
_jac(p2, *args2).block_until_ready();

CPU times: user 548 ms, sys: 179 ms, total: 727 ms
Wall time: 415 ms


DeviceArray([-1.99042703e+01, -1.92961206e+01, -1.00000000e+00,
              6.27103160e+01, -2.79930050e+04,  5.39481157e-22,
             -2.00047934e+01, -1.90901787e+01,  1.49857005e+02,
             -3.02955479e+04,  0.00000000e+00,  1.85596299e+01,
              3.23550227e-34], dtype=float64)