In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torch import nn
import tqdm
from einops import rearrange, repeat
import numpy as np
import pytorch_lightning
from pytorch_lightning.utilities import move_data_to_device
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

from case_studies.dc2_mdt.utils.rml_df import RMLDF
from case_studies.dc2_mdt.utils.new_simulate_image import ImageSimulator

In [None]:
# Notebook-wide settings shared by the simulator, the network and the codecs.
image_normalize_strategy = "linear_scale"  # branch selected inside normalize_image()
max_objects = 2  # object slots per image (passed to the simulator as max_objects)
image_size = 8  # used for both image height and width

## RML

In [None]:
def rml_loss_mask_fn(x0_population: torch.Tensor, output_population: torch.Tensor):
    """Mask the model output according to the ground-truth on/off indicator.

    Channel 0 of the last axis is the source indicator and is passed through
    unchanged. Every other channel keeps the model output where the
    ground-truth indicator (``x0_population[..., 0]``) is positive and is
    replaced by the constant ``-1.0`` elsewhere.
    """
    source_is_on = x0_population[..., :1] > 0.0
    indicator = output_population[..., :1]
    attributes = output_population[..., 1:]
    off_value = torch.full_like(attributes, fill_value=-1.0)
    masked_attributes = torch.where(source_is_on, attributes, off_value)
    return torch.cat((indicator, masked_attributes), dim=-1)

In [None]:
def rml_pred_x0_rectify_fn(pred_x0: torch.Tensor):
    """Force the attributes of predicted-off sources to the constant ``-1.0``.

    Channel 0 of the last axis is the predicted source indicator and is
    returned untouched. The remaining channels are kept where that indicator
    is positive and overwritten with ``-1.0`` where it is not.
    """
    indicator, attributes = pred_x0[..., :1], pred_x0[..., 1:]
    predicted_on = indicator > 0.0
    rectified = torch.where(
        predicted_on,
        attributes,
        torch.full_like(attributes, fill_value=-1.0),
    )
    return torch.cat((indicator, rectified), dim=-1)

In [None]:
def rml_loss_weight_fn(t: torch.Tensor, alpha: torch.Tensor, sigma: torch.Tensor):
    """Look up a per-timestep loss weight for every entry of ``t``.

    The weight table is ``1 / (1 + sigma^2 / alpha^2)`` with ``1e-3`` added to
    both squared terms; it is indexed by the integer timesteps in ``t`` and
    returned with the same shape as ``t`` on ``t``'s device.
    """
    stabilizer = 1e-3
    signal_power = alpha ** 2 + stabilizer
    noise_power = sigma ** 2 + stabilizer
    weight_table = 1 / (1 + noise_power / signal_power)
    weight_table = weight_table.to(device=t.device)
    flat_weights = weight_table[t.flatten()]
    return flat_weights.view(t.shape)

In [None]:
class FourierMLP(nn.Module):
    """MLP denoiser conditioned on per-entry timesteps, an image and a noise tensor.

    Every scalar entry of the input is embedded with a small MLP and summed
    with a learned embedding of sin/cos timestep features. The flattened
    result is concatenated with an image embedding and the flattened
    ``epsilon`` tensor, then passed through an MLP trunk whose hidden blocks
    use residual connections. The output has the same shape as ``x``.

    Args:
        data_shape: shape of one sample without batch axes (e.g. ``[max_objects, k]``).
        num_layers: number of residual ``Linear + GELU`` blocks in the trunk.
        hidden_ch: width of all embeddings and hidden layers.
        image_flat_len: number of pixels of a flattened conditioning image.
            Defaults to ``image_size * image_size`` read from the notebook-level
            ``image_size`` (the original behaviour), so existing callers are
            unaffected.
    """

    def __init__(self, data_shape, num_layers, hidden_ch, image_flat_len=None):
        super().__init__()

        data_flat_len = int(np.prod(data_shape))
        if image_flat_len is None:
            # Backward-compatible default: fall back to the notebook global.
            image_flat_len = image_size * image_size

        # NOTE: the creation order of parameters below is kept unchanged so that
        # seeded initialisation and saved state_dicts stay compatible.
        self.register_buffer("timestep_coeff", torch.linspace(start=0.1, end=100, steps=hidden_ch))  # (hidden, )
        self.timestep_phase = nn.Parameter(torch.randn(hidden_ch))  # (hidden, )
        self.input_embed = nn.Sequential(
            nn.Linear(1, hidden_ch),
            nn.GELU(),
            nn.Linear(hidden_ch, hidden_ch)
        )
        self.timestep_embed = nn.Sequential(
            nn.Linear(2 * hidden_ch, hidden_ch),
            nn.GELU(),
            nn.Linear(hidden_ch, hidden_ch),
        )
        self.image_embed = nn.Sequential(
            nn.Linear(image_flat_len, hidden_ch),
            nn.GELU(),
            nn.Linear(hidden_ch, hidden_ch)
        )

        # Trunk layout: indices 0-3 reduce the concatenated features to
        # hidden_ch, then `num_layers` residual blocks, then the output head.
        self.layers_net = nn.ModuleList([
            nn.Linear(hidden_ch * (data_flat_len + 1) + data_flat_len, hidden_ch * 2),
            nn.GELU(),
            nn.Linear(hidden_ch * 2, hidden_ch),
            nn.GELU(),
            *[
                nn.Sequential(nn.Linear(hidden_ch, hidden_ch), nn.GELU())
                for _ in range(num_layers)
            ],
            nn.Linear(hidden_ch, data_flat_len),
        ])

    def layers(self, x):
        """Run the trunk; only the middle blocks (index >= 4, not last) are residual."""
        last_index = len(self.layers_net) - 1
        for i, m in enumerate(self.layers_net):
            if i < 4 or i == last_index:
                x = m(x)
            else:
                x = m(x) + x
        return x

    def _embed_timestep(self, t):
        """Map timesteps of shape (...) to embeddings of shape (..., hidden)."""
        phase = self.timestep_coeff * t.unsqueeze(-1).float() + self.timestep_phase  # (..., hidden)
        # Concatenating [sin, cos] on the last axis matches the previous
        # stack(dim=0) + rearrange("d ... hidden -> ... (d hidden)") layout.
        fourier_features = torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
        return self.timestep_embed(fourier_features)

    def forward(self, x, t, image, epsilon, is_training):
        """Predict an output with the same shape as ``x``.

        When ``is_training`` is truthy, ``x``/``t``/``epsilon`` carry an extra
        axis after the batch axis, i.e. (b, m, max_objects, k), and the image
        embedding is broadcast along it. Otherwise they are (b, max_objects, k).
        ``image`` is flattened from axis 1 in both cases.
        """
        if is_training:
            flatten_from = 2
        else:
            # Kept from the original implementation.
            # NOTE(review): presumably guards caller-owned tensors — confirm still needed.
            t = t.clone()
            image = image.clone()
            flatten_from = 1

        embed_t = self._embed_timestep(t)  # (..., hidden)
        embed_xt = self.input_embed(x.unsqueeze(-1))  # (..., hidden)
        embed_image = self.image_embed(image.flatten(1))  # (b, hidden)
        if is_training:
            # Share one image embedding across the extra axis: (b, m, hidden).
            embed_image = embed_image.unsqueeze(1).expand(-1, t.shape[1], -1)

        out = self.layers(
            torch.cat([(embed_xt + embed_t).flatten(flatten_from),
                       embed_image,
                       epsilon.flatten(flatten_from)], dim=-1)
        )
        return out.view(x.shape)

In [None]:
# Prefer the GPU index this notebook was developed on, but degrade gracefully so
# that Restart & Run All also works on machines with fewer (or no) CUDA devices.
preferred_cuda_index = 6
if torch.cuda.is_available() and torch.cuda.device_count() > preferred_cuda_index:
    device = torch.device(f"cuda:{preferred_cuda_index}")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
batch_size = 1024
training_time_steps = 1000  # number of diffusion timesteps used for training
training_iters = 20_000  # optimizer steps; also part of the checkpoint file name
ddim_eta = 0.0  # use 0.0 for better results when max_objects >= 2
log_freq = 500  # print the training loss every `log_freq` iterations
seed = 1201023
pytorch_lightning.seed_everything(seed)

In [None]:
# Simulator settings collected in one mapping so they can be read at a glance.
simulator_kwargs = {
    "img_height": image_size,
    "img_width": image_size,
    "max_objects": max_objects,
    "psf_stdev": 1.0,
    "flux_alpha": 10.0,
    "flux_beta": 0.01,
    "pad": 0,
    "always_max_count": False,
    "constant_locs": False,
    "coadd_images": True,
}
image_simulator = ImageSimulator(**simulator_kwargs).to(device=device)

In [None]:
# Notebook-defined hooks handed to the diffusion object.
rml_hooks = {
    "loss_mask_fn": rml_loss_mask_fn,
    "pred_x0_rectify_fn": rml_pred_x0_rectify_fn,
    "loss_weight_fn": rml_loss_weight_fn,
}
training_diffusion = RMLDF(
    num_timesteps=training_time_steps,
    m=32,
    lambda_=1.0,
    beta=1.0,
    **rml_hooks,
)
# Sampling reuses the very same object (and therefore the same schedule) as training.
sampling_diffusion = training_diffusion

In [None]:
# Network, optimizer and LR schedule used by the training cell.
my_net = FourierMLP(data_shape=[2, 4], num_layers=8, hidden_ch=64)
my_net = my_net.to(device=device)
my_optimizer = torch.optim.Adam(my_net.parameters(), lr=1e-3, amsgrad=True)
# Single LR drop (x0.1) after 4/5 of the training iterations.
lr_drop_iter = training_iters // 5 * 4
my_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    my_optimizer,
    milestones=[lr_drop_iter],
    gamma=0.1,
)
# my_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(my_optimizer, T_max=1000)

In [None]:
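# Disabled alternative: log-scale flux encoding, kept for reference only.
# The linear encode_flux/decode_flux pair defined in the next cell is the one in use.
# max_flux_boundary = 2000.0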
# def encode_flux(flux: torch.Tensor):
#     assert flux.min() >= 0.0
#     flux = flux.clamp(max=max_flux_boundary)
#     return (torch.log1p(flux) / torch.log1p(torch.tensor(max_flux_boundary))) * 2 - 1

# def decode_flux(flux_minus1_to_1: torch.Tensor):
#     assert flux_minus1_to_1.min() >= -1.0 and flux_minus1_to_1.max() <= 1.0
#     return torch.expm1((flux_minus1_to_1 + 1) / 2 * torch.log1p(torch.tensor(max_flux_boundary)))

In [None]:
# Fluxes above this value are clamped before encoding.
max_flux_boundary = 2000.0

def encode_flux(flux: torch.Tensor):
    """Clamp ``flux`` at ``max_flux_boundary`` and map [0, max] linearly onto [-1, 1].

    Raises ``AssertionError`` if any flux is negative.
    """
    smallest_flux = flux.min()
    assert smallest_flux >= 0.0
    clipped = torch.clamp(flux, max=max_flux_boundary)
    unit_scaled = clipped / max_flux_boundary
    return unit_scaled * 2 - 1

def decode_flux(flux_minus1_to_1: torch.Tensor):
    """Inverse of ``encode_flux``: map values in [-1, 1] back to [0, max_flux_boundary].

    Raises ``AssertionError`` if the input leaves the [-1, 1] range.
    """
    lowest, highest = flux_minus1_to_1.min(), flux_minus1_to_1.max()
    assert lowest >= -1.0 and highest <= 1.0
    unit_scaled = (flux_minus1_to_1 + 1) / 2
    return unit_scaled * max_flux_boundary

In [None]:
def encode_x_start(catalog):
    """Encode a catalog dict into a (b, m, 4) tensor with entries scaled to [-1, 1].

    The four channels are: on/off indicator (+1/-1), the two location
    coordinates, and the encoded flux. Slots are then reordered within each
    sample by decreasing distance of the encoded location from (-1, -1).
    """
    counts = catalog["counts"]  # (b, )
    locs = catalog["locs"]  # (b, m, 2)
    flux_column = catalog["fluxes"].unsqueeze(-1)  # (b, m, 1)

    # Slot j (1-based) is "on" when the sample holds at least j sources.
    slot_rank = torch.arange(1, locs.shape[1] + 1, device=locs.device)
    slot_is_on = (counts.unsqueeze(-1) >= slot_rank).unsqueeze(-1)  # (b, m, 1)

    channels = [slot_is_on * 2 - 1, locs / image_size * 2 - 1, encode_flux(flux_column)]
    encoded = torch.cat(channels, dim=-1)  # (b, m, 4)

    shifted_locs = encoded[..., 1:3] + 1
    corner_distance = torch.sqrt((shifted_locs ** 2).sum(dim=-1))  # (b, m)
    slot_order = corner_distance.argsort(dim=-1, descending=True)  # (b, m)
    gather_index = repeat(slot_order, "... -> ... r", r=4)
    return torch.take_along_dim(encoded, gather_index, dim=-2)  # (b, m, 4)

In [None]:
def decode_x_start(output_x_start):
    """Turn an encoded (b, m, 4) tensor back into a catalog dict.

    Channel 0 is thresholded at 0 to get the per-slot on/off flag, channels
    1-2 are rescaled from [-1, 1] to [0, image_size], and channel 3 is passed
    through ``decode_flux``.
    """
    slot_is_on = output_x_start[..., 0] > 0.0  # (b, m)
    pixel_locs = (output_x_start[..., 1:3] + 1) / 2 * image_size  # (b, m, 2)
    flux_values = decode_flux(output_x_start[..., 3])  # (b, m)

    decoded = {}
    decoded["counts"] = slot_is_on.sum(dim=-1)
    decoded["n_sources"] = slot_is_on.int()
    decoded["locs"] = pixel_locs
    decoded["fluxes"] = flux_values
    return decoded

In [None]:
def normalize_image(input_image):
    match image_normalize_strategy:
        case "none":
            output_image = input_image
        case "log":
            output_image = torch.log1p(input_image)
        case "linear_scale":
            output_image = input_image / 1000
        case _:
            raise NotImplementedError()
    return output_image

In [None]:
def training_t_schedule(batch_size):
    """Draw per-attribute training timesteps of shape (batch_size, max_objects, 4).

    Three candidate row types are generated (columns = indicator, loc, loc, flux):
      * indicator mode: random t for the indicator, last timestep for the rest;
      * locs mode: t = 0 for the indicator, one shared random t for both
        location columns, last timestep for the flux;
      * fluxes mode: t = 0 for the first three columns, random t for the flux.
    ``batch_size`` rows are then picked at random from the 3 * batch_size
    candidates and repeated along the object axis.
    """
    last_step = training_time_steps - 1

    def random_column():
        # (batch_size, 1) timesteps drawn from [0, training_time_steps)
        return torch.from_numpy(np.random.choice(training_time_steps, size=(batch_size, 1)))

    def constant_columns(width, value):
        return torch.full((batch_size, width), fill_value=value, dtype=torch.long)

    # NOTE: the order of the random draws below is the same as before.
    ns_rows = torch.cat([random_column(), constant_columns(3, last_step)], dim=-1)

    shared_locs_t = random_column()
    locs_rows = torch.cat(
        [constant_columns(1, 0), shared_locs_t, shared_locs_t, constant_columns(1, last_step)],
        dim=-1,
    )

    fluxes_rows = torch.cat([constant_columns(3, 0), random_column()], dim=-1)

    candidates = torch.cat([ns_rows, locs_rows, fluxes_rows], dim=0)
    picked = torch.randperm(candidates.shape[0])[:batch_size]
    return repeat(candidates[picked], "b k -> b m k", m=max_objects)

In [None]:
# Train from scratch only when no checkpoint for this iteration budget exists;
# otherwise load the saved weights so a full re-run of the notebook stays cheap.
saved_file_path = Path(f"./rml_df_model_{training_iters}.pt")
if not saved_file_path.exists():
    my_net.train()
    loss_record = []  # one scalar loss per iteration; only defined on this (training) branch
    for i in tqdm.tqdm(list(range(training_iters))):
        # Every iteration uses a freshly simulated batch and a fresh timestep draw.
        catalog = image_simulator.generate(batch_size)
        t = training_t_schedule(batch_size).to(device=device)
        input_image = catalog["images"]  # (b, h, w)
        input_image = normalize_image(input_image)
        train_loss_args = {
            "model": my_net,
            "x_start": encode_x_start(catalog),
            "t": t,
        }
        # The normalized image is passed to the network through model_kwargs.
        loss = training_diffusion.training_losses(**train_loss_args, 
                                                model_kwargs={"image": input_image})["loss"]
        loss = loss.mean()
        loss_record.append(loss.item())
        my_optimizer.zero_grad()
        loss.backward()
        my_optimizer.step()
        # The LR scheduler is stepped once per iteration (not per epoch).
        my_scheduler.step()
        if (i + 1) % log_freq == 0:
            print(f"[{i + 1}/{training_iters}] loss: {loss.item():.3e}")    
    
    # Only the network weights are saved (no optimizer/scheduler state).
    torch.save(my_net.state_dict(), saved_file_path)

    plt.plot(loss_record)
    plt.xlabel("Training Steps")
    plt.ylabel("Loss")
    plt.yscale("log")
    plt.show()
else:
    # NOTE(review): torch.load unpickles the file — only load checkpoints written by
    # this notebook; consider weights_only=True if the installed torch supports it.
    with open(saved_file_path, "rb") as f:
        my_net_state_dict = torch.load(f, map_location=device)
        my_net.load_state_dict(my_net_state_dict)

In [None]:
def pad_t(input_t: torch.Tensor, pad_front_num, pad_rear_num):
    """Extend a 1-D tensor by repeating its edge values.

    ``pad_front_num`` copies of the first element are prepended and
    ``pad_rear_num`` copies of the last element are appended. The padding is
    created on ``input_t``'s device.
    """
    assert input_t.ndim == 1

    def constant_run(length, value):
        # dtype is inferred from the Python scalar, exactly as with a direct torch.full call
        return torch.full((length, ), fill_value=value, device=input_t.device)

    head = constant_run(pad_front_num, input_t[0].item())
    tail = constant_run(pad_rear_num, input_t[-1].item())
    return torch.cat((head, input_t, tail), dim=0)

In [None]:
# Number of sampling steps spent on each attribute group (indicator, locations, fluxes).
ns_sample_steps, locs_sample_steps, fluxes_sample_steps = 5, 5, 3

def generate_k_vec(sample_steps):
    """Return ``sample_steps`` integer timesteps from ``training_time_steps - 1`` down to 0.

    The values are evenly spaced, truncated to int, and ordered from the
    largest timestep to the smallest.
    """
    ascending = torch.linspace(0, training_time_steps - 1, sample_steps).int()
    return torch.flip(ascending, dims=(0,))

# Attributes are denoised in consecutive stages: indicator -> locations -> fluxes.
# Outside its own stage an attribute is held at an edge value by pad_t: its first
# (largest) timestep while waiting, and its last (smallest) timestep once finished.
ns_stage = generate_k_vec(ns_sample_steps)
locs_stage = generate_k_vec(locs_sample_steps)
fluxes_stage = generate_k_vec(fluxes_sample_steps)

n_sources_k_vec = pad_t(
    ns_stage,
    pad_front_num=0,
    pad_rear_num=locs_sample_steps + fluxes_sample_steps,
)
locs_y_k_vec = pad_t(
    locs_stage,
    pad_front_num=ns_sample_steps,
    pad_rear_num=fluxes_sample_steps,
)
locs_x_k_vec = locs_y_k_vec  # both location coordinates share one schedule
flux_k_vec = pad_t(
    fluxes_stage,
    pad_front_num=ns_sample_steps + locs_sample_steps,
    pad_rear_num=0,
)

per_step_schedule = torch.stack(
    [n_sources_k_vec, locs_y_k_vec, locs_x_k_vec, flux_k_vec], dim=-1
)  # (k, f)
k_matrix = repeat(per_step_schedule, "k f -> k b m f", b=batch_size, m=max_objects)
k_matrix = k_matrix.to(device=device)

In [None]:
# Expected: (ns_sample_steps + locs_sample_steps + fluxes_sample_steps, batch_size, max_objects, 4)
k_matrix.shape

In [None]:
# Validation: simulate fresh batches, sample catalogs from the model, and keep
# both the (encoded-then-decoded) truth and the estimate on the CPU.
my_net.eval()
val_true_cat = []
val_est_cat = []
num_val_batches = 1000
with torch.inference_mode():
    for _ in tqdm.tqdm(range(num_val_batches)):
        # The batch size must match the `b` axis baked into k_matrix, so reuse
        # `batch_size` instead of a separately hard-coded 1024.
        simulated_catalog = image_simulator.generate(batch_size=batch_size)
        input_image = normalize_image(simulated_catalog["images"])
        # Round-trip the truth through the codec so it gets the same slot
        # ordering and flux clamping as the model output.
        val_catalog = decode_x_start(encode_x_start(simulated_catalog))
        val_true_cat.append(move_data_to_device(val_catalog, "cpu"))
        diffusion_sampling_config = {
            "model": my_net,
            "shape": (batch_size, max_objects, 4),
            "clip_denoised": True,
            "model_kwargs": {"image": input_image}
        }
        sample = sampling_diffusion.ddim_sample_loop(**diffusion_sampling_config,
                                                     k_matrix=k_matrix,
                                                     eta=ddim_eta)
        val_est_cat.append(move_data_to_device(decode_x_start(sample), "cpu"))

In [None]:
# Pair estimated and true catalogs batch by batch; strict=True raises if the two
# lists have different lengths.
paired_cats = list(zip(val_est_cat, val_true_cat, strict=True))

# Concatenate every field over all validation batches (along the batch axis).
diffusion_pred_ns = torch.cat([est["n_sources"] for est, _ in paired_cats], dim=0)
diffusion_true_ns = torch.cat([true["n_sources"] for _, true in paired_cats], dim=0)
diffusion_pred_locs = torch.cat([est["locs"] for est, _ in paired_cats], dim=0)
diffusion_true_locs = torch.cat([true["locs"] for _, true in paired_cats], dim=0)
diffusion_pred_fluxes = torch.cat([est["fluxes"] for est, _ in paired_cats], dim=0)
diffusion_true_fluxes = torch.cat([true["fluxes"] for _, true in paired_cats], dim=0)

In [None]:
# Sanity check: both should be (total validation samples, max_objects).
diffusion_true_ns.shape, diffusion_pred_ns.shape

In [None]:
# Sanity check: both should be (total validation samples, max_objects, 2).
diffusion_pred_locs.shape, diffusion_true_locs.shape

In [None]:
# Sanity check: both should be (total validation samples, max_objects).
diffusion_pred_fluxes.shape, diffusion_true_fluxes.shape

In [None]:
def plot_cm(d_pred_bin_index, d_true_bin_index, bin_num, bin_labels, axis_label):
    """Plot a confusion matrix and its relative-asymmetry map.

    Args:
        d_pred_bin_index: tensor of predicted bin indices.
        d_true_bin_index: tensor of true bin indices (same shape as the predictions).
        bin_num: number of bins; indices outside ``[0, bin_num)`` are not counted.
        bin_labels: tick labels, one per bin.
        axis_label: quantity name used in the axis labels.

    Two figures are shown: the raw count matrix (rows = predicted, columns =
    true) and ``(cm - cm.T) / max(min(cm, cm.T), 1)``, coloured by absolute
    value and annotated with the signed value. The title of the second figure
    reports ``sum(|cm - cm.T| / cm.sum()) / (bin_num * (bin_num - 1))``.
    """
    d_cm = torch.zeros(bin_num, bin_num, dtype=torch.int)
    for ri in range(bin_num):
        for ci in range(bin_num):
            d_cm[ri, ci] = ((d_pred_bin_index == ri) & (d_true_bin_index == ci)).sum()

    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    sns.heatmap(d_cm,
                annot=True,
                fmt="d", cmap="Greens", cbar=False,
                xticklabels=bin_labels,
                yticklabels=bin_labels,
                ax=ax)
    ax.set_xlabel(f"True {axis_label}")
    ax.set_ylabel(f"Pred {axis_label}")
    ax.set_title("Diffusion")
    # plt.show() instead of fig.show(): the latter can warn on non-GUI backends.
    plt.show()

    # Computed once and reused for both the colour map and the annotations.
    cm_diff = d_cm - d_cm.T
    relative_asymmetry = cm_diff / torch.minimum(d_cm, d_cm.T).clamp(min=1)

    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    sns.heatmap(relative_asymmetry.abs(),
                annot=relative_asymmetry,
                fmt=".2f", cmap="Greens", cbar=False,
                xticklabels=bin_labels,
                yticklabels=bin_labels,
                ax=ax)
    ax.set_xlabel(f"True {axis_label}")
    ax.set_ylabel(f"Pred {axis_label}")
    af = (cm_diff / d_cm.sum()).abs().sum() / (bin_num * (bin_num - 1))
    ax.set_title(f"Diffusion (Asymmetry Factor = {af:.2e})")
    plt.show()

In [None]:
# Source counts range over 0..max_objects, so the matrix needs max_objects + 1 bins
# (derived from the config instead of a hard-coded 3).
count_bin_num = max_objects + 1
plot_cm(diffusion_pred_ns.sum(dim=-1), diffusion_true_ns.sum(dim=-1),
        bin_num=count_bin_num, bin_labels=list(range(count_bin_num)), axis_label="Source Count")

In [None]:
def inclusive_bucektize(input_t, boundary):
    """Assign each value to a 0-based bin, treating both outer edges as inclusive.

    A copy of ``boundary`` is widened by ``1e-3`` on each end so that values
    equal to the first or last edge fall into the first or last bin; the
    caller's ``boundary`` is not modified. Raises ``AssertionError`` if any
    value lies outside the widened range.
    """
    widened = boundary.clone()
    widened[0].sub_(1e-3)
    widened[-1].add_(1e-3)

    edge_index = torch.bucketize(input_t, widened)
    num_edges = widened.shape[0]
    assert bool(torch.all(edge_index > 0))
    assert bool(torch.all(edge_index < num_edges))
    return edge_index - 1

In [None]:
# A slot is evaluated only if it is predicted "on" AND the predicted source
# count of its sample equals the true count.
count_matches = diffusion_pred_ns.sum(dim=-1) == diffusion_true_ns.sum(dim=-1)
valid_source_mask = diffusion_pred_ns.bool() & count_matches.unsqueeze(-1)

In [None]:
# Decoded locations live in [0, image_size]; derive the bin range and the number
# of bins from the config/boundaries instead of hard-coding 8.0 and 4.
locs_bin_boundary = torch.linspace(0.0, float(image_size), 5)
locs_bin_labels = [f"[{bl1:.1f}, {bl2:.1f}]"
                   for bl1, bl2 in zip(locs_bin_boundary[:-1],
                                       locs_bin_boundary[1:])]
# Column 1 of the location pair is plotted as X.
d_pred_locs_x_bin_index = inclusive_bucektize(diffusion_pred_locs[valid_source_mask][:, 1], locs_bin_boundary)
d_true_locs_x_bin_index = inclusive_bucektize(diffusion_true_locs[valid_source_mask][:, 1], locs_bin_boundary)
plot_cm(d_pred_locs_x_bin_index, d_true_locs_x_bin_index,
        bin_num=len(locs_bin_labels), bin_labels=locs_bin_labels,
        axis_label="Loc X")

In [None]:
# Decoded locations live in [0, image_size]; derive the bin range and the number
# of bins from the config/boundaries instead of hard-coding 8.0 and 4.
locs_bin_boundary = torch.linspace(0.0, float(image_size), 5)
locs_bin_labels = [f"[{bl1:.1f}, {bl2:.1f}]"
                   for bl1, bl2 in zip(locs_bin_boundary[:-1],
                                       locs_bin_boundary[1:])]
# Column 0 of the location pair is plotted as Y.
d_pred_locs_y_bin_index = inclusive_bucektize(diffusion_pred_locs[valid_source_mask][:, 0], locs_bin_boundary)
d_true_locs_y_bin_index = inclusive_bucektize(diffusion_true_locs[valid_source_mask][:, 0], locs_bin_boundary)
plot_cm(d_pred_locs_y_bin_index, d_true_locs_y_bin_index,
        bin_num=len(locs_bin_labels), bin_labels=locs_bin_labels,
        axis_label="Loc Y")

In [None]:
# Decoded fluxes live in [0, max_flux_boundary]; derive the bin range and the
# number of bins from the config/boundaries instead of hard-coding 2000.0 and 5.
fluxes_bin_boundary = torch.linspace(0.0, max_flux_boundary, 6)
fluxes_bin_labels = [f"[{bl1:.1f}, {bl2:.1f}]"
                     for bl1, bl2 in zip(fluxes_bin_boundary[:-1],
                                         fluxes_bin_boundary[1:])]
d_pred_fluxes_bin_index = inclusive_bucektize(diffusion_pred_fluxes[valid_source_mask], fluxes_bin_boundary)
d_true_fluxes_bin_index = inclusive_bucektize(diffusion_true_fluxes[valid_source_mask], fluxes_bin_boundary)
plot_cm(d_pred_fluxes_bin_index, d_true_fluxes_bin_index,
        bin_num=len(fluxes_bin_labels), bin_labels=fluxes_bin_labels,
        axis_label="Fluxes")