## 4. DiffusionPolicy LightningModule

A tiny MLP that predicts the injected noise ϵ given:

- 3D normalized state  
- 1D *noisy* action  
- scalar timestep (normalized by `max_t`)

It logs the training and validation losses (`train_loss`, `val_loss`) aggregated per epoch, so you can plot them later.

In [None]:
# 04. DiffusionPolicy for low-dim observation (3D) and action (1D)

class DiffusionPolicy(pl.LightningModule):
    """Tiny MLP that predicts injected noise ϵ given (obs, noisy_action, timestep).

    The observation, the noisy action and the normalized timestep are
    concatenated along the feature axis and passed through an MLP with a
    single hidden layer whose output has ``act_dim`` features.

    Args:
        obs_dim: Number of observation features.
        act_dim: Number of action features; also the size of the predicted
            noise vector.
        max_t: Divisor used to normalize the timestep before it is fed to
            the network. Must be positive.
        hidden_dim: Width of the hidden layer.

    Raises:
        ValueError: If ``max_t`` is not positive.
    """

    def __init__(
        self,
        obs_dim: int = 3,
        act_dim: int = 1,
        max_t: int = 1000,
        hidden_dim: int = 128,
    ):
        super().__init__()
        if max_t <= 0:
            # A non-positive divisor would silently yield inf/NaN or
            # sign-flipped timestep features in forward().
            raise ValueError(f"max_t must be positive, got {max_t}")
        self.max_t = max_t

        # obs features + action features + 1 timestep feature → noise (act_dim)
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, act_dim),
        )
        self.loss_fn = nn.MSELoss()

    # ---------- forward ----------
    def forward(self, obs, noisy_action, timestep):
        """Predict the noise that was added to ``noisy_action``.

        Args:
            obs: Tensor of shape ``(batch, obs_dim)``.
            noisy_action: Tensor of shape ``(batch, act_dim)``.
            timestep: Tensor holding one timestep per sample; it is
                reshaped to ``(batch, 1)`` and divided by ``max_t``.

        Returns:
            Tensor of shape ``(batch, act_dim)`` with the predicted noise.
        """
        # reshape (not view) so non-contiguous timestep tensors are accepted.
        t = timestep.reshape(-1, 1).float() / self.max_t
        # Cast every input to float32: torch.cat promotes mixed dtypes, and a
        # float64 input would otherwise clash with the float32 layer weights.
        x = torch.cat([obs.float(), noisy_action.float(), t], dim=1)
        return self.net(x)

    # ---------- shared loss ----------
    def _shared_step(self, batch):
        """Compute the MSE between predicted and true noise for one batch.

        Args:
            batch: Mapping with the keys ``"obs"``, ``"noisy_action"``,
                ``"timestep"`` and ``"noise"``. ``"noise"`` is presumably
                shaped like the prediction, ``(batch, act_dim)`` — confirm
                against the dataset.

        Returns:
            Scalar loss tensor.
        """
        pred = self(batch["obs"], batch["noisy_action"], batch["timestep"])
        # Match the target dtype to the float32 prediction.
        return self.loss_fn(pred, batch["noise"].float())

    # ---------- training / validation ----------
    def training_step(self, batch, batch_idx):
        """Return the training loss and log it as ``train_loss``."""
        loss = self._shared_step(batch)
        self.log("train_loss", loss, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation loss and log it as ``val_loss``."""
        loss = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)
