In [None]:
# Enable the autoreload extension; mode 2 re-imports modules before each cell
# runs, so edits to the project's `src` modules are picked up without a restart.
%load_ext autoreload
%autoreload 2
# With no arguments, %aimport just lists what autoreload is tracking.
%aimport

In [None]:
import os
import sys

# Put the directory two levels up on the import path (once) so that the
# project's `src` package can be imported from this notebook.
path = os.path.abspath(os.path.join('..', '..'))
if not any(entry == path for entry in sys.path):
    sys.path.insert(0, path)

# Cell output: the first entry of the module search path.
sys.path[0]

In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch
import torch.optim as optim

import fitsio 

import json

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print('torch version: ', torch.__version__)
print(device)

In [None]:
from src import psf_transform_lib
from src.utils import const
from src.data import simulated_datasets_lib
from src import starnet_lib

# Load the PSF

In [None]:
# Power-law PSF parameters that were fitted using ground-truth Hubble
# locations/fluxes; they are read from disk and moved to `device`.
fitted_psf_params = np.load('../../data/fitted_powerlaw_psf_params.npy')
init_psf_params = torch.tensor(fitted_psf_params, device=device)

# Evaluate the PSF model once and detach the result from the autograd graph.
power_law_psf = psf_transform_lib.PowerLawPSF(init_psf_params)
psf = power_law_psf.forward().detach()

# The first axis of the PSF gives the number of bands (two here).
n_bands = psf.shape[0]
print(n_bands)
print(psf.shape)


# Plot star images

## test star simulator

In [None]:
# Read the default star-simulation settings, then override the star counts
# and the image side length for this test.
with open('../../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

data_params.update({'max_stars': 10, 'mean_stars': 5, 'slen': 30})
print(data_params)

In [None]:
# Constant background image: one fixed level for each of the two bands.
side_length = data_params['slen']
band_levels = (686., 1123.)

background = torch.zeros(n_bands, side_length, side_length, device=device)
for band_index, band_level in enumerate(band_levels):
    background[band_index] = band_level

In [None]:
# Build the simulated star dataset from the parameters, PSF and background above.
n_images = 10

star_dataset_options = dict(transpose_psf=False, add_noise=True, draw_poisson=True)
simulated_dataset = simulated_datasets_lib.StarDataset.load_dataset_from_params(
    n_images, data_params, psf, background, **star_dataset_options)

In [None]:
# Sample one set of source parameters from the simulator and inspect it.
star_simulator = simulated_dataset.simulator
n_sources, locs, fluxes = star_simulator.sample_parameters(batchsize=1)

print(n_sources.shape, locs.shape, fluxes.shape)
for label, value in (('n_sources', n_sources), ('locs', locs), ('fluxes', fluxes)):
    print(label + ':\n', value)

# Shape of the PSF held by the simulator.
print(star_simulator.psf.shape)

In [None]:
# Render the sampled sources and show the first band of the first scene.
images = simulated_dataset.simulator.generate_images(locs, n_sources, fluxes)
print(images.shape)
plt.matshow(images[0, 0].cpu().numpy())

# Scale the stored locations by (slen - 1) to get pixel coordinates.
pixel_scale = simulated_dataset.slen - 1
first_scene_locs = locs[0].cpu().numpy()
locs_x = first_scene_locs[:, 0] * pixel_scale
locs_y = first_scene_locs[:, 1] * pixel_scale
# The second coordinate goes on the horizontal axis of the image plot.
plt.scatter(locs_y, locs_x, marker='x', color='red')

## test batching 

In [None]:
# Pull a batch from the dataset and show its first three scenes (first band)
# with the true source locations overlaid.
batch = simulated_dataset.get_batch(32)
fig, axs = plt.subplots(1, 3, figsize=(15,15));
axes = axs.flatten()
pixel_scale = simulated_dataset.slen - 1
for i in range(len(axes)):
    ax = axes[i]
    images = batch['images'][i]
    locs = batch['locs'][i]

    scene_locs = locs.cpu().numpy()
    locs_x = scene_locs[:, 0] * pixel_scale
    locs_y = scene_locs[:, 1] * pixel_scale

    ax.matshow(images[0].cpu().numpy());
    ax.scatter(locs_y, locs_x, marker='x', color='red');

# Test star encoder

In [None]:
# Star-encoder configuration: image side, patch side, patch stride, padding
# trimmed from each patch edge, band count, and per-patch limits.
slen = 30
patch_slen = 8
step = 2
edge_padding = 3
n_bands = 2
max_detections = 2
n_source_params = 2

# Move the encoder to the notebook-wide `device` rather than calling .cuda()
# unconditionally, so the cell also runs on machines without a GPU.
star_encoder = starnet_lib.StarEncoder(slen, patch_slen, step, edge_padding, n_bands,
                                       max_detections, n_source_params).to(device)

In [None]:
# Reload the default star-simulation settings and apply the same overrides
# (star counts and image side length) used for the simulator test.
with open('../../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

data_params.update({'max_stars': 10, 'mean_stars': 5, 'slen': 30})
print(data_params)

In [None]:
# Rebuild the simulated star dataset for the encoder test.
n_images = 10

simulated_dataset = simulated_datasets_lib.StarDataset.load_dataset_from_params(
    n_images,
    data_params,
    psf,
    background,
    transpose_psf=False,
    add_noise=True,
    draw_poisson=True,
)

# Sample one parameter set and render the corresponding image.
encoder_test_simulator = simulated_dataset.simulator
n_sources, locs, fluxes = encoder_test_simulator.sample_parameters(batchsize=1)
images = encoder_test_simulator.generate_images(locs, n_sources, fluxes)
print(*(t.shape for t in (n_sources, locs, fluxes, images)))

In [None]:
# Split the image into patches and map the true locations/fluxes onto them.
# The fifth return value is not needed here.
star_patch_outputs = star_encoder.get_image_patches(images, locs=locs, source_params=fluxes)
image_patches, patch_locs, patch_fluxes, patch_n_stars, _ = star_patch_outputs

print(*(t.shape for t in (image_patches, patch_locs, patch_fluxes, patch_n_stars)))

In [None]:
# Indices of the patches whose star count is non-zero.
patch_n_stars.nonzero()

In [None]:
# Show one patch (band index 1) with the true and the estimated star
# locations marked on it.
# NOTE(review): the patch index is hard-coded; presumably it was picked from the
# nonzero() output above for one particular random draw. When the patch is empty
# nothing is marked.
indx = 58
plt.matshow(image_patches[indx, 1].detach().cpu().numpy())

# the neural network returns locations / fluxes for any stars in the center of the grid;
# the red lines outline that central region (patch minus `edge_padding` on each side).
plt.axvline(x=star_encoder.edge_padding - 0.5, color = 'r')
plt.axvline(x=star_encoder.patch_slen - star_encoder.edge_padding - 0.5, color = 'r')
plt.axhline(y=star_encoder.edge_padding - 0.5, color = 'r')
plt.axhline(y=star_encoder.patch_slen - star_encoder.edge_padding - 0.5, color = 'r')


# Mark any true sources in this particular tile (blue circles).
if patch_n_stars[indx] > 0: 
    _loc = patch_locs[indx][0:patch_n_stars[indx]].detach().cpu()
    # Scale and shift so the location plots in the right place:
    # all locations, whether on the full image or the tile, are parameterized to be between 0 and 1.
    _loc = _loc * (star_encoder.patch_slen - 2 * star_encoder.edge_padding) +  star_encoder.edge_padding - 0.5
    plt.scatter(_loc[:, 1], _loc[:, 0], marker = 'o', color = 'b')
    
# Plot the estimated star locations (red crosses).
# NOTE(review): `loc_mean` is not assigned by any cell above this one; it is first
# assigned in the galaxy-encoder section further down. On a fresh kernel this
# branch raises NameError whenever the patch contains a star, and after a full
# run it would plot galaxy-encoder output. The star-encoder call that should
# produce `loc_mean` is missing here.
if patch_n_stars[indx] > 0: 
    _loc = loc_mean[indx][0:patch_n_stars[indx]].detach().cpu()
    # Same scaling and shifting from the [0, 1] parameterization to patch pixels as above.
    _loc = _loc * (star_encoder.patch_slen - 2 * star_encoder.edge_padding) +  star_encoder.edge_padding - 0.5
    plt.scatter(_loc[:, 1], _loc[:, 0], marker = 'x', color = 'r')


## Test galaxy simulation

In [None]:
# Default galaxy-simulation settings, used unchanged.
with open('../../data/default_galaxy_parameters.json', 'r') as fp:
    raw_galaxy_params = fp.read()
data_params = json.loads(raw_galaxy_params)
print(data_params)

In [None]:
# Number of simulated scenes. Pass the variable (not a repeated literal) so
# that changing it here actually changes the dataset size.
n_images = 100
sim_ds = simulated_datasets_lib.GalaxyDataset.load_dataset_from_params(n_images, data_params)

In [None]:
# Mean of the dataset's background at index 2 (presumably a band index — confirm).
sim_ds.background[2].mean()

In [None]:
from torch.distributions import Normal

# Unit normal whose parameters live on `device`; draw samples and arrange
# them as two rows.
z_loc = torch.zeros((1,), device=device)
z_scale = torch.ones((1,), device=device)
p_z = Normal(z_loc, z_scale)

z_samples = p_z.rsample(torch.tensor([2, 8]))
z = z_samples.view(2, -1)
print(z)

In [None]:
# Visual check: show index 2 of the first nine galaxies in a 3x3 grid, each
# with its own colorbar.
fig, axs = plt.subplots(3, 3, figsize=(15,15))
axes = axs.flatten()
# The dataset comes first in zip() so it is advanced exactly as often as an
# enumerate-and-break loop over it would be.
for gal, ax in zip(sim_ds.simulator.ds, axes):
    im = ax.imshow(gal[2].cpu().numpy())
    fig.colorbar(im, ax=ax)

In [None]:
from torch.distributions import Normal

# Unit normal with a size-1 batch dimension (CPU tensors).
p_z = Normal(torch.zeros(1), torch.ones(1))
# rsample((2, 8)) returns shape (2, 8, 1): the sample shape followed by the
# size-1 batch dimension. The view folds that trailing axis away, so z has
# shape (2, 8), not (8,).
z = p_z.rsample((2, 8)).view(2, -1)
print(z.shape)
z

In [None]:
# Draw two samples from the underlying galaxy dataset and show the first
# image's first channel.
z, sampled_gals = sim_ds.simulator.ds.sample(2)
gals = sampled_gals.detach().cpu().numpy()
for sampled in (z, gals):
    print(sampled.shape)
plt.imshow(gals[0, 0])

In [None]:
# Test sampling source parameters for a scene with three galaxies.
# Build the (float) count tensor directly on `device`: `.cuda(device)` raises
# when `device` is the CPU fallback.
n_galaxy = torch.tensor([3.], device=device)
galaxy_params, single_galaxies = sim_ds.simulator.get_source_params(n_galaxy)

print(galaxy_params.shape, single_galaxies.shape)
plt.imshow(single_galaxies[0][0][0].detach().cpu().numpy())
# Slots beyond the third galaxy must be empty.
assert single_galaxies[0, 3:, ...].sum() == 0
# Cell output: the first six parameter rows; rows after the third should be zero.
galaxy_params[0, 0:6, ...]


In [None]:
# Sample one scene's parameters; the third return value bundles the galaxy
# parameters with the individual galaxy images.
n_sources, locs, params = sim_ds.simulator.sample_parameters(1)
gal_params, single_galaxies = params
print(*(t.shape for t in (n_sources, locs, gal_params, single_galaxies)))

In [None]:
# Test image drawing from the sampled locations and single-galaxy images.
# `.to(device)` works for both the CUDA and the CPU fallback device, whereas
# `.cuda(device)` fails when `device` is the CPU.
images = sim_ds.simulator.draw_image_from_params(locs, n_sources, sources=single_galaxies.to(device))
print(images.shape)

print(n_sources)
plt.figure(figsize=(12,12))
plt.imshow(images[0][0].detach().cpu().numpy(), interpolation='None')
plt.colorbar()

## Smaller galaxy scene

In [None]:
# Default galaxy settings with the galaxy-count range narrowed to 1..5 (mean 3).
with open('../../data/default_galaxy_parameters.json', 'r') as fp:
    data_params = json.load(fp)

data_params.update({'min_galaxies': 1, 'mean_galaxies': 3, 'max_galaxies': 5})
print(data_params)

In [None]:
# Number of simulated scenes. Pass the variable (not a repeated literal) so
# that changing it here actually changes the dataset size.
n_images = 100
sim_ds = simulated_datasets_lib.GalaxyDataset.load_dataset_from_params(n_images, data_params)

In [None]:
#draw each of the images produced. 

In [None]:
# Show index 2 of the first dataset item's image.
sample = sim_ds[0]
image = sample['image'][2].detach().cpu().numpy()

plt.figure(figsize=(10,10))
plt.imshow(image)

# Galaxy Encoder

## prepare the encoder and the simulated dataset 

In [None]:
import src.starnet_lib as starnet_lib

In [None]:
# Patches may need to be larger for galaxies than for stars, and the padding
# parameters should change with them.

gal_dim = 8 # number of latent dimensions.
# Move the encoder to the notebook-wide `device` instead of calling .cuda()
# unconditionally, so the cell also runs on machines without a GPU.
galaxy_encoder = starnet_lib.SourceEncoder(
                            slen = 100,
                            n_bands = 1,
                            patch_slen = 20,
                            step = 5,
                            edge_padding = 5,
                            max_detections = 2,
                            n_source_params = gal_dim).to(device)

In [None]:
# Number of simulated scenes. The dataset was being built with a literal 128
# while `n_images` said 100; the variable now holds the value actually used.
n_images = 128

# Default galaxy settings with the galaxy-count range set to 0..10 (mean 5).
with open('../../data/default_galaxy_parameters.json', 'r') as fp:
    data_params = json.load(fp)
data_params['min_galaxies'] = 0
data_params['mean_galaxies'] = 5
data_params['max_galaxies'] = 10
print(data_params)


sim_ds = simulated_datasets_lib.GalaxyDataset.load_dataset_from_params(n_images, data_params)

## sample parameters 

In [None]:
# Sample the parameters of one scene and print the shape of each piece.
# NOTE(review): an earlier cell of this notebook unpacks the result of
# sample_parameters(1) into three values (with the third being a pair), while
# this cell unpacks four. Only one of the two can match the current simulator
# API — confirm which and align the other.
n_sources, locs, gal_params, single_galaxies = sim_ds.simulator.sample_parameters(1)
print(
n_sources.shape, 
locs.shape, 
gal_params.shape, 
single_galaxies.shape)

In [None]:
# Visual check of one individual galaxy image (indices 0, 2, 0).
galaxy_stamp = single_galaxies[0, 2, 0]
plt.imshow(galaxy_stamp.cpu().numpy())

In [None]:
# Render the scene from the sampled locations and single-galaxy images, then
# show the resulting shape as the cell output.
galaxy_simulator = sim_ds.simulator
images = galaxy_simulator.generate_images(locs, n_sources, single_galaxies)
images.shape

In [None]:
# Visual check of the rendered scene (first image, first band).
first_scene = images[0, 0]
plt.imshow(first_scene.cpu().numpy())

## batch from dataset

In [None]:
# Fetch a batch of size 1 from the dataset and unpack the fields used below.
batch = sim_ds.get_batch(1)
n_sources, locs, gal_params, images = (
    batch[key] for key in ('n_sources', 'locs', 'gal_params', 'images')
)
print(n_sources)

print(*(t.shape for t in (n_sources, locs, gal_params, images)))

In [None]:
# Show the batch scene with the true galaxy centres overlaid.
plt.figure(figsize=(9,9))
plt.imshow(images[0, 0].cpu().numpy())

# Scale the stored locations by (slen - 1) to get pixel coordinates.
pixel_scale = sim_ds.slen - 1
scene_locs = locs[0].cpu().numpy()
locs_x = scene_locs[:, 0] * pixel_scale
locs_y = scene_locs[:, 1] * pixel_scale
# The second coordinate goes on the horizontal axis of the image plot.
plt.scatter(locs_y, locs_x, marker='x', color='red')

In [None]:
%matplotlib inline

#compare with decoder drawing directly which should be in the center of the image? 
plt.figure(figsize=(6,6))
plt.imshow(sim_ds.simulator.ds[0][0].cpu().numpy())
print(sim_ds.simulator.galaxy_slen)
plt.scatter((sim_ds.simulator.galaxy_slen-1)/2, (sim_ds.simulator.galaxy_slen-1)/2, marker='x', color='r')

# so probably because I move it around the central pixel 
#it's not eactly centered. 

## get patches

In [None]:
# # in which patches are they sources' centers. 
# print(patch_n_sources.nonzero())
# idx = patch_n_sources.nonzero()[0].item()
# print(idx)


image_patches, patch_locs, patch_galaxy_params, patch_n_sources, patch_is_on_array = \
    galaxy_encoder.get_image_patches(images, 
                                     locs = locs, 
                                      source_params = gal_params)


print(image_patches.shape, patch_locs.shape, patch_galaxy_params.shape, patch_n_sources.shape)

%matplotlib inline
#plot some patch
plt.matshow(image_patches[idx][0].cpu().numpy())

## Get variational parameters

In [None]:
# Run the encoder on every patch to get the full set of variational
# parameters, and show the resulting shape.
var_params_all = galaxy_encoder.get_var_params_all(image_patches)
var_params_all.shape
# TODO(review): open question from the author — why does a size of 63 appear here?

In [None]:
# Extract the log-probabilities over the number of sources in each patch.
log_probs_nsource_patch = galaxy_encoder.get_logprob_n_from_var_params(var_params_all)
log_probs_nsource_patch.shape 
# The size of 3 presumably corresponds to 0, 1 or 2 detections, since the
# encoder was built with max_detections = 2 — confirm.

In [None]:
# Variational parameters evaluated at the true number of sources per patch.
# The count is clipped at the encoder's `max_detections` first.
n_sources_clipped = patch_n_sources.clamp(max=galaxy_encoder.max_detections)
var_params_at_n = galaxy_encoder.get_var_params_for_n_sources(
    var_params_all, n_sources=n_sources_clipped)
loc_mean, loc_logvar, gal_param_mean, gal_param_logvar = var_params_at_n

# gal_param_mean has shape n_patches x max_detections x gal_dim
print(loc_mean.shape, gal_param_mean.shape)

In [None]:
# Inspect the estimated location means for the patch at index 2.
loc_mean[2]
# If more than one galaxy was detected, there would be more locations in the second row.

## Plot a few tiles to check

In [None]:
# Close every open matplotlib figure before drawing the patch grid below.
plt.close('all')

In [None]:
# this cell will be especially useful once we train the encoder. 
# same plot as before, but I also mark my estimated galaxy on each patch
# only using i band.
def plot_patch_examples(loc_mean, patch_n_sources, patch_locs, edge_padding, patch_slen,
                        image_patches=None, n_band=0):
    """Plot up to 20 occupied patches with true and estimated source locations.

    A 5x4 grid is filled with the patches whose source count is non-zero, in
    index order. Each panel shows one band of the patch, red lines outlining
    the region that remains after removing `edge_padding` pixels from every
    side, the true locations as blue circles and the encoder's estimates as
    red crosses. Locations are scaled from their unit parameterization to
    patch pixel coordinates before plotting.

    Parameters
    ----------
    loc_mean : tensor
        Estimated locations from the encoder, indexed as
        ``loc_mean[patch][source]`` with (vertical, horizontal) coordinates
        in the last axis.
    patch_n_sources : 1-D integer tensor
        Number of sources in each patch.
    patch_locs : tensor
        True locations, indexed like `loc_mean`.
    edge_padding : int
        Number of pixels excluded on each side of a patch.
    patch_slen : int
        Side length of a patch in pixels.
    image_patches : tensor, optional
        Patch images indexed as ``image_patches[patch, band]``. When omitted,
        the notebook-level variable `image_patches` is used, which keeps
        existing five-argument calls working.
    n_band : int, optional
        Band to display (default 0).

    Returns
    -------
    None
        The figure is drawn as a side effect.
    """
    if image_patches is None:
        # Backward-compatible fallback to the notebook-level variable.
        image_patches = globals()['image_patches']

    inner_slen = patch_slen - 2 * edge_padding
    lower_edge = edge_padding - 0.5
    upper_edge = patch_slen - edge_padding - 0.5

    fig, axs = plt.subplots(5, 4, figsize=(17, 17))
    axes = axs.flatten()

    # zip stops once the grid is full or the occupied patches run out.
    for ax, indx in zip(axes, patch_n_sources.nonzero()):
        indx = indx.item()
        ax.matshow(image_patches[indx, n_band].detach().cpu().numpy())

        # The network returns locations only for sources inside this region.
        ax.axvline(x=lower_edge, color='r')
        ax.axvline(x=upper_edge, color='r')
        ax.axhline(y=lower_edge, color='r')
        ax.axhline(y=upper_edge, color='r')

        n_in_patch = int(patch_n_sources[indx])
        if n_in_patch <= 0:
            continue

        # True sources in this tile.
        true_loc = patch_locs[indx][0:n_in_patch].detach().cpu()
        true_loc = true_loc * inner_slen + edge_padding - 0.5
        ax.scatter(true_loc[:, 1], true_loc[:, 0], marker='o', color='b')

        # Encoder estimates for the same tile.
        est_loc = loc_mean[indx][0:n_in_patch].detach().cpu()
        est_loc = est_loc * inner_slen + edge_padding - 0.5
        ax.scatter(est_loc[:, 1], est_loc[:, 0], marker='x', color='r')


In [None]:
%matplotlib inline
plot_patch_examples(loc_mean, patch_n_sources, patch_locs, galaxy_encoder.edge_padding, galaxy_encoder.patch_slen)

# Trained encoder 

In [None]:
# Same architecture as `galaxy_encoder`, to be filled with trained weights.
# The class is taken from `starnet_lib` (the module this notebook imports;
# `sourcenet_lib` is never imported) and the patch size is passed as
# `patch_slen`, matching the encoder above and the `patch_slen` attribute
# read when plotting.
# NOTE(review): if the checkpoint was trained with a newer module/argument
# naming, import that module explicitly instead.

gal_dim = 8 # number of latent dimensions.
trained_encoder = starnet_lib.SourceEncoder(
                            slen = 100,
                            n_bands = 1,
                            patch_slen = 20,
                            step = 5,
                            edge_padding = 5,
                            max_detections = 2,
                            n_source_params = gal_dim).to(device)

In [None]:
# Load the trained galaxy-encoder weights. The path is relative to the
# notebook, like the data files, rather than tied to one user's home directory.
# NOTE(review): assumes the checkpoint lives under <repo root>/reports/ — confirm.
# The checkpoint is deserialized on the CPU; load_state_dict then copies the
# values into the encoder's own parameters.
trained_encoder.load_state_dict(
    torch.load('../../reports/results_galaxy_2020-04-20/galaxy_i.dat',
               map_location='cpu'))
trained_encoder.eval();

In [None]:
# Evaluate the trained encoder on the same scene/patches as above.
# The true per-patch counts are clipped at the encoder's `max_detections`,
# as is done for the untrained encoder, because a patch can hold more
# sources than the encoder can represent.
loc_mean, loc_logvar, gal_param_mean, gal_param_logvar, log_probs_n = \
    trained_encoder(image_patches,
                    n_sources=patch_n_sources.clamp(max=trained_encoder.max_detections))

In [None]:
%matplotlib inline
plot_patch_examples(loc_mean, patch_n_sources, patch_locs, trained_encoder.edge_padding, trained_encoder.patch_slen)

In [None]:
# The assignment was left unfinished, which is a syntax error. Reuse the
# scene taken from the dataset batch earlier in the notebook, so the trained
# encoder is sampled on the same image whose patches were plotted.
image = images

In [None]:
# Draw one sample from the trained encoder for the scene. Both `return_map_*`
# flags are switched on (presumably requesting MAP estimates — confirm).
locs, source_params, n_sources = trained_encoder.sample_encoder(
    image,
    n_samples=1,
    return_map_n_sources=True,
    return_map_source_params=True,
)

