In [32]:
# -----------------------------------------------------------
# 🟢 Cell 1 – Imports, seed, device
# -----------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as tvu

import numpy as np
import math, random, os, time, itertools
from tqdm import tqdm

import intel_extension_for_pytorch as ipex
import matplotlib.pyplot as plt

# Uncomment below if you want experiment tracking
# import wandb
# wandb.init(project="mnist_diffusion", config={"task": "2.5"})

SEED = 42
random.seed(SEED); 
np.random.seed(SEED); 
torch.manual_seed(SEED)

# ── Device selection ─────────────────────────────────────────
# IPEX registers 'xpu' for Intel GPUs via DPC++
if hasattr(torch, "xpu") and torch.xpu.is_available():
    DEVICE = torch.device("xpu")
else:
    DEVICE = torch.device("cpu")

print("Using device:", DEVICE)


Using device: xpu


In [21]:
# -----------------------------------------------------------
# 🔧 Cell X – Hyperparameters (centralized)
# -----------------------------------------------------------

# Random seed for reproducibility
SEED          = 42

# Device
DEVICE        = "cuda" if torch.cuda.is_available() else "cpu"

# Data
BATCH_SIZE    = 128
NUM_WORKERS   = 2            # keep parallelism now that everything is pickle-safe
NUM_CLASSES   = 10
IMG_SHAPE     = (1, 28, 28)

# Diffusion process
T             = 200          # Number of timesteps
BETA_SCHEDULE = "cosine"     # Options: "linear", "cosine", etc.

# Model architecture
BASE_CHANNELS = 64           # "base_c" in UNet
TIME_EMB_DIM  = 128          # Dimensionality of time (and label) embeddings

# Optimization / Training
LR            = 2e-4         # Learning rate for AdamW
WEIGHT_DECAY  = 0.0          # If you wish to add weight decay
N_EPOCHS      = 15           # Number of training epochs

# Sampling
DDIM_ETA      = 0.0          # η parameter for DDIM sampler (0 → deterministic)

# Auxiliary classifier (for evaluation)
CLF_LR        = 1e-3         # Learning rate for TinyCNN classifier
CLF_EPOCHS    = 2            # Quick training epochs on MNIST

# (Optionally) WandB logging
USE_WANDB     = False
WANDB_PROJECT = "mnist_diffusion"


In [22]:
# -----------------------------------------------------------
# 🟢 Cell 2 – Dataset & label-conditioning helper (patched)
# -----------------------------------------------------------
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))   # built-in, pickle-safe
])

train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_ds  = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True)

In [23]:
# -----------------------------------------------------------
# 🟢 Cell 3 – Diffusion schedule utilities
# -----------------------------------------------------------
def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine schedule from Nichol & Dhariwal 2021 (DDPM++).
    Returns a tensor of betas with shape [T].
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi / 2)**2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clamp(betas, 1e-8, 0.999)

T = 200               # diffusion steps (kept small for MNIST)
betas  = cosine_beta_schedule(T).to(DEVICE)
alphas = 1. - betas
alphas_cumprod     = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod= torch.sqrt(alphas_cumprod)
sqrt_one_minus_acp = torch.sqrt(1 - alphas_cumprod)


In [24]:
# -----------------------------------------------------------
# 🟢 Cell 4 – Positional & label embeddings
# -----------------------------------------------------------
class SinusoidalPosEmb(nn.Module):
    """
    Standard 1-D sinusoidal embeddings for timestep t (shape [B]).
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        device = t.device
        half  = self.dim // 2
        emb   = math.log(10000) / (half - 1)
        emb   = torch.exp(torch.arange(half, device=device) * -emb)
        emb   = t[:, None] * emb[None, :]
        emb   = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        return emb

class LabelEmbedding(nn.Module):
    def __init__(self, num_classes, dim):
        super().__init__()
        self.emb = nn.Embedding(num_classes, dim)
    def forward(self, y):
        return self.emb(y)


In [25]:
# -----------------------------------------------------------
# 🟢 Cell 5 – A *tiny* UNet denoiser  (patched channel counts)
# -----------------------------------------------------------
class ResidualBlock(nn.Module):
    def __init__(self, in_c: int, out_c: int, time_emb_dim: int):
        """
        Simple ResNet block with FiLM-style time/label conditioning.
        """
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, out_c)
        )
        self.block = nn.Sequential(
            nn.Conv2d(in_c, out_c, 3, padding=1),
            nn.GroupNorm(8, out_c),
            nn.SiLU(),
            nn.Conv2d(out_c, out_c, 3, padding=1),
            nn.GroupNorm(8, out_c)
        )
        self.res_conv = nn.Conv2d(in_c, out_c, 1) if in_c != out_c else nn.Identity()

    def forward(self, x, t_emb):
        h = self.block[0](x)          # first conv
        h = self.block[1](h)          # GN
        h = h + self.mlp(t_emb)[:, :, None, None]   # add conditioning
        h = self.block[2:](h)         # SiLU + conv + GN
        return h + self.res_conv(x)   # residual add


class SimpleUNet(nn.Module):
    """
    UNet-ish network that predicts ε for DDPM, conditioned on timestep + class label.
    Channel dimensions have been fixed to avoid concat-mismatch errors.
    """
    def __init__(self, img_ch=1, base_c=64, time_emb_dim=128, num_classes=10):
        super().__init__()

        # ---------- embeddings ----------
        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim * 4),
            nn.SiLU(),
            nn.Linear(time_emb_dim * 4, time_emb_dim)
        )
        self.label_emb = LabelEmbedding(num_classes, time_emb_dim)

        # ---------- down path ----------
        self.down1 = ResidualBlock(img_ch, base_c, time_emb_dim)        # 1  → 64
        self.pool1 = nn.MaxPool2d(2)                                     # 28→14
        self.down2 = ResidualBlock(base_c, base_c * 2, time_emb_dim)    # 64 →128
        self.pool2 = nn.MaxPool2d(2)                                     # 14→7

        # ---------- bottleneck ----------
        self.mid = ResidualBlock(base_c * 2, base_c * 2, time_emb_dim)   # 128→128

        # ---------- up path ----------
        self.up2  = nn.ConvTranspose2d(base_c * 2, base_c, 2, stride=2)  # 7→14, 128→64
        # [u2 (64) ⊕ d2 (128)] = 192 channels
        self.res2 = ResidualBlock(base_c * 3, base_c, time_emb_dim)      # 192→64

        self.up1  = nn.ConvTranspose2d(base_c, base_c // 2, 2, stride=2) # 14→28, 64→32
        # [u1 (32) ⊕ d1 (64)] = 96 channels
        self.res1 = ResidualBlock(base_c + base_c // 2, base_c // 2, time_emb_dim)  # 96→32

        self.out_conv = nn.Conv2d(base_c // 2, img_ch, 1)                # 32→1

    def forward(self, x, t, y):
        """
        x : noisy image  [B, 1, 28, 28]
        t : timestep     [B]
        y : digit label  [B]
        """
        # --- combine timestep + label embeddings ---
        cond_emb = self.time_emb(t) + self.label_emb(y)

        # --- encoder ---
        d1 = self.down1(x, cond_emb)             # 64 ch
        d2 = self.down2(self.pool1(d1), cond_emb)# 128 ch

        # --- bottleneck ---
        m  = self.mid(self.pool2(d2), cond_emb)  # 128 ch

        # --- decoder + skip 2 ---
        u2 = self.up2(m)                         # 64 ch, 14×14
        u2 = torch.cat([u2, d2], dim=1)          # 192 ch
        u2 = self.res2(u2, cond_emb)             # 64 ch

        # --- decoder + skip 1 ---
        u1 = self.up1(u2)                        # 32 ch, 28×28
        u1 = torch.cat([u1, d1], dim=1)          # 96 ch
        u1 = self.res1(u1, cond_emb)             # 32 ch

        return self.out_conv(u1)                 # 1 ch (ε̂)


In [26]:
# -----------------------------------------------------------
# 🟢 Cell 6 – Forward-diffusion helper + loss
# -----------------------------------------------------------
@torch.no_grad()
def q_sample(x0, t, noise=None):
    """Diffuse the clean image x0 to x_t via q(x_t | x0)."""
    if noise is None:
        noise = torch.randn_like(x0)
    sqrt_acp     = sqrt_alphas_cumprod[t][:, None, None, None]
    sqrt_1m_acp  = sqrt_one_minus_acp[t][:, None, None, None]
    return sqrt_acp * x0 + sqrt_1m_acp * noise

def diffusion_loss(model, x0, y):
    """Simplified DDPM objective: predict ε directly (MSE)."""
    B = x0.size(0)
    t = torch.randint(0, T, (B,), device=DEVICE, dtype=torch.long)
    noise = torch.randn_like(x0)
    x_t   = q_sample(x0, t, noise)
    ε_pred = model(x_t, t, y)
    return F.mse_loss(ε_pred, noise)


In [29]:
# -----------------------------------------------------------
# 🟢 Cell 7 – Optimiser & training loop (IPEX on XPU, FP32 only)
# -----------------------------------------------------------

model = SimpleUNet(num_classes=NUM_CLASSES).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

model, optimizer = ipex.optimize(
    model=model,
    optimizer=optimizer,
    dtype=torch.float32,    # pure FP32
    level="O1",             # basic fusions
    weights_prepack=False   # disable weight pre-packing
)

# Training loop (FP32 only)
for epoch in range(1, N_EPOCHS + 1):
    model.train()
    losses = []
    pbar = tqdm(train_loader, desc=f"Epoch {epoch:02d}")

    for xb, yb in pbar:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)

        # Forward + loss (pure FP32)
        loss = diffusion_loss(model, xb, yb)

        # Backward + step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Logging
        losses.append(loss.item())
        mean_loss = np.mean(losses)
        pbar.set_postfix(loss=f"{mean_loss:.4f}")

        # if wandb.run:
        #     wandb.log({"train_loss": loss.item()})

    print(f"Epoch {epoch}: mean loss {mean_loss:.4f}")


Epoch 01: 100%|██████████| 469/469 [04:46<00:00,  1.64it/s, loss=0.1312]


Epoch 1: mean loss 0.1312


Epoch 02: 100%|██████████| 469/469 [04:51<00:00,  1.61it/s, loss=0.0616]


Epoch 2: mean loss 0.0616


Epoch 03: 100%|██████████| 469/469 [04:44<00:00,  1.65it/s, loss=0.0535]


Epoch 3: mean loss 0.0535


Epoch 04: 100%|██████████| 469/469 [05:06<00:00,  1.53it/s, loss=0.0502]


Epoch 4: mean loss 0.0502


Epoch 05: 100%|██████████| 469/469 [05:08<00:00,  1.52it/s, loss=0.0473]


Epoch 5: mean loss 0.0473


Epoch 06: 100%|██████████| 469/469 [04:51<00:00,  1.61it/s, loss=0.0459]


Epoch 6: mean loss 0.0459


Epoch 07: 100%|██████████| 469/469 [04:49<00:00,  1.62it/s, loss=0.0446]


Epoch 7: mean loss 0.0446


Epoch 08: 100%|██████████| 469/469 [04:50<00:00,  1.62it/s, loss=0.0443]


Epoch 8: mean loss 0.0443


Epoch 09: 100%|██████████| 469/469 [04:44<00:00,  1.65it/s, loss=0.0434]


Epoch 9: mean loss 0.0434


Epoch 10: 100%|██████████| 469/469 [04:46<00:00,  1.64it/s, loss=0.0429]


Epoch 10: mean loss 0.0429


Epoch 11: 100%|██████████| 469/469 [04:54<00:00,  1.59it/s, loss=0.0423]


Epoch 11: mean loss 0.0423


Epoch 12: 100%|██████████| 469/469 [05:14<00:00,  1.49it/s, loss=0.0419]


Epoch 12: mean loss 0.0419


Epoch 13: 100%|██████████| 469/469 [05:07<00:00,  1.53it/s, loss=0.0417]


Epoch 13: mean loss 0.0417


Epoch 14: 100%|██████████| 469/469 [04:49<00:00,  1.62it/s, loss=0.0416]


Epoch 14: mean loss 0.0416


Epoch 15: 100%|██████████| 469/469 [04:45<00:00,  1.64it/s, loss=0.0414]

Epoch 15: mean loss 0.0414





In [30]:
# -----------------------------------------------------------
# 🟢 Cell 8 – Sampling (reverse diffusion)
# -----------------------------------------------------------
@torch.no_grad()
def p_sample(model, x, t, y, eta=0.0):
    """One reverse step p(x_{t-1} | x_t) – DDIM."""
    eps_pred = model(x, t, y)
    alpha_t = alphas[t][:, None, None, None]
    alpha_bar = alphas_cumprod[t][:, None, None, None]
    sqrt_inv_alpha = (1 / torch.sqrt(alpha_t))
    x0_pred = sqrt_inv_alpha * (x - ((1 - alpha_t) / torch.sqrt(1 - alpha_bar)) * eps_pred)
    if t[0] == 0:
        return x0_pred
    alpha_bar_prev = alphas_cumprod[t - 1][:, None, None, None]
    sigma_t = eta * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * torch.sqrt(1 - alpha_t / alpha_bar_prev)
    noise = torch.randn_like(x) if eta > 0 else 0
    x_prev = torch.sqrt(alpha_bar_prev) * x0_pred + torch.sqrt(1 - alpha_bar_prev - sigma_t**2) * eps_pred + sigma_t * noise
    return x_prev

@torch.no_grad()
def sample_ddim(model, y, num_steps=T, eta=0.0):
    """Generate images given labels y (tensor shape [B])."""
    B = y.size(0)
    x = torch.randn(B, *IMG_SHAPE, device=DEVICE)
    for t_ in reversed(range(num_steps)):
        t = torch.full((B,), t_, device=DEVICE, dtype=torch.long)
        x = p_sample(model, x, t, y, eta)
    return x

In [37]:
# -----------------------------------------------------------
# 🟢 Cell 9 – Quick visual sanity check
# -----------------------------------------------------------
model.eval()
labels = torch.arange(10, device=DEVICE)
with torch.no_grad():
    # images = sample_ddim(model, labels, eta=0.0).cpu()  # shape [10,1,28,28]
    images = sample_ddim(model, labels, eta=0.0)  # still on XPU

images = images.cpu()  # [10,1,28,28]
grid   = tvu.make_grid((images + 1) / 2, nrow=10, pad_value=1)  # un-normalize to [0,1]

plt.figure(figsize=(12,3))
plt.imshow(grid.permute(1,2,0), interpolation='nearest')
plt.title("Diffusion-generated digits (0–9)")
plt.axis('off')
plt.show()

tvu.save_image(grid, "diffusion_digits.png")
print("Saved grid → diffusion_digits.png")

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and xpu:0! (when checking argument for argument mat1 in method wrapper_XPU_addmm)

In [38]:
# -----------------------------------------------------------
# 🟢 Cell 10 – Auxiliary classifier for quantitative scores
# -----------------------------------------------------------
class TinyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32,64,3,padding=1),   nn.ReLU(), nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*7*7, 256), nn.ReLU(),
            nn.Linear(256, 10)
        )
    def forward(self, x): return self.net(x)

clf = TinyCNN().to(DEVICE)
opt = torch.optim.Adam(clf.parameters(), lr=1e-3)

# --- quick 2-epoch training on MNIST train set ---
for epoch in range(2):
    clf.train()
    for xb, yb in train_loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        logits = clf(xb)
        loss   = F.cross_entropy(logits, yb)
        opt.zero_grad(); loss.backward(); opt.step()

from sklearn.metrics import precision_recall_fscore_support, accuracy_score

def evaluate_generated(model, n_samples=1000, target_digit=3):
    model.eval(); 
    clf.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for xb, yb in test_loader:          # test_loader uses the same [-1,1] normalization
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            logits = clf(xb)
            all_preds.append(logits.argmax(1).cpu())
            all_labels.append(yb.cpu())

    all_preds  = torch.cat(all_preds).numpy()
    all_labels = torch.cat(all_labels).numpy()
    prec, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="micro", zero_division=0)
    acc = accuracy_score(all_labels, all_preds)
    return acc, prec, recall, f1

acc, prec, rec, f1 = evaluate_generated(model, 1000, target_digit=3)
print(f"Digit ‘3’ — acc: {acc:.3f}  precision: {prec:.3f}  recall: {rec:.3f}  F1: {f1:.3f}")


Digit ‘3’ — acc: 0.986  precision: 0.986  recall: 0.986  F1: 0.986


In [40]:
# -----------------------------------------------------------
# 🟢 Cell 11 – Save the trained model
# -----------------------------------------------------------

# 1) Ensure model is in eval() mode (optional, but good practice)
model.eval()

# 2) Choose a path
SAVE_PATH = "simple_unet_mnist_diffusion.pth"

# 3) Save only the state_dict (recommended)
torch.save(model.state_dict(), SAVE_PATH)

print(f"Model weights saved to {SAVE_PATH}")


Model weights saved to simple_unet_mnist_diffusion.pth


In [41]:
# -----------------------------------------------------------
# 🟢 Cell 12 – Reload the model for future use
# -----------------------------------------------------------

# 1) Reconstruct the model architecture
#    (must match exactly the class definition & init args!)
loaded_model = SimpleUNet(num_classes=NUM_CLASSES).to(DEVICE)

# 2) Load the saved weights
#    map_location ensures compatibility if you switch device
state_dict = torch.load(SAVE_PATH, map_location=DEVICE)
loaded_model.load_state_dict(state_dict)

# 3) Set to eval mode for inference
loaded_model.eval()

print("Model reloaded and ready for inference.")


Model reloaded and ready for inference.


In [42]:
# Example: generate a “7”
with torch.no_grad():
    sample = sample_ddim(loaded_model,
                         torch.tensor([7], device=DEVICE),
                         eta=0.0)
# Move to CPU, un-normalize & display
import matplotlib.pyplot as plt
img = ((sample + 1) / 2).cpu()[0,0]
plt.imshow(img, cmap="gray"); plt.axis("off"); plt.show()


RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)

In [None]:
# -----------------------------------------------------------
# 🟢 Cell 13 – Final test on untouched data (classifier sanity-check)
# -----------------------------------------------------------
model.eval(); clf.eval()
test_preds, test_labels = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(DEVICE)
        logits = clf(xb).cpu()
        test_preds.append(logits.argmax(1))
        test_labels.append(yb)
test_preds  = torch.cat(test_preds).numpy()
test_labels = torch.cat(test_labels).numpy()
prec, rec, f1, _ = precision_recall_fscore_support(test_labels, test_preds, average="macro")
acc = accuracy_score(test_labels, test_preds)
print(f"Classifier on real test images  →  acc={acc:.4f}  F1={f1:.4f}")


Classifier on real test images  →  acc=0.9862  F1=0.9861
