In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 
import starnet_lib
import sleep_lib
import image_statistics_lib

import plotting_utils

# Run on the GPU when one is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed every RNG this notebook draws from so Restart & Run All is reproducible:
# numpy (random subimage corners below) and torch (sample_star_encoder may draw
# random samples).
SEED = 34534
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load the data

In [None]:
fmin = 1000

In [None]:
bands = [2, 3]

In [None]:
x0 = 630
x1 = 310
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(x0 = x0,
                                                    x1 = x1, 
                                                    bands = bands)

# image 
image = sdss_hubble_data.sdss_image 

# true parameters
which_bright = (sdss_hubble_data.fluxes[:, 0] > fmin)
true_locs = sdss_hubble_data.locs[which_bright]
true_fluxes = sdss_hubble_data.fluxes[which_bright]


In [None]:
true_locs.shape

In [None]:
image = torch.Tensor(image)
print(image.shape)

In [None]:
plt.matshow(image[0])
plt.colorbar()

In [None]:
sdss_hubble_data.nelec_per_nmgy.mean()

# Get simulator 

In [None]:
import psf_transform_lib
import wake_lib

In [None]:
# psf fitted with ground truth
init_psf_params = torch.Tensor(np.load('../data/fitted_powerlaw_psf_params.npy'))
power_law_psf = psf_transform_lib.PowerLawPSF(init_psf_params.to(device))
psf_og = power_law_psf.forward().detach()

In [None]:
# background fitted with ground truth 
init_background_params = torch.Tensor(np.load('../data/fitted_planar_backgrounds.npy'))
planar_background = wake_lib.PlanarBackground(image_slen = image.shape[-1], 
                            init_background_params = init_background_params.to(device))
background = planar_background.forward().detach()


In [None]:
simulator = simulated_datasets_lib.StarSimulator(psf_og, 
                                                slen = image.shape[-1], 
                                                transpose_psf = False, 
                                                background = background)


# Simulation with ground truth

In [None]:
truth_recon = simulator.draw_image_from_params(locs = sdss_hubble_data.locs.unsqueeze(0), 
                            fluxes = sdss_hubble_data.fluxes.unsqueeze(0),
                            n_stars = torch.Tensor([len(sdss_hubble_data.fluxes)]).type(torch.LongTensor), 
                            add_noise = False).squeeze(0)

In [None]:
for i in range(len(sdss_hubble_data.bands)): 
    foo = (truth_recon[i] - image[i]) / image[i]
    plt.matshow(foo, vmax = foo.abs().max(), vmin = - foo.abs().max(), cmap = plt.get_cmap('bwr')) 
    plt.colorbar()

In [None]:
# A few random patches of band 0 with the bright true stars overlaid.
n_patches = 5
demo_patch_slen = 7
# Largest top-left corner for which the whole patch lies inside the image.
max_corner = image.shape[-1] - demo_patch_slen

fig, axarr = plt.subplots(1, n_patches, figsize=(12, 6))

for ax in axarr:
    # np.random.choice(n) returns a scalar; int() of a size-1 array is deprecated in NumPy.
    plotting_utils.plot_subimage(ax, image[0],
                                 true_locs,
                                 None,
                                 x0 = int(np.random.choice(max_corner + 1)),
                                 x1 = int(np.random.choice(max_corner + 1)),
                                 patch_slen = demo_patch_slen)
    ax.set_xticks([])
    ax.set_yticks([])
    


# Load results

In [None]:
results_dir = '../pcat_results/20200123-122354/'

chain_results = np.load(results_dir + 'chain.npz')

In [None]:
# n bands 
chain_results['f'].shape

In [None]:
fudge_factor = sdss_hubble_data.sdss_data[0]['gain'][0] 

In [None]:
# Build the PCAT catalog to compare against: either PCAT's condensed ("classical")
# catalog or the last sample of the MCMC chain. Fluxes are rescaled by `fudge_factor`
# and thresholded at `fmin`; locations are divided by (slen - 1), the same
# normalization undone for the true locations further below.
include_classical_catalogue = False

if include_classical_catalogue: 
    pcat_catalog = np.loadtxt(results_dir + 'classical_catalog.txt')
    
    x1_loc = pcat_catalog[:, 0]
    x0_loc = pcat_catalog[:, 2]
        
    fluxes = pcat_catalog[:, 4] * fudge_factor
    
    # drop stars with a NaN coordinate or flux, and stars fainter than fmin
    is_na = np.isnan(x0_loc) | np.isnan(x1_loc) | np.isnan(fluxes) | (fluxes < fmin)
    
    x1_loc = x1_loc[~is_na]
    x0_loc = x0_loc[~is_na]
    fluxes = fluxes[~is_na]
    
    # np.stack avoids torch's slow (and warning-emitting) list-of-ndarrays path
    portillos_est_locs = torch.Tensor(np.stack([x0_loc, x1_loc], axis=1)) / (image.shape[-1] - 1)
    portillos_est_fluxes = torch.Tensor(fluxes).unsqueeze(1)
else: 
    # just take one sample (the last one in the chain)
    fluxes = chain_results['f'][:, -1, ].transpose() * fudge_factor
    is_bright = fluxes[:, 0] > fmin
    
    x1_loc = chain_results['x'][-1, ].flatten()[is_bright]
    x0_loc = chain_results['y'][-1, ].flatten()[is_bright]
    
    fluxes = fluxes[is_bright]
        
    portillos_est_locs = torch.Tensor(np.stack([x0_loc, x1_loc], axis=1)) / (image.shape[-1] - 1)
    portillos_est_fluxes = torch.Tensor(fluxes) 
    

# x1_loc_samples = chain_results['x'][-300:, ].flatten()
# x0_loc_samples = chain_results['y'][-300:, ].flatten()

# portillos_est_fluxes_sampled = torch.Tensor(chain_results['f'][0, -300:, ].flatten()) * fudge_factor
# portillos_est_locs_sampled = torch.Tensor([x0_loc_samples, x1_loc_samples]).transpose(0,1) \
#                                 / (image.shape[-1] - 1)
    
# # filter by fmin
# port_which_bright = portillos_est_fluxes_sampled > fmin
# portillos_est_fluxes_sampled = portillos_est_fluxes_sampled[port_which_bright]
# portillos_est_locs_sampled = portillos_est_locs_sampled[port_which_bright]

# Get the mean reconstruction

In [None]:
# Noiseless rendering of the PCAT catalog (batch of one image).
_locs, _fluxes = portillos_est_locs.unsqueeze(0), portillos_est_fluxes.unsqueeze(0)
_n_stars = torch.LongTensor([len(x0_loc)])

portillos_recon_mean = simulator.draw_image_from_params(
    locs=_locs, fluxes=_fluxes, n_stars=_n_stars, add_noise=False
).squeeze(0)

plt.matshow(portillos_recon_mean[0, ...])
plt.colorbar()

In [None]:
portillos_residuals = portillos_recon_mean - image

# Relative residuals per band, keeping rows/columns 5..94 only.
interior = (slice(5, 95), slice(5, 95))
for b in range(portillos_recon_mean.shape[0]):
    rel_resid = (portillos_residuals[b] / image[b])[interior]
    lim = rel_resid.abs().max()
    plt.matshow(rel_resid, vmin=-lim, vmax=lim, cmap=plt.get_cmap('bwr'))
    plt.colorbar()

# Plot subimages

In [None]:
band = 0

In [None]:
subimage_slen = 10

# possible coordinates
x0_vec = np.arange(10, 90, subimage_slen)
x1_vec = np.arange(10, 90, subimage_slen)

In [None]:
# Random subimage: PCAT catalog vs. truth, PCAT reconstruction, relative residuals.
# np.random.choice without `size` returns a scalar (int() of a size-1 array is deprecated).
x0 = int(np.random.choice(x0_vec))
x1 = int(np.random.choice(x1_vec))

fig, axarr = plt.subplots(1, 4, figsize=(15, 4))

# posterior samples (disabled: needs portillos_est_locs_sampled from the commented-out
# chain-sampling code in the catalog cell); hide the unused panel
# plotting_utils.plot_subimage(axarr[0], image[band], 
#                              portillos_est_locs_sampled, 
#                              true_locs, 
#                              x0, x1, subimage_slen)
# axarr[0].set_title('observed; coords: {}\n'.format([x0, x1]));
axarr[0].axis('off')

# condensed catalog
plotting_utils.plot_subimage(axarr[1], image[band],
                             portillos_est_locs, 
                             true_locs, 
                             x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig)
axarr[1].set_title('observed; coords: {}\n'.format([x0, x1]));

# reconstruction
plotting_utils.plot_subimage(axarr[2], portillos_recon_mean[band],
                             portillos_est_locs, 
                             None, 
                             x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig)
axarr[2].set_title('reconstructed\n');

# relative residuals
plotting_utils.plot_subimage(axarr[3], portillos_residuals[band] / image[band], 
                            portillos_est_locs, 
                             None, 
                             x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig,
                            diverging_cmap = True)

axarr[3].set_title('residuals\n');



# Compare with my neural network (StarNet)

In [None]:
star_encoder = starnet_lib.StarEncoder(slen = 100,
                                            patch_slen = 8,
                                            step = 2,
                                            edge_padding = 3, 
                                            n_bands = len(bands),
                                            max_detections = 2)

In [None]:
star_encoder.load_state_dict(torch.load('../fits/results_2020-02-27/starnet_ri', 
                               map_location=lambda storage, loc: storage))


star_encoder.eval(); 


In [None]:
# MAP locations, fluxes and star counts on the full image (passed as a batch of one).
encoder_out = star_encoder.sample_star_encoder(
    image.unsqueeze(0),
    return_map_n_stars=True,
    return_map_star_params=True,
)
map_locs_image, map_fluxes_image, map_n_stars = encoder_out[:3]

In [None]:
vae_recon_mean = simulator.draw_image_from_params(locs = map_locs_image, 
                                                fluxes = map_fluxes_image,
                                                 n_stars = map_n_stars, 
                                                 add_noise = False).squeeze(0)

vae_residuals = vae_recon_mean - image

In [None]:
my_est_locs = map_locs_image.squeeze(0) 
my_est_fluxes = map_fluxes_image.squeeze(0)

In [None]:
band = 0

In [None]:
image.shape

In [None]:
vae_recon_mean

In [None]:
# Random subimage: StarNet catalog vs. truth, StarNet reconstruction, relative residuals.
# np.random.choice without `size` returns a scalar (int() of a size-1 array is deprecated).
x0 = int(np.random.choice(x0_vec))
x1 = int(np.random.choice(x1_vec))

fig, axarr = plt.subplots(1, 3, figsize=(15, 4))

# my catalog
plotting_utils.plot_subimage(axarr[0], image[band], my_est_locs, true_locs, x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig)
axarr[0].set_title('observed; coords: {}\n'.format([x0, x1]));

# reconstruction
plotting_utils.plot_subimage(axarr[1], vae_recon_mean[band], my_est_locs, None, x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig)
axarr[1].set_title('reconstructed\n');

# relative residuals, on a colour scale symmetric about 0 over this subimage
vae_rel_residuals = vae_residuals / image
vmax = torch.abs(vae_rel_residuals[band, x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]).max()
plotting_utils.plot_subimage(axarr[2], vae_rel_residuals[band], 
                            my_est_locs, None, x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig,
                            diverging_cmap = True, 
                            vmax = vmax, vmin = -vmax)

axarr[2].set_title('residuals\n');



# Check out reconstruction errors (mean absolute relative residuals)

In [None]:
len(my_est_fluxes)

In [None]:
len(portillos_est_fluxes)

In [None]:
len(true_fluxes)

In [None]:
# reconstructions 
fig, axarr = plt.subplots(1, 3, figsize=(15, 12))

axarr[0].matshow(vae_recon_mean[band])
axarr[1].matshow(portillos_recon_mean[band])
axarr[2].matshow(truth_recon[band])

In [None]:
band = 1

In [None]:
# Mean absolute relative residual |recon - image| / image over the interior
# [10:90, 10:90] of the current band for each reconstruction. (This is not an MSE.)
crop = (band, slice(10, 90), slice(10, 90))
_image = image[crop]

_my_residual = (vae_recon_mean[crop] - _image) / _image
_portillos_residual = (portillos_recon_mean[crop] - _image) / _image
_true_residual = (truth_recon[crop] - _image) / _image

print('my mean abs rel residual: ', torch.mean(_my_residual.abs()))
print('portillos mean abs rel residual: ', torch.mean(_portillos_residual.abs()))
print('truth mean abs rel residual: ', torch.mean(_true_residual.abs()))

fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

residual_panels = [('StarNet', _my_residual),
                   ('Portillos', _portillos_residual),
                   ('truth', _true_residual)]
for ax, (title, resid) in zip(axarr, residual_panels):
    # symmetric colour scale so that 0 is white in 'bwr'
    lim = resid.abs().max()
    im = ax.matshow(resid, vmin = -lim, vmax = lim, cmap = plt.get_cmap('bwr'))
    ax.set_title(title)
    fig.colorbar(im, ax = ax)

In [None]:
# Histograms of the relative residuals from the previous cell. They are already
# divided by the image, so they are not divided by `_image` a second time.
fig, axarr = plt.subplots(1, 3, figsize=(20, 5))

for ax, (title, resid) in zip(axarr, [('StarNet', _my_residual),
                                      ('Portillos', _portillos_residual),
                                      ('truth', _true_residual)]):
    ax.hist(resid.flatten(), bins = 100)
    ax.set_title(title)
    ax.set_xlabel('relative residual')

# Compare

In [None]:
fig, axarr = plt.subplots(2, 3, figsize=(15, 12))

# np.random.choice without `size` returns a scalar (int() of a size-1 array is deprecated).
x0 = int(np.random.choice(x0_vec))
x1 = int(np.random.choice(x1_vec))

# PCAT locations in pixels, relative to the subimage corner, restricted to the
# subimage. This does not depend on the row, so it is computed once.
_portillos_est_locs = portillos_est_locs * (image.shape[-1] - 1)
which_locs = (_portillos_est_locs[:, 0] > x0) & \
                (_portillos_est_locs[:, 0] < (x0 + subimage_slen - 1)) & \
                (_portillos_est_locs[:, 1] > x1) & \
                (_portillos_est_locs[:, 1] < (x1 + subimage_slen - 1))
portillos_locs = (_portillos_est_locs[which_locs, :] - torch.Tensor([[x0, x1]])) 

###################
# Plot catalogs
##################
# my catalog, with the PCAT catalog overlaid as cyan crosses (same panel in both rows)
for j in range(2):
    plotting_utils.plot_subimage(axarr[j, 0], image[band], my_est_locs, true_locs, x0, x1, subimage_slen, 
                                add_colorbar = True, global_fig = fig)
    axarr[j, 0].set_title('observed; coords: {}\n'.format([x0, x1]));
    axarr[j, 0].scatter(portillos_locs[:, 1], portillos_locs[:, 0], color = 'c', marker = 'x')

#######################
# Reconstructions 
#######################
# my reconstruction
plotting_utils.plot_subimage(axarr[0, 1], vae_recon_mean[band], my_est_locs, None, x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig)
axarr[0, 1].set_title('vae reconstructed\n');

# Portillos reconstruction
plotting_utils.plot_subimage(axarr[1, 1], portillos_recon_mean[band], 
                             portillos_est_locs, None, x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig, 
                            color = 'c', marker = 'x')
axarr[1, 1].set_title('portillos reconstructed\n');

######################
# residuals (relative to the image)
######################
# my residuals
plotting_utils.plot_subimage(axarr[0, 2], (vae_residuals / image)[band], 
                            my_est_locs, None, x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig,
                            diverging_cmap = True)

axarr[0, 2].set_title('vae residuals\n');

# portillos residuals
plotting_utils.plot_subimage(axarr[1, 2], (portillos_residuals / image)[band], 
                            portillos_est_locs, None, x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig,
                            diverging_cmap = True, 
                            color = 'c', marker = 'x')

axarr[1, 2].set_title('portillos residuals\n');

# Get summary statistics

These are rather coarse measures. My true positive rate (TPR) does not account for the fact that several true stars might be matched with just one estimated star (so not all of those true stars were actually detected). Conversely, my positive predictive value (PPV) does not account for the fact that several estimated stars might be matched with just one true star (so only one of those estimated stars is really a true positive).

I tried the Hungarian algorithm to find a minimum-cost one-to-one matching, but it gave strange results: it searches for the permutation that minimizes the **global** cost of the matching, so an individual star is not necessarily paired with its nearest counterpart.

In [None]:
# DISABLED -- kept for reference only. Location-only TPR / PPV (fluxes passed as None);
# the flux-aware version in the next cell is the one actually used.
# tpr and ppv using locations only
# my_tpr, my_ppv, _, _ = \
#     image_statistics_lib.get_summary_stats(my_est_locs, true_locs, 
#                                            image.shape[-1], None, None)
# portillos_tpr, portillos_ppv, _, _ = \
#     image_statistics_lib.get_summary_stats(portillos_est_locs, true_locs, 
#                                            image.shape[-1], None, None)

    
# print('my tpr: {:0.3f}'.format(my_tpr))
# print('portillos tpr: {:0.3f}\n'.format(portillos_tpr))

# NOTE: the two prints below report the PPV, despite their "true positive rate" labels.
# print('my true positive rate: {:0.3f}'.format(my_ppv))
# print('portillos true positive rate: {:0.3f}'.format(portillos_ppv))

In [None]:
# TPR / PPV matching on both locations and fluxes.
# Keep all four outputs for both methods: the per-star boolean arrays
# (my_complete_bool / my_ppv_bool, portillos_complete_bool / portillos_ppv_bool)
# are used by the final diagnostic-plot cell.
my_tpr, my_ppv, my_complete_bool, my_ppv_bool = \
    image_statistics_lib.get_summary_stats(my_est_locs, 
                                           true_locs, 
                                           star_encoder.slen, 
                                           my_est_fluxes[:, 0], 
                                           true_fluxes[:, 0], 
                                           sdss_hubble_data.nelec_per_nmgy.mean())
    
portillos_tpr, portillos_ppv, portillos_complete_bool, portillos_ppv_bool = \
    image_statistics_lib.get_summary_stats(portillos_est_locs, 
                                           true_locs, 
                                           image.shape[-1], 
                                           portillos_est_fluxes[:, 0], 
                                           true_fluxes[:, 0], 
                                           sdss_hubble_data.nelec_per_nmgy.mean())

    
print('my tpr: {:0.3f}'.format(my_tpr))
print('portillos tpr: {:0.3f}\n'.format(portillos_tpr))

print('my ppv: {:0.3f}'.format(my_ppv))
print('portillos ppv: {:0.3f}'.format(portillos_ppv))

In [None]:
import numpy as np

In [None]:
# Averaged nelec_per_nmgy calibration (presumably electrons per nanomaggy); fluxes are
# divided by it before being passed to convert_nmgy_to_mag below.
nelec_per_nmgy = sdss_hubble_data.nelec_per_nmgy.mean()

In [None]:
# TPR binned by true magnitude. Bin edges are 10 evenly spaced percentiles
# (0, 11.1, ..., 100) of the true magnitudes, i.e. 9 roughly equal-count bins.
# Each curve is plotted against the lower-edge percentile of its bin, not against magnitude.
true_mags = sdss_dataset_lib.convert_nmgy_to_mag(true_fluxes[:, 0] / nelec_per_nmgy)
percentiles = np.linspace(0, 1, 10) * 100
mag_vec = np.percentile(true_mags, percentiles)

my_tpr_vec, my_mag_vec = \
    image_statistics_lib.get_tpr_vec(my_est_locs, 
                                     true_locs, 
                                     image.shape[-1],
                                     my_est_fluxes[:, 0], 
                                     true_fluxes[:, 0], 
                                     nelec_per_nmgy, 
                                     mag_vec = mag_vec)[0:2]

portillos_tpr_vec, portillos_mag_vec = \
    image_statistics_lib.get_tpr_vec(portillos_est_locs, 
                                     true_locs, 
                                     image.shape[-1],
                                     portillos_est_fluxes[:, 0], 
                                     true_fluxes[:, 0], 
                                     nelec_per_nmgy, 
                                     mag_vec = mag_vec)[0:2]

# Alternative x-axis: the magnitude bin edges themselves.
# plt.plot(my_mag_vec[0:-1], my_tpr_vec, '--x', label = 'Starnet')
# plt.plot(portillos_mag_vec[0:-1], portillos_tpr_vec, '--x', label = 'Portillos')
plt.plot(percentiles[:-1], my_tpr_vec, '--x', label = 'Starnet')
plt.plot(percentiles[:-1], portillos_tpr_vec, '--x', label = 'Portillos')

plt.legend()
plt.xlabel('percentile of true magnitude (lower bin edge)')
plt.ylabel('tpr')
plt.title('TPR vs. true magnitude');

In [None]:
# PPV binned by estimated magnitude. Each method is binned at the percentiles of
# its OWN estimated magnitudes, so the same x-value can correspond to different
# magnitudes for the two curves.
my_mags = sdss_dataset_lib.convert_nmgy_to_mag(my_est_fluxes[:, 0] / nelec_per_nmgy)
portillos_mags = sdss_dataset_lib.convert_nmgy_to_mag(portillos_est_fluxes[:, 0] / nelec_per_nmgy)

my_ppv_vec, my_mag_vec, my_counts = \
    image_statistics_lib.get_ppv_vec(my_est_locs, 
                                     true_locs, 
                                     image.shape[-1],
                                     my_est_fluxes[:, 0], 
                                     true_fluxes[:, 0], 
                                     nelec_per_nmgy, 
                                     mag_vec = np.percentile(my_mags, percentiles))

portillos_ppv_vec, portillos_mag_vec, portillos_counts = \
    image_statistics_lib.get_ppv_vec(portillos_est_locs, 
                                     true_locs, 
                                     image.shape[-1],
                                     portillos_est_fluxes[:, 0], 
                                     true_fluxes[:, 0], 
                                     nelec_per_nmgy, 
                                     mag_vec = np.percentile(portillos_mags, percentiles))

# Alternative x-axis: the magnitude bin edges themselves.
# plt.plot(my_mag_vec[0:-1], my_ppv_vec, '--x', label = 'Starnet')
# plt.plot(portillos_mag_vec[0:-1], portillos_ppv_vec, '--x', label = 'Portillos')
plt.plot(percentiles[0:-1], my_ppv_vec, '--x', label = 'Starnet')
plt.plot(percentiles[0:-1], portillos_ppv_vec, '--x', label = 'Portillos')

plt.legend()
plt.xlabel('percentile of estimated magnitude (lower bin edge)')
plt.ylabel('ppv')
plt.title('PPV vs. estimated magnitude');

In [None]:
# Diagnostics on one random subimage:
#   row 0: each catalog vs. truth; missed true stars (orange o) and false
#          detections (orange x) highlighted;
#   row 1: true stars detected by only one of the two methods;
#   row 2: noiseless reconstructions.
# Uses my_complete_bool / my_ppv_bool and portillos_complete_bool / portillos_ppv_bool
# from the summary-statistics cell.


def _scatter_locs_in_subimage(ax, locs, x0, x1, subimage_slen, slen, **scatter_kwargs):
    """Overlay the normalized `locs` that fall inside a subimage on `ax`.

    Locations are rescaled by (slen - 1) and kept only if strictly inside
    (x0, x0 + subimage_slen - 1) x (x1, x1 + subimage_slen - 1). They are then
    shifted so the subimage corner is the origin and scattered as (column, row).
    Extra keyword arguments are forwarded to `ax.scatter`.
    """
    pixel_locs = locs * (slen - 1)
    inside = (pixel_locs[:, 0] > x0) & \
             (pixel_locs[:, 0] < (x0 + subimage_slen - 1)) & \
             (pixel_locs[:, 1] > x1) & \
             (pixel_locs[:, 1] < (x1 + subimage_slen - 1))
    sub_locs = pixel_locs[inside, :] - torch.Tensor([[x0, x1]])
    ax.scatter(sub_locs[:, 1], sub_locs[:, 0], **scatter_kwargs)


fig, axarr = plt.subplots(3, 2, figsize=(16, 24))

# np.random.choice without `size` returns a scalar (int() of a size-1 array is deprecated).
x0 = int(np.random.choice(x0_vec))
x1 = int(np.random.choice(x1_vec))
subimage_slen = 10
slen = image.shape[-1]

##########################
# PLOT STARS CAUGHT BY ME
##########################
plotting_utils.plot_subimage(axarr[0, 0], image[band], 
                             my_est_locs, 
                             true_locs, 
                             x0, x1, subimage_slen, 
                            add_colorbar = True, 
                             global_fig = fig, 
                            color = 'b')
axarr[0, 0].set_title('my results; coords: {}\n'.format([x0, x1]));

# true locations that I missed
_scatter_locs_in_subimage(axarr[0, 0], true_locs[my_complete_bool == 0],
                          x0, x1, subimage_slen, slen, color = 'orange', marker = 'o')

# estimated locations that were false
_scatter_locs_in_subimage(axarr[0, 0], my_est_locs[my_ppv_bool == 0],
                          x0, x1, subimage_slen, slen, color = 'orange', marker = 'x')

#####################
# PLOT STARS CAUGHT BY PORTILLOS
####################
plotting_utils.plot_subimage(axarr[0, 1], image[band], 
                             portillos_est_locs, 
                             true_locs, 
                             x0, x1, subimage_slen, 
                            add_colorbar = True, 
                             global_fig = fig, 
                            color = 'b')
axarr[0, 1].set_title('portillos results; coords: {}\n'.format([x0, x1]));

# true locations that Portillos missed
_scatter_locs_in_subimage(axarr[0, 1], true_locs[portillos_complete_bool == 0],
                          x0, x1, subimage_slen, slen, color = 'orange', marker = 'o')

# estimated locations that were false
_scatter_locs_in_subimage(axarr[0, 1], portillos_est_locs[portillos_ppv_bool == 0],
                          x0, x1, subimage_slen, slen, color = 'orange', marker = 'x')

##########################
# PLOT STARS CAUGHT BY ONLY ME
##########################
plotting_utils.plot_subimage(axarr[1, 0], image[band], 
                             my_est_locs, 
                             true_locs[(my_complete_bool == 1) & (portillos_complete_bool == 0)], 
                             x0, x1, subimage_slen, 
                            add_colorbar = True, 
                            global_fig = fig)
axarr[1, 0].set_title('true stars caught only by me')

##########################
# PLOT STARS CAUGHT BY ONLY PORTILLOS
##########################
plotting_utils.plot_subimage(axarr[1, 1], image[band], 
                             portillos_est_locs, 
                             true_locs[(my_complete_bool == 0) & (portillos_complete_bool == 1)], 
                             x0, x1, subimage_slen, 
                            add_colorbar = True, 
                            global_fig = fig)
axarr[1, 1].set_title('true stars caught only by Portillos')

##########################
# RECONSTRUCTIONS
##########################
plotting_utils.plot_subimage(axarr[2, 0], vae_recon_mean[band], 
                             my_est_locs, 
                             None, 
                             x0, x1, subimage_slen, 
                            add_colorbar = True, 
                            global_fig = fig)
axarr[2, 0].set_title('my reconstruction')
plotting_utils.plot_subimage(axarr[2, 1], portillos_recon_mean[band], 
                             portillos_est_locs, 
                             None, 
                             x0, x1, subimage_slen, 
                            add_colorbar = True, 
                            global_fig = fig)
axarr[2, 1].set_title('portillos reconstruction')


##########################
# RESIDUALS (disabled alternative for row 2)
##########################
# _resid = torch.log(vae_recon_mean / image)
# _resid = (vae_recon_mean - image)/image
# vmax = torch.abs(_resid[x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]).max()
# plotting_utils.plot_subimage(axarr[2, 0], _resid, 
#                             my_est_locs, None, x0, x1, subimage_slen, 
#                             add_colorbar = True, global_fig = fig,
#                             diverging_cmap = True, 
#                             vmax = vmax, vmin = -vmax)

# axarr[2, 0].set_title('my residuals\n');

# # _resid = torch.log(portillos_recon_mean / image)
# _resid = (portillos_recon_mean - image)/image
# vmax = torch.abs(_resid[x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]).max()
# plotting_utils.plot_subimage(axarr[2, 1], _resid, 
#                             portillos_est_locs, None, x0, x1, subimage_slen, 
#                             add_colorbar = True, global_fig = fig,
#                             diverging_cmap = True, 
#                             vmax = vmax, vmin = -vmax)

# axarr[2, 1].set_title('Portillos residuals\n');
