In [None]:
import functools
import inspect

import jax
import jax.numpy as jnp
import jaxstronomy
import numpy as np
from jaxstronomy import cosmology_utils

In [None]:
from jaxstronomy.source_models import Interpol

class PaltasGalaxyGatalog:
  """Source model that draws real galaxy images from a paltas catalog.

  At init, every catalog entry passing the paltas cuts is loaded once, fit
  onto a common square canvas of `maximum_size_in_pixels` and kept on the
  CPU device. `function` picks one image from `galaxy_index`, applies the
  zeropoint and redshift rescaling in jax, and renders it through
  jaxstronomy's `Interpol` source model.
  """

  parameters = ('galaxy_index', 'z_source', 'amp', 'center_x', 'center_y', 'angle',)

  # The hash and eq methods are needed for jit on methods to work.
  # These implementations offer very poor guarantees on equality: two
  # instances with the same n_images compare equal (and so may share jit
  # cache entries for z_scale_factor) even if their images or cosmology
  # differ. Do NOT mutate instances post-init!!
  def __hash__(self):
    return hash(self.n_images)

  def __eq__(self, other):
    return (isinstance(other, PaltasGalaxyGatalog) and
            (self.n_images == other.n_images))

  def __init__(self,
               cosmology_params=None,
               paltas_class=None,
               maximum_size_in_pixels=256,
               **source_parameters):
    """Load and resize every catalog image that passes the paltas cuts.

    Args:
      cosmology_params: Cosmology parameters passed to
        cosmology_utils.angular_diameter_distance for redshift rescaling.
        Required when `paltas_class` is given.
      paltas_class: paltas source class (e.g. COSMOSCatalog), used only to
        read raw images and metadata. If None, a placeholder instance is
        built that must not be used for drawing.
      maximum_size_in_pixels: Side length of the square canvas that every
        image is fit onto; larger images are sum-downsampled.
      **source_parameters: Forwarded to `paltas_class` as its
        source_parameters (center, redshift and rotation are overridden).
    """
    if paltas_class is None:
      # No paltas class specified -- this source model shouldn't be used
      # (and will crash if you try it anyway). Still define n_images so the
      # placeholder supports __hash__ / __eq__.
      self.n_images = 0
      return
    assert cosmology_params is not None
    self.cosmology_parameters = cosmology_params
    self.paltas_catalog = paltas_class(
      # We only use the paltas class for extracting raw images, and do
      # redshift and magnitude scaling ourselves in jax.
      # Thus we can load an arbitrary cosmology here.
      cosmology_parameters='planck18',
      # We don't need all parameters, but paltas won't let us initialize
      # without a complete set
      source_parameters=(
        source_parameters
        | dict(center_x=np.nan, center_y=np.nan,
               z_source=np.nan,
               # rotations are done by jaxstronomy, not paltas
               random_rotation=False)))

    passes_cuts = self.paltas_catalog._passes_cuts()
    catalog_indices = np.where(passes_cuts)[0]
    self.n_images = passes_cuts.sum()

    # Allocate memory (in main RAM, not on the GPU)
    # for all the galaxy images that pass the cuts
    size = maximum_size_in_pixels
    images = np.zeros((self.n_images, size, size))
    pixel_sizes = np.zeros(self.n_images)
    redshifts = np.zeros(self.n_images)

    for galaxy_i, catalog_i in tqdm(
        enumerate(catalog_indices),
        total=self.n_images,
        desc='Slurping galaxies into RAM...'):
      img, meta = self.paltas_catalog.image_and_metadata(catalog_i)
      pixel_size = meta['pixel_width']
      redshifts[galaxy_i] = meta['z']

      img_size = max(img.shape[0], img.shape[1])
      if img_size > size:
        # Image is too large: downsample it and adjust the pixel size.
        # Images are in electrons/sec/pixel, so blocks are summed (not
        # averaged) to conserve flux.
        downsample_factor = int(np.ceil(img_size / size))
        assert downsample_factor > 1
        img = self._downsample_sum(img, downsample_factor)
        pixel_size *= downsample_factor

      # Center the image on a zero-filled size x size canvas. This must run
      # for every image -- including ones whose larger side is exactly
      # `size` -- otherwise that slot would stay all zeros.
      images[galaxy_i] = pad_image(img, size, size)
      pixel_sizes[galaxy_i] = pixel_size

    # Convert attributes we need later to jax arrays
    self.pixel_sizes = jnp.asarray(pixel_sizes)
    self.redshifts = jnp.asarray(redshifts)
    # Place the giant image array in main RAM, not GPU memory
    with jax.default_device(jax.devices("cpu")[0]):
      self.images = jnp.asarray(images)

  @staticmethod
  def _downsample_sum(img, factor):
    """Sum a 2D image over non-overlapping factor x factor blocks.

    The image is zero-padded on its high-index edges up to a multiple of
    `factor` (as skimage's downscale_local_mean does), so the total flux is
    conserved. Returns a float array of shape ceil(img.shape / factor).
    """
    n_x, n_y = img.shape
    padded_x = factor * ((n_x + factor - 1) // factor)
    padded_y = factor * ((n_y + factor - 1) // factor)
    padded = np.zeros((padded_x, padded_y), dtype=float)
    padded[:n_x, :n_y] = img
    return padded.reshape(
      padded_x // factor, factor, padded_y // factor, factor).sum(axis=(1, 3))

  def function(self, x, y, galaxy_index, z_source, amp, center_x, center_y, angle):
    """Render the selected catalog galaxy at coordinates (x, y).

    Args:
      x, y: Coordinates at which to evaluate the source (passed straight
        to Interpol.function).
      galaxy_index: Value in [0, 1] selecting which catalog image to use.
      z_source: Redshift at which the galaxy is placed.
      amp, center_x, center_y, angle: Forwarded to Interpol.function.

    Returns:
      The output of Interpol.function for the rescaled image.
    """
    # Convert from uniform[0,1] into a discrete index. galaxy_index == 1
    # maps to n_images (one past the end); clip it onto the last image
    # explicitly instead of relying on JAX's out-of-bounds clamping.
    galaxy_index = jnp.clip(
      jnp.floor(galaxy_index * self.n_images).astype(int),
      0, self.n_images - 1)
    img = self.images[galaxy_index]

    # Convert from electrons/sec/pixel to electrons/sec/arcsec^2
    # (assuming pixel sizes are in arcsec)
    pixel_size = self.pixel_sizes[galaxy_index]
    img = img / pixel_size**2

    # Take into account the difference in the magnitude zeropoints
    # of the input survey and the output survey. Note this doesn't
    # take into account the color of the object!
    img *= 10**((
        self.paltas_catalog.source_parameters['output_ab_zeropoint']
        - self.paltas_catalog.ab_zeropoint
      ) / 2.5)

    # Moving the galaxy to z_source changes its angular size.
    pixel_size *= self.z_scale_factor(self.redshifts[galaxy_index], z_source)

    # TODO: implement k-corrections
    # Apply the k correction to the image from the redshifting
    # self.k_correct_image(img,metadata['z'],z_new)

    return Interpol.function(x, y, img, amp, center_x, center_y, angle, pixel_size)

  @functools.partial(jax.jit, static_argnums=0)
  def z_scale_factor(self, z_old, z_new):
    """Return multiplication factor for object/pixel size for moving its
    redshift from z_old to z_new.

    Args:
      z_old (float): The original redshift of the object.
      z_new (float): The redshift the object will be placed at.

    Returns:
      (float): Factor by which to multiply the pixel size,
        D_A(z_old) / D_A(z_new).
    """
    # Angular size ~ physical size / angular diameter distance. Any overall
    # unit factor in the distances (e.g. /h) cancels in the ratio.
    return (
      cosmology_utils.angular_diameter_distance(self.cosmology_parameters, z_old)
      / cosmology_utils.angular_diameter_distance(self.cosmology_parameters, z_new))


def pad_image(img, nx, ny):
  """Center a 2D `img` inside a zero-filled array of shape (nx, ny).

  When the padding along an axis is odd, the extra zero row/column goes on
  the high-index side. The result has the same dtype as `img`.
  """
  rows, cols = img.shape
  row_start = (nx - rows) // 2
  col_start = (ny - cols) // 2
  padded = np.zeros((nx, ny), dtype=img.dtype)
  padded[row_start:row_start + rows, col_start:col_start + cols] = img
  return padded

## Config and init

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import paltas

from comparison_files import input_config_jaxstronomy as input_config_module
from jaxstronomy.input_pipeline import encode_constant, encode_uniform

# paltas.__path__[0] is the imported `paltas` package directory; its parent
# is the root that holds `datasets/` and `paltas/Sources/` (assumes a source
# checkout / editable install of paltas -- TODO confirm for other installs).
paltas_root_path = os.path.dirname(os.path.normpath(paltas.__path__[0]))
# Trailing '' keeps the trailing separator the original path had.
cosmos_folder = os.path.join(
    paltas_root_path, 'datasets', 'cosmos', 'COSMOS_23.5_training_sample', '')

In [None]:
input_config = input_config_module.get_config()

# Fixed seed for reproducibility; later cells split this key for image draws.
rng = jax.random.PRNGKey(0)
cosmology_params = jaxstronomy.input_pipeline.intialize_cosmology_params(input_config, rng)


def read_catalog_indices(relative_csv_path):
    """Read the single 'catalog_i' column of a paltas CSV into a numpy array."""
    csv_path = os.path.join(paltas_root_path, relative_csv_path)
    return pd.read_csv(csv_path, names=['catalog_i'])['catalog_i'].to_numpy()


# Catalog indices listed in paltas' bad_galaxies.csv and val_galaxies.csv
# are excluded from the source catalog.
excluded_catalog_indices = np.append(
    read_catalog_indices('paltas/Sources/bad_galaxies.csv'),
    read_catalog_indices('paltas/Sources/val_galaxies.csv'))

cosmos_source_model = jaxstronomy.source_models.PaltasGalaxyGatalog(
    paltas_class=paltas.Sources.cosmos.COSMOSCatalog,
    cosmology_params=cosmology_params,
    cosmos_folder=cosmos_folder,
    max_z=1.0,
    minimum_size_in_pixels=64,
    min_flux_radius=10.0,
    faintest_apparent_mag=20,
    smoothing_sigma=0.,
    output_ab_zeropoint=input_config['kwargs_detector']['magnitude_zero_point'],
    source_exclusion_list=excluded_catalog_indices,
)

# Append the catalog as the last entry of the available source models.
existing_source_models = input_config['all_models']['all_source_models']
input_config['all_models']['all_source_models'] = existing_source_models + (cosmos_source_model,)

grid_x, grid_y = jaxstronomy.input_pipeline.generate_grids(input_config)
# Host memory (MB) used by the catalog images -- the catalog is the last source model:
# input_config['all_models']['all_source_models'][-1].images.nbytes / 1e6

In [None]:
# Snapshot the default source parameters before overriding them below.
# Without the copy, `old_params` would alias the dict updated on the next
# statement and silently reflect the new values.
old_params = dict(input_config['lensing_config']['source_params'])
input_config['lensing_config']['source_params'].update(dict(
    center_x=encode_constant(0.),
    center_y=encode_constant(0.),
    amp=encode_constant(1.),
    # Uniform in [0, 1]; the catalog source model maps it to an image index.
    galaxy_index=encode_uniform(0., 1.),
))
# The lens light also gets the catalog-specific parameters (galaxy_index,
# z_source) -- presumably because the lens-light models mirror the source
# models, TODO confirm. The lens light sits at the main deflector's redshift.
input_config['lensing_config']['lens_light_params'].update(dict(
    galaxy_index=encode_constant(0.5),
    z_source=input_config['lensing_config']['main_deflector_params']['z_lens'],
))

In [None]:
draw = jaxstronomy.input_pipeline.draw_image_and_truth

# Arguments fixed explicitly here. They take precedence over any entry of the
# same name in input_config; rng is left free because it varies each draw.
explicit_kwargs = dict(
    # Pass the initialized cosmology, not the config's description of it.
    cosmology_params=cosmology_params,
    grid_x=grid_x,
    grid_y=grid_y,
    normalize_image=False,
)
# Names of the remaining `draw` arguments that are filled from input_config.
# Excluding every explicit kwarg avoids "got multiple values for keyword
# argument" if the config happens to contain one of them.
draw_kwargs = (
    set(inspect.signature(draw).parameters)
    .intersection(input_config)
    .difference({'rng', *explicit_kwargs}))
draw = jax.jit(
    functools.partial(
        draw,
        **{k: input_config[k] for k in draw_kwargs},
        **explicit_kwargs),
    # HACK for Jelle's laptop GPU with only 4GB RAM
    backend='cpu',
)

In [None]:
# Split off a fresh key so each run of this cell draws a new image.
rng, rng_draw = jax.random.split(rng)
image_jax, truth_jax = draw(rng=rng_draw)

fig, ax = plt.subplots()
im = ax.imshow(image_jax)
fig.colorbar(im, ax=ax)
ax.set_title('Simulated lens image (normalize_image=False)')
plt.show()