## 08 · Pixel Diffusion LightningModule  
A minimal **de-noising diffusion** model:  
* Input = noisy image + scalar timestep (packed as a 4-channel tensor)  
* Output = predicted noise ϵ  
Rank 0 logs its per-epoch losses and appends them to a JSON file on shared storage, so the loss curves can be plotted later from any worker.

In [None]:
# 08. Pixel De-noising Diffusion Model

class PixelDiffusion(pl.LightningModule):
    """Tiny CNN that predicts the noise ϵ added to an image.

    The network input is the noisy image concatenated with one extra channel
    holding the timestep scaled by ``max_t``; the output has 3 channels and is
    trained with an MSE loss against the sampled noise.

    Per-step losses are accumulated in memory and, at the end of every
    training epoch, world-rank 0 prints their averages and appends them to the
    JSON file at ``log_path``.

    Args:
        max_t: Exclusive upper bound for sampled timesteps; also the divisor
            used to scale ``t`` before it is fed to the network.
        log_path: Destination JSON file for per-epoch metrics. Falls back to a
            default path under ``/mnt/cluster_storage`` when falsy.
    """

    def __init__(self, max_t=1000, log_path=None):
        super().__init__()
        self.max_t = max_t
        self.log_path = log_path or "/mnt/cluster_storage/generative_cv/epoch_metrics.json"

        # Network: (3 + 1)-channel input -> 3-channel noise prediction.
        # padding=1 with 3x3 kernels keeps the spatial size unchanged.
        self.net = nn.Sequential(
            nn.Conv2d(4, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 3, 3, padding=1),
        )
        self.loss_fn = nn.MSELoss()
        # Per-epoch loss trackers (Python floats), cleared in on_train_epoch_end.
        self._train_losses, self._val_losses = [], []

    # ---------- forward ----------
    def forward(self, noisy_img, t):
        """Predict the noise contained in ``noisy_img``.

        Args:
            noisy_img: Tensor of shape Bx3xHxW.
            t: Timestep(s). Either a tensor with B elements (shape B or Bx1),
                or a single value (Python number / 0-dim or 1-element tensor)
                that is broadcast to the whole batch.

        Returns:
            Tensor of shape Bx3xHxW with the predicted noise.
        """
        b, _, h, w = noisy_img.shape
        # as_tensor accepts Python numbers as well as tensors and places the
        # result on the image's device.
        t = torch.as_tensor(t, device=noisy_img.device)
        t_scaled = (t / self.max_t).float().reshape(-1, 1, 1, 1)
        # Expanding to the explicit batch size lets a single timestep be
        # broadcast across the batch; a per-sample t (B elements) is unchanged.
        t_img = t_scaled.expand(b, 1, h, w)
        x = torch.cat([noisy_img, t_img], dim=1)  # 4 channels
        return self.net(x)

    # ---------- training / validation steps ----------
    def _shared_step(self, batch):
        """Compute the noise-prediction MSE for one batch.

        ``batch`` must provide the clean images under the ``"image"`` key
        (Bx3xHxW; presumably scaled to -1…1 — confirm against the data
        pipeline).
        """
        clean = batch["image"].to(self.device)
        noise = torch.randn_like(clean)                    # ϵ ~ N(0, 1)
        t = torch.randint(0, self.max_t, (clean.size(0),), device=self.device)
        # NOTE: the amount of noise added does not depend on t; t is only
        # passed to the network as a conditioning channel.
        noisy = clean + noise                              # x_t = x_0 + ϵ
        pred_noise = self(noisy, t)
        return self.loss_fn(pred_noise, noise)

    def training_step(self, batch, batch_idx):
        """Return the training loss and record it for the epoch average."""
        loss = self._shared_step(batch)
        self._train_losses.append(loss.item())
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation loss and record it for the epoch average."""
        loss = self._shared_step(batch)
        self._val_losses.append(loss.item())
        return loss

    # ---------- epoch start / end logging ----------
    def on_train_epoch_start(self):
        """Drop validation losses collected before this epoch's training.

        Lightning can run a few validation batches before the first training
        epoch (the sanity check); without this reset those losses would be
        averaged into the first epoch's ``val_loss``. For later epochs the
        list is already empty at this point, so this is a no-op.
        """
        self._val_losses.clear()

    def on_train_epoch_end(self):
        """Print and persist the epoch's average losses (world-rank 0 only)."""
        rank = get_context().get_world_rank()
        if rank == 0:
            # None (rather than NaN from np.mean([])) marks "no data".
            train_avg = float(np.mean(self._train_losses)) if self._train_losses else None
            val_avg = float(np.mean(self._val_losses)) if self._val_losses else None

            def fmt(value):
                return "N/A" if value is None else f"{value:.4f}"

            # NOTE: the printed epoch is 0-based while the JSON record below
            # stores a 1-based epoch number.
            print(f"[Epoch {self.current_epoch}] train={fmt(train_avg)}  val={fmt(val_avg)}")

            # Append to shared JSON so you can plot later
            self._append_epoch_record(
                {"epoch": self.current_epoch+1, "train_loss": train_avg, "val_loss": val_avg}
            )

        # Clear per-epoch trackers on every rank.
        self._train_losses.clear()
        self._val_losses.clear()

    def _append_epoch_record(self, record):
        """Append ``record`` to the JSON list stored at ``self.log_path``."""
        log_dir = os.path.dirname(self.log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        try:
            with open(self.log_path, "r") as f:
                logs = json.load(f)
        except FileNotFoundError:
            logs = []
        logs.append(record)

        # Write to a sibling temp file and swap it in, so an interruption
        # mid-write cannot leave a truncated JSON file behind.
        tmp_path = f"{self.log_path}.tmp"
        with open(tmp_path, "w") as f:
            json.dump(logs, f)
        os.replace(tmp_path, self.log_path)

    def configure_optimizers(self):
        """Use Adam with a fixed learning rate of 2e-4 over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=2e-4)