## 08. Reverse Diffusion Helper

Iteratively denoise a random action vector over **50 steps** until it becomes a feasible Pendulum command.


In [None]:
# 08. Reverse diffusion sampling for 1-D action

# Function to simulate reverse diffusion process
def sample_action(model, obs, n_steps=50, device="cpu"):
    """
    Runs reverse diffusion starting from noise to generate a Pendulum action.
    obs: torch.Tensor of shape (3,)
    returns: torch.Tensor of shape (1,)
    """
    model.eval()
    with torch.no_grad():
        obs = obs.unsqueeze(0).to(device)      # [1, 3]
        obs = obs / np.pi                      # same normalization used in training

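        x = torch.randn(1, 1, device=device)   # start from noise in action space

        # Walk the timesteps from n_steps - 1 down to 0
        for step in reversed(range(n_steps)):
            t = torch.tensor([step], device=device)
            pred_noise = model(obs, x, t)
            # Simplified update: subtract a fixed 10% of the predicted noise per step
            x = x - pred_noise * 0.1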

        return x.squeeze(0)
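
To see what the fixed-step update in `sample_action` does, the following minimal sketch replays the same loop in plain Python with a stand-in noise predictor. `toy_noise_predictor`, `toy_reverse_diffusion`, and `target` are hypothetical names used only for illustration. The predictor returns the offset between the current sample and a known target action, so each step shrinks the remaining error by a factor of 0.9. The trained `DiffusionPolicy` is not involved.

```python
import random

def toy_noise_predictor(x, target):
    """Hypothetical stand-in for the model: the 'noise' is the offset from target."""
    return x - target

def toy_reverse_diffusion(x0, target, n_steps=50, step_size=0.1):
    """Apply the same x <- x - step_size * pred_noise update as sample_action."""
    x = x0
    for _ in reversed(range(n_steps)):
        x = x - step_size * toy_noise_predictor(x, target)
    return x

random.seed(0)
x0 = random.gauss(0.0, 1.0)   # start from Gaussian noise, as in sample_action
target = 1.5                  # assumed "clean" action, for illustration only
x_final = toy_reverse_diffusion(x0, target)

# After 50 steps the remaining error is 0.9**50, roughly 0.5% of the initial error
print(f"start={x0:+.3f}  end={x_final:+.3f}  target={target:+.3f}")
```

With a learned predictor the per-step shrinkage is not guaranteed, which is one reason to experiment with the step size or a proper noise schedule.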

### 09. Sample an Action from the Trained Policy

Finally, load the **latest epoch checkpoint**, supply a sample state  
`[cos θ = 1, sin θ = 0, θ̇ = 0]`, and generate a 1-D torque command.

In [None]:
# 09. In-notebook sampling from trained model

# A plausible pendulum state: [cos(theta), sin(theta), theta_dot]
obs_sample = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float32)   # shape (3,)

# Load the most recent model checkpoint from the checkpoint directory
CKPT_DIR = "/mnt/cluster_storage/pendulum_diffusion/pendulum_diffusion_ckpts"

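# Pick the latest checkpoint directory by modification time
# (a plain name sort is not chronological when directories use uuid-style names)
ckpt_dirs = [os.path.join(CKPT_DIR, name) for name in os.listdir(CKPT_DIR)]
latest = max(ckpt_dirs, key=os.path.getmtime)
model_path = os.path.join(latest, "model.pt")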

# Use one device for both the model and the sampler to avoid a CPU/GPU mismatch
device = "cuda" if torch.cuda.is_available() else "cpu"

model = DiffusionPolicy(obs_dim=3, act_dim=1)
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model = model.to(device)

# Run reverse diffusion sampling
action = sample_action(model, obs_sample, n_steps=50, device=device)
print("Sampled action:", action)
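
The sampler places no bound on its output, while `Pendulum-v1` declares a torque range of [-2, 2]. A minimal sketch of a clipping step, written in plain Python so it runs without the trained model; `clip_torque` and `MAX_TORQUE` are hypothetical names, not part of the notebook:

```python
MAX_TORQUE = 2.0  # Pendulum-v1 action bound; adjust for other environments

def clip_torque(value, max_torque=MAX_TORQUE):
    """Clamp a scalar torque to [-max_torque, max_torque]."""
    return max(-max_torque, min(max_torque, value))

# Values inside the bound pass through; values outside are saturated
for raw in (-3.7, -0.4, 0.0, 1.9, 5.2):
    print(f"raw={raw:+.1f} -> clipped={clip_torque(raw):+.1f}")
```

On the tensor returned by `sample_action`, `action.clamp(-2.0, 2.0)` would apply the same bound.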

In [None]:
# 10. Cleanup -- delete checkpoints and metrics from model training

TARGET_PATH = "/mnt/cluster_storage/pendulum_diffusion"

if os.path.exists(TARGET_PATH):
    shutil.rmtree(TARGET_PATH)
    print(f"✅ Deleted everything under {TARGET_PATH}")
else:
    print(f"⚠️ Path does not exist: {TARGET_PATH}")

### 🎉 Wrapping Up & Next Steps  

Great job reaching the finish line. You’ve transformed a synthetic control demo into a **Ray-native, real-data pipeline**, training a diffusion policy across multiple GPUs, surviving worker restarts, and sampling feasible actions, all within a distributed Ray environment.

You should now feel confident doing the following:

* Logging continuous-control trajectories directly into a **Ray Dataset** for scalable preprocessing  
* Streaming data into a **Ray Train** workload using Ray Data + Lightning with minimal integration overhead  
* Saving structured checkpoints with `ray.train.report()` and leveraging **Ray’s fault-tolerant recovery**  
* Running reverse diffusion sampling directly in-notebook, or scaling it up as **Ray remote tasks**  

---

### 🚀 Where can you take this next?

1. **Evaluate in the Environment**  
   * Load the best checkpoint, deploy the policy in Gym’s `Pendulum-v1`, and log episode returns.  
   * Compare against baseline behavior cloning or TD3/TD3+BC.

2. **Larger & Richer Datasets**  
   * Generate 100 k+ steps with a scripted controller or collect data from a learned agent.  
   * Swap in other classic-control tasks like `CartPole` or `MountainCar`.

3. **Model & Loss Upgrades**  
   * Add timestep embeddings or a small transformer for better temporal reasoning.  
   * Experiment with different noise schedules or auxiliary consistency losses.

4. **Hyperparameter Sweeps**  
   * Wrap the training loop in **Ray Tune** and grid-search learning rate, hidden size, or diffusion steps.  
   * Use `CheckpointConfig(num_to_keep=N)` so each trial retains only its top-N checkpoints.

5. **Mixed Precision & Performance**  
   * Call `torch.set_float32_matmul_precision('high')` to allow TF32 matrix multiplies on A10G Tensor Cores.  
   * Profile GPU utilization across workers and tune batch size accordingly.

6. **Real Robotics Logs**  
   * Replace Pendulum with logs from a real robotic apparatus stored in Parquet; Ray Data shards them just the same.

7. **Serving the Policy**  
   * Export the trained MLP to TorchScript and deploy with **Ray Serve** for low-latency inference.  
   * Hook it to a real-time simulator or a web dashboard.

8. **End-to-End MLOps**  
   * Track checkpoints and metrics with MLflow or Weights & Biases.  
   * Schedule nightly Ray Jobs on Anyscale to retrain as new data arrives.
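
For the noise-schedule experiments suggested in item 3, the sketch below compares two common choices in plain Python: a linear beta schedule and the cosine schedule. It computes the cumulative signal retention (often written as alpha-bar) at each of 50 steps. The function names, the beta range of 1e-4 to 0.02, and the offset `s=0.008` are assumptions taken from commonly used defaults, not values from this notebook's training code.

```python
import math

def linear_alpha_bar(n_steps, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_t) for linearly spaced betas."""
    alpha_bar, running = [], 1.0
    for i in range(n_steps):
        frac = i / (n_steps - 1) if n_steps > 1 else 0.0
        beta = beta_start + frac * (beta_end - beta_start)
        running *= 1.0 - beta
        alpha_bar.append(running)
    return alpha_bar

def cosine_alpha_bar(n_steps, s=0.008):
    """Cosine schedule: alpha_bar(t) = f(t) / f(0),
    with f(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2)."""
    def f(t):
        return math.cos(((t / n_steps + s) / (1 + s)) * math.pi / 2) ** 2
    return [f(t) / f(0) for t in range(1, n_steps + 1)]

# Compare how much signal each schedule retains at the start, middle, and end
for name, schedule in [("linear", linear_alpha_bar(50)),
                       ("cosine", cosine_alpha_bar(50))]:
    print(f"{name}: first={schedule[0]:.4f}  mid={schedule[24]:.4f}  last={schedule[-1]:.6f}")
```

With only 50 steps, the linear schedule at these defaults leaves much of the signal intact at the final step, while the cosine schedule decays to nearly zero, so the two are not interchangeable without retuning.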