### (1) Imports, paths, and config

In [None]:
import sys
import os
from os import environ
from pathlib import Path

import torch
import numpy as np
import matplotlib.pyplot as plt

from hydra import initialize, compose
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

# Locate the BLISS checkout from $BLISS_HOME when it is set; otherwise fall back
# to the original hard-coded path (change twhit to your username, and the rest
# of the path as necessary, if BLISS_HOME is not set in your environment).
# We chdir into the weak-lensing case study before importing its local modules,
# presumably so they resolve from the working directory.
bliss_home = Path(environ.get("BLISS_HOME", "/home/twhit/bliss"))
os.chdir(bliss_home / "case_studies" / "weak_lensing")
from lensing_catalog import LensingTileCatalog
from lensing_variational_dist import LensingVariationalDistSpec, LensingVariationalDist

In [None]:
# Point BLISS at the repository root unless the environment already does
# (change twhit to your username if BLISS_HOME is not already set).
environ.setdefault("BLISS_HOME", "/home/twhit/bliss")
# Hydra's compose() takes `overrides` as a list of "key=value" strings; pass a
# list rather than a set so the overrides keep a well-defined order.
with initialize(config_path="../", version_base=None):
    cfg = compose("lensing_config", overrides=["surveys.sdss.load_image_data=true"])

---

### (2) Generate synthetic images

In [None]:
# Build the simulator from its Hydra config, then draw one batch of synthetic data.
simulator = instantiate(cfg.simulator)
batch = simulator.get_batch()

In [None]:
# Show the first simulated image at index 2 of its second axis (presumably a band).
_ = plt.imshow(batch["images"][0][2])

---

### (3) Instantiate encoder

In [None]:
# Build the encoder and wrap the simulated truth in a lensing tile catalog
# using the encoder's tile side length.
encoder = instantiate(cfg.train.encoder)
target_cat = LensingTileCatalog(encoder.tile_slen, batch["tile_catalog"])


def truth_callback(_):
    """Return the fixed target catalog, ignoring whatever argument it is called with.

    Defined with ``def`` instead of assigning a lambda (PEP 8, E731) so it has
    a meaningful ``__name__`` in tracebacks and reprs.
    """
    return target_cat

---

### (4) Train encoder on a single batch of synthetic images

Here we'll just try to learn the shear and convergence for the single batch of synthetic images we generated above.

Later, we'll generate a fresh batch in each iteration of the training loop.

In [None]:
num_iters = 1000
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

# Fit the encoder to the one fixed batch. Each step clears old gradients, runs
# inference, takes the mean of `compute_nll` of the marginal variational
# distribution against the target catalog, and updates the weights.
for i in range(num_iters):
    optimizer.zero_grad()

    pred = encoder.infer(batch, truth_callback)
    lvd = pred["marginal"]
    loss = lvd.compute_nll(target_cat).mean()

    loss.backward()
    optimizer.step()

    # Report progress every 10 iterations.
    if not i % 10:
        print(f"Iteration {i}: Loss {loss.item()}")

---

### (5) Summarize results

In [None]:
# The `lvd` left over from the training loop was computed *before* the final
# optimizer step and is still attached to autograd. Recompute it with the
# final weights and without gradient tracking.
with torch.no_grad():
    lvd = encoder.infer(batch, truth_callback)["marginal"]
# Optimized variational distribution
q = lvd.factors

#### Horizontal shear

In [None]:
# Pick the first image in the batch *before* squeezing. Squeezing the whole
# tensor first also drops the batch axis when the batch size is 1, which makes
# [0] select a tile row instead of an image. Channel 0 is shear1.
true_shear1_map = batch["tile_catalog"]["shear"][0].squeeze()[:, :, 0]
posterior_mean_shear1_map = q["shear"].mean[0][:, :, 0].detach()

In [None]:
# 2x2 Pearson correlation matrix; the off-diagonal entry compares truth with the posterior mean.
np.corrcoef(true_shear1_map.flatten(), y=posterior_mean_shear1_map.flatten())

In [None]:
# Plot truth and posterior mean side by side on a shared color scale, so that
# equal colors mean equal values in both panels (imshow otherwise autoscales
# each panel independently).
vmin = float(min(true_shear1_map.min(), posterior_mean_shear1_map.min()))
vmax = float(max(true_shear1_map.max(), posterior_mean_shear1_map.max()))
fig, (true, posterior) = plt.subplots(nrows=1, ncols=2)
true.set_title("True")
_ = true.imshow(true_shear1_map, vmin=vmin, vmax=vmax)
posterior.set_title("Posterior mean")
_ = posterior.imshow(posterior_mean_shear1_map, vmin=vmin, vmax=vmax)

#### Diagonal shear

In [None]:
# Pick the first image in the batch *before* squeezing. Squeezing the whole
# tensor first also drops the batch axis when the batch size is 1, which makes
# [0] select a tile row instead of an image. Channel 1 is shear2.
true_shear2_map = batch["tile_catalog"]["shear"][0].squeeze()[:, :, 1]
posterior_mean_shear2_map = q["shear"].mean[0][:, :, 1].detach()

In [None]:
# 2x2 Pearson correlation matrix; the off-diagonal entry compares truth with the posterior mean.
np.corrcoef(true_shear2_map.flatten(), y=posterior_mean_shear2_map.flatten())

In [None]:
# Plot truth and posterior mean side by side on a shared color scale, so that
# equal colors mean equal values in both panels (imshow otherwise autoscales
# each panel independently).
vmin = float(min(true_shear2_map.min(), posterior_mean_shear2_map.min()))
vmax = float(max(true_shear2_map.max(), posterior_mean_shear2_map.max()))
fig, (true, posterior) = plt.subplots(nrows=1, ncols=2)
true.set_title("True")
_ = true.imshow(true_shear2_map, vmin=vmin, vmax=vmax)
posterior.set_title("Posterior mean")
_ = posterior.imshow(posterior_mean_shear2_map, vmin=vmin, vmax=vmax)

#### Convergence

In [None]:
# Pick the first image in the batch *before* squeezing. Squeezing the whole
# tensor first also drops the batch axis when the batch size is 1, which makes
# [0] select a tile row instead of an image.
true_convergence_map = batch["tile_catalog"]["convergence"][0].squeeze()
# NOTE(review): this applies sigmoid to the *base* distribution's mean. If q's
# transform is exactly this sigmoid over a symmetric (e.g. Normal) base, the
# result is the posterior median, not the posterior mean, because sigmoid is
# nonlinear. Confirm before interpreting it as a mean.
posterior_mean_convergence_map = q["convergence"].base_dist.mean.sigmoid()[0].detach()

In [None]:
# 2x2 Pearson correlation matrix; the off-diagonal entry compares truth with the posterior estimate.
np.corrcoef(true_convergence_map.flatten(), y=posterior_mean_convergence_map.flatten())

In [None]:
# Plot truth and the posterior estimate side by side on a shared color scale,
# so that equal colors mean equal values in both panels (imshow otherwise
# autoscales each panel independently).
vmin = float(min(true_convergence_map.min(), posterior_mean_convergence_map.min()))
vmax = float(max(true_convergence_map.max(), posterior_mean_convergence_map.max()))
fig, (true, posterior) = plt.subplots(nrows=1, ncols=2)
true.set_title("True")
_ = true.imshow(true_convergence_map, vmin=vmin, vmax=vmax)
posterior.set_title("Posterior estimate")
_ = posterior.imshow(posterior_mean_convergence_map, vmin=vmin, vmax=vmax)