MIDAL mit TV-Prior

In [None]:
# Notebook cell: mount Google Drive at /content/drive (requires the
# google.colab package, i.e. a Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
"""
MIDAL-Algorithmus (Multiplicative Image Denoising by Augmented Lagrangian) mit TV-Prior
"""

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
from glob import glob


def _tv_divergence(px, py):
    """Return the discrete divergence of the dual field ``(px, py)``.

    This is the negative adjoint of the forward-difference gradient used in
    ``tv_denoise``: differences along axis 0 for ``px`` and along axis 1 for
    ``py``, where the gradient's last row / last column is fixed to zero.
    """
    div = np.zeros(px.shape, dtype=np.float64)

    # Axis 0 contribution (first and last rows are the boundary cases).
    div[1:-1, :] = px[1:-1, :] - px[:-2, :]
    div[0, :] = px[0, :]
    div[-1, :] = -px[-2, :]

    # Axis 1 contribution, accumulated on top of the axis 0 part.
    div[:, 1:-1] += py[:, 1:-1] - py[:, :-2]
    div[:, 0] += py[:, 0]
    div[:, -1] += -py[:, -2]

    return div


def tv_denoise(image, weight, n_iter=100, px_init=None, py_init=None):
    """
    Chambolle's projection algorithm for isotropic TV denoising.

    Approximately solves ``min_u TV(u) + ||u - image||^2 / (2 * weight)`` by
    iterating on a dual field ``p = (px, py)`` with ``|p| <= 1`` per pixel and
    reconstructing ``u = image + weight * div(p)``.

    Args:
        image: 2-D input image with at least two rows and two columns.
        weight: Regularisation parameter, must be > 0. Larger values smooth more.
        n_iter: Number of fixed-point iterations.
        px_init, py_init: Optional warm start for the dual variables (both
            must be given, each with the shape of ``image``); they are copied,
            not modified.

    Returns:
        denoised: Denoised image (float64).
        px, py: Final dual variables, reusable as a warm start.

    Raises:
        ValueError: If ``weight`` is not positive, or ``image`` is not 2-D.
    """
    if weight <= 0:
        # weight == 0 previously produced 0/0 and inf in the update (NaN output).
        raise ValueError(f"weight must be positive, got {weight!r}")

    f = np.asarray(image, dtype=np.float64)
    rows, cols = f.shape

    if px_init is not None and py_init is not None:
        px = np.array(px_init, dtype=np.float64)
        py = np.array(py_init, dtype=np.float64)
    else:
        px = np.zeros((rows, cols), dtype=np.float64)
        py = np.zeros((rows, cols), dtype=np.float64)

    # Step size; 1/8 is the bound from Chambolle's convergence analysis.
    tau = 0.125

    for _ in range(n_iter):
        u = f + weight * _tv_divergence(px, py)

        # Forward differences; the last row / column stays zero.
        gx = np.zeros_like(f)
        gy = np.zeros_like(f)
        gx[:-1, :] = u[1:, :] - u[:-1, :]
        gy[:, :-1] = u[:, 1:] - u[:, :-1]

        # The dual ascent direction is grad(u) / weight. Scaling BOTH the
        # numerator and the denominator by 1/weight keeps |p| <= 1, so the
        # reconstruction u = f + weight * div(p) applies `weight` exactly once
        # (an unscaled numerator makes the effective parameter weight**2).
        gx /= weight
        gy /= weight
        denom = 1.0 + tau * np.sqrt(gx**2 + gy**2)

        px = (px + tau * gx) / denom
        py = (py + tau * gy) / denom

    # Final reconstruction from the last dual iterate.
    denoised = f + weight * _tv_divergence(px, py)

    return denoised, px, py


def _midal_newton_z(g, z_prime, M, mu):
    """Solve the per-pixel z-subproblem of MIDAL with Newton's method.

    Minimises ``M * (z + exp(g - z)) + (mu / 2) * (z - z_prime)**2``
    element-wise, starting from ``z_prime``. At most 15 Newton steps are
    taken; the loop stops early once the largest step is below 1e-9.
    """
    z = z_prime.copy()

    for _ in range(15):
        # The exponent is bounded to [-50, 50] before exponentiating.
        e = np.exp(np.clip(g - z, -50, 50))

        first_deriv = M * (1.0 - e) + mu * (z - z_prime)
        second_deriv = M * e + mu

        step = first_deriv / second_deriv
        z = z - step

        if np.max(np.abs(step)) < 1e-9:
            break

    return z


def midal(noisy, M, lam, max_iter=100, tol=1e-4, verbose=False):
    """
    MIDAL: Multiplicative Image Denoising by Augmented Lagrangian.

    Works on the log of the observation, alternating a per-pixel Newton
    solve (z), a TV denoising step (u) and a dual update (d), and returns
    the estimate mapped back to the original domain via ``exp``.

    Args:
        noisy: Observed image; values are floored at 1e-10 before the log.
        M: Number of looks of the multiplicative noise.
        lam: TV regularisation weight; the penalty ``mu`` is set equal to it.
        max_iter: Maximum number of outer iterations.
        tol: Threshold that both normalised residuals must fall below.
        verbose: If True, print residuals and objective every 10th iteration
            and a message on convergence.

    Returns:
        The denoised image in the original (non-log) domain.
    """
    mu = lam  # the paper recommends mu = lambda

    # Log transform (equation 4 in the paper).
    g = np.log(np.maximum(noisy, 1e-10))

    u = g.copy()
    d = np.zeros_like(g)
    dual_x, dual_y = None, None

    # Residual norms are divided by sqrt(#pixels) to be image-size independent.
    rms_scale = np.sqrt(g.size)

    for k in range(max_iter):
        # z-update (paper lines 3-4): per-pixel Newton iteration.
        z = _midal_newton_z(g, u + d, M, mu)

        # u-update (paper lines 5-6): TV denoising, warm-started with the
        # dual variables of the previous outer iteration.
        u_old = u.copy()
        u, dual_x, dual_y = tv_denoise(
            z - d, lam / mu, n_iter=20, px_init=dual_x, py_init=dual_y
        )

        # Dual update (paper line 7).
        gap = z - u
        d = d - gap

        primal_res = np.linalg.norm(gap) / rms_scale
        dual_res = np.linalg.norm(u - u_old) * mu / rms_scale

        if verbose and k % 10 == 0:
            obj = compute_objective(z, g, M, lam)
            print(f"  Iter {k:3d}: primal_res={primal_res:.2e}, dual_res={dual_res:.2e}, obj={obj:.2f}")

        converged = primal_res < tol and dual_res < tol
        if not converged:
            continue

        if verbose:
            print(f"  Konvergiert nach {k+1} Iterationen")
        break

    return np.exp(u)


def compute_objective(z, g, M, lam):
    """Evaluate the objective L(z) (equation (11) of the paper, per the original note).

    ``L(z) = M * sum(z + exp(g - z)) + lam * TV(z)``, where TV is the
    isotropic total variation built from the same forward differences that
    ``tv_denoise`` uses (last row / column of the gradient is zero). Using
    the same discretisation keeps the reported value consistent with the
    functional the TV step actually minimises.

    Args:
        z: Current 2-D estimate in the log domain.
        g: Log of the observed image, same shape as ``z``.
        M: Number of looks (weight of the data term).
        lam: TV regularisation weight.

    Returns:
        The scalar objective value.
    """
    z = np.asarray(z, dtype=np.float64)

    # The exponent is bounded to [-50, 50] before exponentiating.
    diff = np.clip(g - z, -50, 50)
    data_term = M * np.sum(z + np.exp(diff))

    gx = np.zeros_like(z)
    gy = np.zeros_like(z)
    gx[:-1, :] = z[1:, :] - z[:-1, :]
    gy[:, :-1] = z[:, 1:] - z[:, :-1]
    tv_term = lam * np.sum(np.sqrt(gx**2 + gy**2))

    return data_term + tv_term


def add_gamma_noise(image, M, seed=None):
    """
    Apply M-look multiplicative Gamma noise to an image.

    Model: ``Y = X * N`` with ``N ~ Gamma(shape=M, scale=1/M)``, so ``E[N] = 1``.

    Args:
        image: Clean image (array-like).
        M: Number of looks (Gamma shape parameter).
        seed: Optional seed. When given, a private ``RandomState`` is used so
            the result is reproducible WITHOUT reseeding NumPy's global
            generator; the noise equals what ``np.random.seed(seed)`` followed
            by ``np.random.gamma`` would produce. When None, the global
            generator is used.

    Returns:
        The noisy image, same shape as ``image``.
    """
    image = np.asarray(image)
    rng = np.random if seed is None else np.random.RandomState(seed)
    return image * rng.gamma(shape=M, scale=1.0 / M, size=image.shape)


def compute_psnr(clean, noisy, data_range=1.0):
    """Return the peak signal-to-noise ratio in dB between two images.

    Args:
        clean: Reference image.
        noisy: Image to compare against the reference.
        data_range: Peak signal value (1.0 for images scaled to [0, 1],
            255 for 8-bit images).

    Returns:
        ``10 * log10(data_range**2 / MSE)``, or 100.0 when the MSE is below
        1e-10 (avoids a division by zero for identical images).
    """
    # Cast first: subtracting unsigned-integer arrays would wrap around.
    clean = np.asarray(clean, dtype=np.float64)
    noisy = np.asarray(noisy, dtype=np.float64)

    mse = np.mean((clean - noisy) ** 2)
    if mse < 1e-10:
        return 100.0
    return 10.0 * np.log10(data_range ** 2 / mse)


def compute_mse(clean, noisy):
    """Return the mean squared error between two images of equal shape."""
    # Cast first: subtracting unsigned-integer arrays would wrap around.
    clean = np.asarray(clean, dtype=np.float64)
    noisy = np.asarray(noisy, dtype=np.float64)
    return np.mean((clean - noisy) ** 2)


def compute_mae(clean, noisy):
    """Return the mean absolute error between two images of equal shape."""
    # Cast first: subtracting unsigned-integer arrays would wrap around.
    clean = np.asarray(clean, dtype=np.float64)
    noisy = np.asarray(noisy, dtype=np.float64)
    return np.mean(np.abs(clean - noisy))


def compute_relative_error(clean, estimate):
    """Return ``||estimate - clean|| / ||clean||`` (Frobenius / 2-norm).

    The result is NaN or inf when ``clean`` is identically zero, because the
    denominator is then zero.
    """
    # Cast first: subtracting unsigned-integer arrays would wrap around.
    clean = np.asarray(clean, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    return np.linalg.norm(estimate - clean) / np.linalg.norm(clean)


def get_lambda(M):
    """
    Look up the regularisation weight lambda for a given number of looks M.

    The values come from the author's own experiment (not the paper's
    values). Returns None when no value is stored for M.
    """
    tuned = dict(((1, 1.36), (4, 1.98), (10, 2.65)))
    try:
        return tuned[M]
    except KeyError:
        return None


def process_single_image(clean, M, lam=None, seed=42, verbose=False):
    """
    Add M-look Gamma noise to a clean image, denoise it with MIDAL and score it.

    Args:
        clean: Clean reference image; metrics assume values in [0, 1].
        M: Number of looks of the simulated noise.
        lam: Regularisation weight; None selects the stored value for M via
            ``get_lambda``.
        seed: Seed for the noise generator.
        verbose: Forwarded to ``midal``.

    Returns:
        A dict with the images ('clean', 'noisy', 'denoised' and their
        versions clipped to [0, 1]), the parameters ('M', 'lambda') and the
        PSNR / MSE / MAE / relative-error metrics. All metrics are computed
        on the clipped images.

    Raises:
        ValueError: If ``lam`` is None and no lambda is stored for ``M``.
    """
    if lam is None:
        lam = get_lambda(M)
        if lam is None:
            # Fail here with a clear message instead of passing None into
            # midal, where it would surface as an obscure TypeError.
            raise ValueError(
                f"Kein Lambda für M={M} hinterlegt; bitte lam explizit angeben."
            )

    noisy = add_gamma_noise(clean, M, seed=seed)
    denoised = midal(noisy, M=M, lam=lam, max_iter=100, tol=1e-4, verbose=verbose)

    noisy_clip = np.clip(noisy, 0, 1)
    denoised_clip = np.clip(denoised, 0, 1)

    return {
        'clean': clean,
        'noisy': noisy,
        'noisy_clip': noisy_clip,
        'denoised': denoised,
        'denoised_clip': denoised_clip,
        'M': M,
        'lambda': lam,
        'psnr_noisy': compute_psnr(clean, noisy_clip),
        'psnr_denoised': compute_psnr(clean, denoised_clip),
        'mse_noisy': compute_mse(clean, noisy_clip),
        'mse_denoised': compute_mse(clean, denoised_clip),
        'mae_noisy': compute_mae(clean, noisy_clip),
        'mae_denoised': compute_mae(clean, denoised_clip),
        'err_noisy': compute_relative_error(clean, noisy_clip),
        'err_denoised': compute_relative_error(clean, denoised_clip),
    }


def _save_comparison_figure(results, name, M, lam, target_dir):
    """Save a three-panel figure (original, noisy, denoised) as a PNG in *target_dir*."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = [
        ('clean', 'Original'),
        ('noisy_clip', f'Verrauscht (M={M})\nPSNR={results["psnr_noisy"]:.2f} dB'),
        ('denoised_clip', f'MIDAL\nPSNR={results["psnr_denoised"]:.2f} dB'),
    ]
    for ax, (key, title) in zip(axes, panels):
        ax.imshow(results[key], cmap='gray', vmin=0, vmax=1)
        ax.set_title(title)
        ax.axis('off')
    fig.suptitle(f'{name} | M={M}, λ={lam}', fontweight='bold')
    fig.tight_layout()
    fig.savefig(os.path.join(target_dir, f'{os.path.splitext(name)[0]}.png'), dpi=120)
    # Close this specific figure so figures do not accumulate across images.
    plt.close(fig)


def process_all_images(images_dir, M_values=None, output_dir='midal_results'):
    """
    Run the noise / MIDAL / metrics pipeline on every image in a directory.

    For each M, every image is converted to greyscale in [0, 1], processed
    with ``process_single_image`` (seed ``index + 1000``), reported on stdout
    and saved as a comparison figure under ``<output_dir>/M<M>/``. A summary
    table is printed at the end.

    Args:
        images_dir: Directory searched (non-recursively) for jpg, jpeg, png,
            bmp, tif, tiff and pgm files.
        M_values: Iterable of look counts; None means ``[1, 4, 10]``.
        output_dir: Root directory for the saved figures (created if missing).

    Returns:
        Dict mapping each M to a list of per-image result dicts. Each dict
        holds full image arrays, so memory use grows with the image count.
    """
    # None instead of a list literal avoids a shared mutable default argument.
    if M_values is None:
        M_values = [1, 4, 10]

    os.makedirs(output_dir, exist_ok=True)

    image_paths = []
    for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tif', '*.tiff', '*.pgm']:
        image_paths.extend(glob(os.path.join(images_dir, ext)))
    image_paths.sort()

    print("MIDAL Denoising")
    print(f"Gefunden: {len(image_paths)} Bilder")
    print(f"M-Werte: {M_values}")
    print("=" * 80)

    all_results = {M: [] for M in M_values}

    if not image_paths:
        # Nothing to process: skip the summary, whose averages would be NaN.
        print(f"Keine Bilder gefunden in: {images_dir}")
        return all_results

    for M in M_values:
        lam = get_lambda(M)

        print(f"\n{'=' * 80}")
        print(f"M={M}, λ={lam}")
        print(f"{'=' * 80}")

        M_dir = os.path.join(output_dir, f'M{M}')
        os.makedirs(M_dir, exist_ok=True)

        for idx, img_path in enumerate(image_paths):
            name = os.path.basename(img_path)
            # The context manager releases the file handle once the pixels are read.
            with Image.open(img_path) as pil_img:
                img = np.array(pil_img.convert('L')).astype(np.float64) / 255.0

            results = process_single_image(img, M=M, lam=lam, seed=idx + 1000)
            results['name'] = name
            all_results[M].append(results)

            delta_psnr = results['psnr_denoised'] - results['psnr_noisy']
            print(f"[{idx+1:3d}/{len(image_paths)}] {name:<20} "
                  f"PSNR: {results['psnr_noisy']:.2f} → {results['psnr_denoised']:.2f} dB "
                  f"(Δ{delta_psnr:+.2f}) "
                  f"Err: {results['err_noisy']:.3f} → {results['err_denoised']:.3f}")

            _save_comparison_figure(results, name, M, lam, M_dir)

    print_summary(all_results, M_values)
    return all_results


def _mean_or_nan(values):
    """Return the mean of *values*, or NaN for an empty list (no NumPy warning)."""
    return np.mean(values) if values else float('nan')


def print_summary(all_results, M_values):
    """
    Print an averaged summary table and per-image detail tables to stdout.

    Args:
        all_results: Dict mapping each M to a list of result dicts with the
            keys 'name', 'psnr_noisy', 'psnr_denoised', 'err_noisy',
            'err_denoised', 'mae_noisy', 'mae_denoised' and, optionally,
            'lambda'.
        M_values: The M values to report, in output order.
    """
    print("\n" + "=" * 80)
    print("ZUSAMMENFASSUNG")
    print("=" * 80)

    summary_data = []
    for M in M_values:
        results = all_results[M]

        # Report the lambda that was actually used for these results; fall
        # back to the stored table only when the results do not carry it.
        if results and 'lambda' in results[0]:
            lam = results[0]['lambda']
        else:
            lam = get_lambda(M)

        summary_data.append({
            'M': M,
            'lambda': lam,
            'psnr_in': _mean_or_nan([r['psnr_noisy'] for r in results]),
            'psnr_out': _mean_or_nan([r['psnr_denoised'] for r in results]),
            'delta_psnr': _mean_or_nan([r['psnr_denoised'] - r['psnr_noisy'] for r in results]),
            'err_in': _mean_or_nan([r['err_noisy'] for r in results]),
            'err_out': _mean_or_nan([r['err_denoised'] for r in results]),
            'mae_in': _mean_or_nan([r['mae_noisy'] for r in results]),
            'mae_out': _mean_or_nan([r['mae_denoised'] for r in results]),
        })

    print(f"\n{'M':>4} {'λ':>6} {'PSNR_in':>9} {'PSNR_out':>10} {'ΔPSNR':>8} "
          f"{'Err_in':>8} {'Err_out':>9} {'MAE_in':>8} {'MAE_out':>9}")
    print("-" * 85)
    for s in summary_data:
        # Two decimals: the stored lambdas (1.36, 1.98, 2.65) are not
        # representable with the single decimal used before.
        lam_text = "n/a" if s['lambda'] is None else f"{s['lambda']:.2f}"
        print(f"{s['M']:>4} {lam_text:>6} {s['psnr_in']:>9.2f} {s['psnr_out']:>10.2f} "
              f"{s['delta_psnr']:>+8.2f} {s['err_in']:>8.3f} {s['err_out']:>9.3f} "
              f"{s['mae_in']:>8.4f} {s['mae_out']:>9.4f}")
    print("=" * 85)

    for M in M_values:
        print(f"\n{'=' * 80}")
        print(f"DETAILLIERTE ERGEBNISSE: M={M}")
        print(f"{'=' * 80}")
        print(f"{'Bild':<25} {'PSNR_in':>8} {'PSNR_out':>9} {'ΔPSNR':>7} "
              f"{'Err_in':>8} {'Err_out':>9} {'MAE_in':>8} {'MAE_out':>9}")
        print("-" * 85)
        for r in all_results[M]:
            delta = r['psnr_denoised'] - r['psnr_noisy']
            print(f"{r['name']:<25} {r['psnr_noisy']:>8.2f} {r['psnr_denoised']:>9.2f} "
                  f"{delta:>+7.2f} {r['err_noisy']:>8.3f} {r['err_denoised']:>9.3f} "
                  f"{r['mae_noisy']:>8.4f} {r['mae_denoised']:>9.4f}")


if __name__ == '__main__':
    # Mounting Google Drive is best-effort: outside Colab the import fails
    # and is skipped. Only ImportError is silenced, so KeyboardInterrupt and
    # SystemExit are no longer swallowed; a failed mount is reported and the
    # script continues to the directory check below.
    try:
        from google.colab import drive
    except ImportError:
        pass
    else:
        try:
            drive.mount('/content/drive')
        except Exception as exc:
            print(f"Google Drive konnte nicht eingebunden werden: {exc}")

    images_dir = '/content/drive/MyDrive/Image-Denoising-with-Deep-CNNs-new/dataset/BSDS300/images/test'

    if os.path.exists(images_dir):
        all_results = process_all_images(
            images_dir,
            M_values=[1, 4, 10],
            output_dir='midal_results_corrected'
        )
    else:
        print(f"Verzeichnis nicht gefunden: {images_dir}")