### 8. Reverse diffusion helper

Iteratively denoise a random action vector over **50 steps** to produce a feasible Pendulum torque command.


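```python
# 08. Reverse diffusion sampling for 1-D action
import numpy as np
import torch

# Pendulum-v1 torque limits: the action space is Box(-2.0, 2.0, (1,))
MAX_TORQUE = 2.0

def sample_action(model, obs, n_steps=50, device="cpu", step_size=0.1):
    """
    Runs reverse diffusion starting from Gaussian noise to generate a Pendulum action.

    obs: torch.Tensor of shape (3,) -> [cos(theta), sin(theta), theta_dot]
    returns: torch.Tensor of shape (1,), clipped to the Pendulum torque range

    Note: each step subtracts a fixed fraction (step_size) of the predicted noise.
    This is a simplified update rule, not the full DDPM posterior step.
    """
    model.eval()
    with torch.no_grad():
        obs = obs.unsqueeze(0).to(device)      # [1, 3]
        obs = obs / np.pi                      # Same normalization used in training

        x = torch.randn(1, 1, device=device)   # Start from noise in action space: [1, 1]

        for step in reversed(range(n_steps)):
            t = torch.tensor([step], device=device)   # Current diffusion timestep: [1]
            pred_noise = model(obs, x, t)             # Predicted noise: [1, 1]
            x = x - pred_noise * step_size            # Remove a fraction of the predicted noise

        # Keep the command inside Pendulum-v1's valid torque range
        x = x.clamp(-MAX_TORQUE, MAX_TORQUE)
        return x.squeeze(0)                    # [1]
```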

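The helper above removes a fixed `0.1 ×` predicted noise per step. This is not the standard DDPM update. Suppose instead that the network was trained with the usual DDPM noise-prediction objective and a known β schedule. The textbook ancestral sampling loop would then look roughly like the sketch below. It is a sketch under stated assumptions, not this notebook's method. The linear schedule (`beta_start=1e-4`, `beta_end=0.02`), the function name `ddpm_reverse_sample`, and the `eps_model(obs, x, t)` call signature are all assumptions. The schedule must match whatever was used during training.

```python
import torch

def ddpm_reverse_sample(eps_model, obs, n_steps=50, beta_start=1e-4, beta_end=0.02,
                        act_dim=1, device="cpu", generator=None):
    """Hypothetical DDPM ancestral sampler for a 1-D action (assumed linear beta schedule).

    x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t) + sqrt(beta_t) * z
    """
    betas = torch.linspace(beta_start, beta_end, n_steps, device=device)
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)

    x = torch.randn(1, act_dim, generator=generator, device=device)  # x_T ~ N(0, I)
    with torch.no_grad():
        for t in reversed(range(n_steps)):
            t_tensor = torch.tensor([t], device=device)
            eps = eps_model(obs, x, t_tensor)  # predicted noise, same shape as x
            # Posterior mean of x_{t-1} given x_t and the predicted noise
            mean = (x - betas[t] / torch.sqrt(1.0 - alpha_bars[t]) * eps) / torch.sqrt(alphas[t])
            if t > 0:
                # Add fresh noise with variance sigma_t^2 = beta_t (one common DDPM choice)
                z = torch.randn(x.shape, generator=generator, device=device)
                x = mean + torch.sqrt(betas[t]) * z
            else:
                x = mean  # no noise on the final step
    return x.squeeze(0)
```

Passing a seeded `torch.Generator` makes the sample reproducible. Swapping this loop in only makes sense if training used the same schedule.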
### 9. Sample an action from the trained policy

Finally, load the trained checkpoint (`best_ckpt`), supply a sample state `[cos θ = 1, sin θ = 0, θ̇ = 0]` (the pendulum upright and at rest), and generate a 1-D torque command.

```python
# 09. In-notebook sampling from the trained model (Ray Train Lightning checkpoint)
import glob
import os

import torch

# A plausible pendulum state: [cos(theta), sin(theta), theta_dot]
obs_sample = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float32)   # shape (3,)

assert best_ckpt is not None, "No checkpoint found. Did training complete successfully?"

# Load the trained weights from the Ray Train checkpoint
model = DiffusionPolicy(obs_dim=3, act_dim=1)

with best_ckpt.as_directory() as ckpt_dir:
    # RayTrainReportCallback saves a file named "checkpoint.ckpt"
    ckpt_file = os.path.join(ckpt_dir, "checkpoint.ckpt")
    if not os.path.exists(ckpt_file):
        # Fallback: use any .ckpt file if the name differs
        candidates = sorted(glob.glob(os.path.join(ckpt_dir, "*.ckpt")))
        ckpt_file = candidates[0] if candidates else None

    assert ckpt_file is not None, f"No Lightning checkpoint found in {ckpt_dir}"
    # Lightning checkpoints hold more than tensors; weights_only=False is for trusted files only
    state = torch.load(ckpt_file, map_location="cpu", weights_only=False)

    # Lightning stores weights under "state_dict"; a raw state dict is used as-is
    result = model.load_state_dict(state.get("state_dict", state), strict=False)
    # strict=False silently skips mismatched keys, so report them instead of sampling from random weights
    if result.missing_keys:
        print("⚠️ Missing keys (not loaded):", result.missing_keys)
    if result.unexpected_keys:
        print("⚠️ Unexpected keys in checkpoint:", result.unexpected_keys)

# Move to device
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Run reverse diffusion sampling
action = sample_action(model, obs_sample, n_steps=50, device=device)
print("Sampled action:", action)
```

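With `strict=False`, any key that does not match is skipped without an error. One common cause is a name prefix. A LightningModule's `state_dict()` prefixes every key with the attribute name under which it stores the wrapped network, for example `model.weight` instead of `weight`. This notebook doesn't show whether `DiffusionPolicy` is wrapped that way. So the `"model."` prefix below is a hypothetical placeholder, and `strip_state_dict_prefix` is an illustrative helper rather than a Lightning API.

```python
import torch

def strip_state_dict_prefix(state_dict, prefix="model."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it.

    Keys without the prefix are kept unchanged. "model." is a hypothetical default:
    use whatever attribute name your LightningModule gives the wrapped network.
    """
    stripped = {}
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped
```

Once the keys line up, you can load with `strict=True`. Any remaining mismatch then raises an error instead of leaving layers randomly initialized. For example: `model.load_state_dict(strip_state_dict_prefix(state["state_dict"]), strict=True)`.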
### 10. Clean up

When you're finished, release Ray resources and delete the checkpoints and metrics written during training.  
This frees the cluster for other jobs and avoids unnecessary storage costs.

```python
# 10. Cleanup -- delete checkpoints and metrics from model training
import os
import shutil

import ray

TARGET_PATH = "/mnt/cluster_storage/pendulum_diffusion"

if os.path.exists(TARGET_PATH):
    shutil.rmtree(TARGET_PATH)
    print(f"✅ Deleted everything under {TARGET_PATH}")
else:
    print(f"⚠️ Path does not exist: {TARGET_PATH}")

# Disconnect this notebook from Ray so its resources are released
if ray.is_initialized():
    ray.shutdown()
```

### Wrap up and next steps  

You transformed a synthetic control demo into a **Ray-native, real-data pipeline**, training a diffusion policy across multiple GPUs, surviving worker restarts, and sampling feasible actions, all within a distributed Ray environment.

You should now be comfortable with:

* Logging continuous-control trajectories directly into a **Ray Dataset** for scalable preprocessing  
* Streaming data into a **Ray Train** workload using Ray Data and Lightning with minimal integration overhead  
* Saving structured checkpoints automatically through **Lightning + Ray Train callbacks**, enabling **fault-tolerant recovery**  
* Running reverse diffusion sampling directly in the notebook

---

### Where can you take this next?

The following are a few directions you can explore to extend or adapt this workload:

1. **Evaluate in the environment**  
   * Load the best checkpoint, deploy the policy in Gym’s `Pendulum-v1`, and log episode returns.  
   * Compare against baseline behavior cloning or TD3/TD3+BC.

2. **Larger and richer datasets**  
   * Generate 100k+ steps with a scripted controller or collect data from a learned agent.  
   * Swap in other continuous-control tasks such as `MountainCarContinuous-v0`; `CartPole-v1` has a discrete action space, so it would need a different action head.

3. **Model and loss upgrades**  
   * Add timestep embeddings or a small transformer for better temporal reasoning.  
   * Experiment with different noise schedules or auxiliary consistency losses.

4. **Hyperparameter sweeps**  
   * Wrap the training loop in **Ray Tune** and grid-search learning rate, hidden size, or diffusion steps.  
   * Use `CheckpointConfig(num_to_keep=...)` to keep only the top-N checkpoints, and a trial scheduler such as ASHA to stop underperforming trials early.

5. **Mixed precision and performance**  
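   * Call `torch.set_float32_matmul_precision('high')` to let float32 matmuls use TF32 Tensor Cores on Ampere GPUs such as the A10G, or pass `precision="bf16-mixed"` to the Lightning `Trainer` for true mixed-precision training.  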
   * Profile GPU utilization across workers and tune batch size accordingly.

6. **Real robotics logs**  
   * Replace Pendulum with logs from a real robotic apparatus stored in Parquet; Ray Data shards them the same way.

7. **Serving the policy**  
   * Export the trained MLP to TorchScript and deploy with **Ray Serve** for low-latency inference.  
   * Hook it to a real-time simulator or a web dashboard.

8. **End-to-end MLOps**  
   * Track checkpoints and metrics with MLflow or Weights & Biases.  
   * Schedule nightly Ray jobs on Anyscale to retrain as new data arrives.
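For the first direction, evaluating in the environment, a rollout loop only needs a policy callable and a Gymnasium-style environment. The minimal sketch below assumes the Gymnasium API, where `reset(seed=...)` returns `(obs, info)` and `step(action)` returns `(obs, reward, terminated, truncated, info)`. The function name `evaluate_policy` is hypothetical.

```python
import numpy as np

def evaluate_policy(env, policy_fn, n_episodes=5, seed=0):
    """Roll out policy_fn for n_episodes and return the undiscounted return of each episode.

    Assumes the Gymnasium reset/step API (terminated and truncated flags).
    """
    returns = []
    for ep in range(n_episodes):
        obs, _ = env.reset(seed=seed + ep)  # distinct, reproducible seed per episode
        done, ep_return = False, 0.0
        while not done:
            action = policy_fn(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            ep_return += float(reward)
            done = terminated or truncated  # Pendulum-v1 ends by truncation (time limit)
        returns.append(ep_return)
    return np.array(returns)
```

With Gymnasium installed, you might pass `env = gymnasium.make("Pendulum-v1")`. For the policy, one option is `lambda o: sample_action(model, torch.as_tensor(o, dtype=torch.float32), device=device).cpu().numpy()`. You can then compare the mean return against a behavior-cloning baseline.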
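For the third direction, adding timestep embeddings, a common choice maps the integer diffusion step `t` to sinusoidal features, as in Transformer positional encodings. The network then sees a smooth, multi-frequency signal instead of a raw integer. The sketch below is a generic version of that technique. The function name and `max_period=10000` are assumptions, and this notebook's `DiffusionPolicy` may encode `t` differently.

```python
import math

import torch

def sinusoidal_timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0) -> torch.Tensor:
    """Map integer timesteps of shape [B] to sinusoidal features of shape [B, dim]."""
    assert dim >= 2, "embedding dimension must be at least 2"
    half = dim // 2
    # Geometric sequence of frequencies from 1 down to roughly 1/max_period
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float().unsqueeze(-1) * freqs.unsqueeze(0)          # [B, half]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # [B, 2 * half]
    if dim % 2 == 1:
        # Pad one zero column so the output width is exactly `dim`
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb
```

Inside a noise-prediction MLP, you would typically concatenate this embedding with the observation and noisy action before the first linear layer. Alternatively, pass it through a small projection first.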
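For the serving direction, TorchScript exports the noise-prediction network, and the reverse diffusion loop then runs around the loaded module. The sketch below uses a hypothetical stand-in, `TinyNoisePredictor`, because the real `DiffusionPolicy` architecture isn't shown here. It assumes the same `(obs, x, t)` forward signature that `sample_action` calls.

```python
import torch
import torch.nn as nn

class TinyNoisePredictor(nn.Module):
    """Hypothetical stand-in for DiffusionPolicy: predicts noise from (obs, x, t)."""

    def __init__(self, obs_dim: int = 3, act_dim: int = 1, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim + 1, hidden),
            nn.ReLU(),
            nn.Linear(hidden, act_dim),
        )

    def forward(self, obs: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        t_feat = t.float().unsqueeze(-1)                      # [B] -> [B, 1]
        return self.net(torch.cat([obs, x, t_feat], dim=-1))  # [B, act_dim]

def export_torchscript(model: nn.Module, path: str) -> str:
    """Compile the model with torch.jit.script and save it to `path`."""
    model.eval()
    scripted = torch.jit.script(model)
    scripted.save(path)
    return path
```

A Ray Serve deployment could then call `torch.jit.load(path)` once at startup and run the sampling loop for each request.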