In [None]:
import torch

import torch.nn as nn

import matplotlib.pyplot as plt


In [None]:
import sys
# Put the parent directory first on the module search path so it is checked
# before any other location (presumably where the project's *_lib modules
# live -- TODO confirm).
sys.path.insert(0, '../')

In [None]:
from simulated_datasets_lib import StarSimulator
from psf_transform_lib import PsfLocalTransform

In [None]:
# Relative path to the PSF .fit file that is handed to StarSimulator.
psf_fit_file = '../../celeste_net/sdss_stage_dir/2583/2/136/psField-002583-2-0136.fit'

In [None]:
# Star simulator built from the PSF file, with side length 101 and zero sky
# intensity.
simulator = StarSimulator(
    psf_fit_file,
    slen=101,
    sky_intensity=0.,
)

In [None]:
# Visualize the simulator's `psf` attribute as an image.
plt.matshow(simulator.psf)

In [None]:
# Visualize the simulator's `psf_og` attribute (presumably the original,
# unexpanded PSF -- TODO confirm) for comparison with `simulator.psf`.
plt.matshow(simulator.psf_og)

In [None]:
# Wrap the simulator's `psf_og` in a torch tensor so it can be passed to the
# PSF transform and to torch functions.
psf_og = torch.Tensor(simulator.psf_og)

In [None]:
# Build the local PSF transform from the PSF tensor alone; any further
# constructor arguments are left at their defaults.
m = PsfLocalTransform(psf_og)

In [None]:
# Inspect the shape of the transform's tiled PSF.
m.psf_tiled.shape

In [None]:
# Reference array: the PSF padded by one pixel on each side of its last two
# dimensions, so every pixel has a full 3 x 3 neighbourhood.
_foo = nn.functional.pad(psf_og, (1, 1, 1, 1))

# Row `i` of `m.psf_tiled` must equal the flattened 3 x 3 window of the padded
# PSF whose top-left corner is (k, j); entry 4 is the window's centre pixel.
for i in range(m.psf_tiled.shape[0]):
    # 51 is presumably the PSF width, i.e. tiles per row -- TODO confirm.
    k, j = divmod(i, 51)

    assert torch.all(m.psf_tiled[i] == _foo[k:(k + 3), j:(j + 3)].flatten())
    assert torch.all(m.psf_tiled[i, 4] == _foo[(k + 1), (j + 1)])

In [None]:
# Set every weight to -16 except column 4 (the centre of each 3 x 3 window),
# which is set to +16.
m.weight = nn.Parameter(torch.zeros(m.weight.shape) - 16.)
# An nn.Parameter is a leaf tensor that requires grad, so indexed in-place
# assignment raises a RuntimeError unless autograd tracking is disabled.
with torch.no_grad():
    m.weight[:, 4] = 16.

In [None]:
# Display the manually initialised weights.
m.weight

In [None]:
# Run the transform by calling the module itself rather than `forward()`
# directly, so that nn.Module.__call__ (and any registered hooks) is used.
out = m()

In [None]:
# The transform output should reproduce the simulator's PSF.
# Take the absolute value BEFORE the max: `.max().abs()` only looks at the
# largest signed difference and would miss large negative deviations.
assert (out - simulator.psf).abs().max() < 1e-10

# Check out my training

In [None]:
import numpy as np

import torch
import torch.optim as optim

import sdss_dataset_lib
import simulated_datasets_lib
import starnet_vae_lib
import kl_objective_lib

import time

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('device: ', device)

# Report the installed torch version.
print('torch version: ', torch.__version__)


In [None]:
# Load the SDSS/Hubble dataset object with its default arguments; its image,
# background, locs, fluxes and psf_file attributes are used in later cells.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData()

In [None]:
# Observed image and background: add a leading dimension and move to `device`.
full_image = sdss_hubble_data.sdss_image.unsqueeze(dim=0).to(device)
full_background = sdss_hubble_data.sdss_background.unsqueeze(dim=0).to(device)

# true parameters (leading dimension added)
# NOTE(review): unlike the image, these are not moved to `device` -- confirm
# this is intended.
true_full_locs = sdss_hubble_data.locs.unsqueeze(dim=0)
true_full_fluxes = sdss_hubble_data.fluxes.unsqueeze(dim=0)


In [None]:
# simulator: PSF file taken from the dataset, side length equal to the last
# dimension of the full image, zero sky intensity
simulator = simulated_datasets_lib.StarSimulator(
    psf_fit_file=str(sdss_hubble_data.psf_file),
    slen=full_image.shape[-1],
    sky_intensity=0.,
)


In [None]:
# define VAE encoder; its full side length is the last dimension of the full
# image
star_encoder = starnet_vae_lib.StarEncoder(
    full_slen=full_image.shape[-1],
    stamp_slen=9,
    step=2,
    edge_padding=3,
    n_bands=1,
    max_detections=4,
)


In [None]:
# define transform: local PSF transform over the simulator's `psf_og`, with
# the simulator's side length and a 3 x 3 kernel
psf_transform = PsfLocalTransform(
    torch.Tensor(simulator.psf_og),
    simulator.slen,
    kernel_size=3,
)

In [None]:
# define optimizer: Adam over the PSF transform's parameters only
learning_rate = 1e-3
weight_decay = 1e-5
optimizer = optim.Adam(
    psf_transform.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)


In [None]:
# Keep a copy of the initial weights so they can be compared with the weights
# after the optimizer step.
init = torch.clone(psf_transform.weight)
print(init)

In [None]:
# Clear any existing gradients before computing the loss and calling backward.
optimizer.zero_grad()

In [None]:
# Inspect the shape of the true locations (after the leading dimension was added).
true_full_locs.shape

In [None]:
# Get the per-stamp locations and fluxes from the encoder; the first and the
# last two return values are not needed here.
(_, subimage_locs, subimage_fluxes, _, _) = star_encoder.get_image_stamps(
    full_image,
    true_full_locs,
    true_full_fluxes,
    trim_images=False,
)

In [None]:
import psf_transform_lib

In [None]:
# Compute the reconstruction mean and the loss for the current PSF transform.
recon_mean, loss = psf_transform_lib.get_psf_transform_loss(
    full_image,
    full_background,
    subimage_locs,
    subimage_fluxes,
    star_encoder.tile_coords,
    star_encoder.stamp_slen,
    star_encoder.edge_padding,
    simulator,
    psf_transform,
)


In [None]:
# Plot the observed image. `full_image` was moved to `device`, so copy it back
# to host memory first: matplotlib cannot convert a CUDA tensor to an array
# (`.cpu()` is a no-op when the tensor is already on the CPU).
plt.matshow(full_image.squeeze().cpu())

In [None]:
# Plot the reconstruction mean. Detach it from the autograd graph and copy it
# to host memory so matplotlib can convert it even when it lives on a CUDA
# device (`.cpu()` is a no-op for CPU tensors).
plt.matshow(recon_mean.detach().squeeze().cpu())

In [None]:
# Reduce the loss to its mean and backpropagate.
torch.mean(loss).backward()

In [None]:
# Apply one optimizer update using the gradients from the backward pass.
optimizer.step()

In [None]:
# Largest absolute change between the saved initial weights and the current ones.
torch.max(torch.abs(init - psf_transform.weight))

In [None]:
# True if at least one weight is now smaller than its initial value.
((init - psf_transform.weight) > 0.).any()

In [None]:
# Write the transform's state dict to ./test_out for the reload check.
torch.save(psf_transform.state_dict(), './test_out')

In [None]:
# Display the weights that were just saved, for comparison after reloading.
psf_transform.weight

In [None]:
# A second transform built with the same arguments as `psf_transform`, used as
# the target for loading the saved state.
psf_transform2 = PsfLocalTransform(
    torch.Tensor(simulator.psf_og),
    simulator.slen,
    kernel_size=3,
)

In [None]:
# Display the new transform's weights before the saved state is loaded.
psf_transform2.weight

In [None]:
# Load the saved state dict into the second transform. The map_location
# callable returns the storage it is given unchanged.
psf_transform2.load_state_dict(
    torch.load(
        './test_out',
        map_location=lambda storage, loc: storage,
    )
)

In [None]:
# Display the weights after loading, to compare with the saved transform's weights.
psf_transform2.weight

In [None]:
# Largest entry of the gradient stored on the transform's weights.
psf_transform.weight.grad.max()