# 
https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/annotated_diffusion.ipynb#scrollTo=cc57b01f

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np
import torch

from ddpm import LatentDiffusionModel, CustomDataset, get_ddpm_scheduler_variables, extract
from tqdm.auto import tqdm

# Number of diffusion steps. The scheduler helper hands it back together with
# the schedule variables that the sampling functions below read as globals.
timesteps = 1000
(
    timesteps,
    betas,
    alphas,
    alphas_cumprod,
    alphas_cumprod_prev,
    sqrt_recip_alphas,
    sqrt_alphas_cumprod,
    sqrt_one_minus_alphas_cumprod,
    posterior_variance,
) = get_ddpm_scheduler_variables(timesteps=timesteps)

# Prefer GPU 7, but fall back to GPU 0 when the machine exposes fewer devices
# (requesting a non-existent index fails later, at the first `.to(device)`),
# and to the CPU when CUDA is not available at all.
preferred_gpu_index = 7
if torch.cuda.is_available():
    gpu_index = preferred_gpu_index if preferred_gpu_index < torch.cuda.device_count() else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')
print('Using device:', device)

# Initialize the model
latent_dim = 256
hidden_dims = [2048, 2048, 2048, 2048]
# hidden_dims = [1024, 1024, 1024, 1024]
max_freq = 4  # Example max frequency for Fourier features
num_bands = 4  # Number of frequency bands
scalar_hidden_dims = [256, 256, 256, 256]
diffusion_model = LatentDiffusionModel(latent_dim, hidden_dims, scalar_hidden_dims, max_freq, num_bands).to(device)

# Checkpoints tried earlier, kept for reference:
# diffusion_ckpt_path = "/home/ziran/se3/EFEM/lib_shape_prior/dev_ckpt/mugs_ddpm_cos_20k_test1/model.pth"
# diffusion_ckpt_path = "/home/ziran/se3/EFEM/lib_shape_prior/dev_ckpt/mugs_ddpm_cos_20k_l2loss/model.pth"
# diffusion_ckpt_path = "/home/ziran/se3/EFEM/lib_shape_prior/dev_ckpt/mugs_ddpm_cos_30k_l1huber_normT/model.pth"
# diffusion_ckpt_path = "/home/ziran/se3/EFEM/lib_shape_prior/dev_ckpt/mugs_ddpm_cos_30k_l2_normT/model.pth"
# diffusion_ckpt_path = "/home/ziran/se3/EFEM/lib_shape_prior/dev_ckpt/NEWmugs_ddpm_cos_10k_l1huber/model.pth"
# diffusion_ckpt_path = "/home/ziran/se3/EFEM/lib_shape_prior/dev_ckpt/mugs_ddpm_cos_30k_l1huber_normT_1024/model.pth"
# diffusion_ckpt_path = "/home/ziran/se3/EFEM/lib_shape_prior/dev_ckpt/mugs_ddpm_cos_100k_l1huber_normT_2048_steps10k/model.pth"

# 12-31
diffusion_ckpt_path = "/home/ziran/se3/EFEM/lib_shape_prior/dev_ckpt/mugs_ddpm_cos_200k_l1huber_normT_2048_dataNorm/model_epo199999.pth"
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved from another GPU index (or loaded on a CPU-only machine)
# still deserializes.
diffusion_ckpt = torch.load(diffusion_ckpt_path, map_location=device)
# The model already lives on `device`; load_state_dict copies the weights into
# its existing parameters, so no further `.to(device)` is needed.
diffusion_model.load_state_dict(diffusion_ckpt['model'])
print('Diffusion Model parameters:', sum(p.numel() for p in diffusion_model.parameters()))

In [None]:
# Scratch check of `extract`: query `posterior_variance` with a batch of three
# identical timestep indices. The third argument is a length-3 tensor passed
# in the position where the sampling code passes `x.shape`.
time_value = 999
t = torch.full((3,), time_value, dtype=torch.long)
tensor = torch.full((3,), 10, dtype=torch.long)
extract(posterior_variance, t, tensor)

In [None]:
# Display the first 100 entries of `sqrt_alphas_cumprod`.
sqrt_alphas_cumprod[:100]

In [None]:
# Collects the x_0 estimate computed by every call to `uncond_p_sample`,
# in call order.
pred_x_0_list = []


@torch.no_grad()
def uncond_p_sample(model, x, t, t_index):
    """Run one reverse-diffusion step and return the next sample.

    Args:
        model: Noise predictor, invoked as ``model(x, t, timesteps)``.
        x: Current noisy sample.
        t: Per-example timestep indices used to look up schedule values.
        t_index: The step index as a plain number; ``0`` marks the last step,
            where the mean is returned without added noise.

    Returns:
        The model mean when ``t_index == 0``; otherwise the mean plus Gaussian
        noise scaled by ``sqrt(posterior_variance)`` at ``t``.

    Side effects:
        Appends this step's x_0 estimate to the module-level ``pred_x_0_list``.
    """
    # Schedule values for the timesteps in `t`, looked up through `extract`.
    beta_t = extract(betas, t, x.shape)
    sqrt_one_minus_acp_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alpha_t = extract(sqrt_recip_alphas, t, x.shape)
    sqrt_acp_t = extract(sqrt_alphas_cumprod, t, x.shape)

    eps_pred = model(x, t, timesteps)

    # Posterior mean (Equation 11 in the paper): the predicted noise divided by
    # sqrt(1 - alpha_bar_t) is the negated score of p(x_t | x_0).
    scaled_eps = eps_pred / sqrt_one_minus_acp_t
    mean = sqrt_recip_alpha_t * (x - beta_t * scaled_eps)

    # Invert the forward process to get the clean-sample estimate for this step.
    x0_estimate = (x - sqrt_one_minus_acp_t * eps_pred) / sqrt_acp_t
    pred_x_0_list.append(x0_estimate)

    if t_index == 0:
        return mean

    # Algorithm 2 line 4: add noise with the posterior standard deviation.
    std_t = torch.sqrt(extract(posterior_variance, t, x.shape))
    return mean + std_t * torch.randn_like(x)
        
# Algorithm 2:
@torch.no_grad()
def uncond_p_sample_loop(model, shape, return_traj=False):
    """Sample unconditionally by running the full reverse-diffusion chain.

    Starts from Gaussian noise of the given ``shape`` and applies
    ``uncond_p_sample`` for every step from ``timesteps - 1`` down to ``0``.

    Args:
        model: Noise predictor; its first parameter's device is used for
            sampling.
        shape: Shape of the batch to generate; ``shape[0]`` is the batch size.
        return_traj: When true, return the whole trajectory instead of only
            the final sample.

    Returns:
        If ``return_traj`` is true, a list with one NumPy array per step (the
        sample after that step, in sampling order). Otherwise the final sample
        tensor on the model's device.

    Side effects:
        Empties the module-level ``pred_x_0_list`` before sampling, so after
        the call it holds exactly one x_0 estimate per step of this run.
    """
    device = next(model.parameters()).device
    batch_size = shape[0]

    # Estimates left over from an earlier run would otherwise accumulate and
    # shift the positions that later cells index by time step.
    pred_x_0_list.clear()

    # start from pure noise (for each example in the batch)
    x_t = torch.randn(shape, device=device)
    traj = []

    for i in tqdm(reversed(range(timesteps)), desc='sampling loop time step', total=timesteps):
        t = torch.full((batch_size,), i, device=device, dtype=torch.long)
        x_t = uncond_p_sample(model, x_t, t, i)
        # Only pay for the device-to-host copy when the trajectory is wanted.
        if return_traj:
            traj.append(x_t.cpu().numpy())

    return traj if return_traj else x_t


In [None]:
# Seed the global RNG before sampling.
torch.manual_seed(12)
bs = 5
eta = 1.0  # not read by any code visible in this notebook
# With return_traj=True the loop returns one NumPy array per sampling step.
trajs = uncond_p_sample_loop(diffusion_model, (bs, 256, 4), return_traj=True)

In [None]:

from init import get_AEmodel_cfg
from core.models import get_model
# Build the autoencoder named in the config and load its pretrained weights.
category = "mugs"
# category = "kit4cates"
# category = "chairs"
config_Only_model = get_AEmodel_cfg()
ModelClass = get_model(config_Only_model["model"]["model_name"])
model = ModelClass(config_Only_model)
ckpt_path = f"/home/ziran/se3/EFEM/weights/{category}.pt"
# ckpt_path = f"/home/ziran/se3/EFEM/lib_shape_prior/log/12_10_shape_prior_mugs_old/12_10_shape_prior_mugs_FOR_hopefullybetterAE/checkpoint/15409.pt"
# map_location remaps the stored tensors onto the active device, so the
# checkpoint loads regardless of the device it was saved from.
ckpt = torch.load(ckpt_path, map_location=device)
model.network.load_state_dict(ckpt['model_state_dict'])
model.network = model.network.to(device)


In [None]:
# Number of x_0 estimates recorded so far (one is appended per
# `uncond_p_sample` call).
len(pred_x_0_list)

# pred_x_0_list[0].shape

In [None]:
# Build the dataset from the cached codebook file with normalization turned
# on; the visualization cells read `train_ds.normal_params` from it.
codebook_path = "/home/ziran/se3/EFEM/cache/mugs.npz"
train_ds = CustomDataset(codebook_path, normalization=True)

## Traj over time (predicted x_0)

In [None]:

# Take every 50th recorded x_0 estimate, starting at list index 1, and stack
# the entries belonging to one sample of the batch.
time_slice = slice(1, 1001, 50)
print(time_slice)
sample_idx = 2
selected_preds = pred_x_0_list[time_slice]
a = torch.stack([step_pred[sample_idx] for step_pred in selected_preds])
a.shape


In [None]:
from IPython.display import display
from viz_x import viz_x, viz_x_img

# Render the stacked x_0 estimates and show the combined image inline.
single_images, combined_image = viz_x_img(a, model, device, train_ds.normal_params)
display(combined_image)

## Traj over time (x_t)

In [None]:
# Pick the trajectory entries from list index 990 onward and keep a single
# sample of the batch from each of them.
time_slice = slice(990, 1001, 1)
sample_idx = 0

stacked_steps = np.stack(trajs[time_slice])
latent_x = torch.from_numpy(stacked_steps[:, sample_idx, ...])
latent_x.shape


In [None]:
from viz_x import viz_x

# Pass the dataset's normalization parameters, as the other visualization
# calls in this notebook do, so these latents are handled the same way.
# NOTE(review): `viz_x`'s signature is not visible here — confirm the fourth
# argument is the normalization parameters.
viz_x(latent_x, model, device, train_ds.normal_params)


## simple viz

In [None]:
from viz_x import viz_x

# Visualize the last trajectory entry, i.e. the batch after the final
# sampling step.
final_latents = torch.tensor(trajs[-1])
viz_x(final_latents, model, device, train_ds.normal_params)