In [None]:
%load_ext autoreload
%autoreload 2
%aimport

In [None]:
import os 
import sys
# Make the directory two levels above this notebook importable, so that the
# `from src ...` imports further down resolve. Prepend it only once, so that
# re-running the cell does not grow sys.path.
repo_root = os.path.abspath(os.path.join('..', '..'))
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)
sys.path[0]

In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch
import torch.optim as optim

import fitsio 

import json

# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print('torch version: ', torch.__version__)
print(device)

In [None]:
from src.data import simulated_datasets_lib
from src import psf_transform_lib
from src.utils import const

# Load the fitted PSF

In [None]:
# Power-law PSF parameters, previously fitted using the ground-truth Hubble
# source locations and fluxes.
fitted_psf_params = np.load('../../data/fitted_powerlaw_psf_params.npy')
init_psf_params = torch.Tensor(fitted_psf_params)

# Evaluate the PSF once and detach it from the autograd graph; from here on it
# is used as a fixed tensor.
power_law_psf = psf_transform_lib.PowerLawPSF(init_psf_params.to(device))
psf = power_law_psf.forward().detach()

# The leading PSF axis indexes bands (the data used here has two).
n_bands = psf.size(0)


In [None]:
psf.size()  # same torch.Size as psf.shape: (n_bands, ...)

# Plot images

## Test the star simulator

In [None]:
# Start from the default star-simulation parameters, then shrink the problem
# to small 30x30 images with at most 5 (on average 3) stars.
with open('../../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)
data_params.update(max_stars=5, mean_stars=3, slen=30)
print(data_params)

In [None]:
# Constant sky background: one level per band, broadcast over every pixel.
# Listing the levels explicitly (rather than writing background[0] and
# background[1] by hand) makes a PSF with a different band count fail loudly
# instead of leaving extra bands at zero or raising a bare IndexError.
background_levels = [686., 1123.]
assert len(background_levels) == n_bands, \
    f'expected {n_bands} background levels, got {len(background_levels)}'

background = torch.zeros(n_bands, data_params['slen'], data_params['slen'])
for band, level in enumerate(background_levels):
    background[band] = level

In [None]:
# Build a small simulated star dataset from the PSF and background above
# (noise enabled via add_noise / draw_poisson).
n_images = 10
simulated_dataset = simulated_datasets_lib.StarsDataset.load_dataset_from_params(
    n_images,
    data_params,
    psf,
    background,
    transpose_psf=False,
    add_noise=True,
    draw_poisson=True,
)

In [None]:
# Sample source parameters for a single image and print them.
n_sources, locs, params = simulated_dataset.simulator.sample_parameters(batchsize=1)

print(n_sources.shape, locs.shape, params.shape)
for label, tensor in (('n_sources', n_sources), ('locs', locs), ('params', params)):
    print(f'{label}:\n', tensor)

# Confirm which PSF the simulator holds.
print(simulated_dataset.simulator.psf.shape)

In [None]:
# Render the sampled parameters and show entry [0, 0] of the result
# (presumably image 0, band 0).
images = simulated_dataset.simulator.draw_image_from_params(
    locs, n_sources, fluxes=params)
first_band = images.cpu().numpy()[0, 0]
plt.matshow(first_band)

## Test the galaxy simulator

In [None]:
# Load the default galaxy-simulation parameters. Note: this reuses the name
# `data_params`, replacing the star parameters loaded earlier.
with open('../../data/default_galaxy_parameters.json') as fp:
    data_params = json.load(fp)
print(data_params)

In [None]:
# Number of galaxy images to simulate. It is passed through below rather than
# repeated as a literal, so changing it here actually takes effect.
n_images = 100
sim_ds = simulated_datasets_lib.GalaxyDataset.load_dataset_from_params(n_images, data_params)

In [None]:
# Sanity check: the simulator's underlying dataset can be indexed. Show index 0
# of its first item (presumably the first band/channel).
first_galaxy = sim_ds.simulator.ds[0]
plt.imshow(first_galaxy[0])

In [None]:
from torch.distributions import Normal

# A standard normal with batch shape (1,). Sampling with sample_shape (2, 8)
# yields shape (2, 8, 1); the view then drops the trailing singleton axis.
p_z = Normal(torch.zeros(1), torch.ones(1))
z = p_z.rsample((2, 8)).view(2, -1)  # shape = (2, 8)
z

In [None]:
# Draw 2 samples from the galaxy dataset's sampler (`z` is presumably the
# latent code), move the images to NumPy, and show index [0, 0].
z, gals = sim_ds.simulator.ds.sample(2)
gals = gals.detach().cpu().numpy()
for arr in (z, gals):
    print(arr.shape)
plt.imshow(gals[0, 0])

In [None]:
# Request source parameters for 3 galaxies (a one-element count tensor,
# presumably a batch of one image). `.to(device)` instead of `.cuda(device)`
# keeps this cell runnable when `device` fell back to the CPU.
n_galaxy = torch.Tensor([3]).to(device)
galaxy_params, single_galaxies = sim_ds.simulator.get_source_params(n_galaxy)

print(galaxy_params.shape, single_galaxies.shape)
# matplotlib needs a host-side, non-grad tensor; detach and copy to CPU first.
plt.imshow(single_galaxies[0][0][0].detach().cpu())
# Galaxy slots beyond the first three must be empty (all zeros).
assert single_galaxies[0, 3:, ...].sum() == 0
galaxy_params[0, 0:6, ...]  # rows after the third should be zero


In [None]:
# test sample parameters
n_sources, locs, params = sim_ds.simulator.sample_parameters(1)
gal_params, single_galaxies = params 
print(n_sources.shape, 
      locs.shape, 
      gal_params.shape, 
      single_galaxies.shape)

In [None]:
single_galaxies.size()  # equivalent to single_galaxies.shape

In [None]:
# Render the sampled galaxies into full images. `.to(device)` rather than
# `.cuda(device)` so the cell also runs on CPU-only machines (where
# `.cuda(...)` raises).
images = sim_ds.simulator.draw_image_from_params(
    locs, n_sources, sources=single_galaxies.to(device))
print(images.shape)

In [None]:
# Show how many sources were sampled, then display entry [0, 0] of the render.
print(n_sources)
first_image_band = images[0, 0].detach().cpu().numpy()
plt.imshow(first_image_band)

# Testing

In [None]:
# Build the "source slot is on" mask from the sampled per-image source counts.
# `n_stars` was never defined in this notebook (NameError on a fresh kernel),
# so use `n_sources` from the sampling cell above.
# NOTE(review): 5 is presumably the maximum number of source slots; confirm.
is_on_array = const.get_is_on_from_n_sources(n_sources, 5)
is_on_array
# Presumably one row per image, marking which of the 5 slots hold a source.

In [None]:
# 1.0 where an image has more than 2 sources (i.e. slot index 2 is occupied).
# Uses `n_sources` from the sampling cell; `n_stars` is never defined here.
(2 < n_sources).float()

In [None]:
# Keep slot `n`'s locations only for images with more than `n` sources, and
# zero them elsewhere. Uses `n_sources` because `n_stars` is never defined here.
# NOTE(review): assumes locs is (batch, n_slots, 2); confirm.
n = 2
is_on_n = (n < n_sources).float()
locs[:, n, :] * is_on_n.unsqueeze(1)

In [None]:
class Parent:
    """Scratch check: does ``__init__`` pick up a subclass's staticmethod?

    ``get_y`` is looked up through ``self``, so a subclass that overrides the
    staticmethod changes the value stored in ``y``.
    """

    def __init__(self, x):
        self.x, self.y = x, self.get_y()

    @staticmethod
    def get_y():
        """Value stored in ``y`` for plain ``Parent`` instances."""
        default_y = 3
        return default_y


class Child(Parent):
    """Overrides ``get_y``, so instances end up with ``y == 4``."""

    @staticmethod
    def get_y():
        """Value stored in ``y`` for ``Child`` instances."""
        child_y = 4
        return child_y

In [None]:
c = Child(x=3)

In [None]:
getattr(c, 'y')  # 4: Child's get_y override is used by Parent.__init__