In [None]:
import functools

from flax.training import checkpoints
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from scipy.stats import linregress

from jaxstronomy import train
from jaxstronomy import input_pipeline

In [None]:
# Pull our input configurations
from jaxstronomy import train_config
from jaxstronomy import input_config as ic

# Training configuration and simulation (input) configuration used by every cell below.
config = train_config.get_config()
input_config = ic.get_config()
# `image_size` comes from the detector's `n_x` entry and is later handed to train.train_and_evaluate.
image_size = input_config['kwargs_detector']['n_x']
# Fixed seed. Note that this same key is passed unchanged to the cosmology setup, both training runs,
# and both evaluation draws further down.
rng = jax.random.PRNGKey(0)

# A few parameters we'll need later on to generate our images for comparison.
# NOTE(review): `intialize` (sic) is the name as called here -- presumably it matches the spelling
# exported by input_pipeline; verify before "correcting" it.
cosmology_params = input_pipeline.intialize_cosmology_params(input_config, rng)
grid_x, grid_y = input_pipeline.generate_grids(input_config)

# A function for drawing large batches of images without running out of memory.
def generate_images_and_truth(input_config, cosmology_params, grid_x, grid_y, rng, n_batches,
                              batch_size=None):
    """Draw `n_batches` batches of images with their truth values and stack them.

    Images are drawn one batch at a time through a jitted, vmapped version of
    `input_pipeline.draw_image_and_truth`, and the per-batch results are
    concatenated along axis 0.

    Args:
        input_config: Input configuration. The entries 'all_models',
            'principal_md_index', 'principal_source_index', 'kwargs_simulation',
            'kwargs_detector', 'kwargs_psf', 'truth_parameters' and
            'lensing_config' are read.
        cosmology_params: Forwarded unchanged to `draw_image_and_truth`.
        grid_x: Forwarded unchanged to `draw_image_and_truth`.
        grid_y: Forwarded unchanged to `draw_image_and_truth`.
        rng: jax PRNG key from which the per-batch and per-image keys are split.
        n_batches: Number of batches to draw. Must be at least 1.
        batch_size: Number of images per batch. Defaults to the notebook-level
            `config.batch_size`, which is what this function always used before
            the parameter existed.

    Returns:
        Tuple `(image, truth)` of the per-batch draws concatenated along axis 0.

    Raises:
        ValueError: If `n_batches` is smaller than 1 (there would be nothing to
            concatenate).
    """
    if n_batches < 1:
        raise ValueError(f'n_batches must be at least 1, got {n_batches}.')
    if batch_size is None:
        # Fall back to the global training configuration, looked up at call time.
        batch_size = config.batch_size

    # Only the last positional argument (the rng key) is batched; everything else is shared.
    draw_image_and_truth_vmap = jax.jit(jax.vmap(
        functools.partial(
            input_pipeline.draw_image_and_truth,
            all_models=input_config['all_models'],
            principal_md_index=input_config['principal_md_index'],
            principal_source_index=input_config['principal_source_index'],
            kwargs_simulation=input_config['kwargs_simulation'],
            kwargs_detector=input_config['kwargs_detector'],
            kwargs_psf=input_config['kwargs_psf'],
            truth_parameters=input_config['truth_parameters']),
        in_axes=(None, None, None, None, 0)))

    # The key-splitting sequence is kept exactly as before so that a given `rng` still yields
    # the same draws.
    rng_batch, _ = jax.random.split(rng)
    images = []
    truths = []
    for _ in range(n_batches):
        rng_batch, _ = jax.random.split(rng_batch)
        rng_images = jax.random.split(rng_batch, batch_size)
        image_draw, truth_draw = draw_image_and_truth_vmap(
            input_config['lensing_config'], cosmology_params, grid_x, grid_y, rng_images)
        images.append(image_draw)
        truths.append(truth_draw)

    return jnp.concatenate(images, axis=0), jnp.concatenate(truths, axis=0)

# Jit-compiled forward pass through the model held by the current training state.
@jax.jit
def get_outputs(state, image):
    """Apply the model stored in `state` to `image`.

    Args:
        state: Object exposing `apply_fn`, `params` and `batch_stats`.
        image: Input images. A trailing axis of length one is appended before
            the model is applied.

    Returns:
        Whatever `state.apply_fn` returns when it is called with
        `mutable=['batch_stats']`.
    """
    variables = {'params': state.params, 'batch_stats': state.batch_stats}
    image_with_channel = jnp.expand_dims(image, axis=-1)
    return state.apply_fn(variables, image_with_channel, mutable=['batch_stats'])

# We're just going for short trainings on these examples. The problems will be kept very easy, so the
# signal should be strong.
# Every step count below is expressed as a multiple of steps_per_epoch, so the order of these
# assignments matters: steps_per_epoch must be set first.
config.steps_per_epoch = 10
config.num_train_steps = config.steps_per_epoch * 50  # 50 epochs in total
config.keep_every_n_steps = config.steps_per_epoch * 1  # one epoch's worth of steps
config.warmup_steps = config.steps_per_epoch * 5  # five epochs' worth of steps
config.batch_size = 32  # also read by generate_images_and_truth when it is called
# Bare expression so the notebook displays the resulting configuration.
config

### Training a model that has no variation in any parameter except substructure.

In [None]:
# Freeze everything except the substructure: each listed parameter is replaced by a constant
# encoding, and the truth parameters are reduced to sigma_sub so that it is the only quantity
# the model predicts.
static_values = {
    'main_deflector_params': {
        'theta_e': 1.1,
        'slope': 2.0,
        'center_x': 0.08,
        'center_y': -0.16,
        'axis_ratio': 0.9,
        'angle': 0.0,
        'gamma_ext': 0.0,
    },
    'source_params': {
        'amp': 5.0,
        'sersic_radius': 1.5,
        'n_sersic': 1.5,
        'axis_ratio': 0.9,
        'angle': 0.0,
        'center_x': -0.08,
        'center_y': 0.04,
    },
}
for component, constants in static_values.items():
    for parameter, value in constants.items():
        input_config['lensing_config'][component][parameter] = input_pipeline.encode_constant(value)

input_config['truth_parameters'] = (['subhalo_params'],['sigma_sub'])

In [None]:
# Working directory for this run; the evaluation cell restores its checkpoint from the same path.
# NOTE(review): absolute, user-specific path -- adjust it before running on another machine.
workdir = '/scratch/users/swagnerc/notebook_outputs/no_variation'
learning_rate = 1e-2
# Train with the static configuration set up above and keep the returned training state.
state = train.train_and_evaluate(config, input_config, workdir, rng, image_size, learning_rate)

In [None]:
# Let's take a peek at how we did: draw fresh images, restore the checkpoint from `workdir`,
# and compare the network output against the truth for every predicted parameter.
n_batches = 8
image, truth = generate_images_and_truth(input_config, cosmology_params, grid_x, grid_y, rng, n_batches)
state = checkpoints.restore_checkpoint(workdir, state)
outputs = get_outputs(state, image)
# get_outputs applies the model with mutable=['batch_stats'] and therefore returns a tuple;
# element 0 is what we compare against the truth.
# NOTE(review): column i of that array is presumably the predicted mean of parameter i -- confirm.
predictions = outputs[0]
parameter_names = input_config['truth_parameters'][1]
n_params = len(input_config['truth_parameters'][0])
x = jnp.linspace(-2, 2, 10)
for i in range(n_params):
    name = parameter_names[i]
    rho = linregress(truth[:,i], predictions[:,i]).rvalue
    print('rho for ' + name + ': ' + str(rho))
    fig, ax = plt.subplots()
    ax.plot(truth[:,i], predictions[:,i], '.')
    # Black one-to-one line: a perfect prediction falls exactly on it.
    ax.plot(x, x, c='k')
    ax.set_xlabel('True Normalized ' + name)
    # The space after $\mu$ keeps the symbol from running into the parameter name.
    ax.set_ylabel(r'Predicted $\mu$ ' + name)
    ax.set_xlim((-2,2))
    ax.set_ylim((-2,2))
    plt.show()

### Training a model that has variation in theta_e.

In [None]:
# Varying theta_e should make the problem a bit harder since, to first order, the sigma_sub signal comes from
# increasing and decreasing the Einstein radius.
input_config['lensing_config']['main_deflector_params']['theta_e'] = input_pipeline.encode_normal(mean=1.1, std=0.15)
# Add theta_e to the truth parameters alongside sigma_sub; the two lists are index-aligned
# (component name, parameter name).
input_config['truth_parameters'] = (['main_deflector_params', 'subhalo_params'],['theta_e', 'sigma_sub'])

In [None]:
# Separate working directory for the theta_e run; the evaluation cell restores its checkpoint from
# the same path.
# NOTE(review): absolute, user-specific path -- adjust it before running on another machine.
workdir = '/scratch/users/swagnerc/notebook_outputs/theta_e_variation'
learning_rate = 1e-2
# Train with theta_e now varying, passing the same `rng` key as the first run.
state = train.train_and_evaluate(config, input_config, workdir, rng, image_size, learning_rate)

In [None]:
# Let's take a peek at how we did: draw fresh images, restore the checkpoint from `workdir`,
# and compare the network output against the truth for every predicted parameter.
n_batches = 8
image, truth = generate_images_and_truth(input_config, cosmology_params, grid_x, grid_y, rng, n_batches)
state = checkpoints.restore_checkpoint(workdir, state)
outputs = get_outputs(state, image)
# get_outputs applies the model with mutable=['batch_stats'] and therefore returns a tuple;
# element 0 is what we compare against the truth.
# NOTE(review): column i of that array is presumably the predicted mean of parameter i -- confirm.
predictions = outputs[0]
parameter_names = input_config['truth_parameters'][1]
n_params = len(input_config['truth_parameters'][0])
x = jnp.linspace(-2, 2, 10)
for i in range(n_params):
    name = parameter_names[i]
    rho = linregress(truth[:,i], predictions[:,i]).rvalue
    print('rho for ' + name + ': ' + str(rho))
    fig, ax = plt.subplots()
    ax.plot(truth[:,i], predictions[:,i], '.')
    # Black one-to-one line: a perfect prediction falls exactly on it.
    ax.plot(x, x, c='k')
    ax.set_xlabel('True Normalized ' + name)
    # The space after $\mu$ keeps the symbol from running into the parameter name.
    ax.set_ylabel(r'Predicted $\mu$ ' + name)
    ax.set_xlim((-2,2))
    ax.set_ylim((-2,2))
    plt.show()