In [None]:
# --- standard library ---
import os
import pathlib
import sys

# --- third-party ---
import matplotlib.pyplot as plt
import numpy as np
import torch
from astropy.io import fits
from astropy.wcs import WCS
from torch.utils.data import Dataset

# --- local modules (they live in the parent directory, so extend sys.path first) ---
sys.path.insert(0, '../')
import sdss_dataset_lib
import sdss_psf

# Use the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Build the SDSS/Hubble dataset object (defined in the local sdss_dataset_lib module).
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData()

In [None]:
# Check that the Hubble source locations overlap the globular cluster in the full SDSS image.
# Each location is drawn as an individual marker ('.'): without a format string,
# `plt.plot` joins consecutive catalog entries with line segments, which hides the positions.
# x1 goes on the horizontal axis and x0 on the vertical one, matching matshow's (row, col) layout.
plt.matshow(sdss_hubble_data.sdss_image_full.squeeze())
plt.plot(sdss_hubble_data.locs_x1, sdss_hubble_data.locs_x0, '.', alpha = 0.2)

In [None]:
# Inspect the per-tile counts matrix; the colorbar gives its scale.
counts_image = plt.matshow(sdss_hubble_data.counts_mat)
plt.colorbar(counts_image)

In [None]:
# Mask of the tiles that were kept: a nonzero count strictly below `max_detections`.
tile_counts = sdss_hubble_data.counts_mat
plt.matshow((tile_counts > 0) & (tile_counts < sdss_hubble_data.max_detections))

In [None]:
# Sample a few random tiles and overlay the catalog star locations on each.

In [None]:
# Show a few randomly drawn tiles with the catalog locations overlaid.
# A dedicated, seeded generator keeps the selection reproducible under Restart & Run All
# (the global np.random state was previously unseeded).
sample_rng = np.random.default_rng(0)

# Locations appear to be stored scaled to [0, 1] (TODO confirm in sdss_dataset_lib);
# multiplying by (slen - 1) converts them to pixel coordinates within the tile.
pixel_scale = sdss_hubble_data.slen - 1

for _ in range(10):
    # Draw a size-1 array (not a scalar) so the dataset is indexed exactly as before.
    indx = sample_rng.choice(len(sdss_hubble_data), 1)
    data = sdss_hubble_data[indx]
    n_stars = data['n_stars']

    plt.matshow(data['image'].squeeze())
    # Column 1 is the horizontal axis, column 0 the vertical axis.
    plt.plot(data['locs'][0:n_stars, 1] * pixel_scale,
             data['locs'][0:n_stars, 0] * pixel_scale, 'o', color = 'b')

    plt.title('image {}'.format(indx))

# The image patches shown in Portillo et al. (2017)

In [None]:
# Patch at offset (53, 70) from the (630, 309) origin of the region inspected below.
sdss_hubble_data.plot_image_patch(x0=630 + 53, x1=309 + 70, slen=10, flip_y=True)

In [None]:
# Patch at offset (41, 23) from the (630, 309) origin of the region inspected below.
sdss_hubble_data.plot_image_patch(x0=630 + 41, x1=309 + 23, slen=10, flip_y=True)

In [None]:
# Patch at offset (31, 83) from the (630, 309) origin of the region inspected below.
sdss_hubble_data.plot_image_patch(x0=630 + 31, x1=309 + 83, slen=10, flip_y=True)

In [None]:
# Patch at offset (32, 64) from the (630, 309) origin of the region inspected below.
sdss_hubble_data.plot_image_patch(x0=630 + 32, x1=309 + 64, slen=10, flip_y=True)

# Check the PCAT results

In [None]:
# Reference patch at offset (53, 70) from the (630, 309) region origin, shown again for comparison.
sdss_hubble_data.plot_image_patch(x0=630 + 53, x1=309 + 70, slen=10, flip_y=True)

In [None]:
# PCAT MCMC chain for the real SDSS field. Opening the .npz inside a `with` block closes
# the underlying file handle; `dict(...)` reads every array eagerly, and the result is
# still indexed the same way (e.g. chain_results['x']).
with np.load('../../multiband_pcat/pcat-lion-results/20190920-152141/chain.npz') as chain_npz:
    chain_results = dict(chain_npz)

In [None]:
# PCAT's "classical" catalog for the same run, stored as a plain-text table.
classical_catalog_path = '../../multiband_pcat/pcat-lion-results/20190920-152141/classical_catalog.txt'
pcat_catalog = np.loadtxt(classical_catalog_path)

In [None]:
np.shape(pcat_catalog)

In [None]:
# Choose the source of PCAT positions: the last 300 entries of the MCMC chain (default),
# or columns 0 and 2 of the classical catalog.
# x1 is the horizontal (column) coordinate and x0 the vertical (row) coordinate.
use_classical_catalogue = False

if not use_classical_catalogue:
    x1_loc = chain_results['x'][-300:]
    x0_loc = chain_results['y'][-300:]
else:
    x1_loc = pcat_catalog[:, 0]
    x0_loc = pcat_catalog[:, 2]

In [None]:
# 100x100 SDSS cutout starting at row 630, column 309 (compare with the PCAT input image below).
plt.matshow(sdss_hubble_data.sdss_image_full[630:630 + 100, 309:309 + 100])

In [None]:
# The counts image PCAT read for SDSS field idR-002583-2-0136 (plain-text pixel table).
pcat_image_path = '../../multiband_pcat/Data/idR-002583-2-0136/cts/idR-002583-2-0136-ctsr.txt'
pcat_image = np.loadtxt(pcat_image_path)

In [None]:
# Full PCAT input image, for a side-by-side look with the SDSS cutout above.
plt.matshow(pcat_image)

In [None]:
# The same 100x100 SDSS cutout (rows 630.., cols 309..), kept for pixel-wise comparison with PCAT.
foo = sdss_hubble_data.sdss_image_full[630:630 + 100, 309:309 + 100]

In [None]:
# Pixel-wise ratio SDSS / PCAT over the top-left 10x10 corner; a constant ratio would
# mean the two images differ only by a scale factor.
ratio_image = plt.matshow(foo[:10, :10] / pcat_image[:10, :10])
plt.colorbar(ratio_image)

In [None]:
# Top-left 10x10 corner of the PCAT input image.
plt.matshow(pcat_image[:10, :10])

In [None]:
# The matching 10x10 corner of the SDSS image (offset by the cutout origin 630, 309).
plt.matshow(sdss_hubble_data.sdss_image_full[630:630 + 10, 309:309 + 10])

In [None]:
# Overlay PCAT source positions on one patch of the real SDSS image.
# (x0, x1) is the patch's top-left pixel relative to the cutout origin (630, 309).
x0 = 32
x1 = 64
slen = 10

# Assuming pixel centres sit at integer coordinates, the patch spans the open box
# (x0 - 0.5, x0 + slen - 0.5) x (x1 - 0.5, x1 + slen - 0.5) in cutout coordinates.
which_pixels = ((x0_loc > x0 - 0.5) & (x0_loc < x0 + slen - 0.5) &
                (x1_loc > x1 - 0.5) & (x1_loc < x1 + slen - 0.5))

# Pass `slen` through (previously a literal 10) so the plotted patch always matches the
# region used to select sources above.
# NOTE(review): the Portillo patches above use flip_y=True; confirm False is intended here.
sdss_hubble_data.plot_image_patch(x0 = 630 + x0,
                                  x1 = 309 + x1,
                                  slen = slen,
                                  flip_y = False)

# Boolean-mask indexing yields 1-D arrays of the selected positions, shifted into patch coordinates.
plt.scatter(x1_loc[which_pixels] - x1,
            x0_loc[which_pixels] - x0,
            color = 'r', marker = 'x', alpha = 0.01)

# Check PCAT results on my simulated data

In [None]:
# Simulated SDSS-like counts image that PCAT was run on (plain-text pixel table).
simulated_image_path = '../../multiband_pcat/Data/sdss_simulated/cts/sdss_simulated-ctsr.txt'
simulated_image = np.loadtxt(simulated_image_path)

In [None]:
# Full simulated image.
plt.matshow(simulated_image)

In [None]:
# PCAT MCMC chain for the simulated image. Opening the .npz inside a `with` block closes
# the underlying file handle; `dict(...)` reads every array eagerly, and the result is
# still indexed the same way (e.g. chain_results_sim['x']).
with np.load('../../multiband_pcat/pcat-lion-results/20190924-111329/chain.npz') as chain_npz_sim:
    chain_results_sim = dict(chain_npz_sim)

In [None]:
# Last 300 chain entries for the simulated run: x -> horizontal (x1), y -> vertical (x0).
x1_loc = chain_results_sim['x'][-300:]
x0_loc = chain_results_sim['y'][-300:]

In [None]:
# Simulated image again, as a reference for choosing the patch below.
plt.matshow(simulated_image)

In [None]:
# Overlay PCAT chain positions on a slen x slen patch of the simulated image,
# whose top-left pixel is (x0, x1).
x0, x1, slen = 40, 10, 10

# Keep only positions inside the open box (x0 - 0.5, x0 + slen - 0.5) x (x1 - 0.5, x1 + slen - 0.5).
which_pixels = ((x0_loc > x0 - 0.5) & (x0_loc < x0 + slen - 0.5)
                & (x1_loc > x1 - 0.5) & (x1_loc < x1 + slen - 0.5))

plt.matshow(simulated_image[x0:x0 + slen, x1:x1 + slen])

# Shift the selected positions into patch coordinates before plotting.
plt.scatter((x1_loc[which_pixels] - x1).flatten(),
            (x0_loc[which_pixels] - x0).flatten(),
            color = 'r', marker = 'x', alpha = 0.01)

In [None]:
# Horizontal coordinates of the chain positions that fell inside the patch.
selected_x1 = x1_loc[which_pixels]
selected_x1