In [3]:
# Conditional Reverse Diffusion Sampling

import torch
import torch.nn as nn
import math
import os
from google.colab import drive

drive.mount('/content/drive')

# Prefer the GPU whenever CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Input: trained conditional denoiser. Output: file receiving the sampled latents.
MODEL_PATH = "/content/drive/MyDrive/AMP-Generation/models/diffusion_conditional_final.pth"
SAVE_PATH = "/content/drive/MyDrive/AMP-Generation/data/generated_latent_amp.pth"

# Diffusion hyperparameters: T steps with a linear β schedule.
T = 500
beta_start, beta_end = 1e-4, 0.02

betas = torch.linspace(beta_start, beta_end, T, device=device)  # β_t
alphas = 1.0 - betas                                            # α_t = 1 - β_t
alphas_bar = alphas.cumprod(dim=0)                              # ᾱ_t = ∏_{s≤t} α_s

# ᾱ_{t-1}: shift ᾱ right by one step and define ᾱ_{-1} = 1.
alphas_bar_prev = torch.cat((alphas_bar.new_ones(1), alphas_bar[:-1]))

# β̃_t = (1 - ᾱ_{t-1}) / (1 - ᾱ_t) · β_t, the posterior variance (exactly 0 at t = 0).
beta_tilde = (1.0 - alphas_bar_prev) / (1.0 - alphas_bar) * betas

# Time embedding
class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding followed by a two-layer MLP.

    A 1-D tensor of timesteps ``t`` (one per sample) is mapped to ``dim``
    sin/cos features with geometrically spaced frequencies, then projected by
    ``Linear(dim, 4*dim) -> SiLU -> Linear(4*dim, dim)``.

    Args:
        dim: Width of both the sinusoidal features and the output embedding.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )

    def _sinusoidal(self, t):
        """Return the raw (batch, dim) sin/cos features for timesteps ``t``."""
        half = self.dim // 2
        # Guard the divisor: for dim < 4, half - 1 <= 0 would raise
        # ZeroDivisionError. For every dim >= 4 this is the original formula.
        scale = math.log(10000) / max(half - 1, 1)
        freqs = torch.exp(torch.arange(half, device=t.device) * -scale)
        angles = t[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if self.dim % 2:
            # Odd dim: sin/cos only yield dim - 1 columns. Pad one zero column
            # so the width matches the MLP's Linear(dim, ...) input.
            emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=1)
        return emb

    def forward(self, t):
        """Embed a 1-D tensor of timesteps into a (batch, dim) tensor."""
        return self.mlp(self._sinusoidal(t))

# Conditional Latent Diffusion Model
class LatentDiffusion(nn.Module):
    """Class-conditional denoiser for flat latent vectors.

    ``forward(x, t, y)`` does the following:

    1. Projects ``x`` to ``hidden_dim``.
    2. Adds a timestep embedding and a class embedding.
    3. Applies LayerNorm and dropout.
    4. Runs the result through a Transformer encoder as a length-1 sequence.
    5. Projects back to ``latent_dim``.

    The output is the model's prediction of X0.

    The keyword-only arguments expose encoder settings that were previously
    hard-coded. Their defaults are exactly those former values, so
    ``LatentDiffusion()`` builds the same modules, and the same state_dict
    keys and shapes, as before.

    Args:
        latent_dim: Width of the input and output latent vectors.
        hidden_dim: Model width; must be divisible by ``nhead``.
        num_classes: Number of distinct condition labels accepted in ``y``.
        nhead: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each encoder layer's feed-forward block.
        dropout: Dropout rate, used both before the encoder and inside it.
    """

    def __init__(self, latent_dim=64, hidden_dim=512, num_classes=2, *,
                 nhead=8, num_layers=6, dim_feedforward=2048, dropout=0.1):
        super().__init__()

        self.time_embed = TimeEmbedding(hidden_dim)
        self.cond_embed = nn.Embedding(num_classes, hidden_dim)
        self.fc_in = nn.Linear(latent_dim, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )

        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x, t, y):
        """Predict X0 from noisy latents.

        Args:
            x: Noisy latents; the sampler passes shape (batch, latent_dim).
            t: Integer timesteps, one per sample.
            y: Integer class labels, one per sample.

        Returns:
            The X0 prediction, with the same leading batch size as ``x`` and
            ``latent_dim`` features.
        """
        h = self.fc_in(x)
        h = h + self.time_embed(t) + self.cond_embed(y)
        h = self.norm(h)
        h = self.dropout(h)
        # Each sample is fed as a one-token sequence (batch_first=True).
        # NOTE(review): softmax over a single key is 1, so attention here
        # reduces to a learned projection of that token. The encoder acts as
        # a stack of residual blocks rather than mixing information across
        # positions.
        h = self.transformer(h.unsqueeze(1)).squeeze(1)
        return self.fc_out(h)  # predicts X0


# Build the denoiser on the target device and restore its trained weights.
model = LatentDiffusion().to(device)

# NOTE(review): torch.load unpickles the file unless weights_only=True is in
# effect; only point MODEL_PATH at checkpoints you trust.
checkpoint = torch.load(MODEL_PATH, map_location=device)

# Accept either a bare state_dict or a training checkpoint dict that stores
# the weights under "model_state_dict".
model.load_state_dict(
    checkpoint["model_state_dict"]
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint
    else checkpoint
)

# Inference mode: disables the dropout layers for sampling.
model.eval()

# Reverse diffusion sampling
@torch.no_grad()
def conditional_reverse_sampling(
    model,
    batch_size=512,
    num_batches=10,
    class_label=1,  # AMP
    latent_dim=64,
):
    """Generate class-conditional latents by ancestral (DDPM) reverse sampling.

    Starting from X_T ~ N(0, I), each step asks ``model`` for an estimate X̂_0
    and moves to X_{t-1} with the Gaussian posterior q(X_{t-1} | X_t, X̂_0):

        mean    = sqrt(ᾱ_{t-1}) β_t / (1 - ᾱ_t) · X̂_0
                + sqrt(α_t) (1 - ᾱ_{t-1}) / (1 - ᾱ_t) · X_t
        X_{t-1} = mean + sqrt(β̃_t) · ε          (no noise at t = 0)

    Reads the module-level ``T``, ``betas``, ``alphas``, ``alphas_bar``,
    ``alphas_bar_prev``, ``beta_tilde`` and ``device``.

    Args:
        model: Denoiser called as ``model(x_t, t, y)``; its output is used as
            the X0 prediction. An ``nn.Module`` is switched to eval mode while
            sampling and restored to its previous mode afterwards.
        batch_size: Number of latents drawn per batch.
        num_batches: Number of batches to draw.
        class_label: Integer condition given to every sample (1 = AMP here).
        latent_dim: Width of each latent vector. It must match the model's
            input size (64 for the default ``LatentDiffusion``).

    Returns:
        A CPU tensor of shape ``(batch_size * num_batches, latent_dim)``.
        It is empty if ``num_batches`` <= 0.
    """
    # The posterior coefficients depend only on t, so compute them once for
    # every step and batch rather than inside the inner loop.
    coef_x0 = torch.sqrt(alphas_bar_prev) * betas / (1.0 - alphas_bar)
    coef_xt = torch.sqrt(alphas) * (1.0 - alphas_bar_prev) / (1.0 - alphas_bar)
    sigma = torch.sqrt(beta_tilde)

    # Dropout must be off, otherwise the network itself randomly perturbs
    # every denoising step. Remember the caller's mode so it can be restored.
    is_module = isinstance(model, nn.Module)
    was_training = model.training if is_module else False
    if is_module:
        model.eval()

    all_samples = []
    try:
        for _ in range(num_batches):
            # X_T ~ N(0, I)
            xt = torch.randn(batch_size, latent_dim, device=device)
            y = torch.full((batch_size,), class_label, dtype=torch.long, device=device)

            for t in reversed(range(T)):
                t_tensor = torch.full((batch_size,), t, dtype=torch.long, device=device)

                # Predict X0, then take the posterior mean μ̃(X_t, X̂_0).
                x0_pred = model(xt, t_tensor, y)
                mean = coef_x0[t] * x0_pred + coef_xt[t] * xt

                if t > 0:
                    xt = mean + sigma[t] * torch.randn_like(xt)
                else:
                    xt = mean  # t = 0, no noise

            all_samples.append(xt.cpu())
    finally:
        if is_module:
            model.train(was_training)

    if not all_samples:
        # torch.cat([]) raises, so return an explicitly empty batch instead.
        return torch.empty(0, latent_dim)
    return torch.cat(all_samples, dim=0)

# Run generation
generated_latents = conditional_reverse_sampling(
    model=model,
    batch_size=512,
    num_batches=10,
    class_label=1,  # AMP
)

# torch.save does not create missing parent directories, so ensure the output
# folder exists first. An empty dirname means the current directory.
_save_dir = os.path.dirname(SAVE_PATH)
if _save_dir:
    os.makedirs(_save_dir, exist_ok=True)
torch.save(generated_latents, SAVE_PATH)

print(f"Generated {generated_latents.shape[0]} AMP latent vectors")
print(f"Saved to: {SAVE_PATH}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Generated 5120 AMP latent vectors
Saved to: /content/drive/MyDrive/AMP-Generation/data/generated_latent_amp.pth
