Running this notebook requires some packages that are not included in the `paltax` requirements. If you do not already have them installed, you will need to run:
.. code-block:: bash

    $ pip install matplotlib


In [None]:
import functools

import jax
import matplotlib.pyplot as plt

from paltax import input_pipeline
from paltax.InputConfigs import input_config_psf

# Generating Images Using `paltax`

__Author:__ Sebastian Wagner-Carena

__Goals:__ 

1. Import a `paltax` input configuration file.
2. Use the input_pipeline functions to draw batches of images.

### Table of Contents

1. [Input Configuration File](#input_config)
2. [Drawing Images](#draw_images) 
    

## Input Configuration File <a class="anchor" id="input_config"></a>

**Import a paltax input configuration file.**

Let's start by importing a `paltax` configuration file and dissecting some of its values.

In [None]:
# Load the input configuration from a file
input_config = input_config_psf.get_config()

Some of the values in the input configuration are fairly straightforward. They are dictionaries that specify parameter values required for simulating the strong lenses. For example, `kwargs_detector` specifies the parameters that control the size of the image, the supersampling, and the noise.

In [None]:
input_config['kwargs_detector']

The parameters that control the lensing configuration and the PSF are more complicated. They hold large arrays that encode distributions from which parameter values can be drawn. Here's an example:

In [None]:
# Start by taking a look at the x-coordinate center of the main deflector.
input_config['lensing_config']['main_deflector_params']['center_x']

In [None]:
rng = jax.random.PRNGKey(1)
input_pipeline.draw_sample(input_config['lensing_config']['main_deflector_params']['center_x'], rng)

In practice we will not call `draw_sample` ourselves: the `input_pipeline` functions do this sampling for us whenever they draw an image.
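The `rng` passed to `draw_sample` is an explicit JAX random key, which behaves differently from a global random seed. The sketch below is not `paltax` code: it uses `jax.random.normal` as a stand-in sampler to show that reusing a key reproduces the same draw, while splitting the key yields new, independent draws.

```python
import jax

rng = jax.random.PRNGKey(1)

# Reusing the same key always reproduces the same draw.
first_draw = jax.random.normal(rng)
repeat_draw = jax.random.normal(rng)
print(bool(first_draw == repeat_draw))  # True

# Splitting the key gives two new keys, each with its own random stream.
rng_a, rng_b = jax.random.split(rng)
draw_a = jax.random.normal(rng_a)
draw_b = jax.random.normal(rng_b)
print(float(draw_a), float(draw_b))
```

This is why the image-drawing code later in the notebook creates one key per image rather than reusing a single key.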

## Drawing Images <a class="anchor" id="draw_images"></a>

**Use the input_pipeline functions to draw batches of images.**

Using this configuration file, let's draw some images. All of the parameters we need to generate our images are already sitting in our input configuration. Let's start by creating our vmapped and jitted functions.

In [None]:
# This is the main function we will be calling to create our images.
draw_image_and_truth_vmap = jax.jit(jax.vmap(
        functools.partial(
            input_pipeline.draw_image_and_truth,
            all_models=input_config['all_models'],
            principal_model_indices=input_config['principal_model_indices'],
            kwargs_simulation=input_config['kwargs_simulation'],
            kwargs_detector=input_config['kwargs_detector'],
            kwargs_psf=input_config['kwargs_psf'],
            truth_parameters=input_config['truth_parameters'],
            normalize_image=False),
        in_axes=(None, None, None, None, 0, None)))

# We will also want to pre-calculate a few values using input_pipeline functions.
cosmology_params = input_pipeline.initialize_cosmology_params(input_config, rng)
grid_x, grid_y = input_pipeline.generate_grids(input_config)
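The `functools.partial` / `jax.vmap` / `jax.jit` pattern above can be easier to follow on a toy function. In the minimal sketch below, `draw_toy_image` and its `noise_rms` argument are hypothetical stand-ins, not part of `paltax`. The static argument is bound with `functools.partial`, and `in_axes` tells `jax.vmap` to share the first two positional arguments across the batch (`None`) and to map over the leading axis of the third (`0`), much as the notebook maps only over its array of random keys.

```python
import functools

import jax
import jax.numpy as jnp


def draw_toy_image(amplitude, grid, rng, noise_rms):
    """Hypothetical stand-in for an image-drawing function."""
    # A Gaussian profile plus noise drawn from this image's own key.
    signal = amplitude * jnp.exp(-grid ** 2)
    noise = noise_rms * jax.random.normal(rng, grid.shape)
    return signal + noise


# Bind the static keyword argument, batch over the keys only, then compile.
draw_toy_batch = jax.jit(jax.vmap(
    functools.partial(draw_toy_image, noise_rms=0.1),
    in_axes=(None, None, 0)))

grid = jnp.linspace(-2.0, 2.0, 16)
toy_rng_array = jax.random.split(jax.random.PRNGKey(0), 4)
toy_images = draw_toy_batch(1.0, grid, toy_rng_array)
print(toy_images.shape)  # (4, 16)
```

Each of the four keys produces one row of the output, so the batch size is set entirely by the number of keys passed in.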

In [None]:
# Let's draw some images
n_images = 32
rng_array = jax.random.split(jax.random.PRNGKey(1), n_images)
# You can specify rotation angles to be applied to the images. The rotation angles are applied
# at the level of the ray-tracing code, so there are no issues with pixelization. For the most
# part, we're only interested in this functionality when we want to capture the effects of
# image augmentations during training.
rotation_angle = 0.0

# This first call will be slow since we are compiling the function. On CPU this may be
# painfully slow.
_ = draw_image_and_truth_vmap(input_config['lensing_config'], cosmology_params, grid_x,
                              grid_y, rng_array, rotation_angle)

Once the function has compiled, we have a fast way to draw batches of images.

In [None]:
%%timeit
draw_image_and_truth_vmap(input_config['lensing_config'], cosmology_params, grid_x,
                          grid_y, rng_array, rotation_angle)
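To see the compilation cost separately from the steady-state cost, you can time the first and second calls by hand. The sketch below uses a hypothetical `toy_simulation` function rather than the `paltax` pipeline. It calls `block_until_ready()` on the result because JAX dispatches work asynchronously, so a timer can otherwise stop before the computation has finished. How much this affects a given measurement depends on the backend.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_simulation(x):
    """Hypothetical stand-in for an expensive jitted function."""
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((256, 256))

# The first call traces and compiles the function before running it.
start = time.perf_counter()
first_result = toy_simulation(x).block_until_ready()
first_call_seconds = time.perf_counter() - start

# The second call with the same input shape reuses the compiled code.
start = time.perf_counter()
second_result = toy_simulation(x).block_until_ready()
second_call_seconds = time.perf_counter() - start

print(f"first call:  {first_call_seconds:.4f} s")
print(f"second call: {second_call_seconds:.4f} s")
```

The same idea applies to `draw_image_and_truth_vmap`: it returns a tuple of arrays, which can be passed to `jax.block_until_ready` before the timer is stopped.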

In [None]:
# Finally, let's take a look at the first five of our generated strong lenses.
images, truths = draw_image_and_truth_vmap(input_config['lensing_config'], cosmology_params, grid_x,
                                           grid_y, rng_array, rotation_angle)
for image in images[:5]:
    plt.imshow(image)
    plt.show()