In [286]:
from typing import Dict, List, Optional

from hydra import compose, initialize
from hydra.utils import instantiate
import numpy as np
import pytorch_lightning as pl
import torch
from einops import rearrange
from torch.utils.data import DataLoader, Dataset

from bliss.catalog import TileCatalog, get_is_on_from_n_sources, FullCatalog
from bliss.datasets.background import ConstantBackground
from bliss.datasets.sdss import convert_flux_to_mag
from bliss.models.galsim_decoder import SingleGalsimGalaxyDecoder,SingleGalsimGalaxyPrior,UniformGalsimGalaxiesPrior
from bliss.datasets.galsim_galaxies import GalsimBlends
from bliss.reporting import get_single_galaxy_ellipticities

In [332]:
def _add_noise_and_background(image: Tensor, background: Tensor) -> Tensor:
    image_with_background = image + background
    noise = image_with_background.sqrt() * torch.randn_like(image_with_background)
    return image_with_background + noise

def load_psf_from_file(psf_image_file: str, pixel_scale: float) -> galsim.GSObject:
    """Return normalized PSF galsim.GSObject from numpy psf_file."""
    assert Path(psf_image_file).suffix == ".npy"
    psf_image = np.load(psf_image_file)
    assert len(psf_image.shape) == 3 and psf_image.shape[0] == 1
    psf_image = galsim.Image(psf_image[0], scale=pixel_scale)
    return galsim.InterpolatedImage(psf_image).withFlux(1.0)

def _sample_n_sources(max_n_sources) -> int:
    return int(torch.randint(1, max_n_sources + 1, (1,)).int().item())

def _uniform(a, b, n_samples=1) -> Tensor:
    # uses pytorch to return a single float ~ U(a, b)
    return (a - b) * torch.rand(n_samples) + b

class CoaddUniformGalsimGalaxiesPrior(UniformGalsimGalaxiesPrior):
    def __init__(
        self,
        single_galaxy_prior: SingleGalsimGalaxyPrior,
        max_n_sources: int,
        max_shift: float,
        num_dithers: int,
    ):
        super().__init__(
            single_galaxy_prior,
            max_n_sources,
            max_shift,
        )
    def sample(self) -> Dict[str, Tensor]:
        """Returns a single batch of source parameters."""
        n_sources = _sample_n_sources(self.max_n_sources)

        params = torch.zeros(self.max_n_sources, self.dim_latents)
        params[:n_sources, :] = self.single_galaxy_prior.sample(n_sources)

        locs = torch.zeros(self.max_n_sources, 2)
        locs[:n_sources, 0] = _uniform(-self.max_shift, self.max_shift, n_sources) + 0.5
        locs[:n_sources, 1] = _uniform(-self.max_shift, self.max_shift, n_sources) + 0.5

        # for now, galaxies only
        galaxy_bools = torch.zeros(self.max_n_sources, 1)
        galaxy_bools[:n_sources, :] = 1
        star_bools = torch.zeros(self.max_n_sources, 1)

        dithers = [((-0.5 - 0.5) * torch.rand((2,)) + 0.5).numpy() for x in range(num_dithers)]

        return {
            "n_sources": torch.tensor(n_sources),
            "galaxy_params": params,
            "locs": locs,
            "galaxy_bools": galaxy_bools,
            "star_bools": star_bools,
            "dithers": dithers,
        }

class CoaddSingleGalaxyDecoder(SingleGalsimGalaxyDecoder):
    def __init__(
        self,
        slen: int,
        n_bands: int,
        pixel_scale: float,
#        psf_image_file: str,
     ):
        super().__init__(
            slen,
            n_bands,
            pixel_scale,
            psf_image_file,
        )
        assert n_bands == 1, "Only 1 band is supported"
        self.slen = slen
        self.n_bands = 1
        self.pixel_scale = pixel_scale
        self.psf = load_psf_from_file(psf_image_file, self.pixel_scale)
    
    def render_galaxy(
        self,
        galaxy_params: Tensor,
        psf: galsim.GSObject,
        slen: int,
        offset: Optional[Tensor] = None,
        dithers: Optional[Tensor] = None,
    ) -> Tensor:
        assert offset is None or offset.shape == (2,)
        if isinstance(galaxy_params, Tensor):
            galaxy_params = galaxy_params.cpu().detach()
        total_flux, disk_frac, beta_radians, disk_q, a_d, bulge_q, a_b = galaxy_params
        bulge_frac = 1 - disk_frac

        disk_flux = total_flux * disk_frac
        bulge_flux = total_flux * bulge_frac

        components = []
        if disk_flux > 0:
            b_d = a_d * disk_q
            disk_hlr_arcsecs = np.sqrt(a_d * b_d)
            disk = galsim.Exponential(flux=disk_flux, half_light_radius=disk_hlr_arcsecs).shear(
                q=disk_q,
                beta=beta_radians * galsim.radians,
            )
            components.append(disk)
        if bulge_flux > 0:
            b_b = bulge_q * a_b
            bulge_hlr_arcsecs = np.sqrt(a_b * b_b)
            bulge = galsim.DeVaucouleurs(
                flux=bulge_flux, half_light_radius=bulge_hlr_arcsecs
            ).shear(q=bulge_q, beta=beta_radians * galsim.radians)
            components.append(bulge)
        galaxy = galsim.Add(components)
        gal_conv = galsim.Convolution(galaxy, psf)
        offset = (0,0) if offset is None else offset.numpy()
        shift = torch.add(torch.Tensor(dithers), torch.Tensor(offset))
        images = []
        for i in shift:
            image = gal_conv.drawImage(
                nx=slen, ny=slen, method="auto", scale=self.pixel_scale, offset=i
            )
            image = image.array
            images.append(image)
        return torch.tensor(images[:]).reshape(len(dithers), 1, slen, slen)

class FullCatalogDecoder:
    def __init__(
        self, single_galaxy_decoder: CoaddSingleGalaxyDecoder, slen: int, bp: int
    ) -> None:
        self.single_decoder = single_galaxy_decoder
        self.slen = slen
        self.bp = bp
        assert self.slen + 2 * self.bp >= self.single_decoder.slen

    def __call__(self, full_cat: FullCatalog):
        return self.render_catalog(full_cat, self.single_decoder.psf)

    def render_catalog(self, full_cat: FullCatalog, psf: galsim.GSObject, dithers: Optional[Tensor]):
        size = self.slen + 2 * self.bp
        full_plocs = full_cat.plocs
        b, max_n_sources, _ = full_plocs.shape
        assert b == 1, "Only one batch supported for now."
 #       assert self.single_decoder.n_bands == 1, "Only 1 band supported for now"

        image = torch.zeros(1, size, size)
        noiseless_centered = torch.zeros(max_n_sources, 1, size, size)
        noiseless_uncentered = torch.zeros(max_n_sources, 1, size, size)

        n_sources = int(full_cat.n_sources[0].item())
        galaxy_params = full_cat["galaxy_params"][0]
        plocs = full_plocs[0]
        for ii in range(n_sources):
            offset_x = plocs[ii][1] + self.bp - size / 2
            offset_y = plocs[ii][0] + self.bp - size / 2
            offset = torch.tensor([offset_x, offset_y])
            centered = self.single_decoder.render_galaxy(galaxy_params[ii], psf, size)
            uncentered = self.single_decoder.render_galaxy(galaxy_params[ii], psf, size, offset, dithers)
            noiseless_centered[ii] = centered
            noiseless_uncentered[ii] = uncentered
            image += uncentered
        return image, noiseless_centered, noiseless_uncentered

class CoaddGalsimBlends(GalsimBlends):
    """Dataset of coadd galsim blends."""

    def __init__(self,
        prior: UniformGalsimGalaxiesPrior,
        decoder: FullCatalogDecoder,
        background: ConstantBackground,
        tile_slen: int,
        max_sources_per_tile: int,
        num_workers: int,
        batch_size: int,
        n_batches: int,
        fix_validation_set: bool = False,
        valid_n_batches: Optional[int] = None,
    ):
        super().__init__(
            prior, 
            decoder,
            background,
            tile_slen,
            max_sources_per_tile,
            num_workers,
            batch_size,
            n_batches,
            fix_validation_set,
            valid_n_batches,
        )
        self.slen = self.decoder.slen
        self.pixel_scale = self.decoder.single_decoder.pixel_scale
        

    def _get_images(self, full_cat, dithers):
        noiseless, noiseless_centered, noiseless_uncentered = FullCatalogDecoder.render_catalog(
            full_cat, dithers
        )

        # get background and noisy image.
        background = self.background.sample((1, *noiseless.shape)).squeeze(0)
        noisy_image = _add_noise_and_background(noiseless, background)

        return noisy_image, noiseless, noiseless_centered, noiseless_uncentered, background





In [128]:
with initialize(config_path="../sdss_galaxies/config"):
    cfg = compose("config", overrides=[])

In [165]:
prior = instantiate(cfg.datasets.galsim_blended_galaxies.prior) 
decoder = instantiate(cfg.datasets.sdss_galaxies.decoder)
background = instantiate(cfg.datasets.galsim_blended_galaxies.background)
tile_slen = 4
max_tile_n_sources = 1
num_workers = 5
n_batches = 1
max_n_sources = 1
max_shift = 0.5
num_dithers = 4
psf_sampler = PsfSampler(0.5, 1.2)
mprior = instantiate(cfg.models.prior)
tile_catalog = mprior.sample_prior(tile_slen=4, batch_size=3, n_tiles_h=2, n_tiles_w=2)
full_catalog = TileCatalog.to_full_params(tile_catalog)

In [131]:
# Test CUGGP
cuggp = CoaddUniformGalsimGalaxiesPrior(prior, max_n_sources, max_shift, num_dithers)


In [212]:
cuggprior = cuggp.sample()
n_sources = cuggprior["n_sources"]
galaxy_params = cuggprior["galaxy_params"]
locs = cuggprior["locs"]
galaxy_bools = cuggprior["galaxy_bools"]
star_bools = cuggprior["star_bools"]
dithers = cuggprior["dithers"]
offset = None
psf = galsim.Kolmogorov((700 * 1.e-9) / (0.15 * (700/500)**1.2))


In [227]:
csgd = CoaddSingleGalaxyDecoder(decoder.slen, decoder.n_bands, decoder.pixel_scale)
csgd.render_galaxy(galaxy_params = galaxy_params[0], slen = decoder.slen, psf = decoder.psf, offset = offset, dithers = dithers)

tensor([[[[3.3258e-05, 2.9777e-05, 2.6534e-05,  ..., 3.0838e-07,
           3.2309e-07, 4.9081e-07],
          [4.6939e-05, 4.1960e-05, 3.7550e-05,  ..., 1.3410e-07,
           2.6802e-07, 3.2376e-07],
          [6.6510e-05, 5.9702e-05, 5.3342e-05,  ..., 2.0643e-07,
           1.5306e-07, 3.1153e-07],
          ...,
          [3.2187e-07, 2.8910e-07, 3.5551e-07,  ..., 8.0205e-05,
           8.9391e-05, 9.9055e-05],
          [2.6956e-07, 2.6508e-07, 2.3644e-07,  ..., 5.6071e-05,
           6.2556e-05, 6.9705e-05],
          [3.4609e-07, 2.7451e-07, 2.9154e-07,  ..., 3.9406e-05,
           4.4148e-05, 4.9229e-05]]],


        [[[4.0317e-05, 3.5991e-05, 3.2106e-05,  ..., 1.9041e-07,
           2.2000e-07, 2.9964e-07],
          [5.6988e-05, 5.1018e-05, 4.5518e-05,  ..., 1.9611e-07,
           1.9967e-07, 2.3819e-07],
          [8.1168e-05, 7.2808e-05, 6.5133e-05,  ..., 3.1555e-07,
           2.6318e-07, 2.6624e-07],
          ...,
          [1.7119e-07, 2.6860e-07, 2.9527e-07,  ..., 6.55

In [354]:
batch_size = 1000
decoder = instantiate(cfg.datasets.galsim_blended_galaxies.decoder)
full_catalog = CoaddGalsimBlends(prior, decoder, background, tile_slen, max_tile_n_sources, num_workers, n_batches, batch_size)._sample_full_catalog()

In [356]:
decoder = instantiate(cfg.datasets.sdss_galaxies.decoder)
FullCatalogDecoder(decoder, decoder.slen, bp).render_catalog(full_catalog, psf, dithers)

GalSimFFTSizeError: drawFFT requires an FFT that is too large.
The required FFT size would be 9588 x 9588, which requires 2.05 GB of memory.
If you can handle the large FFT, you may update gsparams.maximum_fft_size.

In [357]:

size = decoder.slen + 2 * bp
full_plocs = full_catalog.plocs
b, max_n_sources, _ = full_plocs.shape
assert b == 1, "Only one batch supported for now."
assert FullCatalogDecoder(single_galaxy_decoder = decoder, slen = decoder.slen, bp = bp).single_decoder.n_bands == 1, "Only 1 band supported for now"

image = torch.zeros(1, size, size)
noiseless_centered = torch.zeros(max_n_sources, 1, size, size)
noiseless_uncentered = torch.zeros(max_n_sources, 1, size, size)

n_sources = int(full_catalog.n_sources[0].item())
galaxy_params = full_catalog["galaxy_params"][0]
plocs = full_plocs[0]
for ii in range(n_sources):
    offset_x = plocs[ii][1] + bp - size / 2
    offset_y = plocs[ii][0] + bp - size / 2
    offset = torch.tensor([offset_x, offset_y])
    centered = FullCatalogDecoder(single_galaxy_decoder = decoder, slen = decoder.slen, bp = bp).single_decoder.render_galaxy(galaxy_params[ii], psf, size)
    uncentered = FullCatalogDecoder(single_galaxy_decoder = decoder, slen = decoder.slen, bp = bp).single_decoder.render_galaxy(galaxy_params[ii], psf, size, offset, dithers)
    noiseless_centered[ii] = centered
    noiseless_uncentered[ii] = uncentered
    image += uncentered


GalSimFFTSizeError: drawFFT requires an FFT that is too large.
The required FFT size would be 9588 x 9588, which requires 2.05 GB of memory.
If you can handle the large FFT, you may update gsparams.maximum_fft_size.