## 04. DiffusionPolicy LightningModule

A tiny MLP (one hidden layer of 128 units) that predicts the noise ε that was injected into the action, given:

- 3D normalized state  
- 1D *noisy* action  
- scalar timestep (normalized by `max_t`)

At the end of every epoch, world rank 0 appends the mean train and validation losses to a shared JSON file so you can plot them later.

In [None]:
# 04. DiffusionPolicy for low-dim observation (3D) and action (1D)

class DiffusionPolicy(pl.LightningModule):
    """Noise-prediction network for a low-dimensional diffusion policy.

    A one-hidden-layer MLP maps ``[obs, noisy_action, timestep / max_t]`` to a
    prediction of the noise that was added to the action, trained with MSE.
    At the end of each training epoch, world rank 0 appends the mean train and
    validation step losses to a JSON list stored at ``log_path``.

    Args:
        obs_dim: Size of the observation vector.
        act_dim: Size of the action vector; also the size of the output.
        max_t: Normalizer; timesteps are divided by this before entering the net.
        log_path: JSON file that per-epoch loss records are appended to.
            Records already present in the file (e.g. from earlier runs) are kept.
    """

    def __init__(self, obs_dim: int = 3, act_dim: int = 1, max_t: int = 1000,
                 log_path: str = "/mnt/cluster_storage/pendulum_diffusion/epoch_metrics.json"):
        super().__init__()
        self.max_t = max_t
        self.log_path = log_path

        # obs + action + 1 timestep feature -> predicted noise (act_dim)
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim + 1, 128),
            nn.ReLU(),
            nn.Linear(128, act_dim),
        )
        self.loss_fn = nn.MSELoss()
        # Python-float step losses, accumulated between epoch boundaries.
        self._train_losses, self._val_losses = [], []

    # ---------- forward ----------
    def forward(self, obs, noisy_action, timestep):
        """Predict the injected noise.

        Args:
            obs: Observations; assumed ``(batch, obs_dim)``.
            noisy_action: Noised actions; reshaped to ``(batch, act_dim)``.
            timestep: One diffusion timestep per sample (``batch`` elements).

        Returns:
            Predicted noise of shape ``(batch, act_dim)``.
        """
        # Bring every input to the network's parameter dtype so e.g. float64
        # tensors built from NumPy data don't clash with float32 weights.
        dtype = self.net[0].weight.dtype
        obs = obs.to(dtype)
        noisy_action = noisy_action.to(dtype).reshape(obs.shape[0], -1)
        t = timestep.reshape(-1, 1).to(dtype) / self.max_t
        x = torch.cat([obs, noisy_action, t], dim=1)
        return self.net(x)

    # ---------- shared step ----------
    def _shared_step(self, batch):
        """Return the MSE between predicted and true noise for one batch.

        ``batch`` must provide ``"obs"``, ``"noisy_action"``, ``"timestep"``
        and ``"noise"`` tensors.
        """
        pred = self.forward(batch["obs"], batch["noisy_action"], batch["timestep"])
        # Match the target to the prediction's dtype and shape: a (batch,)
        # target would otherwise broadcast against the (batch, act_dim)
        # prediction inside MSELoss and silently yield a wrong loss.
        target = batch["noise"].to(pred.dtype).reshape_as(pred)
        return self.loss_fn(pred, target)

    def training_step(self, batch, _):
        """Compute the training loss and record it for the epoch average."""
        loss = self._shared_step(batch)
        self._train_losses.append(loss.item())
        return loss

    def validation_step(self, batch, _):
        """Compute the validation loss and record it for the epoch average."""
        loss = self._shared_step(batch)
        self._val_losses.append(loss.item())
        return loss

    # ---------- epoch boundaries ----------
    def on_train_epoch_start(self):
        """Start each epoch with empty loss buffers.

        This also discards losses from Lightning's pre-training sanity-check
        validation pass, which would otherwise be averaged into epoch 0.
        """
        self._train_losses.clear()
        self._val_losses.clear()

    def on_train_epoch_end(self):
        """Print and persist this epoch's mean losses (rank 0), then reset buffers.

        NOTE: relies on Lightning finishing the epoch's validation loop before
        this hook. ``val_loss`` is ``None`` when no validation ran this epoch.
        NOTE(review): under multi-worker training these averages cover only
        rank 0's own batches; confirm whether cross-rank aggregation is wanted.
        """
        rank = get_context().get_world_rank()
        if rank == 0:
            tr_avg = float(np.mean(self._train_losses)) if self._train_losses else None
            va_avg = float(np.mean(self._val_losses)) if self._val_losses else None

            def _fmt(value):
                return "N/A" if value is None else f"{value:.4f}"

            # The printed epoch is 0-based (current_epoch); the JSON record
            # below stores it 1-based.
            print(f"[Epoch {self.current_epoch}] "
                  f"train={_fmt(tr_avg)}  val={_fmt(va_avg)}")

            self._append_epoch_record({"epoch": self.current_epoch + 1,
                                       "train_loss": tr_avg,
                                       "val_loss": va_avg})

        self._train_losses.clear()
        self._val_losses.clear()

    def _append_epoch_record(self, record):
        """Append ``record`` to the JSON list at ``self.log_path``."""
        log_dir = os.path.dirname(self.log_path)
        if log_dir:  # a bare filename has no directory part; makedirs("") raises
            os.makedirs(log_dir, exist_ok=True)

        logs = []
        if os.path.exists(self.log_path):
            with open(self.log_path, "r") as f:
                logs = json.load(f)
        logs.append(record)

        # Write to a sibling temp file and swap it in, so an interruption
        # mid-write cannot leave a truncated file that breaks the next load.
        tmp_path = f"{self.log_path}.tmp"
        with open(tmp_path, "w") as f:
            json.dump(logs, f)
        os.replace(tmp_path, self.log_path)

    def configure_optimizers(self):
        """Use Adam over all parameters with a fixed learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)