# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
from torchvision import datasets
import matplotlib.pyplot as plt
import numpy as np
import random

# -----------------------
# 1. Setup
# -----------------------
# Default to the CPU and switch to the GPU only when PyTorch reports one.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Seed the Python, NumPy and PyTorch random number generators with 0.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# -----------------------
# 2. Load a real grayscale image (MNIST example)
# -----------------------
# ToTensor is the only preprocessing step applied to each sample.
transform = T.Compose([T.ToTensor()])
mnist = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# Take the sample at index 1, discard its label, then prepend a batch axis
# and move the result to the selected device -> [1,1,H,W].
clean_img, _ = mnist[1]
clean_img = clean_img[None].to(device)

# Small preview of the clean image (first batch item, first channel).
preview = clean_img[0, 0].cpu()
plt.figure(figsize=(2, 2))
plt.imshow(preview, cmap='gray')
plt.title("Clean image")
plt.axis('off')
plt.show()

# -----------------------
# 3. Small CNN denoiser
# -----------------------
class TinyDenoiser(nn.Module):
    """Three-layer convolutional network from a 1-channel to a 1-channel image.

    Every convolution is 3x3 with padding 1, so the spatial size of the input
    is kept. The two hidden layers use 32 channels and are each followed by a
    ReLU; the output layer has no activation.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in forward order and stored under ``self.net``.
        conv_in = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        conv_mid = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        conv_out = nn.Conv2d(32, 1, kernel_size=3, padding=1)
        self.net = nn.Sequential(
            conv_in,
            nn.ReLU(),
            conv_mid,
            nn.ReLU(),
            conv_out,
        )

    def forward(self, x):
        """Apply the convolutional stack to ``x`` and return its output."""
        return self.net(x)

# -----------------------
# 4. Noisy pair generator
# -----------------------
def make_noisy_pair(clean, sigma=0.3):
    """Return two noisy copies of ``clean``, each with its own noise draw.

    Args:
        clean: Tensor to corrupt; it is not modified.
        sigma: Factor that scales the standard-normal noise samples.

    Returns:
        Tuple ``(noisy_a, noisy_b)``; both have the shape of ``clean``.
    """
    # Draw both noise samples first so the two draws stay back to back.
    noise_a = sigma * torch.randn_like(clean)
    noise_b = sigma * torch.randn_like(clean)
    noisy_a = clean + noise_a
    noisy_b = clean + noise_b
    return noisy_a, noisy_b

# -----------------------
# 5. Training function
# -----------------------
def train(model, clean, mode='n2n', steps=300, lr=1e-3, *,
          sigma=None, mask_ratio=0.1, log_every=50):
    """Fit ``model`` on freshly drawn noisy versions of a single clean image.

    A new noisy pair is generated from ``clean`` at every step via
    ``make_noisy_pair``; the loss depends on ``mode``:

    * ``'n2n'``: predict the second noisy copy from the first (MSE over all
      pixels).
    * ``'n2v'`` / ``'n2s'``: zero out a random subset of pixels of the first
      noisy copy and compute the MSE only on those pixels, against their
      unmasked noisy values.

    Args:
        model: Module to optimise; it is moved to the module-level ``device``.
        clean: Clean image tensor; it is moved to ``device`` as well.
        mode: One of ``'n2n'``, ``'n2v'``, ``'n2s'``.
        steps: Number of optimisation steps.
        lr: Learning rate for the Adam optimiser.
        sigma: Noise scale forwarded to ``make_noisy_pair``. ``None`` keeps
            that function's own default.
        mask_ratio: Per-pixel probability of being masked in the ``'n2v'``
            and ``'n2s'`` modes. Must lie in ``(0, 1]``.
        log_every: Print the loss every this many steps; ``0`` or ``None``
            disables printing.

    Returns:
        Tuple ``(model, losses)``: the model switched to eval mode and the
        list of per-step loss values as Python floats.

    Raises:
        ValueError: If ``mode`` is not a supported name or ``mask_ratio`` is
            outside ``(0, 1]``.
    """
    # Validate before touching the model so bad arguments fail immediately,
    # including when ``steps`` is 0.
    if mode not in ('n2n', 'n2v', 'n2s'):
        raise ValueError("Mode must be n2n, n2v, or n2s")
    if not 0.0 < mask_ratio <= 1.0:
        raise ValueError("mask_ratio must be in (0, 1]")

    model = model.to(device)
    clean = clean.to(device)  # keep data and parameters on the same device
    # This function returns the model in eval mode, so a model passed back in
    # for more training must be switched to training mode explicitly.
    model.train()
    opt = optim.Adam(model.parameters(), lr=lr)
    losses = []

    for step in range(1, steps + 1):
        opt.zero_grad()
        if sigma is None:
            noisy1, noisy2 = make_noisy_pair(clean)
        else:
            noisy1, noisy2 = make_noisy_pair(clean, sigma=sigma)

        if mode == 'n2n':
            # Noise2Noise: the target is an independent noisy copy.
            pred = model(noisy1)
            loss = F.mse_loss(pred, noisy2)
        else:
            # NOTE(review): both masked modes blank the selected pixels with
            # zeros, so 'n2v' and 'n2s' currently optimise the same objective.
            # The published Noise2Void scheme substitutes neighbouring pixel
            # values instead -- confirm whether that distinction is wanted.
            mask = torch.rand_like(noisy1) < mask_ratio
            if mode == 'n2v':
                masked_input = noisy1.clone()
                masked_input[mask] = 0.0
            else:
                masked_input = noisy1 * (~mask)
            pred = model(masked_input)
            # Only the hidden pixels contribute to the loss.
            loss = F.mse_loss(pred[mask], noisy1[mask])

        loss.backward()
        opt.step()

        # Read the scalar once; it is reused for both the history and the log.
        loss_value = loss.item()
        losses.append(loss_value)

        if log_every and step % log_every == 0:
            print(f"[{mode}] step {step}/{steps}  loss={loss_value:.4f}")

    return model.eval(), losses

# -----------------------
# 6. Train all three modes
# -----------------------
# Trained model and per-step loss history, keyed by training mode.
models, losses = {}, {}

for mode in ('n2n', 'n2v', 'n2s'):
    print(f"\n=== Training {mode.upper()} ===")
    # Each mode starts from a newly constructed network; train() returns the
    # model together with its list of losses.
    models[mode], losses[mode] = train(
        TinyDenoiser(), clean_img, mode=mode, steps=300
    )

# -----------------------
# 7. Test on a new noisy image
# -----------------------
# Only the first image of the newly drawn pair is used as the test input.
noisy_test, _ = make_noisy_pair(clean_img, sigma=0.3)

# Run the three trained models on the same noisy image with autograd disabled.
with torch.no_grad():
    out_n2n, out_n2v, out_n2s = (
        models[key](noisy_test) for key in ('n2n', 'n2v', 'n2s')
    )

# -----------------------
# 8. Visualize results
# -----------------------
# Panel titles and the matching 2-D images, in left-to-right display order.
titles = ['Clean', 'Noisy', 'Noise2Noise', 'Noise2Void', 'Noise2Self']
images = [
    batch[0, 0].cpu()
    for batch in (clean_img, noisy_test, out_n2n, out_n2v, out_n2s)
]

plt.figure(figsize=(12, 3))
# Subplot positions are 1-based, so start the counter at 1.
for position, (panel_title, panel_image) in enumerate(zip(titles, images), start=1):
    plt.subplot(1, 5, position)
    # A fixed 0..1 grey range keeps all panels on the same intensity scale.
    plt.imshow(panel_image, cmap='gray', vmin=0, vmax=1)
    plt.title(panel_title)
    plt.axis('off')
plt.tight_layout()
plt.show()

# -----------------------
# 9. Plot training loss curves
# -----------------------
plt.figure(figsize=(6, 4))
# One curve per training mode, labelled with the mode name.
for mode_name, curve in losses.items():
    plt.plot(curve, label=mode_name)
plt.legend()
plt.xlabel("Iteration")
plt.ylabel("MSE loss")
plt.title("Training loss curves")
plt.show()
