In [None]:
import torch

import torch.nn as nn

import matplotlib.pyplot as plt

import fitsio

import numpy as np

In [None]:
import sys

# Make the project modules in the parent directory importable. The guard keeps
# re-runs of this cell from stacking duplicate '../' entries on sys.path.
if '../' not in sys.path:
    sys.path.insert(0, '../')

In [None]:
from simulated_datasets_lib import StarSimulator
from psf_transform_lib import PsfLocalTransform

# Load the PSF

In [None]:
psf_dir = '../../multiband_pcat/Data/idR-002583-2-0136/psfs/'

# Read the primary HDU of each band's PSF file. The context managers close the
# FITS handles instead of leaving them open for the lifetime of the kernel.
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-r.fits') as fits_r:
    psf_r = fits_r[0].read()
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-i.fits') as fits_i:
    psf_i = fits_i[0].read()

# Only the r-band PSF is used in these checks; wrapping it in a list adds a
# leading band axis of size 1 (assumes psf_r is 2-D -- TODO confirm).
psf_og = torch.Tensor(np.array([psf_r]))

# get transform

In [None]:
# Build the local PSF transform around the single-band r PSF, relying on the
# constructor's default image size and kernel size.
m = PsfLocalTransform(psf_og)

In [None]:
psf_og.size()

In [None]:
m.psf_tiled.size()

# check tiling

In [None]:
# Row i of `psf_tiled` should hold, for every band, the flattened
# kernel_size x kernel_size neighbourhood of PSF pixel (i // psf_slen, i % psf_slen),
# with zero padding at the borders. The centre tap of that neighbourhood is the
# pixel itself. Sizes are read from the transform instead of being hard-coded.
pad_width = m.kernel_size // 2
center_idx = m.kernel_size ** 2 // 2
psf_padded = nn.functional.pad(psf_og, (pad_width, pad_width, pad_width, pad_width))

for i in range(m.psf_tiled.shape[0]):
    k, j = divmod(i, m.psf_slen)

    for b in range(m.n_bands):
        window = psf_padded[b, k:(k + m.kernel_size), j:(j + m.kernel_size)]
        assert torch.all(m.psf_tiled[i, b] == window.flatten())
        assert torch.all(m.psf_tiled[i, b, center_idx] == psf_og[b, k, j])

# Check identity map

In [None]:
# Identity weights: for every pixel and band, put weight 1 on the centre tap of
# the kernel and 0 everywhere else. The centre index is derived from
# kernel_size rather than hard-coded for a 3x3 kernel.
w = torch.zeros(m.psf_slen ** 2, m.n_bands, m.kernel_size ** 2)
w[:, :, m.kernel_size ** 2 // 2] = 1.

In [None]:
# Applying the identity weights must reproduce the original PSF exactly.
assert torch.all(m.apply_weights(w) == psf_og)

# Check normalization

In [None]:
## with initial weights

In [None]:
torch.max(psf_og)

In [None]:
# Evaluate the transform with its initial weights. No autograd graph is kept,
# since psf0 is only used for comparisons and plotting.
with torch.no_grad():
    psf0 = m.forward()

In [None]:
# The transformed PSF must keep each band's total flux.
for band in range(m.n_bands):
    flux_gap = (psf_og[band].sum() - psf0[band].sum()).abs()
    assert flux_gap < 1e-6

In [None]:
# `m.forward()` returns the PSF embedded in a larger frame (presumably
# 101 x 101 here -- TODO confirm). Crop the centred window that matches the
# original PSF size instead of hard-coding 38:63, then plot the difference.
psf_slen_og = psf_og.shape[-1]
crop_start = (psf0.shape[-1] - psf_slen_og) // 2
crop = slice(crop_start, crop_start + psf_slen_og)

plt.matshow(psf0[0, crop, crop] - psf_og[0])
plt.colorbar()
plt.title('initial-weights PSF minus original PSF (band 0)');

In [None]:
## with random weights

In [None]:
# Replace the transform weights with random ones to check that normalization
# does not depend on the weights. Seed first so the check is reproducible, and
# use randn_like so the new weights keep the old parameter's dtype and device.
torch.manual_seed(0)
m.weight = nn.Parameter(torch.randn_like(m.weight))
psf1 = m.forward()

In [None]:
# Even with random weights, each band's total PSF flux must be unchanged.
for band in range(m.n_bands):
    flux_gap = (psf_og[band].sum() - psf1[band].sum()).abs()
    assert flux_gap < 1e-6

# Check out the PSF training

In [None]:
import sys

# Make the project modules in the parent directory importable. The guard keeps
# re-runs (and the earlier identical cell) from stacking duplicate '../'
# entries on sys.path.
if '../' not in sys.path:
    sys.path.insert(0, '../')

In [None]:
import numpy as np

import torch
import torch.optim as optim

import sdss_dataset_lib
import simulated_datasets_lib
import starnet_vae_lib
from psf_transform_lib import PsfLocalTransform, get_psf_loss

import time

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print('device: ', device)

print('torch version: ', torch.__version__)

# Reproducibility: force deterministic cuDNN kernels and seed numpy and torch.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
np.random.seed(4534)
_ = torch.manual_seed(2534)


In [None]:
import fitsio

In [None]:
# Load the SDSS data object, which exposes the image, background and catalogue
# locations/fluxes. bands=[2, 3] presumably selects r and i in ugriz order --
# TODO confirm.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(bands=[2, 3])


In [None]:
# Observed image and background, each given a leading batch axis of size 1 and
# moved to the compute device.
full_image = sdss_hubble_data.sdss_image[None].to(device)
full_background = sdss_hubble_data.sdss_background[None].to(device)


In [None]:
# True catalogue parameters (locations and fluxes), batched like the image.
true_full_locs = sdss_hubble_data.locs[None].to(device)
true_full_fluxes = sdss_hubble_data.fluxes[None].to(device)

In [None]:
true_full_fluxes.size()

In [None]:
# Mean of each loaded tensor, as a quick sanity check of units and scale.
for summary_tensor in (full_image, full_background, true_full_locs, true_full_fluxes):
    print(summary_tensor.mean())

In [None]:
# Preview the first catalogue entries rather than dumping the whole flux tensor.
true_full_fluxes[0, :10]

In [None]:
# load psf
psf_dir = '../../multiband_pcat/Data/idR-002583-2-0136/psfs/'

# Read the primary HDU of each band's PSF file. The context managers close the
# FITS handles instead of leaving them open in the kernel.
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-r.fits') as fits_r:
    psf_r = fits_r[0].read()
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-i.fits') as fits_i:
    psf_i = fits_i[0].read()

# Stack the two PSFs along a new leading band axis (r first, then i).
psf_og = np.array([psf_r, psf_i])

In [None]:
# Define the local PSF transform. The second argument is the full image's
# last-axis length (presumably the image side length -- TODO confirm); the
# kernel is 3x3.
psf_transform = PsfLocalTransform(
    torch.Tensor(psf_og).to(device),
    full_image.size(-1),
    kernel_size=3,
)


In [None]:
locs = true_full_locs
fluxes = true_full_fluxes
# A catalogue entry counts as a star when its first-band flux is positive.
n_stars = (fluxes[:, :, 0] > 0).sum(dim=1)

psf_trained = psf_transform.forward()
# Alternative baseline (untransformed PSF):
# psf_trained = torch.Tensor(simulated_datasets_lib._expand_psf(psf_og, 101))

# Quick sanity summaries of every input to the loss.
for quantity in (
    (psf_trained ** 2).mean(),
    full_image.mean(),
    full_background.mean(),
    locs.mean(),
    fluxes.mean(),
    n_stars,
):
    print(quantity)

# Reconstruction and loss for the current PSF.
recon_mean, loss = get_psf_loss(
    full_image,
    full_background,
    locs,
    fluxes,
    n_stars,
    psf_trained,
    pad=5,
    grid=None,
)

In [None]:
# Relative residual (recon - data) / data, one panel per band. Tensors are moved
# to the CPU before plotting because matplotlib cannot read CUDA tensors.
n_bands = psf_transform.n_bands
# squeeze=False keeps axarr 2-D, so indexing also works for a single band.
fig, axarr = plt.subplots(1, n_bands, figsize=(4 * n_bands, 4), squeeze=False)

for b in range(n_bands):
    # NOTE(review): 5:95 presumably strips the pad=5 border passed to
    # get_psf_loss, assuming a 100-pixel image -- confirm.
    resid0 = ((recon_mean[0, b].detach() - full_image[0, b]) / full_image[0, b])[5:95, 5:95].cpu()
    vlim = resid0.abs().max().item()
    im0 = axarr[0, b].matshow(resid0,
                              vmax=vlim,
                              vmin=-vlim,
                              cmap=plt.get_cmap('bwr'))
    axarr[0, b].set_title(f'band {b}: relative residual')

    fig.colorbar(im0, ax=axarr[0, b])

In [None]:
# Show the reconstruction loss as the cell's rich output rather than via print().
loss

In [None]:
# Adam over the PSF-transform parameters (lr 0.1 for this group), with a small
# weight decay.
optimizer = optim.Adam(
    [{'params': psf_transform.parameters(), 'lr': 0.1}],
    weight_decay=1e-5,
)


In [None]:
# Clear any gradients accumulated on the optimizer's PSF-transform parameters
# before the next backward pass.
optimizer.zero_grad()

In [None]:
# Toy objective: sum of squared PSF pixels. NOTE(review): presumably only a
# check that gradients flow back to the transform weights -- confirm.
loss = psf_transform.forward().pow(2).sum()

In [None]:
# Inspect the current value of `loss` before back-propagating it.
loss

In [None]:
# Back-propagate the loss into the PSF-transform parameters' .grad fields.
torch.autograd.backward(loss)