In [None]:
import numpy as np
import pathlib 

import matplotlib.pyplot as plt

import torch

from torch.utils.data import Dataset

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib

import fitsio
from astropy.io import fits
from astropy.wcs import WCS

# Use the first CUDA GPU when torch reports one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

import os

# Test my simulator on a small image

In [None]:
# Directory holding the PSF FITS files that the next cell reads.
# The path is relative, so it depends on the notebook's working directory.
psf_dir = '../../multiband_pcat/Data/idR-002583-2-0136/psfs/'

In [None]:
# Read HDU 0 of the `-psf-r` and `-psf-g` FITS files.
# Each file is opened in a `with` block so its handle is closed as soon as the
# array has been read, instead of staying open until garbage collection.
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-r.fits') as psf_file:
    psf_r = psf_file[0].read()

with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-g.fits') as psf_file:
    psf_g = psf_file[0].read()

In [None]:
# Stack the PSF arrays along a new leading axis.
# Two-PSF variant (currently disabled):
# psf_og = np.array([psf_r, psf_g])
# Active variant: only psf_r, so the leading axis has length 1.
# NOTE(review): later cells index position 1 (e.g. `simulator.psf_og[1]`,
# `out[0, 1]`, `star_dataset.sky_intensity[1]`) and pass two flux values per
# star, which looks like it expects the two-PSF variant -- confirm which one
# is intended (and the matching sky_intensity length) before a Run All.
psf_og = np.array([psf_r])

In [None]:
# Simulator under test: built from the PSF stack, `slen = 11` (presumably the
# image side length in pixels -- confirm) and a single sky-intensity value.
simulator = simulated_datasets_lib.StarSimulator(
    psf=psf_og,
    slen=11,
    sky_intensity=torch.Tensor([686.]),
)

In [None]:
# Show the PSF stored on the simulator (`simulator.psf_og`), one panel per entry

In [None]:
# Show the first two entries of `simulator.psf_og` side by side, each with its
# own colorbar.
f, axarr = plt.subplots(1, 2, figsize=(8, 4))

im0 = axarr[0].matshow(simulator.psf_og[0])
f.colorbar(im0, ax=axarr[0])

im1 = axarr[1].matshow(simulator.psf_og[1])
# The colorbar must be built from this panel's own image (`im1`); passing
# `im0` would label the right panel with the left panel's color scale.
f.colorbar(im1, ax=axarr[1])

In [None]:
# Parameters for one test image (leading axis of length 1) with three stars.
# locs: one coordinate pair per star, placed on the diagonal
#   (presumably fractions of the image side -- confirm against the simulator).
locs = torch.tensor([[[0.2, 0.2],
                      [0.4, 0.4],
                      [0.6, 0.6]]], dtype=torch.float32)
# fluxes: two values per star, one for each index plotted in the cells below.
fluxes = torch.tensor([[[2000, 8000],
                        [4000, 4000],
                        [8000, 2000]]], dtype=torch.float32)
# n_stars: integer (long) star count for the single image.
n_stars = torch.tensor([3], dtype=torch.long)

In [None]:
# Render the image for the three stars defined above. `add_noise` is not
# passed here, so the simulator's default for it applies.
out = simulator.draw_image_from_params(locs, fluxes, n_stars)

In [None]:
# Show entries [0, 0] and [0, 1] of the rendered image side by side,
# each panel with its own colorbar.
f, axarr = plt.subplots(1, 2, figsize=(8, 4))
left_ax, right_ax = axarr

im0 = left_ax.matshow(out[0, 0])
f.colorbar(im0, ax=left_ax)

im0 = right_ax.matshow(out[0, 1])
f.colorbar(im0, ax=right_ax)

# Check psf is placed on image correctly

In [None]:
# Noise-free image of a single star at (0.5, 0.5) with flux 10000 in both
# flux slots, with `simulator.sky_intensity` subtracted from the result.
out = simulator.draw_image_from_params(
    locs=torch.Tensor([[[0.5, 0.5]]]),
    fluxes=torch.Tensor([[[10000, 10000]]]),
    n_stars=torch.Tensor([1.0]).type(torch.long),
    add_noise=False,
) - simulator.sky_intensity

In [None]:
# Residual for index 0: the noise-free, sky-subtracted image minus
# `simulator.psf[0]` scaled by the star's flux (10000). Given the
# "Check psf is placed on image correctly" heading, this is presumably
# expected to be close to zero everywhere -- confirm.
plt.matshow(out[0, 0].squeeze() - simulator.psf[0] * 10000 )
plt.colorbar()

In [None]:
# Residual for index 1: the noise-free, sky-subtracted image minus
# `simulator.psf[1]` scaled by the star's flux (10000). Given the
# "Check psf is placed on image correctly" heading, this is presumably
# expected to be close to zero everywhere -- confirm.
plt.matshow(out[0, 1].squeeze() - simulator.psf[1] * 10000 )
plt.colorbar()

# Check my dataset

In [None]:
import json

In [None]:
# Read the default star parameters from JSON and echo them for reference.
with open('../data/default_star_parameters.json') as params_file:
    data_params = json.load(params_file)

print(data_params)


In [None]:
# Number of images passed to `load_dataset_from_params` in the next cell.
n_images = 5

In [None]:
# Build the simulated dataset from the PSF stack and the JSON parameters,
# requesting `n_images` images with noise enabled.
star_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_og,
    data_params,
    n_images,
    sky_intensity=torch.Tensor([686.]),
    add_noise=True,
)

In [None]:
# Display the dataset's `n_stars` attribute
# (presumably the number of stars in each image -- confirm).
star_dataset.n_stars

In [None]:
# Shape of the simulated images; later cells index them as images[i, j, :, :].
star_dataset.images.shape

In [None]:
# Shape of the flux tensor; later cells slice it as fluxes[:, :, k].
star_dataset.fluxes.shape

In [None]:
# Histogram of every strictly positive location coordinate. Values <= 0 are
# dropped (presumably padding for absent stars -- confirm).
flat_locs = star_dataset.locs.flatten()
plt.hist(flat_locs[flat_locs > 0])

In [None]:
# Shape of the flux tensor again (same expression as the earlier
# `star_dataset.fluxes.shape` cell).
star_dataset.fluxes.shape

In [None]:
# Histogram of log10 flux at index 0 of the last axis, strictly positive entries only.
flux_first_band = star_dataset.fluxes[:, :, 0].flatten()
plt.hist(torch.log10(flux_first_band[flux_first_band > 0]))

In [None]:
# Histogram of log10 flux at index 1 of the last axis, strictly positive entries only.
flux_second_band = star_dataset.fluxes[:, :, 1].flatten()
plt.hist(torch.log10(flux_second_band[flux_second_band > 0]))

In [None]:
# Flattened flux columns for index 0 and index 1 of the last axis.
flux_band0 = star_dataset.fluxes[:, :, 0].flatten()
flux_band1 = star_dataset.fluxes[:, :, 1].flatten()

# Entries whose index-1 flux is strictly positive.
which_on = flux_band1 > 0

# Color as a magnitude-style difference: -2.5 * (log10 flux_1 - log10 flux_0).
color = (torch.log10(flux_band1) - torch.log10(flux_band0)) * (-2.5)
plt.hist(color[which_on], bins = 100);

In [None]:
# Mean color over the entries whose index-1 flux is strictly positive.
color[which_on].mean()

In [None]:
# Spread of the color over the same entries: square root of the variance.
color[which_on].var().sqrt()

In [None]:
# First simulated image, index 0 on the second axis.
plt.matshow(star_dataset.images[0, 0, :, :])

In [None]:
# First simulated image, index 1 on the second axis.
plt.matshow(star_dataset.images[0, 1, :, :])

In [None]:
# Subtract the matching sky intensity from each of the two slices of the first image.
band0_sky_subtracted = star_dataset.images[0, 0, :, :] - star_dataset.sky_intensity[0]
band1_sky_subtracted = star_dataset.images[0, 1, :, :] - star_dataset.sky_intensity[1]

# Difference image: index 1 minus index 0.
foo = band1_sky_subtracted - band0_sky_subtracted

# Symmetric color limits so that zero sits at the middle of the 'bwr' colormap.
color_limit = foo.abs().max()
plt.matshow(foo, vmin=-color_limit, vmax=color_limit, cmap=plt.get_cmap('bwr'))
plt.colorbar()

In [None]:
# Largest absolute value in the dataset simulator's PSF at index 0.
star_dataset.simulator.psf[0].abs().max()

In [None]:
# Largest absolute value in the dataset simulator's PSF at index 1.
star_dataset.simulator.psf[1].abs().max()

In [None]:
# Display `n_stars` again (same expression as the earlier `star_dataset.n_stars` cell).
star_dataset.n_stars

In [None]:
# Wrap the simulated dataset in a DataLoader that yields `batchsize` items per batch.
batchsize = 2
loader = torch.utils.data.DataLoader(dataset=star_dataset, batch_size=batchsize)


In [None]:
# Take the 'background' entry from the first batch only.
# `next(iter(...))` raises StopIteration on an empty loader instead of silently
# leaving `foo` bound to the difference image computed in an earlier cell, and
# it avoids assigning an unused loop index to `_`.
data = next(iter(loader))
foo = data['background']

In [None]:
# Shape of the 'background' entry taken from the first batch.
foo.shape

In [None]:
# Show entry [0, 1] of the batched background (first item, index 1 on the second axis).
plt.matshow(foo[0, 1])