In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch

import bliss
import bliss.models.galaxy_net
import bliss.datasets.galsim_galaxies

# Create dataset

First we create the `SDSSGalaxies` dataset. Its galaxies have realistic sizes and fluxes drawn from a catalog (the CosmoDC2-based file configured below), but their morphology is a bulge + disk + AGN model (essentially a parametric Sérsic mixture), so they are not as realistic as they could be.

In [None]:
# The catalog holds a sample of the 'easiest' galaxies (neither too small nor too faint).
catalog_file = '../../data/gold_dc2_catalog.fits'

# Fail fast with a clear message instead of deep inside dataset construction:
# the path is relative, so it depends on the notebook's working directory.
if not Path(catalog_file).is_file():
    raise FileNotFoundError(f"Galaxy catalog not found: {Path(catalog_file).resolve()}")

In [None]:
# we prepare a configuration object that is used to create the dataset & model.
from hydra.experimental import initialize, compose
def get_cfg(overrides):
    """Compose the project configuration named "config" from ``../../config``.

    Args:
        overrides: mapping of config keys to values; each pair becomes a
            ``key=value`` Hydra override string, in the mapping's iteration order.

    Returns:
        Whatever ``compose("config", overrides=...)`` returns for those overrides.
    """
    override_strings = []
    for key, value in overrides.items():
        override_strings.append(f"{key}={value}")
    # `initialize` is used as a context manager so the Hydra global state is
    # released again when the block exits.
    with initialize(config_path="../../config"):
        return compose("config", overrides=override_strings)

# Hydra overrides selecting the dataset, model and trainer config groups.
overrides = {
    # dataset information
    "dataset": "sdss_galaxies",
    "dataset.cosmoDC2_file": catalog_file,
    # model information
    "model": "galaxy_net",
    # pytorch lightning trainer
    "training": "default",
}

cfg = get_cfg(overrides)
dataset = bliss.datasets.galsim_galaxies.SDSSGalaxies(cfg)

In [None]:
# Show a 4x4 grid of randomly chosen, centered individual galaxies from the dataset.
# A seeded generator makes the selection of indices reproducible across kernel restarts.
rng = np.random.default_rng(0)
fig, axes = plt.subplots(4, 4, figsize=(20, 20))

for ax in axes.flatten():
    # cast to a plain int so indexing behaves like the original `np.random.randint`
    idx = int(rng.integers(len(dataset.catalog)))
    ex = dataset[idx]
    im = ax.imshow(ex['images'][0])
    ax.set_title(f"index {idx}")
    fig.colorbar(im, ax=ax)

plt.tight_layout()

# Create and train the VAE

The configuration object we created above already contains the model information for our galaxy VAE. 

In [None]:
# Display the `model` section of the composed configuration.
print(cfg["model"])

We can create the VAE directly from this configuration. 

In [None]:
# Seed the python, numpy and torch RNGs before building the model so that weight
# initialisation (and the training run below) is reproducible.
pl.seed_everything(0)
VAE = bliss.models.galaxy_net.OneCenteredGalaxy(cfg)

We also need a PyTorch Lightning `Trainer` to fit the model.

In [None]:
# Train for exactly `n_epochs` epochs, with no logger and no checkpointing.
n_epochs = 101

# GPU(s) to train on. If the requested device index does not exist on this machine,
# fall back to CPU instead of failing inside the Trainer.
gpu_ids = [2]
gpus = gpu_ids if torch.cuda.device_count() > max(gpu_ids) else None
if gpus is None:
    print(f"GPU(s) {gpu_ids} not available; training on CPU.")

trainer = pl.Trainer(
    profiler=None,
    logger=False,
    checkpoint_callback=False,
    max_epochs=n_epochs,
    min_epochs=n_epochs,
    gpus=gpus,
    # An interval longer than the run means periodic validation never triggers,
    # regardless of the value of `n_epochs`.
    check_val_every_n_epoch=n_epochs + 1,
)

# train!
trainer.fit(VAE, datamodule=dataset)

# Compare results

Now that the model is trained, we can compare a few input images with their VAE reconstructions.

In [None]:
# Compare a few training images (left column) with the VAE's mean reconstruction (right column).
VAE.eval()

# A fresh train dataloader is requested here, so the batch may differ between runs
# (presumably the dataset renders new galaxies on each call — TODO confirm).
batch = next(iter(dataset.train_dataloader()))

# Run inference on whatever device the model is on after training, without building
# an autograd graph.
device = VAE.device
with torch.no_grad():
    recon_mean, recon_var, _ = VAE(batch['images'].to(device), batch['background'].to(device))

# Show at most 5 examples; `squeeze=False` keeps `axes` 2-D even for a single row.
n_rows = min(5, len(batch['images']))
fig, axes = plt.subplots(n_rows, 2, figsize=(10, 24 * n_rows / 5), squeeze=False)

for i, (ax_orig, ax_recon) in enumerate(axes):
    im_orig = ax_orig.imshow(batch['images'][i][0])
    ax_orig.set_title(f"original {i}")
    fig.colorbar(im_orig, ax=ax_orig)

    # `.cpu()` is required before `.numpy()` when the model ran on a GPU.
    im_recon = ax_recon.imshow(recon_mean[i][0].cpu().numpy())
    ax_recon.set_title(f"reconstruction {i}")
    fig.colorbar(im_recon, ax=ax_recon)

plt.tight_layout()