In [None]:
import numpy as np
import pathlib 

import matplotlib.pyplot as plt

import torch

from torch.utils.data import Dataset

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib

import fitsio
from astropy.io import fits
from astropy.wcs import WCS

# Prefer the first CUDA GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

import os

# Load the PSFs

In [None]:
# Directory holding the PSF FITS files. Keep the trailing '/', because file
# names are appended to this string with `+`.
psf_dir = '/'.join(['..', 'data', ''])

In [None]:
def _read_psf(band):
    """Return the image in HDU 0 of the PSF file for `band` ('r' or 'i').

    The FITS file is opened in a `with` block so the handle is closed as soon
    as the image has been read, instead of staying open for the kernel's lifetime.
    """
    path = psf_dir + f'sdss-002583-2-0136-psf-{band}.fits'
    with fitsio.FITS(path) as psf_file:
        return psf_file[0].read()


psf_r = _read_psf('r')
psf_i = _read_psf('i')

In [None]:
# Stack the per-band PSFs (band order: r, i) into a single float32 tensor
# whose leading dimension indexes the band.
psf_og = torch.tensor(np.stack([psf_r, psf_i]), dtype=torch.float32)

In [None]:
# The number of bands is the size of the PSF tensor's leading dimension.
n_bands = len(psf_og)
print(psf_og.shape)

# Set up the simulator

In [None]:
# Side length of the square simulated images (the background below is
# n_bands x slen x slen).
slen = 10

In [None]:
# Constant per-band sky background for the simulator: 686 in band 0 and
# 1123 in band 1. Any further band would keep the initial value of 1.
background = torch.ones(n_bands, slen, slen)
background[:2] = torch.tensor([686., 1123.]).view(2, 1, 1)

In [None]:
# Star simulator on an slen x slen grid, using the PSFs and background above.
# transpose_psf=False: the PSF images are used as read from disk.
simulator = simulated_datasets_lib.StarSimulator(
    psf=psf_og,
    slen=slen,
    background=background,
    transpose_psf=False,
)

In [None]:
# Shape of the PSF as held by the simulator. This may differ from psf_og.shape
# if StarSimulator preprocesses the PSF — confirm against simulated_datasets_lib.
simulator.psf.shape

# The PSF

In [None]:
# Show each band's PSF next to its own colorbar.
f, axarr = plt.subplots(1, n_bands, figsize=(4 * n_bands, 4), squeeze=False)

for band, ax in enumerate(axarr[0]):
    im = ax.matshow(simulator.psf_og[band])
    # The colorbar must use the image drawn on the same axes. Reusing the
    # band-0 image would label band 1 with band 0's colour scale.
    f.colorbar(im, ax=ax)
    ax.set_title(f'PSF, band {band}')

# Check a simulation

In [None]:
# Sweep one bright star along the first location coordinate, one grid step at a
# time, and show both bands for each position.
# i / (slen - 1) for i in range(slen) reaches 1.0 for any slen; a hard-coded
# range(10) only does so when slen == 10.
fluxes = torch.tensor([[[80000., 8000.]]])
n_stars = torch.tensor([1], dtype=torch.long)

for i in range(slen):
    x_loc = i / (slen - 1)
    locs = torch.tensor([[[x_loc, 0.2]]])

    out = simulator.draw_image_from_params(locs, fluxes, n_stars)

    f, axarr = plt.subplots(1, 2, figsize=(8, 4))
    for band, ax in enumerate(axarr):
        im = ax.matshow(out[0, band])
        f.colorbar(im, ax=ax)
        ax.set_title(f'locs = ({x_loc:.2f}, 0.20), band {band}')

In [None]:
# One image containing three stars with different per-band fluxes.
locs = torch.tensor([[[0.8, 0.2], [0.4, 0.4], [0.6, 0.6]]])
fluxes = torch.tensor([[[20000, 8000], [4000, 4000], [8000, 2000]]], dtype=torch.float32)
n_stars = torch.tensor([3])  # integer literal -> int64 (torch.long)

out = simulator.draw_image_from_params(locs, fluxes, n_stars)

# Band 0 on the left, band 1 on the right, each with its own colorbar.
f, axarr = plt.subplots(1, 2, figsize=(8, 4))
for ax, band_image in zip(axarr, out[0]):
    f.colorbar(ax.matshow(band_image), ax=ax)

# Check that the PSF is placed on the image correctly

In [None]:
# Render one noiseless star whose location is intended to fall on pixel
# (x0, x0), then remove the background so only the rendered star remains.
x0 = 6
loc_x0 = x0 / (slen - 1)
out = simulator.draw_image_from_params(
    locs=torch.tensor([[[loc_x0, loc_x0]]]),
    fluxes=torch.tensor([[[10000., 10000.]]]),
    n_stars=torch.tensor([1]),
    add_noise=False,
)
out = out - simulator.background.unsqueeze(0)

In [None]:
# 5x5 window of the band-0 image centred on (x0, x0).
window = slice(x0 - 2, x0 + 3)
psf_plotted = out[0, 0, window, window]
plt.matshow(psf_plotted)

In [None]:
# Residual: rendered band-0 star minus the trimmed band-0 PSF scaled by the
# injected flux (10000). The star sits on the diagonal (x0, x0), so this check
# cannot tell whether the two location coordinates are swapped.
expected_band0 = 10000 * simulated_datasets_lib._trim_psf(simulator.psf, 5)[0]
plt.matshow(psf_plotted - expected_band0)
plt.colorbar()

In [None]:
# Same residual check for band 1. The band-1 PSF has to be compared with the
# band-1 image (out[0, 1]); psf_plotted is the band-0 window, so reusing it here
# would compare mismatched bands.
psf_plotted_band1 = out[0, 1, (x0 - 2):(x0 + 3), (x0 - 2):(x0 + 3)]
plt.matshow(psf_plotted_band1 - simulated_datasets_lib._trim_psf(simulator.psf, 5)[1] * 10000)
plt.colorbar()

# Check the simulated dataset

In [None]:
import json

In [None]:
# data parameters: load the defaults from JSON, then override the image size.
data_params = json.loads(pathlib.Path('../data/default_star_parameters.json').read_text())
data_params.update(slen=100)
print(data_params)


In [None]:
# Number of images requested from load_dataset_from_params below.
n_images = 5

In [None]:
# Per-band constant background for the dataset images (686 in band 0 and
# 1123 in band 1), sized to data_params['slen']. This replaces the
# slen x slen simulator background defined earlier.
background = torch.zeros(n_bands, data_params['slen'], data_params['slen'])
background[:2] = torch.tensor([686., 1123.]).view(2, 1, 1)

In [None]:
# Build the simulated star dataset from the PSFs and parameters above, with noise.
dataset_kwargs = dict(
    background=background,
    n_images=n_images,
    transpose_psf=False,
    add_noise=True,
)
star_dataset = simulated_datasets_lib.load_dataset_from_params(psf_og, data_params, **dataset_kwargs)

In [None]:
# Star counts of the simulated dataset (presumably one entry per image — confirm).
star_dataset.n_stars

In [None]:
# Image tensor size. The images[0, band] indexing used below suggests
# (n_images, n_bands, slen, slen) — confirm.
star_dataset.images.size()

In [None]:
# Flux tensor size. The fluxes[:, :, band] indexing used below suggests the
# band is the last dimension — confirm.
star_dataset.fluxes.size()

In [None]:
# Histogram of all positive location coordinates, with both coordinates pooled.
# Zero entries are presumably padding for absent stars — confirm.
locs_flat = star_dataset.locs.flatten()
plt.hist(locs_flat[locs_flat > 0])

In [None]:
# Sanity check: locations and fluxes should agree on their leading
# (image, star-slot) dimensions.
assert star_dataset.locs.shape[:2] == star_dataset.fluxes.shape[:2]
star_dataset.locs.shape

In [None]:
# log10 flux distribution in band 0, restricted to positive fluxes
# (zeros presumably mark absent stars).
band0_fluxes = star_dataset.fluxes[:, :, 0].flatten()
plt.hist(torch.log10(band0_fluxes[band0_fluxes > 0]))

In [None]:
# log10 flux distribution in band 1, restricted to positive fluxes
# (zeros presumably mark absent stars).
band1_fluxes = star_dataset.fluxes[:, :, 1].flatten()
plt.hist(torch.log10(band1_fluxes[band1_fluxes > 0]))

In [None]:
# Colour of each star slot: -2.5 * (log10 f_1 - log10 f_0), i.e. m_1 - m_0
# with m = -2.5 * log10(flux).
fluxes_0 = star_dataset.fluxes[:, :, 0].flatten()
fluxes_1 = star_dataset.fluxes[:, :, 1].flatten()

# A star has a finite colour only if its flux is positive in BOTH bands.
# Checking band 1 alone lets a zero band-0 flux through as an infinite colour,
# which makes plt.hist's automatic range fail.
which_on = (fluxes_0 > 0) & (fluxes_1 > 0)

# `color` keeps one entry per star slot (off slots are nan/inf) so that it can
# be indexed with `which_on` in later cells.
color = (torch.log10(fluxes_1) - torch.log10(fluxes_0)) * (-2.5)
plt.hist(color[which_on], bins=100);

In [None]:
# Mean colour over the stars selected by `which_on`.
torch.mean(color[which_on])

In [None]:
# Standard deviation of the colour. Equivalent to var().sqrt(): torch.std and
# torch.var both default to the unbiased (correction=1) estimator.
color[which_on].std()

In [None]:
# First simulated image, band 0 (the trailing full slices are implicit).
plt.matshow(star_dataset.images[0, 0])

In [None]:
# First simulated image, band 1 (the trailing full slices are implicit).
plt.matshow(star_dataset.images[0, 1])

In [None]:
# check alignment between two bands: background-subtracted band 1 minus
# background-subtracted band 0, for the first simulated image.
# Each band must have its OWN background removed (band 1 -> background[1],
# band 0 -> background[0]). `background` is the (n_bands, slen, slen) tensor
# passed to load_dataset_from_params above.
band_diff = (star_dataset.images[0, 1] - background[1]) - \
            (star_dataset.images[0, 0] - background[0])

# Symmetric limits so that zero maps to the white centre of the 'bwr' colormap.
vlim = band_diff.abs().max()
plt.matshow(band_diff, vmax=vlim, vmin=-vlim, cmap=plt.get_cmap('bwr'))
plt.colorbar()