In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 
import starnet_lib
import inv_kl_objective_lib as inv_kl_lib
import image_statistics_lib

import plotting_utils

# Fix the seed of NumPy's global RNG so NumPy-based randomness in this notebook is repeatable.
# NOTE(review): torch's RNG is not seeded here — confirm nothing below draws from it.
np.random.seed(34534)

# Load the data

In [None]:
# Brightness cut: only sources whose flux is strictly above this value are kept,
# both for the reference catalog and for the PCAT estimates built further down.
fmin = 1000

In [None]:
# Band indices handed to SDSSHubbleData; their count also sets the encoder's n_bands.
# NOTE(review): presumably the SDSS r and i bands (the checkpoint loaded later is
# named 'starnet_ri') — confirm.
bands = [2, 3]

In [None]:
# Load the SDSS image together with its reference catalog for the selected bands.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(bands=bands)

# Reference ("true") parameters, keeping only sources whose first-band flux exceeds fmin.
which_bright = sdss_hubble_data.fluxes[:, 0] > fmin
true_locs, true_fluxes = (sdss_hubble_data.locs[which_bright],
                          sdss_hubble_data.fluxes[which_bright])

# Observed image and its background, used as encoder input further down.
full_image, full_background = (sdss_hubble_data.sdss_image,
                               sdss_hubble_data.sdss_background)


In [None]:
# Colour vs. magnitude scatter for every source in the reference catalog.
plt.scatter(x=sdss_hubble_data.hubble_color,
            y=sdss_hubble_data.hubble_rmag,
            alpha=0.1,
            marker='x')

# Load Portillo's results

In [None]:
# Directory holding the PCAT run whose output is compared against below.
results_dir = '../../multiband_pcat/pcat-lion-results/20191116-135326/'

# Sampler output archive; arrays are looked up by key (e.g. 'f', 'x', 'y') in later cells.
chain_results = np.load(f'{results_dir}chain.npz')

In [None]:
# Shape of the stored flux chain.
# NOTE(review): the indexing used further down (`['f'][:, -1, ]` followed by a transpose)
# suggests the axes are (n bands, n samples, n sources) — confirm against the PCAT output format.
chain_results['f'].shape

In [None]:
# Multiplicative factor applied to the PCAT fluxes before they are compared with fmin.
# It is read from the first 'gain' value of the first sdss_data entry.
# NOTE(review): presumably this converts PCAT flux units into the units used by the
# reference catalog — confirm.
# Earlier alternative, kept for reference:
# fudge_factor = 1 / (1 - 0.83)
fudge_factor = sdss_hubble_data.sdss_data[0]['gain'][0]

In [None]:
# Build the PCAT ("Portillo") estimates:
#   portillos_est_locs   -- (n sources, 2) tensor of (x0, x1) divided by (image side - 1)
#   portillos_est_fluxes -- fluxes scaled by fudge_factor, restricted to sources above fmin
# Set the flag to True to read the condensed classical catalogue instead of a chain sample.
include_classical_catalogue = False

if include_classical_catalogue:
    pcat_catalog = np.loadtxt(results_dir + 'classical_catalog.txt')

    x1_loc = pcat_catalog[:, 0]
    x0_loc = pcat_catalog[:, 2]

    # Single flux column, so `fluxes` is 1-D in this branch.
    fluxes = pcat_catalog[:, 4] * fudge_factor

    # remove na (and anything not brighter than fmin)
    is_na = (fluxes < fmin) | np.isnan(fluxes)
    keep_source = ~is_na

else:
    # just take one sample: the last one stored in the chain.
    # After the transpose, rows index sources and columns index bands.
    fluxes = chain_results['f'][:, -1, ].transpose() * fudge_factor

    x1_loc = chain_results['x'][-1, ].flatten()
    x0_loc = chain_results['y'][-1, ].flatten()

    # Brightness cut on the first band; computed once and reused for every array.
    keep_source = fluxes[:, 0] > fmin

x1_loc = x1_loc[keep_source]
x0_loc = x0_loc[keep_source]
fluxes = fluxes[keep_source]

# Stack into a single (n sources, 2) ndarray before handing it to torch: building a
# tensor from a Python list of ndarrays goes through a very slow conversion path.
portillos_est_locs = torch.Tensor(np.stack([x0_loc, x1_loc], axis=1)) \
                        / (full_image.shape[-1] - 1)

portillos_est_fluxes = torch.Tensor(fluxes)
if include_classical_catalogue:
    # Give the 1-D classical fluxes a trailing band axis.
    # NOTE(review): this yields a single band column, while later cells index
    # column 1 of the derived magnitudes — confirm before enabling this branch.
    portillos_est_fluxes = portillos_est_fluxes.unsqueeze(-1)
    

# x1_loc_samples = chain_results['x'][-300:, ].flatten()
# x0_loc_samples = chain_results['y'][-300:, ].flatten()

# portillos_est_fluxes_sampled = torch.Tensor(chain_results['f'][0, -300:, ].flatten()) * fudge_factor
# portillos_est_locs_sampled = torch.Tensor([x0_loc_samples, x1_loc_samples]).transpose(0,1) \
#                                 / (full_image.shape[-1] - 1)
    
# # filter by fmin
# port_which_bright = portillos_est_fluxes_sampled > fmin
# portillos_est_fluxes_sampled = portillos_est_fluxes_sampled[port_which_bright]
# portillos_est_locs_sampled = portillos_est_locs_sampled[port_which_bright]

In [None]:
def convert_fluxes_to_mag(fluxes, nelect_per_nmgy, fudge_conversion):
    """Convert fluxes to magnitudes.

    The fluxes are first divided by ``nelect_per_nmgy * fudge_conversion``;
    the magnitude is then ``22.5 - 2.5 * log10`` of that ratio, computed
    elementwise.

    Args:
        fluxes: torch tensor of fluxes (any shape).
        nelect_per_nmgy: divisor applied to the fluxes (scalar or broadcastable).
        fudge_conversion: additional multiplicative correction to the divisor.

    Returns:
        Tensor of magnitudes with the same shape as ``fluxes``.
    """
    scaled_fluxes = fluxes / (nelect_per_nmgy * fudge_conversion)
    log_scaled = torch.log10(scaled_fluxes)
    return 22.5 - log_scaled * 2.5

In [None]:
# Magnitudes of the PCAT sources, one column per band.
portillos_mag = convert_fluxes_to_mag(
    fluxes=portillos_est_fluxes,
    nelect_per_nmgy=sdss_hubble_data.nelec_per_nmgy_full.mean(),
    fudge_conversion=sdss_hubble_data.fudge_conversion,
)

In [None]:
# PCAT estimates: magnitude difference (second band minus first) against first-band magnitude.
plt.scatter(x=portillos_mag[:, 1] - portillos_mag[:, 0],
            y=portillos_mag[:, 0],
            alpha=0.1,
            marker='x')

In [None]:
# Reference catalog again, limited to sources brighter (smaller magnitude) than the
# faintest PCAT source, so the two diagrams cover the same magnitude range.
brighter_than_faintest_pcat = sdss_hubble_data.hubble_rmag < portillos_mag[:, 0].max()

plt.scatter(x=sdss_hubble_data.hubble_color[brighter_than_faintest_pcat],
            y=sdss_hubble_data.hubble_rmag[brighter_than_faintest_pcat],
            alpha=0.1,
            marker='x')

# My starnet result

In [None]:
# Instantiate the encoder network; n_bands follows the band selection made above.
# NOTE(review): the remaining arguments presumably have to match the configuration
# the saved checkpoint was trained with — confirm before changing any of them.
star_encoder1 = starnet_lib.StarEncoder(full_slen = 101,
                                            stamp_slen = 7,
                                            step = 2,
                                            edge_padding = 2, 
                                            n_bands = len(bands),
                                            max_detections = 2)

In [None]:
# Read the saved weights; the map_location callable returns each storage unchanged,
# i.e. it keeps them where deserialization first placed them instead of moving them
# to the device they were saved from.
saved_state = torch.load('../fits/results_11202019/starnet_ri',
                         map_location=lambda storage, loc: storage)
star_encoder1.load_state_dict(saved_state)

# Put the network in evaluation mode; the trailing semicolon hides the module repr.
star_encoder1.eval();


In [None]:
# get parameters on the full image
# (earlier call, kept for reference)
# map_locs_full_image, map_fluxes_full_image, map_n_stars_full = \
#     star_encoder.get_results_on_full_image(full_image.unsqueeze(0).unsqueeze(0), 
#                                            full_background.unsqueeze(0).unsqueeze(0))

# Run the encoder with a leading batch axis of size one and keep only the first
# three returned items: locations, fluxes and number of stars.
(map_locs_full_image,
 map_fluxes_full_image,
 map_n_stars_full) = star_encoder1.sample_star_encoder(
    full_image.unsqueeze(0),
    full_background.unsqueeze(0),
    return_map=True,
)[:3]

In [None]:
# Magnitudes of the encoder's estimates, after dropping the leading batch axis.
my_mag = convert_fluxes_to_mag(
    fluxes=map_fluxes_full_image.squeeze(0),
    nelect_per_nmgy=sdss_hubble_data.nelec_per_nmgy_full.mean(),
    fudge_conversion=sdss_hubble_data.fudge_conversion,
)

In [None]:
# Encoder estimates: magnitude difference (second band minus first) against first-band magnitude.
plt.scatter(x=my_mag[:, 1] - my_mag[:, 0],
            y=my_mag[:, 0],
            alpha=0.1,
            marker='x')

In [None]:
# Display the PCAT magnitudes for inspection (one row per source, one column per band).
portillos_mag

In [None]:
# Overlaid histograms of the between-band magnitude difference for both methods.
# NOTE(review): the difference here is first band minus second, the opposite sign
# to the scatter plots above — confirm which convention is intended.
starnet_mag_diff = my_mag[:, 0] - my_mag[:, 1]
pcat_mag_diff = portillos_mag[:, 0] - portillos_mag[:, 1]

plt.hist(starnet_mag_diff)
plt.hist(pcat_mag_diff, alpha=0.5)