## Twisted Diffusion Sampler (TDS)

The **Twisted Diffusion Sampler (TDS)** is a Sequential Monte Carlo (SMC) method designed for conditional sampling from diffusion models. It addresses the challenge of sampling from a posterior distribution $p(x_0 | y)$ given a measurement $y$, where the prior $p(x_0)$ is defined by a pre-trained diffusion model.

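Standard diffusion guidance methods (such as classifier guidance) approximate the conditional score $\nabla_{x_t} \log p(x_t | y)$, and the approximation error biases the resulting samples. TDS instead uses an approximation $\tilde{p}(y|x_t)$ of the intractable likelihood $p(y|x_t)$ as a "twisting" function (or auxiliary target): it steers the intermediate proposal distributions of an SMC sampler, while importance weights correct for the mismatch between proposal and target.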
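
To make the idea of a twisting function concrete, here is a minimal sketch for an inpainting-style measurement, where `y` holds the observed pixels and a binary `mask` marks which pixels were observed. It assumes a Gaussian likelihood with a hand-chosen noise level `sigma`, and it evaluates the likelihood on `x_t` directly rather than on a denoised estimate, which is a simplification. The names `make_inpainting_twist`, `mask`, and `sigma` are illustrative assumptions, not part of any library.

```python
import torch

def make_inpainting_twist(mask, sigma=0.1):
    """Build a twisting function log p_tilde(y | x_t) for masked observations.

    mask:  tensor broadcastable to x_t, 1 where a pixel is observed, 0 elsewhere.
    sigma: assumed measurement noise level (hypothetical, hand-tuned).
    """
    def twisting_func(x_t, y, t):
        # t is unused in this simplified sketch; a fuller version might
        # widen sigma at noisier timesteps.
        residual = mask * (x_t - y)
        # Sum squared error over all non-batch dimensions: one value per particle.
        sq_err = residual.flatten(start_dim=1).pow(2).sum(dim=1)
        return -0.5 * sq_err / sigma**2
    return twisting_func

# Two particles of shape (1, 2, 2); only the top row of pixels is observed.
mask = torch.tensor([[[1.0, 1.0], [0.0, 0.0]]])
y = torch.zeros(1, 2, 2)
x_t = torch.stack([torch.zeros(1, 2, 2), torch.ones(1, 2, 2)])
twist = make_inpainting_twist(mask, sigma=1.0)
log_twist = twist(x_t, y, t=0)
```

The returned callable has the `(x_t, y, t)` signature and yields one log-value per particle, so its gradient with respect to `x_t` can be taken with `torch.autograd.grad`.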

### Key Steps in Algorithm 1:
1.  **Initialization**: Start with particles $x_T$ from the standard normal prior and assign initial weights based on the twisting function.
2.  **Resampling**: At each timestep, resample particles based on their importance weights to focus on high-probability regions.
3.  **Conditional Score Approximation**: Compute a "twisted" score that combines the unconditional diffusion score with the gradient of the twisting function $\nabla_{x_t} \log \tilde{p}(y|x_t)$.
4.  **Proposal Step**: Propagate particles to the next timestep $x_{t-1}$ using the twisted score (similar to guidance).
5.  **Weight Update**: Update particle weights to account for the discrepancy between the optimal target distribution and the proposal distribution.
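
The resampling step is the part of the loop that distinguishes SMC from plain guidance, so a small standalone sketch may help. It shows multinomial resampling from log-weights, plus the effective sample size (ESS), a common diagnostic for weight degeneracy. The helper names `resample` and `effective_sample_size` are assumptions made for illustration.

```python
import torch

def resample(particles, log_weights, generator=None):
    """Multinomial resampling: draw K particles with probability proportional to weight."""
    weights = torch.softmax(log_weights, dim=0)  # numerically stable normalization
    k = particles.shape[0]
    indices = torch.multinomial(weights, k, replacement=True, generator=generator)
    return particles[indices], indices

def effective_sample_size(log_weights):
    """ESS = 1 / sum(w_k^2): K for uniform weights, 1 when a single particle dominates."""
    weights = torch.softmax(log_weights, dim=0)
    return 1.0 / weights.pow(2).sum()

# Four scalar particles; all of the weight sits on particle 0.
particles = torch.arange(4.0).reshape(4, 1)
log_w = torch.tensor([0.0, float("-inf"), float("-inf"), float("-inf")])
resampled, idx = resample(particles, log_w)
```

With all weight on one particle, every resampled index points to it and the ESS drops to 1. Some SMC variants resample only when the ESS falls below a threshold instead of at every timestep.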

In [5]:
import torch
import torch.nn.functional as F
import numpy as np

def twisted_diffusion_sampler(
    model,              # The diffusion model (e.g., a UNet that predicts noise or x_0)
    scheduler,          # Diffusion scheduler (e.g., DDIM/DDPM scheduler)
    y,                  # The observed measurement (y)
    twisting_func,      # Callable (x_t, y, t) -> log p_tilde(y | x_t), one value per particle
    num_particles=2,    # K: Number of particles
    num_steps=100,      # T: Number of timesteps
    shape=(3, 256, 256),# Shape of x
    device='cuda'
):
    """
    Implementation of Algorithm 1: Twisted Diffusion Sampler (TDS)
    """

    # 1. Initialization
    # x_T ~ p(x_T) (Standard Normal)
    # shape needs to include batch size K
    x_t = torch.randn((num_particles, *shape), device=device)

    # Initial weights
    # w_k = p_tilde(y | x_T)
    with torch.no_grad():
        log_weights = twisting_func(x_t, y, num_steps)
        weights = torch.softmax(log_weights, dim=0)

    # Time loop T-1 to 0
    # Note: scheduler timesteps usually go from T-1 down to 0
    scheduler.set_timesteps(num_steps)

    for i, t in enumerate(scheduler.timesteps):
        # 2. Resampling
        # Resample indices based on weights
        indices = torch.multinomial(weights, num_particles, replacement=True)
        x_t = x_t[indices]

        # Resampling resets the weights to uniform. The twisting value of the
        # resampled particles (log_p_tilde_{t+1}) is not cached; it is recomputed
        # below together with its gradient, which the twisted score needs.

        # Enable grad for the score approximation step
        x_t = x_t.detach().requires_grad_(True)

        # 3. Conditional Score Approximation
        # Estimate x_0 (x_hat_theta)
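        model_output = model(x_t, t)  # Assuming model outputs noise or x_0
        # diffusers models such as UNet2DModel return an output object; unwrap the tensor
        if hasattr(model_output, "sample"):
            model_output = model_output.sample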

        # Convert model output to x_0 prediction using scheduler
        # (This depends on specific scheduler API, here is a generic placeholder)
        # For DDIM/DDPM, we usually get prev_sample, but we need x_0 for the formula.
        # Let's assume a function `get_x0_from_noise` exists or scheduler provides it.
        # step_output = scheduler.step(model_output, t, x_t)
        # x_0_hat = step_output.pred_original_sample

        # Simplified: let's assume we calculate the score update directly:
        # s_tilde = (x_0_hat - x_t) / sigma^2 + grad(log_p_tilde_{t+1})

        # Calculate log twisting function for gradients
        log_twisting = twisting_func(x_t, y, t)
        grad_log_twisting = torch.autograd.grad(log_twisting.sum(), x_t)[0]

        # Standard diffusion score (unconditional)
        # score_uncond = (x_0_hat - x_t) / sigma_t^2
        # Often schedulers give us the mean of p(x_{t-1} | x_t).
        # Let's use the "Proposal" logic defined in the algorithm:
        # x_{t-1} ~ N(x_t + sigma^2 * s_tilde, ...)

        # 4. Proposal Step (Transition)
        # We combine the unconditional transition with the twisting gradient
        # This is equivalent to Classifier Guidance usually.
        with torch.no_grad():
            # Standard step for p(x_{t-1} | x_t). In diffusers, `prev_sample` is a
            # draw from this transition (the noise is already added), not its mean.
            step_output = scheduler.step(model_output, t, x_t)
            prev_sample = step_output.prev_sample

            # The scale of the twisting gradient depends on the variance schedule.
            # `_get_variance` is a private diffusers helper; 1.0 is only a placeholder fallback.
            variance = scheduler._get_variance(t) if hasattr(scheduler, '_get_variance') else 1.0
            variance = torch.as_tensor(variance, device=x_t.device, dtype=x_t.dtype)

            # Shifting a draw from N(mean, variance) by variance * gradient gives a draw
            # from N(mean + variance * gradient, variance), so no second noise term is added.
            x_prev = prev_sample + variance * grad_log_twisting

            # 5. Weight Update
            # w = p(x_t | x_{t+1}) * p_tilde_t / [ p_tilde(x_t | x_{t+1}, y) * p_tilde_{t+1} ]
            # In many Twisted implementations, if the proposal is optimal, weights are uniform.
            # Otherwise, we compute the importance weights.

            # Calculate new twisting func value
            log_p_tilde_t = twisting_func(x_prev, y, t-1 if t>0 else 0)

            # Update weights in the log domain (simplified).
            # The twisting ratio p_tilde_t / p_tilde_{t+1} is included. The transition-density
            # ratio p(x_t | x_{t+1}) / p_tilde(x_t | x_{t+1}, y) from the formula above is
            # omitted, so this only approximates the exact TDS weights.
            log_weights = log_p_tilde_t - log_twisting.detach()
            weights = torch.softmax(log_weights, dim=0)

            x_t = x_prev

    return x_t
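
The weight update in the function above leaves out the transition-density term. As a sketch of what that correction could look like, the helper below computes $\log p(x_{t-1} | x_t) - \log \tilde{p}(x_{t-1} | x_t, y)$ for two Gaussians that share a variance. It assumes the means of both the unconditional transition and the twisted proposal are available. The names `gaussian_log_prob` and `transition_log_ratio` are hypothetical, and the tensors are random stand-ins rather than diffusion model outputs.

```python
import math
import torch

def gaussian_log_prob(x, mean, variance):
    """Log density of an isotropic Gaussian, summed over all non-batch dimensions."""
    variance = torch.as_tensor(variance, dtype=x.dtype, device=x.device)
    log_p = -0.5 * ((x - mean) ** 2 / variance + torch.log(2 * math.pi * variance))
    return log_p.flatten(start_dim=1).sum(dim=1)

def transition_log_ratio(x_prev, base_mean, proposal_mean, variance):
    """log p(x_{t-1} | x_t) - log p_tilde(x_{t-1} | x_t, y), one value per particle."""
    return (gaussian_log_prob(x_prev, base_mean, variance)
            - gaussian_log_prob(x_prev, proposal_mean, variance))

torch.manual_seed(0)
base_mean = torch.randn(2, 3, 4, 4)   # stand-in for the unconditional transition mean
grad = torch.randn_like(base_mean)    # stand-in for grad log p_tilde
variance = 0.05
proposal_mean = base_mean + variance * grad
x_prev = proposal_mean + variance ** 0.5 * torch.randn_like(base_mean)
correction = transition_log_ratio(x_prev, base_mean, proposal_mean, variance)
```

Adding `correction` to the log twisting ratio would give the full log-weight from the formula in the weight-update comment. When the proposal mean equals the unconditional mean, the correction is zero.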

In [6]:
from diffusers import DDPMScheduler, UNet2DModel

# Load a model suited to images (e.g., trained on faces or landscapes)
model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256")
scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")

