In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch

import sys
sys.path.insert(0, './../')
import sdss_psf
import simulated_datasets_lib
import starnet_vae_lib
import inv_kl_objective_lib as objectives_lib

import kl_objective_lib 

import json

# Run on the first GPU when one is visible, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by any cell shown here.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

print('torch version: ', torch.__version__)

In [None]:
# load PSF
# SDSS run / camcol / field whose psField file supplies the PSF model;
# change these three numbers to use a different field.
sdss_run, sdss_camcol, sdss_field = 3900, 6, 269

psf_fit_file = (
    '../../celeste_net/sdss_stage_dir/{run}/{camcol}/{field}/'
    'psField-{run:06d}-{camcol}-{field:04d}.fit'
).format(run=sdss_run, camcol=sdss_camcol, field=sdss_field)
print('psf file: \n', psf_fit_file)

In [None]:
# Seed numpy's global RNG and torch's default generator so random draws
# (presumably including the simulator's) are reproducible across runs.
NUMPY_SEED = 43534
TORCH_SEED = 24534
np.random.seed(NUMPY_SEED)
_ = torch.manual_seed(TORCH_SEED)

# Draw data

In [None]:
# data parameters
# Start from the default star-simulation parameters, then override the image
# side length `slen`, the star-count range and `alpha` for this experiment.
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

data_params.update(
    slen=31,
    min_stars=200,
    max_stars=200,
    alpha=0.5,
)

print(data_params)


In [None]:
# Upper bound on the star count (presumably per simulated image), read back
# from the parameters configured above.
# NOTE(review): not referenced by any cell shown here.
max_stars = data_params['max_stars']

In [None]:
# Simulate a small, noisy dataset of star images with the PSF loaded above.
n_images = 4

star_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_fit_file,
    data_params,
    n_images=n_images,
    add_noise=True,
)

# leading dimension of the simulated fluxes = number of simulated samples
num_unlabeled = star_dataset.fluxes.shape[0]
print('num unlabeled', num_unlabeled)

In [None]:
# get loader
# Batch size equals n_images, so (presumably) one batch covers the whole
# dataset; shuffle=False keeps the dataset order.
batchsize = n_images

loader = torch.utils.data.DataLoader(dataset=star_dataset,
                                     batch_size=batchsize,
                                     shuffle=False)

# `loader.dataset` is `star_dataset` itself, so call the method on it directly.
star_dataset.set_params_and_images()

In [None]:
# Take the first batch from the loader. The original loop only ever used its
# first element, so fetch it directly instead of looping and breaking.
data = next(iter(loader))

true_full_fluxes = data['fluxes']
true_full_locs = data['locs']
full_images = data['image']
full_backgrounds = data['background']

In [None]:
# Show band 0 of the first `n_images_to_plot` simulated images, one figure each.
n_images_to_plot = 1
for i in range(n_images_to_plot):
    fig, ax = plt.subplots()
    im = ax.matshow(full_images[i, 0, :, :])
    fig.colorbar(im, ax=ax)
    ax.set_title('simulated image {} (band 0)'.format(i))

# Define the VAE encoder (StarEncoder)

In [None]:
# Encoder hyper-parameters; full_slen is taken from the simulated image side length.
encoder_config = dict(full_slen=data_params['slen'],
                      stamp_slen=7,
                      step=4,
                      edge_padding=1,
                      n_bands=1,
                      max_detections=3)
star_encoder = starnet_vae_lib.StarEncoder(**encoder_config)


In [None]:
# Switch the encoder to evaluation mode (nn.Module.eval() is train(False)).
star_encoder.train(False);

In [None]:
def set_bn_eval(model):
    """Put `model` into eval mode if its class name contains 'BatchNorm'.

    Matching is done on the class name only: any object whose class name
    includes 'BatchNorm' has its ``eval()`` method called, and everything else
    is left untouched. The one-argument signature makes it usable with
    ``torch.nn.Module.apply``.

    Returns None.
    """
    is_batchnorm = 'BatchNorm' in model.__class__.__name__
    if is_batchnorm:
        model.eval()

In [None]:
# Display the encoder architecture through the module's repr. The original
# bare `star_encoder.parameters` only showed the repr of the bound method
# (it was never called).
n_encoder_params = sum(p.numel() for p in star_encoder.parameters())
print('number of encoder parameters:', n_encoder_params)
star_encoder

In [None]:
# Sanity-check the KL loss (kl_objective_lib.get_kl_loss) on the first simulated batch

In [None]:
# Constant sky background at `sky_intensity`, shaped like the image batch.
# NOTE(review): this replaces the `full_backgrounds` taken from the loader batch.
full_backgrounds = data_params['sky_intensity'] * torch.ones(full_images.shape)

In [None]:
# Evaluate the KL objective on the first batch; get_kl_loss returns two terms.
# NOTE(review): the exact meaning of `ps_loss` is defined in kl_objective_lib.
map_loss, ps_loss = kl_objective_lib.get_kl_loss(
    star_encoder, full_images, full_backgrounds, star_dataset.simulator)

In [None]:
# Image batch shape: 4-D, presumably (n_images, n_bands, slen, slen) -- TODO confirm.
full_images.shape

In [None]:
# Second term returned by get_kl_loss (its meaning is defined in kl_objective_lib).
ps_loss

In [None]:
# Cut the images, and then the matching backgrounds, into stamps; only the
# first element returned by get_image_stamps (presumably the stamps) is kept.
image_stamps = star_encoder.get_image_stamps(
    full_images, locs=None, fluxes=None)[0]

# the background stamps must line up pixel-for-pixel with the image stamps
assert full_backgrounds.shape == full_images.shape
background_stamps = star_encoder.get_image_stamps(
    full_backgrounds, locs=None, fluxes=None)[0]

In [None]:
# Shape of the image stamps produced by get_image_stamps.
image_stamps.shape

In [None]:
# Encoder's last hidden layer computed from the image and background stamps
# (a private StarEncoder method, called directly for debugging).
h = star_encoder._forward_to_last_hidden(image_stamps, background_stamps)

In [None]:
# Log-probabilities from the last hidden layer (presumably over the number of
# stars per row) and the most probable (MAP) class of each row.
log_probs = star_encoder._get_logprobs_from_last_hidden_layer(h)
map_n_stars = torch.argmax(log_probs, dim=1)

In [None]:
# MAP class index of each row of log_probs (argmax over dim 1).
map_n_stars

In [None]:
# Loss conditional on the MAP star counts, reusing the hidden layer `h`.
# NOTE(review): this rebinds `map_loss`, replacing the value from get_kl_loss.
map_loss = kl_objective_lib.get_loss_cond_nstars(
    star_encoder, full_images, full_backgrounds, h,
    map_n_stars, star_dataset.simulator)

In [None]:
# Loss conditional on the MAP star counts (from get_loss_cond_nstars).
map_loss

In [None]:
# Row indices 0..n_rows-1, created directly as an int64 tensor on the same
# device as log_probs. Building them as a float32 tensor and then casting was
# inexact above 2**24 and always produced a CPU tensor, which can clash with
# CUDA-resident `map_n_stars` when indexing.
seq_tensor = torch.arange(log_probs.shape[0], device=log_probs.device)
# Pick each row's MAP log-probability, then sum the rows belonging to each
# image (presumably that image's stamps -- TODO confirm) -> one value per image.
log_q = log_probs[seq_tensor, map_n_stars].view(full_images.shape[0], -1).sum(1)


In [None]:
# Expected: one entry per image, since log_q is viewed as
# (full_images.shape[0], -1) and summed over dim 1.
log_q.shape

In [None]:
# NOTE(review): repeats the earlier `map_loss` display; map_loss has not been
# modified since then.
map_loss

In [None]:
# Distribution over classes that excludes each row's MAP class: drop the MAP
# class and renormalize what remains. This is done in log space with softmax.
# With exp() followed by division, a very confident encoder makes exp() of the
# other log-probs underflow to 0, and the renormalization becomes 0/0 = NaN.
mask = objectives_lib.get_one_hot_encoding_from_int(map_n_stars, star_encoder.max_detections + 1)
conditional_probs = torch.softmax(
    log_probs.masked_fill(mask.bool(), float('-inf')), dim=1)

In [None]:
# One-hot MAP mask; expected to be broadcast-compatible with log_probs.shape.
mask.shape

In [None]:
# Presumably (n_rows, max_detections + 1), matching the one-hot mask width.
log_probs.shape

In [None]:
# Draw a class per row from conditional_probs (presumably categorical sampling
# -- TODO confirm in kl_objective_lib); detached so no gradient flows back.
n_stars_sampled = kl_objective_lib.sample_class_weights(conditional_probs).detach()

In [None]:
# Sanity check: should presumably be nonzero everywhere, because the MAP class
# gets zero weight in conditional_probs (assumes sample_class_weights never
# picks a zero-weight class -- TODO confirm).
n_stars_sampled - map_n_stars