# Preamble 

In [None]:
# Reload edited project modules (e.g. celeste) automatically before each cell runs.
%load_ext autoreload
# Mode 2: reload all modules (except those excluded via %aimport) every time.
%autoreload 2
# With no arguments, %aimport only lists which modules autoreload will / will not reload.
%aimport

In [None]:
import os
import sys

# Make the repository root (two directories up) importable, so the project's
# `celeste` package can be imported by the cells below.
path = os.path.abspath(os.path.join('..', '..'))
if path not in sys.path:
    sys.path.insert(0, path)
sys.path[0]

In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch
import torch.optim as optim

import fitsio 

import json

# Prefer the first GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed the RNGs explicitly so "Restart & Run All" reproduces the same simulated
# images (the dataset further down is drawn with Poisson noise).
# torch.manual_seed also seeds every CUDA device.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

print('torch version: ', torch.__version__)
print(device)

In [None]:
from celeste import psf_transform, utils
from celeste.datasets import simulated_datasets
from celeste.models import sourcenet

# Load dataset 

In [None]:
# PSF parameters fitted beforehand against the ground-truth Hubble locations/fluxes.
psf_params_file = '../../data/fitted_powerlaw_psf_params.npy'
init_psf_params = torch.tensor(np.load(psf_params_file), device=device)
power_law_psf = psf_transform.PowerLawPSF(init_psf_params)

# Evaluate the PSF once and detach it, so no gradients flow back into its parameters.
psf = power_law_psf.forward().detach()

# The PSF's leading dimension indexes the bands (two for this data).
n_bands = psf.shape[0]
print(n_bands)
print(psf.shape)


In [None]:
# Start from the default star-simulation parameters, then override the star
# counts and the image side length (`slen`) for this small demo.
with open('../../config/dataset_params/default_star_parameters.json') as param_file:
    data_params = json.load(param_file)
data_params.update(max_stars=10, mean_stars=5, slen=20)
data_params

In [None]:
# Constant background level for each band, listed in band order.
# NOTE(review): provenance of 686 / 1123 is not recorded here — TODO document it.
background_levels = [686., 1123.]

# Guard against a PSF/background band mismatch: otherwise extra bands would
# silently get a zero background (or a too-short PSF would raise an IndexError).
assert len(background_levels) == n_bands, (
    f'expected one background level per band ({n_bands}), got {len(background_levels)}')

background = torch.zeros(n_bands, data_params['slen'], data_params['slen'], device=device)
for band, level in enumerate(background_levels):
    background[band] = level

In [None]:
# Simulate a small dataset of star images from the PSF and background above,
# with noise added (Poisson draws).
n_images = 32
simulated_dataset = simulated_datasets.StarDataset.load_dataset_from_params(
    n_images,
    data_params,
    psf,
    background,
    transpose_psf=False,
    add_noise=True,
    draw_poisson=True,
)

In [None]:
# Draw ONE small batch and read every field from it, so that images, locations
# and fluxes describe the same sources. Calling get_batch once per field (as
# before) can yield mismatched fields if each call draws a fresh batch.
batch_size = 2
batch = simulated_dataset.get_batch(batch_size)
images = batch['images']
locs = batch['locs']
log_fluxes = batch['log_fluxes']

# Only a cell's last bare expression is displayed, so print every shape explicitly.
print(images.shape)
print(locs.shape)
print(log_fluxes.shape)

In [None]:
# Show images[0, 0] — presumably the first band of the first image in the batch.
fig, ax = plt.subplots()
im = ax.imshow(images[0, 0].cpu().numpy())
fig.colorbar(im, ax=ax)
ax.set_title('images[0, 0]');

# Matrices from encoder 

## load encoder 

In [None]:
# Encoder hyper-parameters.
slen = data_params['slen']   # side length of the simulated images
patch_slen = 8
step = 2
edge_padding = 3
# Derive the band count from the PSF rather than hard-coding it, so the encoder
# always agrees with the simulated data (2 bands here).
n_bands = psf.shape[0]
max_detections = 2
n_source_params = 2

# Move the encoder to `device` instead of calling `.cuda()`, so the notebook
# also runs on CPU-only machines (`device` falls back to the CPU).
star_encoder = sourcenet.SourceEncoder(slen, patch_slen, step, edge_padding, n_bands,
                                       max_detections, n_source_params).to(device)

In [None]:
# Split the batch into image ptiles together with the matching true per-tile
# quantities, then compute the encoder's variational-parameter matrix `h` for
# every ptile.
(image_ptiles,
 true_tile_locs,
 _,
 true_tile_n_sources,
 true_tile_is_on_array) = star_encoder.get_image_ptiles(
    images, locs, log_fluxes, clip_max_sources=True)
h = star_encoder._get_var_params_all(image_ptiles)
image_ptiles.shape

# Note: the tiles are not grouped by image; tiles from every image in the batch are stacked along the first dimension.
# This is intentional: the encoder treats every tile the same way, regardless of which image it came from.

In [None]:
# Shape of the per-tile "is on" indicator returned by get_image_ptiles.
true_tile_is_on_array.shape

## investigate shapes 

In [None]:
# Dimension 0 of both the ptiles and `h` indexes the ptiles.
print(image_ptiles.shape, h.shape, sep='\n')

In [None]:
# Index matrix presumably used to pick the location means out of `h` (inferred from its name).
# shape[0] corresponds to (max_detections + 1)
# shape[1] corresponds to (max_detections * len(x,y))
star_encoder.locs_mean_indx_mat


In [None]:
# Column indices into `h` for the probability of a tile containing 0, 1 or 2 objects
# (over-parametrized).
star_encoder.prob_indx

In [None]:
## What happens when we apply softmax?
# Raw columns of `h` that hold the source-count scores for every ptile.
print(h[:, star_encoder.prob_indx].shape)
# NOTE(review): presumably _get_logprob_n_from_var_params applies a (log-)softmax
# over those columns — confirm in sourcenet.
star_encoder._get_logprob_n_from_var_params(h).shape

## => the output keeps both dimensions of h[:, prob_indx]: one row per ptile, one column per source count.

In [None]:
# Estimated number of sources per ptile: the count with the highest log-probability.
log_probs_n = star_encoder._get_logprob_n_from_var_params(h)
n_sources = log_probs_n.argmax(dim=1)
n_sources.shape

In [None]:
# Variational parameters for the estimated number of sources in each ptile.
loc_mean, loc_logvar, source_param_mean, source_param_logvar = \
    star_encoder._get_var_params_for_n_sources(h, n_sources)
# loc_mean: n_ptiles x max_detections x len(x,y)
print(loc_mean.shape)
print(source_param_mean.shape)

In [None]:
# Full encoder pass over the ptiles, passing the true per-tile source counts as n_sources.
encoder_outputs = star_encoder.forward(image_ptiles, n_sources=true_tile_n_sources)
(loc_mean, loc_logvar,
 source_param_mean, source_param_logvar,
 bernoulli_param, log_probs_n_sources_per_tile) = encoder_outputs

In [None]:
# Shape of the Bernoulli parameters returned by the encoder's forward pass.
bernoulli_param.shape

# Understanding indexing in sleep phase 