# Check how to limit the maximum depth

# Method 2: using grad from a ReLU network between depth and target

In [None]:
import os
# Expose only GPU index 1 to CUDA; set here, before torch is imported below.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import sys
# Extend the import path with a local ControlNet checkout
# (presumably where the `ldm` package imported below lives -- verify on your machine).
sys.path.append("/home/wukailu/latent-nerf/src/ControlNet")
import torch
from torch import nn
from ldm.modules.diffusionmodules.openaimodel import UNetModel

# All tensors created by this notebook live on this device.
device = torch.device("cpu")

# One (mean, std) row per channel, in R, G, B, D order. The sampler applies
# `x * std + mean` to its output to undo this normalization.
mean_std = torch.tensor(
    [
        (-0.24595006, 0.54566),  # R
        (-0.13202806, 0.55846),  # G
        (-0.02775778, 0.57430),  # B
        (1.72314970, 0.99023),  # D
    ]
)
_stats_on_device = mean_std.to(device=device)
mean = _stats_on_device[:, 0]  # shape (4,)
std = _stats_on_device[:, 1]  # shape (4,)

In [None]:
# Side length (pixels) of the square images the network works on.
IMG_SIZE = 128


def get_unet():
    """Build the RGBD denoising UNet and force every GroupNorm to a single group.

    The network takes 8 input channels and produces 4 output channels.

    Returns:
        The configured ``UNetModel`` instance.
    """
    # Per-level layout implied by the arguments below:
    # down 1     2      4          8          16         32
    # res  128   64     32         16         8          4
    # chan 128   256    384        384        512        512
    # type conv  conv   conv+attn  conv+attn  conv+attn  conv+attn
    unet = UNetModel(
        image_size=IMG_SIZE,
        in_channels=8,
        out_channels=4,
        model_channels=128,  # the base channel (smallest)
        channel_mult=[1, 2, 3, 3, 4, 4],
        num_res_blocks=2,
        num_head_channels=32,
        attention_resolutions=[4, 8, 16, 32],
        use_checkpoint=True,
        use_fp16=False,
    )

    # use num_groups==1, to avoid color shift problem
    group_norms = (
        (layer_name, layer)
        for layer_name, layer in unet.named_modules()
        if isinstance(layer, nn.GroupNorm)
    )
    for layer_name, layer in group_norms:
        layer.num_groups = 1
        print(f"convert GN to LN for module: {layer_name}")

    return unet

class Model(nn.Module):
    """Thin wrapper around the UNet returned by ``get_unet``.

    Also registers a 4-element ``no_pixel`` parameter that ``forward`` never
    reads (presumably kept so checkpoint keys line up -- TODO confirm).
    """

    def __init__(self):
        super().__init__()
        self.unet = get_unet()
        # dummy params
        self.no_pixel = nn.Parameter(torch.zeros(4))

    def forward(self, rgbd, t):
        """Run the UNet on ``rgbd`` with an all-zero rendered view prepended.

        Args:
            rgbd: tensor of shape (B, 4, H, W); the current (noised) view.
            t: timestep argument forwarded unchanged to the UNet.

        Returns:
            The UNet output, (B, C, H, W).
        """
        blank_render = torch.zeros_like(rgbd)
        stacked = torch.cat((blank_render, rgbd), dim=1)  # (B, 4 + 4, H, W)
        return self.unet(stacked, t)  # (B, C, H, W)

In [None]:
# load model from dir_ckpt, "checkpoint", "model.pt"
model_path = "/home/wukailu/RGBD-Diffusion/out/RGBD2/checkpoint/model.pt"
# map_location restores every tensor directly onto `device`, so a checkpoint
# saved on a GPU still loads when that GPU is not visible from this process.
ckpt = torch.load(model_path, map_location=device)
# Keep the model on the same device as the tensors the sampler creates.
model = Model().to(device)
model.load_state_dict(ckpt['model'])

In [None]:
from einops import rearrange
from diffusers import DDIMScheduler
from tqdm import tqdm

# Number of inference steps handed to the scheduler's `set_timesteps`.
num_steps = 50
# Shared DDIM scheduler, read by the sampler and by the guidance callbacks.
diffusion_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    clip_sample=False,
    set_alpha_to_one=False,
)
@torch.no_grad()
def sampling_forward_fn(rgbd, t):
    """Gradient-free UNet forward pass used while sampling.

    (forward func only for classifier-free sampling)

    Args:
        rgbd: network input passed straight to ``model.unet``.
        t: timestep tensor; reshaped to shape (1,) before the call.

    Returns:
        The UNet prediction cast to float32.
    """
    timestep = t.reshape([1])
    # forward
    with torch.cuda.amp.autocast(enabled=True):
        noise = model.unet(rgbd, timestep)
    return noise.float()

class Sampler:
    """Draws one RGBD image with the DDIM scheduler, optionally steered by a callback."""

    @torch.no_grad()
    def __call__(self, seed, call_back=None):
        """Run the full denoising loop once.

        Args:
            seed: integer seed for the initial Gaussian noise.
            call_back: optional callable invoked every step as
                ``call_back(i, t, pred_noise, rgbd_in)``. It may modify
                ``pred_noise`` in place before the scheduler step uses it.

        Returns:
            Tensor of shape (H, W, 4): de-normalized RGBD with RGB clamped to
            [0, 1] and depth clamped to be non-negative.
        """
        # sample noise
        generator = torch.Generator(device).manual_seed(seed)
        z_t = torch.randn([1, 4, IMG_SIZE, IMG_SIZE], generator=generator, device=device)
        # The "known part" (rendered view) is all zeros here, i.e. nothing is known.
        known_part = torch.zeros((1, 4, IMG_SIZE, IMG_SIZE), device=device)

        diffusion_scheduler.set_timesteps(num_steps)
        time_step_lst = diffusion_scheduler.timesteps
        assert num_steps == diffusion_scheduler.num_inference_steps == len(time_step_lst)

        model.eval()
        # tqdm wraps the sequence itself (not the enumerate iterator) so it can
        # read the length and show a proper total / ETA.
        for i, t in enumerate(tqdm(time_step_lst)):
            t = t.to(device=device)
            rgbd_in = torch.cat([known_part, z_t], dim=1)  # (1, 8, H, W)
            pred_noise = sampling_forward_fn(rgbd_in, t)
            if call_back is not None:
                call_back(i, t, pred_noise, rgbd_in)
            z_t = diffusion_scheduler.step(pred_noise, t, z_t).prev_sample

        # reshape to (H, W, C)
        rgbd_curr = rearrange(z_t, "() C H W -> H W C")
        # undo the per-channel normalization
        rgbd_curr = rgbd_curr * std + mean
        rgbd_curr[..., :3] = (rgbd_curr[..., :3] + 1) / 2 # 0~1
        rgbd_curr[..., :3] = rgbd_curr[..., :3].clamp(min=0, max=1)
        rgbd_curr[...,  3] = rgbd_curr[...,  3].clamp(min=0)
        return rgbd_curr

def midas_callback_rgbd(i, t, noise_pred_uncond, model_input):
    """Guidance callback that pulls the depth channel toward its own mean.

    Modifies ``noise_pred_uncond`` in place and returns nothing.

    Args:
        i: step index (unused).
        t: current timestep; indexes ``diffusion_scheduler.alphas_cumprod``.
        noise_pred_uncond: noise prediction, adjusted in place.
        model_input: network input whose last channel is treated as depth.
            It may have more channels than ``noise_pred_uncond`` (``Sampler``
            passes the 8-channel ``[known_part, z_t]`` concatenation); only
            the gradient of the trailing channels is applied.
    """
    # TODO: check this
    with torch.cuda.amp.autocast(enabled=True):
        with torch.enable_grad():
            ltt: torch.Tensor = model_input.detach().requires_grad_(True)

            pred = ltt[:, -1]
            target = torch.ones_like(pred) * (pred.mean().detach())
            loss_func = torch.nn.MSELoss()  # MSE because we assume it follows a normal distribution
            loss = loss_func(pred.flatten(start_dim=1), target.flatten(start_dim=1))
            loss.backward()

        grad = ltt.grad * -0.5 * 3 / (1**2)  # divide sigma^2, if depth distribution follows N(target, sigma)
        # Keep only the trailing channels that line up with the noise prediction.
        # Without this, an 8-channel model_input cannot be subtracted in place
        # from a 4-channel prediction. No-op when the channel counts already match.
        grad = grad[:, -noise_pred_uncond.shape[1]:]
        grad = grad.to(dtype=noise_pred_uncond.dtype)
        sqrt_one_minus_alpha_prod = (1 - diffusion_scheduler.alphas_cumprod[t]) ** 0.5
        noise_pred_uncond -= sqrt_one_minus_alpha_prod.to(grad) * grad
        print("grad", grad.mean(), grad.std(), grad.shape)
        print("noise_pred_uncond", noise_pred_uncond.mean(), noise_pred_uncond.std(), noise_pred_uncond.shape)

In [None]:
# Unguided sample: no call_back is passed, so the noise prediction is used as-is.
rgbd = Sampler()(seed=237) # (H, W, 4)

In [None]:
from PIL import Image
import numpy as np

def show_rgbd(rgbd_display):
    """Display an RGBD tensor as three images: RGB, depth, and all four channels.

    Args:
        rgbd_display: tensor with channels last (RGB in [..., :3], depth in
            [..., 3]); it is cloned, so the caller's tensor is left untouched.
    """
    canvas = rgbd_display.clone()
    canvas[..., :3] = canvas[..., :3] * 255

    depth = canvas[..., 3].clone()
    # Depths below 5% of the maximum are ignored when choosing the display range.
    cutoff = canvas[..., 3:].max() * 0.05
    usable = depth[depth >= cutoff]
    near, far = usable.min(), usable.max()
    # Stretch [near, far] to [0, 255]; anything closer than `near` maps to 0.
    canvas[..., 3] = (depth - near).clip(0) / (far - near) * 255

    pixels = canvas.numpy().round().astype(np.uint8)
    for view in (pixels[..., :3], pixels[..., 3], pixels):
        display(Image.fromarray(view))

In [None]:
# Render the sample as RGB, depth and combined images.
show_rgbd(rgbd)
# Last expression of the cell: echoes the overall max and mean of the sample.
rgbd.max(), rgbd.mean()

In [None]:
def generate_depth_limit(rgbd):
    """Return a constant depth map filled with the mean of the non-negligible depths.

    Depths at or below 5% of the maximum depth are excluded from the mean.

    Args:
        rgbd: torch tensor or numpy array with channels last; depth is read
            from ``rgbd[..., 3]``.

    Returns:
        An array of the same kind (torch tensor or numpy array) and shape as
        ``rgbd[..., 3]``, with every element equal to that mean. The value is
        NaN when no depth exceeds the threshold.
    """
    depths = rgbd[..., 3]
    threshold = depths.max() * 0.05
    d_mean = depths[depths > threshold].mean()
    if isinstance(depths, torch.Tensor):
        # Stay in torch: the guidance callback compares and indexes this map
        # against tensors, which a numpy array cannot do.
        return torch.ones_like(depths) * d_mean
    return np.ones_like(depths) * d_mean

In [None]:
# Constant depth map derived from the sample generated earlier in the notebook.
d_limits = generate_depth_limit(rgbd)

In [None]:
def midas_limit_depth(depth_limit):
    """Build a sampler callback that penalizes depth values above ``depth_limit``.

    Args:
        depth_limit: scalar, numpy array or tensor broadcastable to the depth
            channel of the callback's ``model_input`` (e.g. an (H, W) map).
            NOTE(review): the comparison is made directly against the last
            channel of ``model_input``, which in ``Sampler`` is the still
            normalized, noisy sample -- confirm the limit is expressed in the
            same space.

    Returns:
        ``callback(i, t, noise_pred_uncond, model_input)``, which modifies
        ``noise_pred_uncond`` in place and returns nothing.
    """
    def callback(i, t, noise_pred_uncond, model_input, target=depth_limit):
        # TODO: check this
        with torch.cuda.amp.autocast(enabled=True):
            with torch.enable_grad():
                ltt: torch.Tensor = model_input.detach().requires_grad_(True)

                pred = ltt[:, -1]  # (B, H, W)
                # Coerce the limit to a tensor shaped like `pred`, so the boolean
                # mask below can index both (an (H, W) map has no batch axis).
                limit = torch.as_tensor(target, dtype=pred.dtype, device=pred.device).expand_as(pred)
                mask = (pred > limit)
                if not mask.any():
                    # Nothing exceeds the limit: leave the prediction untouched.
                    return
                loss_func = torch.nn.MSELoss()  # MSE because we assume it follows a normal distribution
                loss = loss_func(pred[mask], limit[mask])
                loss.backward()

            grad = ltt.grad * -0.5 * 3 / (1**2)  # divide sigma^2, if depth distribution follows N(target, sigma)
            # Keep only the trailing channels that line up with the noise prediction.
            # Without this, an 8-channel model_input cannot be subtracted in place
            # from a 4-channel prediction. No-op when the channel counts already match.
            grad = grad[:, -noise_pred_uncond.shape[1]:]
            grad = grad.to(dtype=noise_pred_uncond.dtype)
            sqrt_one_minus_alpha_prod = (1 - diffusion_scheduler.alphas_cumprod[t]) ** 0.5
            noise_pred_uncond -= sqrt_one_minus_alpha_prod.to(grad) * grad
            print("grad", grad.mean(), grad.std(), grad.shape)
            print("noise_pred_uncond", noise_pred_uncond.mean(), noise_pred_uncond.std(), noise_pred_uncond.shape)
    return callback

# Method 3: extend method 2 to latent diffusion