## 8. Pixel diffusion LightningModule  
A minimal **de-noising diffusion** model:  
* Input = noisy image + scalar timestep (packed as a 4-channel tensor)  
* Output = predicted noise ϵ  
Log per-epoch losses and save them so every worker can later plot global curves.

In [None]:
# 08. Pixel De-noising Diffusion Model  — final, logging via Lightning/Ray

class PixelDiffusion(pl.LightningModule):
    """Tiny CNN that predicts the noise ϵ added to an image, given a timestep.

    The timestep is divided by ``max_t`` and broadcast into a constant extra
    image channel, so the network receives a (3+1)-channel input and returns
    a 3-channel noise prediction with the same spatial size.

    Args:
        max_t: Exclusive upper bound for the timesteps sampled during
            training/validation; also the divisor used to scale timesteps
            before they are fed to the network.

    Raises:
        ValueError: If ``max_t`` is smaller than 1.
    """

    def __init__(self, max_t=1000):
        super().__init__()
        # Fail fast: max_t == 0 would divide by zero in forward() and make
        # torch.randint(0, max_t, ...) in _shared_step() raise later on.
        if max_t < 1:
            raise ValueError(f"max_t must be >= 1, got {max_t!r}")
        self.max_t = max_t

        # Network: (3+1)-channel input → 3-channel noise prediction.
        # padding=1 with 3x3 kernels keeps H and W unchanged.
        self.net = nn.Sequential(
            nn.Conv2d(4, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 3, 3, padding=1),
        )
        self.loss_fn = nn.MSELoss()

    # ---------- forward ----------
    def forward(self, noisy_img: torch.Tensor, t) -> torch.Tensor:
        """Predict the noise contained in ``noisy_img``.

        Args:
            noisy_img: Image batch of shape Bx3xHxW.
            t: Timestep(s). Either one value per sample (shape B or Bx1) or
                a single value (Python number, 0-d or 1-element tensor) that
                is shared by the whole batch.

        Returns:
            Tensor of shape Bx3xHxW with the predicted noise.
        """
        b, _, h, w = noisy_img.shape

        # Keep the timestep channel in the image's floating dtype so the
        # concatenation does not silently change precision.
        dtype = noisy_img.dtype if noisy_img.is_floating_point() else torch.float32

        # as_tensor accepts Python numbers as well as tensors and moves the
        # value onto the image's device.
        t = torch.as_tensor(t, device=noisy_img.device).to(dtype)
        t_scaled = (t / self.max_t).reshape(-1, 1, 1, 1)

        # Expanding to the explicit batch size lets a single timestep
        # broadcast over the batch; a mismatching length raises here.
        t_img = t_scaled.expand(b, 1, h, w)

        x = torch.cat([noisy_img, t_img], dim=1)  # 4 channels
        return self.net(x)

    # ---------- shared loss ----------
    def _shared_step(self, batch) -> torch.Tensor:
        """Compute the MSE between predicted and sampled noise for one batch.

        Args:
            batch: Mapping with an ``"image"`` entry holding the clean images
                (Bx3xHxW).

        Returns:
            Scalar loss tensor.
        """
        clean = batch["image"].to(self.device)             # Bx3xHxW, -1…1
        noise = torch.randn_like(clean)                    # ϵ ~ N(0, 1)
        # One integer timestep per sample, drawn from [0, max_t).
        t = torch.randint(0, self.max_t, (clean.size(0),), device=self.device)
        # NOTE: the noise magnitude does not depend on t in this minimal
        # version; t is only passed to the network as conditioning.
        noisy = clean + noise                              # x_t = x_0 + ϵ
        pred_noise = self(noisy, t)
        return self.loss_fn(pred_noise, noise)

    # ---------- training / validation ----------
    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` and log it as ``train_loss``."""
        loss = self._shared_step(batch)
        # Let Lightning aggregate + Ray callback report
        self.log("train_loss", loss, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch`` and log it as ``val_loss``."""
        loss = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr=2e-4)."""
        return torch.optim.Adam(self.parameters(), lr=2e-4)