In [None]:
import torch
import tqdm
import torch.nn.functional as F
from torch import optim

from hydra import initialize, compose
from hydra.utils import instantiate
from pytorch_lightning.utilities import move_data_to_device

import matplotlib.pyplot as plt

from bliss.surveys.dc2 import DC2DataModule
from case_studies.dc2_diffusion.utils.catalog_parser import CatalogParser
from bliss.catalog import TileCatalog
from bliss.encoder.metrics import CatalogMatcher
from case_studies.dc2_diffusion.utils.metrics import DetectionPerformance
from bliss.global_env import GlobalEnv

from case_studies.dc2_new_diffusion.utils.autoencoder import CatalogEncoder, CatalogDecoder

In [None]:
# Select the compute device. The notebook prefers the second GPU ("cuda:1"),
# but asking for that index on a machine with a single visible GPU fails as
# soon as a module or tensor is moved there. Fall back to "cuda:0" when only
# one GPU is visible, and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")
# Compose the notebook configuration from the ./latent_diffusion_config directory.
with initialize(version_base=None, config_path="./latent_diffusion_config"):
    new_diffusion_notebook_cfg = compose(config_name="latent_diffusion_notebook_config")

In [None]:
# Pull the survey settings this notebook needs out of the composed config.
dc2_cfg = new_diffusion_notebook_cfg.surveys.dc2
tile_slen = dc2_cfg.tile_slen
max_sources_per_tile = dc2_cfg.max_sources_per_tile
r_band_min_flux = new_diffusion_notebook_cfg.notebook_var.r_band_min_flux

# Build the DC2 data module, override its batch size, then prepare the "fit" stage.
dc2: DC2DataModule = instantiate(dc2_cfg)
dc2.batch_size = 512
dc2.setup(stage="fit")

# The global epoch counter and seed are set before the first train dataloader is requested.
GlobalEnv.current_encoder_epoch = 1
GlobalEnv.seed_in_this_program = 7272
dc2_train_dataloader = dc2.train_dataloader()

# Parser used by the later cells to encode tile catalogs into tensors and decode them back.
catalog_parser: CatalogParser = instantiate(new_diffusion_notebook_cfg.encoder.catalog_parser)

In [None]:
# Catalog autoencoder: the channel count comes from the parser, and the decoder
# is built with a quarter of the encoder's hidden_dim.
target_ch = catalog_parser.n_params_per_source
encoder_hidden_dim = 32
encoder = CatalogEncoder(target_ch, hidden_dim=encoder_hidden_dim)
decoder = CatalogDecoder(target_ch, hidden_dim=encoder_hidden_dim // 4)

In [None]:
# Place both networks on the selected device and optimise them jointly with a
# single Adam optimizer over the combined parameter list.
encoder = encoder.to(device)
decoder = decoder.to(device)
autoencoder_params = [*encoder.parameters(), *decoder.parameters()]
optimizer = optim.Adam(autoencoder_params, lr=1e-3)

### Train encoder and decoder with image weights

In [None]:
# Jointly train the encoder and decoder. Before decoding, the encoder output is
# multiplied by the reciprocal of a per-pixel weight derived from the image.
encoder.train()
decoder.train()

total_batch = len(dc2_train_dataloader)
epoch = 3
for _ in range(epoch):
    # A new train dataloader is requested at the start of every epoch.
    dc2_train_dataloader = dc2.train_dataloader()
    for step, batch in enumerate(tqdm.tqdm(dc2_train_dataloader), start=1):
        batch = move_data_to_device(batch, device=device)

        # Ground-truth catalog -> parsed tensor, with the last dim moved to position 1.
        target_cat = TileCatalog(batch["tile_catalog"])
        target_cat1 = target_cat.get_brightest_sources_per_tile(band=2, exclude_num=0)
        encoded_catalog_tensor = catalog_parser.encode(target_cat1).permute([0, 3, 1, 2])  # (b, k, h, w)

        optimizer.zero_grad()
        encoder_pred = encoder(encoded_catalog_tensor)

        # weight = 100 * log(L2 norm of the image over dim 1 + 1) + 1
        image_norm = torch.norm(batch["images"], dim=1, p=2, keepdim=True)
        image_weights = torch.log(image_norm + 1) * 100 + 1
        weighted_encoder_pred = encoder_pred * (1 / image_weights)
        recovered_target = decoder(weighted_encoder_pred)

        # Mean squared reconstruction error against the parsed catalog tensor.
        loss = (recovered_target - encoded_catalog_tensor).pow(2).mean()
        loss.backward()
        optimizer.step()

        # Log the loss every 100 steps.
        if step % 100 == 0:
            print(f"step [{step}/{total_batch}], loss: {loss.item():.6f}")
    GlobalEnv.current_encoder_epoch += 1

### Test encoder and decoder

In [None]:
# Fresh catalog matcher, detection metric (on the compute device), and
# validation dataloader for the evaluation pass that follows.
matcher = CatalogMatcher(dist_slack=1.0)
detection_metric = DetectionPerformance()
f1_metric = detection_metric.to(device)
dc2_val_dataloader = dc2.val_dataloader()

In [None]:
# Evaluate the jointly trained encoder/decoder on the validation set, using the
# same image-derived weighting of the encoder output as in training.
encoder.eval()
decoder.eval()
for batch in tqdm.tqdm(dc2_val_dataloader):
    batch = move_data_to_device(batch, device=device)

    # Reference catalog in tile form and in full form.
    target_tile_cat = TileCatalog(batch["tile_catalog"]).get_brightest_sources_per_tile(
        band=2, exclude_num=0
    )
    target_full_cat = target_tile_cat.to_full_catalog(tile_slen)

    encoded_catalog_tensor = catalog_parser.encode(target_tile_cat).permute([0, 3, 1, 2])  # (b, k, h, w)
    with torch.no_grad():
        encoder_pred = encoder(encoded_catalog_tensor)
        image_norm = torch.norm(batch["images"], dim=1, p=2, keepdim=True)
        image_weights = torch.log(image_norm + 1) * 100 + 1  # regularization
        weighted_encoder_pred = encoder_pred * (1 / image_weights)
        recovered_target = decoder(weighted_encoder_pred)

    # Undo the permutation, clip, and decode the reconstruction into catalogs.
    recovered_target = catalog_parser.clip_tensor(recovered_target.permute([0, 2, 3, 1]))
    recovered_tile_cat = catalog_parser.decode(recovered_target)
    recovered_full_cat = recovered_tile_cat.to_full_catalog(tile_slen)

    # Match reference against reconstruction and accumulate the detection metric.
    matching = matcher.match_catalogs(target_full_cat, recovered_full_cat)
    f1_metric.update(target_full_cat, recovered_full_cat, matching)

In [None]:
# Report every entry of the accumulated detection metric, one per line.
detection_results = f1_metric.compute()
for metric_name, metric_value in detection_results.items():
    print(f"{metric_name}: {metric_value}")

In [None]:
# Heatmap of log(L2 norm over dim 0 + 1) for the first item of the current encoder output.
first_pred_log_norm = (encoder_pred[0].norm(dim=0, p=2) + 1).log().cpu().numpy()
fig, ax = plt.subplots()
heatmap = ax.imshow(first_pred_log_norm, cmap="viridis", interpolation="nearest")
fig.colorbar(heatmap, ax=ax)
ax.set_title("Heatmap")
ax.set_xlabel("Columns")
ax.set_ylabel("Rows")
plt.show()

### Train a new decoder without image weights

In [None]:
# A second decoder built with the same arguments as the first one, with its own
# Adam optimizer that holds only this decoder's parameters.
new_decoder = CatalogDecoder(target_ch, hidden_dim=32 // 4).to(device)
new_optimizer = optim.Adam(new_decoder.parameters(), lr=1e-3)

In [None]:
# Train only the new decoder, feeding it the encoder's raw (unweighted) output.
# The encoder stays in eval mode and is not part of new_optimizer.
encoder.eval()
new_decoder.train()

total_batch = len(dc2_train_dataloader)
epoch = 3
for _ in range(epoch):
    # A new train dataloader is requested at the start of every epoch.
    dc2_train_dataloader = dc2.train_dataloader()
    for step, batch in enumerate(tqdm.tqdm(dc2_train_dataloader), start=1):
        batch = move_data_to_device(batch, device=device)

        # Ground-truth catalog -> parsed tensor, with the last dim moved to position 1.
        target_cat = TileCatalog(batch["tile_catalog"])
        target_cat1 = target_cat.get_brightest_sources_per_tile(band=2, exclude_num=0)
        encoded_catalog_tensor = catalog_parser.encode(target_cat1).permute([0, 3, 1, 2])  # (b, k, h, w)

        new_optimizer.zero_grad()
        # The encoder forward pass runs without autograd, so gradients reach new_decoder only.
        with torch.no_grad():
            encoder_pred = encoder(encoded_catalog_tensor)
        recovered_target = new_decoder(encoder_pred)

        # Mean squared reconstruction error against the parsed catalog tensor.
        loss = (recovered_target - encoded_catalog_tensor).pow(2).mean()
        loss.backward()
        new_optimizer.step()

        # Log the loss every 100 steps.
        if step % 100 == 0:
            print(f"step [{step}/{total_batch}], loss: {loss.item():.6f}")
    GlobalEnv.current_encoder_epoch += 1

### Test this new decoder

In [None]:
# Reset the matcher, detection metric, and validation dataloader so the new
# decoder's evaluation starts from a clean metric state.
matcher = CatalogMatcher(dist_slack=1.0)
detection_metric = DetectionPerformance()
f1_metric = detection_metric.to(device)
dc2_val_dataloader = dc2.val_dataloader()

In [None]:
# Evaluate the encoder together with the new decoder on the validation set;
# no image-derived weighting is applied to the encoder output here.
encoder.eval()
new_decoder.eval()
for batch in tqdm.tqdm(dc2_val_dataloader):
    batch = move_data_to_device(batch, device=device)

    # Reference catalog in tile form and in full form.
    target_tile_cat = TileCatalog(batch["tile_catalog"]).get_brightest_sources_per_tile(
        band=2, exclude_num=0
    )
    target_full_cat = target_tile_cat.to_full_catalog(tile_slen)

    encoded_catalog_tensor = catalog_parser.encode(target_tile_cat).permute([0, 3, 1, 2])  # (b, k, h, w)
    with torch.no_grad():
        encoder_pred = encoder(encoded_catalog_tensor)
        recovered_target = new_decoder(encoder_pred)

    # Undo the permutation, clip, and decode the reconstruction into catalogs.
    recovered_target = catalog_parser.clip_tensor(recovered_target.permute([0, 2, 3, 1]))
    recovered_tile_cat = catalog_parser.decode(recovered_target)
    recovered_full_cat = recovered_tile_cat.to_full_catalog(tile_slen)

    # Match reference against reconstruction and accumulate the detection metric.
    matching = matcher.match_catalogs(target_full_cat, recovered_full_cat)
    f1_metric.update(target_full_cat, recovered_full_cat, matching)

In [None]:
# Report every entry of the accumulated detection metric, one per line.
detection_results = f1_metric.compute()
for metric_name, metric_value in detection_results.items():
    print(f"{metric_name}: {metric_value}")

In [None]:
# Persist the trained weights: the encoder, and the decoder that was trained
# without image weights, as state dicts in the working directory.
for module, checkpoint_path in ((encoder, "encoder.pt"), (new_decoder, "decoder.pt")):
    torch.save(module.state_dict(), checkpoint_path)

### Plotting

In [None]:
# Index, within the current `batch`, of the example shown by the plotting cells below.
# NOTE(review): assumes the batch holds at least 7 items — confirm this for the
# last validation batch, which may be smaller than the configured batch size.
image_index = 6

In [None]:
# Heatmap of log(L2 norm over dim 0 + 1) for the selected image of the current batch.
image_log_norm = (batch["images"][image_index].norm(dim=0, p=2) + 1).log().cpu().numpy()
fig, ax = plt.subplots()
heatmap = ax.imshow(image_log_norm, cmap="viridis", interpolation="nearest")
fig.colorbar(heatmap, ax=ax)
ax.set_title("Heatmap")
ax.set_xlabel("Columns")
ax.set_ylabel("Rows")
plt.show()

In [None]:
# Largest and smallest value of the most recently computed encoder output.
torch.max(encoder_pred), torch.min(encoder_pred)

In [None]:
# Heatmap of log(L2 norm over dim 0 + 1) for the selected item of the current encoder output.
pred_log_norm = (encoder_pred[image_index].norm(dim=0, p=2) + 1).log().cpu().numpy()
fig, ax = plt.subplots()
heatmap = ax.imshow(pred_log_norm, cmap="viridis", interpolation="nearest")
fig.colorbar(heatmap, ax=ax)
ax.set_title("Heatmap")
ax.set_xlabel("Columns")
ax.set_ylabel("Rows")
plt.show()

In [None]:
# Heatmap of the multiplicative weight that the weighted training and evaluation
# cells apply to the encoder output: 1 / (100 * log(||image||_2 + 1) + 1).
# This cell previously plotted 1 / (100 * ||image||_2 + 1), which drops the log
# and therefore did not show the weights that were actually used.
# NOTE(review): assumes this plot is meant to mirror the training weights — confirm.
selected_image_norm = batch["images"][image_index].norm(dim=0, p=2)
selected_image_weights = (selected_image_norm + 1).log() * 100 + 1
plt.imshow((1 / selected_image_weights).cpu().numpy(),
           cmap="viridis", interpolation="nearest")
plt.colorbar()
plt.title("Heatmap")
plt.xlabel("Columns")
plt.ylabel("Rows")
plt.show()

In [None]:
# Bilinearly upsample the parsed catalog tensor by a factor of 4 and show
# channel 0 of the selected item.
upsampled_cat_tensor = F.interpolate(encoded_catalog_tensor, scale_factor=4, mode="bilinear")
catalog_channel0 = upsampled_cat_tensor[image_index, 0].cpu().numpy()
fig, ax = plt.subplots()
heatmap = ax.imshow(catalog_channel0, cmap="viridis", interpolation="nearest")
fig.colorbar(heatmap, ax=ax)
ax.set_title("Heatmap")
ax.set_xlabel("Columns")
ax.set_ylabel("Rows")
plt.show()

In [None]:
# Move the last dim of the recovered tensor back to position 1, then show
# channel 0 of the selected item.
recovered_channel0 = recovered_target.permute([0, 3, 1, 2])[image_index, 0].cpu().numpy()
fig, ax = plt.subplots()
heatmap = ax.imshow(recovered_channel0, cmap="viridis", interpolation="nearest")
fig.colorbar(heatmap, ax=ax)
ax.set_title("Heatmap")
ax.set_xlabel("Columns")
ax.set_ylabel("Rows")
plt.show()

### Get the min and max of the encoder output over the train and validation sets

In [None]:
def _scan_encoder_pred_range(dataloader, running_min, running_max):
    """Fold the extrema of the encoder output over ``dataloader`` into running values.

    Uses the notebook-level ``device``, ``catalog_parser`` and ``encoder``.
    Loop variables stay local to this helper, so running it does not overwrite
    the module-level ``batch`` / ``encoder_pred`` that the plotting cells read.

    Args:
        dataloader: iterable of batches that contain a "tile_catalog" entry.
        running_min: smallest encoder output seen so far (float or 0-dim tensor).
        running_max: largest encoder output seen so far (float or 0-dim tensor).

    Returns:
        The updated ``(running_min, running_max)`` pair; the inputs are returned
        unchanged when ``dataloader`` yields no batches.
    """
    for batch in tqdm.tqdm(dataloader):
        batch = move_data_to_device(batch, device=device)
        brightest_cat = TileCatalog(batch["tile_catalog"]).get_brightest_sources_per_tile(
            band=2, exclude_num=0
        )
        encoded_catalog_tensor = catalog_parser.encode(brightest_cat).permute([0, 3, 1, 2])  # (b, k, h, w)

        with torch.no_grad():
            encoder_pred = encoder(encoded_catalog_tensor)

        # Reduce once per batch instead of recomputing max()/min() for the assignment.
        batch_max = encoder_pred.max()
        batch_min = encoder_pred.min()
        if batch_max > running_max:
            running_max = batch_max
        if batch_min < running_min:
            running_min = batch_min
    return running_min, running_max


encoder.eval()

# Scan the training set first, starting from the identity values for max/min.
dc2_train_dataloader = dc2.train_dataloader()
encoder_pred_max = -torch.inf
encoder_pred_min = torch.inf
encoder_pred_min, encoder_pred_max = _scan_encoder_pred_range(
    dc2_train_dataloader, encoder_pred_min, encoder_pred_max
)

# Then continue the same running extrema over the validation set.
dc2_val_dataloader = dc2.val_dataloader()
encoder_pred_min, encoder_pred_max = _scan_encoder_pred_range(
    dc2_val_dataloader, encoder_pred_min, encoder_pred_max
)

In [None]:
# Display the extrema of the encoder output collected over the train and validation dataloaders.
encoder_pred_min, encoder_pred_max