In [1]:
import torch

import matplotlib.pyplot as plt
%matplotlib inline

import pytorch_lightning as pl

import bliss
import bliss.models.galaxy_net
import bliss.datasets.galsim_galaxies

In [2]:
import numpy as np

In [3]:
# Seed the global RNGs through Lightning so the random sample grid and the
# training run below are reproducible across kernel restarts.
SEED = 743384
pl.seed_everything(SEED)

743384

# Create dataset

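First we create the dataset containing SDSS galaxies. These galaxies have realistic sizes and fluxes taken from a catalog, but their morphology is bulge + disk + AGN (essentially a parametric Sérsic mixture), so they are not as realistic as they could be. **Note:** the configuration below actually selects the `toy_gaussian` dataset (`ToyGaussian`), and the `catalog_file` path defined in the next cell is not passed to it.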

In [4]:
# Catalog restricted to the "easiest" galaxies (neither too small nor too faint).
# NOTE(review): this path is not passed to the dataset built from `cfg` below.
catalog_file = "../../data/gold_dc2_catalog.fits"

In [5]:
# we prepare a configuration object that is used to create the dataset & model.
from hydra.experimental import initialize, compose
def get_cfg(overrides=None, config_path="../../config", config_name="config"):
    """Compose the Hydra configuration used to build the dataset and model.

    Args:
        overrides: mapping of config keys to values. Each item becomes a Hydra
            override string ``"key=value"`` (e.g. ``{'dataset': 'toy_gaussian'}``
            -> ``"dataset=toy_gaussian"``), in the mapping's iteration order.
            ``None`` means no overrides.
        config_path: config directory passed to Hydra's ``initialize``.
        config_name: name of the primary config file to compose.

    Returns:
        The config object produced by Hydra's ``compose``.
    """
    # Keep the incoming dict and the derived list under different names
    # instead of rebinding the parameter to a different type.
    override_strings = [f"{key}={value}" for key, value in (overrides or {}).items()]
    # initialize() is used as a context manager so Hydra's global
    # initialization is scoped to this call; per the Hydra compose-API docs
    # this is what allows get_cfg to be called again in the same kernel.
    with initialize(config_path=config_path):
        return compose(config_name, overrides=override_strings)

# Hydra config-group selections; insertion order is preserved when they are
# turned into override strings.
overrides = dict(
    dataset='toy_gaussian',   # which dataset to simulate
    model='galaxy_net',       # galaxy VAE definition
    training='default',       # pytorch lightning trainer settings
)
cfg = get_cfg(overrides)

# The dataset is built purely from the composed config.
dataset = bliss.datasets.galsim_galaxies.ToyGaussian(cfg)

In [6]:
# Display the full composed configuration. Interpolations such as
# ${paths.root} appear unresolved in this display.
cfg

{'mode': 'train', 'general': {'overwrite': False}, 'gpus': [0], 'paths': {'root': '${env:BLISS_HOME}', 'data': '${paths.root}/data', 'models': '${paths.root}/models', 'output': '${paths.root}/temp/default', 'sdss': '${paths.root}/data/sdss'}, 'optimizer': {'name': 'torch.optim.Adam', 'params': {'lr': 0.0001, 'weight_decay': 1e-06}}, 'model': {'name': 'OneCenteredGalaxy', 'warm_up': 0, 'params': {'slen': 51, 'latent_dim': 8, 'n_bands': 1, 'hidden': 256}}, 'training': {'deterministic': False, 'plotting': True, 'n_epochs': 121, 'trainer': {'profiler': None, 'logger': True, 'checkpoint_callback': False, 'reload_dataloaders_every_epoch': False, 'max_epochs': '${training.n_epochs}', 'min_epochs': '${training.n_epochs}', 'gpus': '${gpus}', 'limit_train_batches': 1.0, 'limit_val_batches': 1.0, 'check_val_every_n_epoch': 10}}, 'dataset': {'name': 'ToyGaussian', 'n_batches': 10, 'batch_size': 128, 'num_workers': 0, 'params': {'slen': 51, 'deviate_center': False, 'background': 865, 'noise_factor'

In [7]:
# Redundant: `%matplotlib inline` is already set in the first (imports) cell.
%matplotlib inline

In [8]:
# Show 25 randomly drawn examples (drawn with replacement) on a 5x5 grid,
# each with its own colorbar. Only index 0 of each example's 'images' entry
# is shown (presumably the first/only band).
fig, ax = plt.subplots(5, 5, figsize=(20, 20))

for panel in ax.flat:  # row-major: same panel order as i // 5, i % 5
    sample_idx = np.random.choice(len(dataset))
    example = dataset[sample_idx]
    mappable = panel.matshow(example['images'][0])
    fig.colorbar(mappable, ax=panel)

Error in callback <function flush_figures at 0x7f5c5401d670> (for post_execute):


KeyboardInterrupt: 

# Create VAE and Train

The configuration object we created above already contains the model information for our galaxy VAE. 

In [9]:
model_cfg = cfg.model  # just the model section of the composed config
print(model_cfg)

{'name': 'OneCenteredGalaxy', 'warm_up': 0, 'params': {'slen': 51, 'latent_dim': 8, 'n_bands': 1, 'hidden': 256}}


We can create the VAE directly from this configuration. 

In [10]:
# The VAE class reads its architecture parameters from the composed config.
vae_cls = bliss.models.galaxy_net.OneCenteredGalaxy
VAE = vae_cls(cfg)

We also need a PyTorch Lightning `Trainer` to fit the model.

In [None]:
# Training settings for this notebook.
# NOTE: these differ from the composed config, which specifies
# cfg.training.n_epochs = 121, cfg.gpus = [0] and validation every 10 epochs;
# the Trainer below is built from these values, not from cfg.training.trainer.
n_epochs = 2000            # min_epochs == max_epochs -> train exactly this long
train_gpus = [2]           # device index/indices used for training
val_every_n_epochs = 100   # run the validation loop this often

trainer = pl.Trainer(
    profiler=None,
    logger=False,
    checkpoint_callback=False,
    max_epochs=n_epochs,
    min_epochs=n_epochs,
    gpus=train_gpus,
    check_val_every_n_epoch=val_every_n_epochs,
)

# train! The dataset object doubles as the Lightning datamodule.
trainer.fit(VAE, datamodule=dataset)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name | Type                  | Params
-----------------------------------------------
0 | enc  | CenteredGalaxyEncoder | 10.9 M
1 | dec  | CenteredGalaxyDecoder | 11.3 M


Validation sanity check:  50%|█████     | 1/2 [00:00<00:00,  7.72it/s]



Epoch 0:   0%|          | 0/20 [00:00<?, ?it/s]                       



Epoch 34:  40%|████      | 4/10 [00:00<00:01,  5.62it/s, loss=12104249344.000] 

In [None]:
# Persist only the learned parameters (the state_dict), not the model object.
vae_weights = VAE.state_dict()
torch.save(vae_weights, './galaxy_vae')

# Compare results

Now that the model is trained, we can compare its reconstructions with the original images.

In [None]:
VAE.eval()

# Take the first batch from the training dataloader. The original notes say the
# batch is "always different" between runs of this cell.
batch = next(iter(dataset.train_dataloader()))

# Inference only: no_grad avoids building an autograd graph, so the outputs
# can be plotted directly without .detach().
with torch.no_grad():
    recon_mean, recon_var, _ = VAE(batch['images'], batch['background'])

# Don't index past the end of a small batch.
n_show = min(10, len(batch['images']))

for i in range(n_show):
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(15, 3))

    # Index 0 of dim 1 for images and outputs alike (presumably the only band).
    obs = batch['images'][i][0]
    recon = recon_mean[i][0]
    var = recon_var[i][0]

    im1 = ax1.imshow(obs)
    ax1.set_title('observed')
    fig.colorbar(im1, ax=ax1)

    im2 = ax2.imshow(recon)
    ax2.set_title('reconstruction (mean)')
    fig.colorbar(im2, ax=ax2)

    # Previously plotted `var * 0.`, which is identically zero and hides the
    # model's predicted variance; show the actual recon_var instead.
    im3 = ax3.imshow(var)
    ax3.set_title('recon_var')
    fig.colorbar(im3, ax=ax3)

    # Raw residual with a symmetric colour scale centred on zero. A
    # variance-normalised residual would be (obs - recon) / torch.sqrt(var).
    diff = obs - recon
    vmax = diff.abs().max()
    im4 = ax4.imshow(diff, vmax=vmax, vmin=-vmax, cmap=plt.get_cmap('bwr'))
    ax4.set_title('observed - reconstruction')
    fig.colorbar(im4, ax=ax4)

    fig.tight_layout()

In [None]:
# Encode the background-subtracted images. Call the module itself rather than
# .forward() so any registered hooks run; no_grad because this is inference.
with torch.no_grad():
    z_mean, z_var = VAE.enc(batch['images'] - batch['background'])

In [None]:
# Histogram of every entry of the encoder's z_mean output (flattened).
# .detach().cpu() makes this safe whether or not z_mean tracks gradients or
# lives on a GPU; both are no-ops for a plain CPU tensor.
fig_z, ax_z = plt.subplots()
ax_z.hist(z_mean.flatten().detach().cpu().numpy())
ax_z.set(title='Encoder latent means (flattened z_mean)', xlabel='z_mean value', ylabel='count');