# In [None]:
"""
PnP-MIDAL: MIDAL mit CNN Prior
"""

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import os
from glob import glob

# Tuned ADMM penalty parameter mu for each Gamma noise level M
# (values determined in our own experiment; used by pnp_midal).
MU_FOR_M = {1: 1.2, 4: 7.0, 10: 18.0}


class DnCNN(nn.Module):
    """Residual DnCNN denoiser with the same layout as used in training.

    Layout: a 3->C input convolution with ReLU, ``D`` hidden stages of
    C->C convolution + BatchNorm + ReLU, and a C->3 output convolution whose
    result is added to the input (residual connection).

    The attribute names ``conv``/``bn`` and the module creation order define
    the ``state_dict`` keys, so they must not change or checkpoints stop loading.

    Args:
        D: number of hidden conv/BatchNorm stages.
        C: number of feature channels.
    """

    def __init__(self, D, C=64):
        super().__init__()
        self.D = D

        # Created in order input conv -> hidden convs -> output conv, giving
        # state_dict keys conv.0 ... conv.{D+1}.
        layers = [nn.Conv2d(3, C, 3, padding=1)]
        layers += [nn.Conv2d(C, C, 3, padding=1) for _ in range(D)]
        layers.append(nn.Conv2d(C, 3, 3, padding=1))
        self.conv = nn.ModuleList(layers)

        # NOTE: the second positional argument of BatchNorm2d is ``eps``; the
        # original ``nn.BatchNorm2d(C, C)`` therefore sets eps=C. This is kept on
        # purpose: eps is not part of the state_dict, and the checkpoints were
        # presumably trained with this value, so changing it would silently
        # alter inference.
        self.bn = nn.ModuleList(nn.BatchNorm2d(C, eps=C) for _ in range(D))

    def forward(self, x):
        features = F.relu(self.conv[0](x))
        for conv, norm in zip(self.conv[1:self.D + 1], self.bn):
            features = F.relu(norm(conv(features)))
        # Residual connection: the network output is added to its input.
        return self.conv[self.D + 1](features) + x


def load_cnn_model(checkpoint_path, D=6, device='cuda'):
    """Build a DnCNN (C=64) and load the weights stored under the checkpoint's 'Net' key.

    Args:
        checkpoint_path: path to the checkpoint file.
        D: number of hidden stages of the DnCNN (must match the checkpoint).
        device: torch device the tensors are mapped to and the model is moved to.

    Returns:
        The loaded model on ``device``, in eval mode.

    NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
    """
    network = DnCNN(D=D, C=64)
    weights = torch.load(checkpoint_path, map_location=device)['Net']
    network.load_state_dict(weights)
    # Module.to() and Module.eval() both return the module itself.
    return network.to(device).eval()


def cnn_denoise(image, model, device='cuda'):
    """Apply the CNN denoiser to a log-domain image and return a log-domain result.

    Pipeline: log -> exp (clipped to [0, 1]) -> [-1, 1] -> CNN -> [0, 1] -> log.
    The grayscale image is replicated to 3 identical channels for the network
    (as in training), and the 3 output channels are averaged back to one.

    Args:
        image: 2-D array in the log domain.
        model: network mapping a (1, 3, H, W) tensor to a (1, 3, H, W) tensor.
        device: torch device to run the model on.

    Returns:
        2-D array, the log of the denoised intensities (floored at 1e-10).
    """
    # Log domain -> intensities in [0, 1] -> network input range [-1, 1].
    scaled = 2 * np.clip(np.exp(image), 0, 1) - 1

    with torch.no_grad():
        # (H, W) -> (3, H, W) -> (1, 3, H, W)
        batch = torch.from_numpy(np.stack((scaled,) * 3)).float()[None].to(device)
        prediction = model(batch)
        gray = prediction[0].cpu().numpy().mean(axis=0)

        # [-1, 1] -> [0, 1]; the lower bound keeps the following log finite.
        intensity = np.clip((gray + 1) / 2, 1e-10, 1)

    return np.log(intensity)


def pnp_midal(noisy, M, model, max_iter=100, tol=1e-4, mu=None, device='cuda', verbose=False):
    """Plug-and-Play MIDAL: ADMM in the log domain with a CNN replacing the TV prior.

    Args:
        noisy: observed image in the intensity domain; values are floored at
            1e-10 before taking the log.
        M: shape parameter of the Gamma noise model.
        model: denoiser network, applied through ``cnn_denoise``.
        max_iter: maximum number of ADMM iterations.
        tol: stopping threshold for both the primal and the dual residual
            (both normalized by sqrt(number of pixels)).
        mu: ADMM penalty parameter. ``None`` uses the tuned value
            ``MU_FOR_M[M]``; if M has no tuned entry, ``mu = M`` is used.
        device: torch device for the denoiser.
        verbose: print the residuals every 10 iterations and on convergence.

    Returns:
        The estimate ``exp(u)`` in the intensity domain.
    """
    if mu is None:
        # Tuned penalty for this noise level; M itself as fallback for other M.
        mu = MU_FOR_M.get(M, M)

    # Log-domain observation; the floor avoids log(0) for zero-valued pixels.
    g = np.log(np.maximum(noisy, 1e-10))
    u = g.copy()
    d = np.zeros_like(g)
    rms_scale = np.sqrt(g.size)  # turns residual norms into per-pixel RMS values

    for k in range(max_iter):
        # z-update (paper lines 3-4): per-pixel Newton iterations. grad/hess are
        # the first/second derivatives of M*(z + exp(g - z)) + mu/2*(z - z_prime)**2.
        z_prime = u + d
        z = z_prime.copy()

        for _ in range(15):
            # Clipping the exponent keeps exp() finite.
            exp_term = np.exp(np.clip(g - z, -50, 50))

            grad = M * (1.0 - exp_term) + mu * (z - z_prime)
            hess = M * exp_term + mu

            delta = grad / hess
            z = z - delta

            if np.max(np.abs(delta)) < 1e-9:
                break

        # u-update (paper lines 5-6): the CNN denoiser replaces the TV step.
        # cnn_denoise returns a new array, so keeping a reference is sufficient.
        u_old = u
        u = cnn_denoise(z - d, model, device=device)

        # Dual update (paper line 7).
        d = d - (z - u)

        # Convergence check (normalized to the image size).
        primal_res = np.linalg.norm(z - u) / rms_scale
        dual_res = np.linalg.norm(u - u_old) * mu / rms_scale

        if verbose and k % 10 == 0:
            print(f"  Iter {k:3d}: primal_res={primal_res:.2e}, dual_res={dual_res:.2e}")

        if primal_res < tol and dual_res < tol:
            if verbose:
                print(f"  Konvergiert nach {k+1} Iterationen")
            break

    return np.exp(u)


def add_gamma_noise(image, M, seed=None):
    """Apply multiplicative Gamma noise: ``Y = X * N`` with ``N ~ Gamma(M, 1/M)``.

    ``N`` has shape ``M`` and scale ``1/M``, so ``E[N] = 1`` and ``Var[N] = 1/M``
    (larger M means weaker noise).

    Args:
        image: clean image array.
        M: shape parameter of the Gamma distribution (must be > 0).
        seed: if given, the noise is drawn from a private
            ``np.random.RandomState(seed)``. This produces exactly the same samples
            as ``np.random.seed(seed)`` followed by ``np.random.gamma``, but does
            not reseed NumPy's global RNG as a side effect. If ``None``, the
            global RNG is used.

    Returns:
        ``image * N`` with the same shape as ``image``.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    return image * rng.gamma(shape=M, scale=1.0 / M, size=image.shape)


def compute_psnr(clean, noisy):
    """Return the PSNR in dB, assuming a peak signal value of 1.0.

    If the MSE is below 1e-10 (practically identical images), 100.0 is returned
    instead of a huge or infinite value.
    """
    diff = clean - noisy
    err = np.mean(diff ** 2)
    return 100.0 if err < 1e-10 else 10.0 * np.log10(1.0 / err)


def compute_mse(clean, noisy):
    """Return the mean squared error between ``clean`` and ``noisy``."""
    residual = clean - noisy
    return np.mean(residual ** 2)


def compute_mae(clean, noisy):
    """Return the mean absolute error between ``clean`` and ``noisy``."""
    deviation = np.abs(clean - noisy)
    return np.mean(deviation)


def compute_relative_error(clean, estimate):
    """Return ``||estimate - clean|| / ||clean||`` using NumPy's default norm."""
    error_norm = np.linalg.norm(estimate - clean)
    reference_norm = np.linalg.norm(clean)
    return error_norm / reference_norm


def get_checkpoint_path(M, base_path='/content/drive/MyDrive/Image-Denoising-with-Deep-CNNs/checkpoints'):
    """Return the checkpoint path for the model trained on noise level ``M``.

    Layout: ``<base_path>/mult_M<M>/checkpoint.pth.tar``.
    """
    run_dir = os.path.join(base_path, f'mult_M{M}')
    return os.path.join(run_dir, 'checkpoint.pth.tar')


def process_single_image(clean, M, model, device='cuda', seed=42, verbose=False):
    """Corrupt ``clean`` with Gamma noise, denoise it with PnP-MIDAL and score both.

    mu is looked up in ``MU_FOR_M``. For an M without a tuned entry, ``mu=None``
    is passed so that ``pnp_midal`` picks its own default instead of this
    function raising KeyError.

    Args:
        clean: clean grayscale image; the metrics assume a peak value of 1.0.
        M: Gamma noise level.
        model: denoiser network used inside PnP-MIDAL.
        device: torch device for the denoiser.
        seed: seed for the noise generation.
        verbose: forwarded to ``pnp_midal`` (progress output).

    Returns:
        dict with the images ('clean', 'noisy', 'noisy_clip', 'denoised',
        'denoised_clip'), 'M', and PSNR / MSE / MAE / relative error of the
        clipped noisy and clipped denoised images w.r.t. ``clean``.
    """
    noisy = add_gamma_noise(clean, M, seed=seed)
    denoised = pnp_midal(noisy, M=M, model=model, mu=MU_FOR_M.get(M),
                         device=device, verbose=verbose)

    # Metrics are computed on images clipped to the valid range [0, 1].
    noisy_clip = np.clip(noisy, 0, 1)
    denoised_clip = np.clip(denoised, 0, 1)

    return {
        'clean': clean,
        'noisy': noisy,
        'noisy_clip': noisy_clip,
        'denoised': denoised,
        'denoised_clip': denoised_clip,
        'M': M,
        'psnr_noisy': compute_psnr(clean, noisy_clip),
        'psnr_denoised': compute_psnr(clean, denoised_clip),
        'mse_noisy': compute_mse(clean, noisy_clip),
        'mse_denoised': compute_mse(clean, denoised_clip),
        'mae_noisy': compute_mae(clean, noisy_clip),
        'mae_denoised': compute_mae(clean, denoised_clip),
        'err_noisy': compute_relative_error(clean, noisy_clip),
        'err_denoised': compute_relative_error(clean, denoised_clip),
    }


def _save_comparison_figure(results, name, M, out_dir):
    """Save a 1x3 panel (original / noisy / PnP-MIDAL) for one image as PNG in ``out_dir``."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        panels = (
            (results['clean'], 'Original'),
            (results['noisy_clip'], f'Verrauscht (M={M})\nPSNR={results["psnr_noisy"]:.2f} dB'),
            (results['denoised_clip'], f'PnP-MIDAL\nPSNR={results["psnr_denoised"]:.2f} dB'),
        )
        for ax, (image, title) in zip(axes, panels):
            ax.imshow(image, cmap='gray', vmin=0, vmax=1)
            ax.set_title(title)
            ax.axis('off')
        fig.suptitle(f'{name} | M={M}', fontweight='bold')
        fig.tight_layout()
        fig.savefig(os.path.join(out_dir, f'{os.path.splitext(name)[0]}.png'), dpi=120)
    finally:
        # Release the figure even if drawing/saving fails, so figures do not pile up.
        plt.close(fig)


def process_all_images(images_dir, M_values=None, output_dir='pnp_midal_results',
                       checkpoint_base_path='/content/drive/MyDrive/Image-Denoising-with-Deep-CNNs/checkpoints',
                       max_images=None, device='cuda'):
    """Run PnP-MIDAL on every image in ``images_dir`` for each noise level in ``M_values``.

    For each M the matching checkpoint is loaded, every image is converted to
    grayscale in [0, 1], noised (seed = 1000 + image index), denoised, and a
    comparison figure is written to ``<output_dir>/M<M>/``.

    Args:
        images_dir: directory searched (non-recursively) for image files.
        M_values: noise levels to process; ``None`` means ``[1, 4, 10]``.
        output_dir: root directory for the figures.
        checkpoint_base_path: base directory passed to ``get_checkpoint_path``.
        max_images: optional cap on the number of (sorted) images.
        device: torch device.

    Returns:
        dict mapping each M to the list of per-image result dicts (each with an
        additional 'name' key).
    """
    if M_values is None:
        M_values = [1, 4, 10]
    os.makedirs(output_dir, exist_ok=True)

    extensions = ('*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tif', '*.tiff', '*.pgm')
    image_paths = sorted(
        path
        for ext in extensions
        for path in glob(os.path.join(images_dir, ext))
    )

    if max_images is not None:
        image_paths = image_paths[:max_images]

    print("PnP-MIDAL Denoising")
    print(f"Gefunden: {len(image_paths)} Bilder, M-Werte: {M_values}, Device: {device}")
    print("=" * 80)

    all_results = {M: [] for M in M_values}
    if not image_paths:
        # Nothing to do; avoid loading models and averaging empty result lists.
        print(f"Keine Bilder in {images_dir} gefunden.")
        return all_results

    for M in M_values:
        checkpoint_path = get_checkpoint_path(M, checkpoint_base_path)
        print(f"\nM={M} | Checkpoint: {checkpoint_path}")
        print("=" * 80)

        model = load_cnn_model(checkpoint_path, D=6, device=device)
        M_dir = os.path.join(output_dir, f'M{M}')
        os.makedirs(M_dir, exist_ok=True)

        for idx, img_path in enumerate(image_paths):
            name = os.path.basename(img_path)
            with Image.open(img_path) as pil_img:
                img = np.array(pil_img.convert('L')).astype(np.float64) / 255.0

            results = process_single_image(img, M=M, model=model, device=device, seed=idx + 1000)
            results['name'] = name
            all_results[M].append(results)

            delta_psnr = results['psnr_denoised'] - results['psnr_noisy']
            print(f"[{idx+1:3d}/{len(image_paths)}] {name:<20} "
                  f"PSNR: {results['psnr_noisy']:.2f} → {results['psnr_denoised']:.2f} dB "
                  f"(Δ{delta_psnr:+.2f}) Err: {results['err_noisy']:.3f} → {results['err_denoised']:.3f}")

            _save_comparison_figure(results, name, M, M_dir)

    print_summary(all_results, M_values)
    return all_results


def print_summary(all_results, M_values):
    """Print a table of mean PSNR and mean relative error before/after denoising per M.

    Args:
        all_results: mapping M -> list of result dicts; the keys 'psnr_noisy',
            'psnr_denoised', 'err_noisy' and 'err_denoised' are read.
        M_values: the M values (table rows) to print, in order. An M with no
            results gets a placeholder row instead of NaN averages.
    """
    print("\n" + "=" * 80)
    print("ZUSAMMENFASSUNG")
    print("=" * 80)

    print(f"\n{'M':>4} {'PSNR_in':>9} {'PSNR_out':>10} {'ΔPSNR':>8} "
          f"{'Err_in':>8} {'Err_out':>9}")
    print("-" * 60)

    for M in M_values:
        results = all_results.get(M, [])
        if not results:
            # np.mean([]) would emit a RuntimeWarning and return NaN.
            print(f"{M:>4} {'-':>9} {'-':>10} {'-':>8} {'-':>8} {'-':>9}  (keine Ergebnisse)")
            continue

        avg_psnr_in = np.mean([r['psnr_noisy'] for r in results])
        avg_psnr_out = np.mean([r['psnr_denoised'] for r in results])
        avg_err_in = np.mean([r['err_noisy'] for r in results])
        avg_err_out = np.mean([r['err_denoised'] for r in results])

        print(f"{M:>4} {avg_psnr_in:>9.2f} {avg_psnr_out:>10.2f} "
              f"{avg_psnr_out - avg_psnr_in:>+8.2f} {avg_err_in:>8.3f} {avg_err_out:>9.3f}")

    print("=" * 60)


def test_cnn_direct(clean, M, model, device='cuda', seed=42):
    """Apply the CNN once, directly to a Gamma-noised image (no ADMM), and report PSNR.

    The noise is drawn after ``np.random.seed(seed)`` (this reseeds NumPy's global
    RNG). The noisy image is clipped to [0, 1], mapped to [-1, 1], replicated to
    3 channels, denoised, averaged back to one channel and mapped to [0, 1].

    Returns:
        ``(noisy_clip, denoised_clip)``, both clipped to [0, 1].
    """
    np.random.seed(seed)
    gamma_noise = np.random.gamma(shape=M, scale=1.0 / M, size=clean.shape)
    noisy_clip = np.clip(clean * gamma_noise, 0, 1)

    gray_in = noisy_clip * 2 - 1
    with torch.no_grad():
        batch = torch.from_numpy(np.stack((gray_in,) * 3)).float()[None].to(device)
        prediction = model(batch)
        gray_out = prediction[0].cpu().numpy().mean(axis=0)
        denoised_clip = np.clip((gray_out + 1) / 2, 0, 1)

    psnr_noisy, psnr_denoised = compute_psnr(clean, noisy_clip), compute_psnr(clean, denoised_clip)
    print(f"CNN direkt: PSNR {psnr_noisy:.2f} → {psnr_denoised:.2f} dB (Δ{psnr_denoised-psnr_noisy:+.2f})")

    return noisy_clip, denoised_clip


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")

    # Mount Google Drive when running in Colab. Outside Colab the import fails
    # and the data is expected at the same paths locally.
    try:
        from google.colab import drive
    except ImportError:
        drive = None
    if drive is not None:
        try:
            drive.mount('/content/drive')
        except Exception as exc:  # best effort: continue without the mount
            print(f"Google Drive konnte nicht eingebunden werden: {exc}")

    images_dir = '/content/drive/MyDrive/Image-Denoising-with-Deep-CNNs-new/dataset/BSDS300/images/test'
    checkpoint_base_path = '/content/drive/MyDrive/Image-Denoising-with-Deep-CNNs/checkpoints'

    # Check the data directory before opening any image from it.
    if not os.path.isdir(images_dir):
        print(f"Verzeichnis nicht gefunden: {images_dir}")
    else:
        test_image_path = os.path.join(images_dir, '101085.jpg')
        if os.path.exists(test_image_path):
            with Image.open(test_image_path) as pil_img:
                img = np.array(pil_img.convert('L')).astype(np.float64) / 255.0

            print("\nTEST: CNN direkt (ohne ADMM)")
            print("=" * 60)
            for M in [1, 4, 10]:
                checkpoint_path = get_checkpoint_path(M, checkpoint_base_path)
                print(f"\nTest M={M}:")
                model = load_cnn_model(checkpoint_path, D=6, device=device)
                test_cnn_direct(img, M=M, model=model, device=device)
        else:
            print(f"Testbild nicht gefunden: {test_image_path}")

        print("\nPnP-MIDAL Verarbeitung")
        print("=" * 60)

        all_results = process_all_images(
            images_dir,
            M_values=[1, 4, 10],
            output_dir='pnp_midal_results',
            checkpoint_base_path=checkpoint_base_path,
            max_images=None,
            device=device
        )