### 13. Reverse diffusion sampler
A simple Euler-style loop that starts from Gaussian noise and, at each step, subtracts a fixed fraction of the model's predicted noise.  
This isn't production-grade sampling: it uses a fixed step size of 0.1 and no noise schedule. It's enough to illustrate inference after training. Use **Ray Data** when performing inference at scale.

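```python
# 13. Reverse diffusion sampling
import torch


def sample_image(model, steps=50, device="cpu"):
    """Generate one 224x224 RGB image by iteratively denoising Gaussian noise.

    Returns a NumPy array of shape (H, W, 3) with values in [0, 1].
    """
    model.eval()
    with torch.no_grad():
        # Start from pure Gaussian noise: batch of 1, 3 channels, 224x224
        img = torch.randn(1, 3, 224, 224, device=device)
        for step in reversed(range(steps)):
            # Current timestep as a 1-element batch tensor
            t = torch.tensor([step], device=device)
            pred_noise = model(img, t)
            # Simple Euler-style update with a fixed step size (no noise schedule)
            img = img - pred_noise * 0.1
        # Map from [-1, 1] back to [0, 1] for display
        img = torch.clamp(img * 0.5 + 0.5, 0.0, 1.0)
        # (1, C, H, W) -> (H, W, C) for matplotlib
        return img.squeeze(0).cpu().permute(1, 2, 0).numpy()
```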
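For comparison, the standard DDPM (Denoising Diffusion Probabilistic Models) ancestral sampler does not use a fixed `0.1` step. It derives each step's coefficients from a variance schedule β_t. The NumPy sketch below rests on several assumptions. It assumes a model trained with the DDPM forward process and a noise-prediction objective, which this tutorial's `PixelDiffusion` model may not be. `predict_noise` is a hypothetical callable standing in for the model. The linear schedule from 1e-4 to 0.02 matches the DDPM paper's default but is an assumption here.

```python
import numpy as np


def linear_beta_schedule(num_steps, beta_start=1e-4, beta_end=0.02):
    """Linear variance schedule (assumed; DDPM's default endpoints)."""
    return np.linspace(beta_start, beta_end, num_steps)


def ddpm_sample(predict_noise, shape, num_steps=1000, seed=0):
    """Ancestral DDPM sampling, sketched in NumPy.

    predict_noise(x, t) is a hypothetical callable that returns the model's
    noise estimate for x at integer timestep t, with the same shape as x.
    """
    rng = np.random.default_rng(seed)
    betas = linear_beta_schedule(num_steps)
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)

    x = rng.standard_normal(shape)  # x_T ~ N(0, I)
    for t in reversed(range(num_steps)):
        eps = predict_noise(x, t)
        # Posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        mean = (x - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps) / np.sqrt(alphas[t])
        if t > 0:
            # Add fresh noise with variance sigma_t^2 = beta_t
            x = mean + np.sqrt(betas[t]) * rng.standard_normal(shape)
        else:
            x = mean  # no noise at the final step
    # Map from the assumed [-1, 1] training range back to [0, 1]
    return np.clip(x * 0.5 + 0.5, 0.0, 1.0)
```

You can try it with a dummy predictor that always returns zeros: `ddpm_sample(lambda x, t: np.zeros_like(x), shape=(3, 8, 8), num_steps=50)` runs end to end and returns a `(3, 8, 8)` array in [0, 1]. To use a real model, convert between arrays and tensors at each step or port the loop to `torch`.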

### 14. Generate and display samples from the best checkpoint
Load the model weights from `best_ckpt`, move the model to the GPU if one is available, generate three images, and show them side by side.  
With a tiny CNN trained for only 10 epochs, expect these samples to look like noise. A larger backbone or a longer training run should produce better-looking images.

```python
# 14. Generate and display samples
import glob
import os

import matplotlib.pyplot as plt
import torch

assert best_ckpt is not None, "Checkpoint is missing. Did training run and complete?"

# Restore model weights from the Ray Train checkpoint (Lightning first)
model = PixelDiffusion()

with best_ckpt.as_directory() as ckpt_dir:
    # Prefer Lightning checkpoints (*.ckpt) saved by ModelCheckpoint;
    # sort so the choice is deterministic if there is more than one
    ckpt_files = sorted(glob.glob(os.path.join(ckpt_dir, "*.ckpt")))
    if ckpt_files:
        # Lightning .ckpt files also hold non-tensor training state, which may
        # require weights_only=False on PyTorch 2.6+; only load trusted files
        pl_ckpt = torch.load(ckpt_files[0], map_location="cpu", weights_only=False)
        state = pl_ckpt.get("state_dict", pl_ckpt)
    elif os.path.exists(os.path.join(ckpt_dir, "model.pt")):
        # Fallback for older or manually saved checkpoints
        state = torch.load(os.path.join(ckpt_dir, "model.pt"), map_location="cpu")
    else:
        raise FileNotFoundError(
            f"No Lightning .ckpt or model.pt found in: {ckpt_dir}"
        )

# strict=False tolerates key mismatches, so report them instead of failing silently
result = model.load_state_dict(state, strict=False)
if result.missing_keys or result.unexpected_keys:
    print(f"Missing keys: {result.missing_keys}")
    print(f"Unexpected keys: {result.unexpected_keys}")

# Move to the GPU if available, then sample
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Generate three images
samples = [sample_image(model, steps=50, device=device) for _ in range(3)]

# Show the samples side by side
fig, axs = plt.subplots(1, 3, figsize=(9, 3))
for ax, img in zip(axs, samples):
    ax.imshow(img)
    ax.axis("off")
plt.suptitle("Food-101 Diffusion Samples (unconditional)")
plt.tight_layout()
plt.show()
```
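You can factor the checkpoint lookup above into a small helper that is testable without Ray or PyTorch. `find_weights_file` is a hypothetical name, not part of Ray Train. The directory may hold more than one `.ckpt` file, and `glob` returns them in arbitrary order. This sketch therefore picks the most recently modified file. That tie-break is an assumption, not something the tutorial specifies.

```python
import glob
import os


def find_weights_file(ckpt_dir):
    """Return the weights file to load from a checkpoint directory.

    Prefers Lightning *.ckpt files (most recently modified wins) and falls
    back to a plain model.pt; raises FileNotFoundError if neither exists.
    """
    ckpt_files = glob.glob(os.path.join(ckpt_dir, "*.ckpt"))
    if ckpt_files:
        # glob order is arbitrary, so choose the newest file explicitly
        return max(ckpt_files, key=os.path.getmtime)
    fallback = os.path.join(ckpt_dir, "model.pt")
    if os.path.exists(fallback):
        return fallback
    raise FileNotFoundError(f"No Lightning .ckpt or model.pt found in: {ckpt_dir}")
```

Inside the `with best_ckpt.as_directory() as ckpt_dir:` block, you would call `find_weights_file(ckpt_dir)` and pass the result to `torch.load`.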

### 15. Clean up shared storage  
Reclaim cluster disk space by deleting the entire tutorial output directory.  
Run this only when you’re **sure** you don’t need the checkpoints or metrics anymore.

```python
# 15. Cleanup: delete checkpoints and metrics from model training
import os
import shutil

TARGET_PATH = "/mnt/cluster_storage/generative_cv"

if os.path.exists(TARGET_PATH):
    shutil.rmtree(TARGET_PATH)
    print(f"✅ Deleted everything under {TARGET_PATH}")
else:
    print(f"⚠️ Path does not exist: {TARGET_PATH}")
```
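To see what you are about to delete before removing it, a dry-run variant can first report the directory's size. `cleanup` and `dir_size_bytes` are hypothetical helpers, not part of Ray or Anyscale. Deletion still uses `shutil.rmtree`, and it only runs when you pass `dry_run=False`.

```python
import os
import shutil


def dir_size_bytes(path):
    """Total size in bytes of regular files under path (symlinks are skipped)."""
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            fp = os.path.join(root, name)
            if not os.path.islink(fp):
                total += os.path.getsize(fp)
    return total


def cleanup(path, dry_run=True):
    """Report the space deleting path would free; delete only if dry_run is False."""
    if not os.path.exists(path):
        print(f"Path does not exist: {path}")
        return 0
    size = dir_size_bytes(path)
    if dry_run:
        print(f"[dry run] Would delete {path} ({size / 1e6:.1f} MB)")
    else:
        shutil.rmtree(path)
        print(f"Deleted {path} ({size / 1e6:.1f} MB)")
    return size
```

For example, you could run `cleanup("/mnt/cluster_storage/generative_cv")` first to check the size. Once you're sure, rerun it with `dry_run=False`.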

### Wrap up and next steps

In this tutorial, you used **Ray Train and Ray Data on Anyscale** to scale a compact image diffusion workload, from raw JPEG bytes to distributed training and sampling, without changing the core PyTorch logic. You should now feel confident doing the following:

* Using **Ray Data** to decode, normalize, and shard large image datasets in parallel  
* Scaling training across multiple GPUs using **TorchTrainer** and a Ray-native `train_loop`  
* Managing distributed training state with **Ray Checkpoints** and automatic resume  
* Running fault-tolerant multi-node jobs on Anyscale without orchestration scripts  

---

### Where can you take this next?

Below are a few directions you might explore to adapt or extend the pattern:

1. **Backbones and architecture upgrades**  
   * Swap in a larger ResNet or another vision model for better sample quality.  
   * Try pre-trained encoders and fine-tune only the diffusion-specific layers.

2. **Conditional diffusion**  
   * Use the `label` column to condition the model (for example, class-conditioning).  
   * Compare unconditional versus conditional generation side by side.

3. **Sampling improvements**  
   * Replace naive reverse diffusion with Denoising Diffusion Implicit Models (DDIM), Pseudo Numerical Methods for Diffusion Models (PNDM), or learned denoisers.  
   * Add timestep embeddings so the network can condition on the noise level, and use an explicit noise schedule in both training and sampling.

4. **Longer training and mixed precision**  
   * Increase `max_epochs` and enable Automatic Mixed Precision (AMP) for faster training with lower memory use.  
   * Visualize convergence and training stability across longer runs.

5. **Hyperparameter sweeps**  
   * Use **Ray Tune** to search over learning rates, model size, or sampling steps.  
   * Use Tune schedulers such as ASHA to stop unpromising trials early, and cap the number of retained checkpoints to prune the rest.

6. **Data handling and scaling**  
   * Shard the dataset into multiple Parquet files and distribute them across more workers.  
   * Store and load datasets from S3 or other cloud storage.

7. **Image quality evaluation**  
   * Log Fréchet Inception Distance (FID) scores, perceptual similarity, or diffusion-specific metrics.  
   * Compare generated samples from different checkpoints or backbones.

8. **Model serving**  
   * Package the reverse sampler into a Ray task or **Ray Serve** endpoint.  
   * Run a demo app that generates images on demand from a class name or random seed.

9. **End-to-end MLOps**  
   * Register the best checkpoint with MLflow or Weights & Biases.  
   * Wrap the training loop in a Ray Job and run it on a schedule with Anyscale.
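Item 3 suggests adding timestep embeddings. A common choice is the sinusoidal embedding borrowed from Transformer positional encodings. This NumPy sketch shows one variant of that embedding. It is not part of the tutorial's `PixelDiffusion` model, and `timestep_embedding` and `max_period` are illustrative names.

```python
import numpy as np


def timestep_embedding(timesteps, dim, max_period=10000.0):
    """Sinusoidal embedding of integer timesteps, shape (len(timesteps), dim).

    The first half of each row holds sines and the second half cosines, with
    frequencies spaced geometrically from 1 toward 1/max_period.
    """
    assert dim % 2 == 0, "dim must be even"
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    # Outer product: one row per timestep, one column per frequency
    args = np.asarray(timesteps, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=1)
```

In a PyTorch model, you would typically compute the same values with `torch` ops. You would then pass the result through a small MLP and add it to intermediate feature maps, so each layer knows which noise level it is denoising.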