In [4]:
# Conditional Latent Diffusion Fine-Tuning

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import math
import os
from google.colab import drive

drive.mount('/content/drive')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Diffusion hyperparameters
T = 500
beta_start = 1e-4
beta_end = 0.02

betas = torch.linspace(beta_start, beta_end, T, device=device)
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0)

# Time embedding
class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, t):
        half = self.dim // 2
        scale = math.log(10000) / (half - 1)
        emb = torch.exp(torch.arange(half, device=t.device) * -scale)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        return self.mlp(emb)

# Conditional Latent Diffusion Model
class LatentDiffusion(nn.Module):
    def __init__(self, latent_dim=64, hidden_dim=512, num_classes=2):
        super().__init__()

        self.time_embed = TimeEmbedding(hidden_dim)
        self.cond_embed = nn.Embedding(num_classes, hidden_dim)

        self.fc_in = nn.Linear(latent_dim, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=8,
            dim_feedforward=2048,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=6
        )

        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(0.1)
        self.fc_out = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x, t, y):
        h = self.fc_in(x)
        h = h + self.time_embed(t) + self.cond_embed(y)
        h = self.norm(h)
        h = self.dropout(h)
        h = self.transformer(h.unsqueeze(1)).squeeze(1)
        return self.fc_out(h)

# Forward diffusion (q)
def q_sample(x0, t, noise):
    a_bar = alphas_bar[t].unsqueeze(1)
    return torch.sqrt(a_bar) * x0 + torch.sqrt(1.0 - a_bar) * noise


UNCOND_MODEL_PATH = "/content/drive/MyDrive/AMP-Generation/checkpoints/diffusion_paper_final_ep100.pth"
POS_LATENT_PATH = "/content/drive/MyDrive/AMP-Generation/data/latent_cond_pos.pth"
NEG_LATENT_PATH = "/content/drive/MyDrive/AMP-Generation/data/latent_cond_neg.pth"
SAVE_PATH = "/content/drive/MyDrive/AMP-Generation/models/diffusion_conditional_final.pth"
CHECKPOINT_DIR = "/content/drive/MyDrive/AMP-Generation/checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


# Load model (from unconditional)
model = LatentDiffusion().to(device)
checkpoint = torch.load(UNCOND_MODEL_PATH, map_location=device)

if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
    model.load_state_dict(checkpoint["model_state_dict"], strict=False)
else:
    model.load_state_dict(checkpoint, strict=False)

model.train()

# Load conditional latent data
pos_latent = torch.load(POS_LATENT_PATH)
neg_latent = torch.load(NEG_LATENT_PATH)

pos_labels = torch.ones(pos_latent.size(0), dtype=torch.long)
neg_labels = torch.zeros(neg_latent.size(0), dtype=torch.long)

latents = torch.cat([pos_latent, neg_latent], dim=0)
labels = torch.cat([pos_labels, neg_labels], dim=0)

dataset = TensorDataset(latents, labels)
loader = DataLoader(
    dataset,
    batch_size=512,
    shuffle=True,
    drop_last=True
)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Training loop
EPOCHS = 50
CHECKPOINT_INTERVAL = 5

for epoch in range(EPOCHS):
    total_loss = 0.0

    for x0, y in loader:
        x0 = x0.to(device)
        y = y.to(device)

        t = torch.randint(0, T, (x0.size(0),), device=device)
        noise = torch.randn_like(x0)

        xt = q_sample(x0, t, noise)
        x_pred = model(xt, t, y)

        loss = F.mse_loss(x_pred, x0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch+1:03d} | Loss: {avg_loss:.6f}")

    # Save Checkpoint every 5 epochs
    if epoch % CHECKPOINT_INTERVAL == 0:
        ckpt_path = os.path.join(CHECKPOINT_DIR, f"cond_diffusion_epoch_{epoch}.pth")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, ckpt_path)
        print(f"--- Checkpoint saved: {ckpt_path} ---")

# Save conditional model
torch.save(
    {"model_state_dict": model.state_dict()},
    SAVE_PATH
)

print(f"Saved conditional diffusion model to: {SAVE_PATH}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Epoch 001 | Loss: 2.065861
--- Checkpoint saved: /content/drive/MyDrive/AMP-Generation/checkpoints/cond_diffusion_epoch_0.pth ---
Epoch 002 | Loss: 1.438597
Epoch 003 | Loss: 1.359149
Epoch 004 | Loss: 1.321990
Epoch 005 | Loss: 1.297335
Epoch 006 | Loss: 1.279717
--- Checkpoint saved: /content/drive/MyDrive/AMP-Generation/checkpoints/cond_diffusion_epoch_5.pth ---
Epoch 007 | Loss: 1.270250
Epoch 008 | Loss: 1.256148
Epoch 009 | Loss: 1.246658
Epoch 010 | Loss: 1.245291
Epoch 011 | Loss: 1.237857
--- Checkpoint saved: /content/drive/MyDrive/AMP-Generation/checkpoints/cond_diffusion_epoch_10.pth ---
Epoch 012 | Loss: 1.232610
Epoch 013 | Loss: 1.221161
Epoch 014 | Loss: 1.224280
Epoch 015 | Loss: 1.224826
Epoch 016 | Loss: 1.211987
--- Checkpoint saved: /content/drive/MyDrive/AMP-Generation/checkpoints/cond_diffusion_epoch_15.pth ---
Epoch 017 | Loss: 1.21235