In [5]:
!pip install torch torchvision numpy scikit-image matplotlib tqdm



In [7]:
# ============================================================
# GMiSDE-Net
# GammaMoE + CHSN + Implicit SDE
# MNIST + Speckle (Gamma noise)
# ============================================================

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# ---------------- CONFIG ----------------
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Image size and optimisation settings.
IMG_SIZE, BATCH, EPOCHS, LR = 32, 64, 10, 2e-4

# EPS: lower clamp applied to images and model outputs.
# GAMMA_K: shape (and rate) of the unit-mean Gamma speckle.
# NUM_EXPERTS: number of CHSN experts in the GammaMoE mixture.
EPS, GAMMA_K, NUM_EXPERTS = 1e-4, 2.0, 4

# Output folder for evaluation figures, created eagerly.
SAVE_DIR = "./gmisde_results"
os.makedirs(SAVE_DIR, exist_ok=True)

# ============================================================
# DATASET: MNIST + GAMMA SPECKLE
# ============================================================

class MNISTSpeckle(Dataset):
    """MNIST digits corrupted by multiplicative Gamma speckle.

    Each image is resized to ``IMG_SIZE x IMG_SIZE`` and clamped to
    ``[EPS, 1]`` to form the clean target. The noisy input is
    ``clamp(clean * n, EPS, 1)``, where ``n ~ Gamma(GAMMA_K, rate=GAMMA_K)``
    is drawn independently per pixel.

    All noise is drawn once in ``__init__``. Indexing is therefore
    deterministic and returns the same ``(noisy, clean)`` pair of 2-D tensors
    every time.

    Args:
        train: use the MNIST training split if True, else the test split.
    """

    def __init__(self, train=True):
        tfm = T.Compose([
            T.Resize((IMG_SIZE, IMG_SIZE)),
            T.ToTensor(),
        ])
        base = datasets.MNIST("./data", train=train, download=True, transform=tfm)

        # Gamma(k, rate=k) has mean 1 and variance 1/k, i.e. unit-mean speckle
        # whose strength falls as k grows. The distribution object is
        # loop-invariant, so build it once instead of once per image. It holds
        # no RNG state, so the noise drawn from it is unchanged.
        speckle = torch.distributions.Gamma(GAMMA_K, GAMMA_K)

        self.samples = []
        for img, _ in base:
            # The EPS floor keeps pixels strictly positive, so log(clean) is finite.
            clean = img.squeeze(0).clamp(EPS, 1.0)
            noisy = (clean * speckle.sample(clean.shape)).clamp(EPS, 1.0)
            self.samples.append((noisy, clean))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        return self.samples[i]

# ============================================================
# CHSN EXPERT
# ============================================================

class CHSN(nn.Module):
    """Single expert: a small conv encoder feeding three 1-channel heads.

    ``forward(x)`` returns ``(mu, sigma, unc)``:
      * mu    -- unconstrained per-pixel map,
      * sigma -- ``0.1 * softplus(.) + EPS``, hence positive,
      * unc   -- squashed to [0, 1] by a sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # Three independent 64 -> 1 heads that share the encoder features.
        self.mu = nn.Conv2d(64, 1, kernel_size=3, padding=1)
        self.sigma = nn.Conv2d(64, 1, kernel_size=3, padding=1)
        self.unc = nn.Conv2d(64, 1, kernel_size=3, padding=1)

    def forward(self, x):
        feats = self.enc(x)
        mean = self.mu(feats)
        spread = F.softplus(self.sigma(feats)).mul(0.1).add(EPS)
        confidence = self.unc(feats).sigmoid()
        return mean, spread, confidence

# ============================================================
# GAMMA MIXTURE OF EXPERTS
# ============================================================

class GammaMoE(nn.Module):
    """Per-pixel soft mixture of NUM_EXPERTS CHSN experts.

    A 1x1 conv router maps the input to NUM_EXPERTS logits per pixel. A
    softmax over the expert axis (dim 1) turns these into weights that blend
    each expert's ``(mu, sigma, unc)`` maps. The three blended maps are
    returned.
    """

    def __init__(self):
        super().__init__()
        self.router = nn.Conv2d(1, NUM_EXPERTS, 1)
        self.experts = nn.ModuleList(CHSN() for _ in range(NUM_EXPERTS))

    def forward(self, x):
        gates = F.softmax(self.router(x), dim=1)
        heads = [expert(x) for expert in self.experts]
        # One weighted sum per output slot (0: mu, 1: sigma, 2: unc);
        # sum() starts from 0, matching a running total.
        mu, sigma, unc = (
            sum(gates[:, i:i + 1] * out[slot] for i, out in enumerate(heads))
            for slot in range(3)
        )
        return mu, sigma, unc

# ============================================================
# IMPLICIT SDE HEAD
# ============================================================

class ImplicitSDE(nn.Module):
    """
    Learns terminal solution:
    x_T = x + f_theta(x, μ, σ)

    No time-stepping happens here. A single 3-layer conv pass over the
    channel-stacked ``[x, mu, sigma]`` predicts the correction, and the
    result is clamped to ``[EPS, 1]``.
    """

    def __init__(self):
        super().__init__()
        width = 64
        self.net = nn.Sequential(
            nn.Conv2d(3, width, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(width, width, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(width, 1, 3, padding=1),
        )

    def forward(self, x, mu, sigma):
        conditioning = torch.cat((x, mu, sigma), dim=1)
        refined = x + self.net(conditioning)
        return refined.clamp(EPS, 1.0)

# ============================================================
# FULL MODEL
# ============================================================

class GMiSDENet(nn.Module):
    """Full model: a GammaMoE front end followed by the ImplicitSDE head.

    The mixture yields ``(mu, sigma, unc)``. The head is conditioned on the
    uncertainty-gated mean ``mu * unc`` and on ``sigma``, and returns the
    reconstruction.
    """

    def __init__(self):
        super().__init__()
        self.moe = GammaMoE()
        self.sde = ImplicitSDE()

    def forward(self, x):
        mean, spread, confidence = self.moe(x)
        gated_mean = mean * confidence
        return self.sde(x, gated_mean, spread)

# ============================================================
# TRAINING
# ============================================================

def train():
    """Fit a fresh GMiSDENet on the noisy MNIST training split.

    Per-batch objective: ``MSE(recon, clean) + 0.2 * mean((log recon - log clean)^2)``.
    The mean per-batch loss is printed after each of the EPOCHS epochs.

    Returns:
        The trained model, left on DEVICE.
    """
    loader = DataLoader(MNISTSpeckle(train=True), BATCH, shuffle=True)

    net = GMiSDENet().to(DEVICE)
    optimizer = optim.Adam(net.parameters(), lr=LR)

    print("\nTraining GMiSDE-Net...\n")

    for epoch in range(1, EPOCHS + 1):
        running = 0
        for noisy, clean in loader:
            inputs = noisy.unsqueeze(1).to(DEVICE)
            targets = clean.unsqueeze(1).to(DEVICE)

            pred = net(inputs)
            pixel_term = F.mse_loss(pred, targets)
            log_term = (torch.log(pred) - torch.log(targets)).pow(2).mean()
            loss = pixel_term + 0.2 * log_term

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running += loss.item()

        print(f"[Epoch {epoch:02d}] Loss={running / len(loader):.4f}")

    return net

# ============================================================
# EVALUATION
# ============================================================

def evaluate(model, num_samples=10):
    """Score ``model`` on the first noisy MNIST test images and save figures.

    For each image, PSNR and SSIM against the clean target are recorded. A
    four-panel figure (clean / noisy / |noisy - clean| / reconstruction) is
    saved to ``SAVE_DIR/sample_{i}.png``. The averages are printed.

    The model runs in eval mode under ``torch.no_grad()``, and its previous
    train/eval mode is restored afterwards.

    Args:
        model: denoiser mapping a ``(1, 1, H, W)`` tensor on DEVICE to an
            output of the same shape.
        num_samples: number of leading test images to score (capped at the
            split size).

    Returns:
        ``(mean_psnr, mean_ssim)`` over the scored images.
    """
    dataset = MNISTSpeckle(train=False)
    count = min(num_samples, len(dataset))
    psnr_list, ssim_list = [], []

    print("\nEvaluating...\n")

    was_training = model.training
    model.eval()
    try:
        for i in range(count):
            noisy, clean = dataset[i]
            with torch.no_grad():
                recon = model(noisy[None, None].to(DEVICE))

            # Named *_np so nothing shadows the torch.nn module alias.
            clean_np = clean.cpu().numpy()
            noisy_np = noisy.cpu().numpy()
            recon_np = recon[0, 0].cpu().numpy()

            p = peak_signal_noise_ratio(clean_np, recon_np, data_range=1.0)
            s = structural_similarity(clean_np, recon_np, data_range=1.0)
            psnr_list.append(p)
            ssim_list.append(s)

            fig, ax = plt.subplots(1, 4, figsize=(12, 3))
            ax[0].imshow(clean_np, cmap="gray")
            ax[0].set_title("Clean")
            ax[1].imshow(noisy_np, cmap="gray")
            ax[1].set_title("Noisy")
            ax[2].imshow(np.abs(noisy_np - clean_np), cmap="hot")
            ax[2].set_title("Noise")
            ax[3].imshow(recon_np, cmap="gray")
            ax[3].set_title(f"GMiSDE\nPSNR={p:.2f}")
            for a in ax:
                a.axis("off")
            # The two-line title needs room; tight_layout keeps it from clipping.
            fig.tight_layout()
            fig.savefig(f"{SAVE_DIR}/sample_{i}.png", dpi=150)
            plt.close(fig)
    finally:
        model.train(was_training)

    mean_psnr, mean_ssim = np.mean(psnr_list), np.mean(ssim_list)
    print("\n=== RESULTS ===")
    print(f"Avg PSNR: {mean_psnr:.2f}")
    print(f"Avg SSIM: {mean_ssim:.4f}")
    return mean_psnr, mean_ssim

# ============================================================
# MAIN
# ============================================================

if __name__ == "__main__":
    # Fit on the noisy MNIST training split, then score and plot test images.
    evaluate(model := train())


100%|██████████| 9.91M/9.91M [00:00<00:00, 18.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 507kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.69MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.86MB/s]



Training GMiSDE-Net...

[Epoch 01] Loss=0.0641
[Epoch 02] Loss=0.0425
[Epoch 03] Loss=0.0393
[Epoch 04] Loss=0.0352
[Epoch 05] Loss=0.0301
[Epoch 06] Loss=0.0262
[Epoch 07] Loss=0.0237
[Epoch 08] Loss=0.0209
[Epoch 09] Loss=0.0188
[Epoch 10] Loss=0.0171

Evaluating...


=== RESULTS ===
Avg PSNR: 26.22
Avg SSIM: 0.9705


In [10]:
# ============================================================
# GMiSDE-Net vs ID-CNN vs DnCNN
# MNIST + Speckle (Gamma noise) — FULL VISUALIZATION
# ============================================================

import os, torch, numpy as np, matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# ---------------- CONFIG ----------------
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Image size and optimisation settings (shared by all three models).
IMG_SIZE, BATCH, EPOCHS, LR = 32, 64, 10, 2e-4

# EPS: lower clamp for images/outputs; GAMMA_K: Gamma speckle shape and rate;
# NUM_EXPERTS: experts in the GammaMoE mixture.
EPS, GAMMA_K, NUM_EXPERTS = 1e-4, 2.0, 4

# Output folder for the comparison figures, created eagerly.
SAVE_DIR = "./gmisde_compare"
os.makedirs(SAVE_DIR, exist_ok=True)

# ============================================================
# DATASET
# ============================================================

class MNISTSpeckle(Dataset):
    """MNIST digits corrupted by multiplicative Gamma speckle.

    Each image is resized to ``IMG_SIZE x IMG_SIZE`` and clamped to
    ``[EPS, 1]`` to form the clean target. The noisy input is
    ``clamp(clean * n, EPS, 1)``, where ``n ~ Gamma(GAMMA_K, rate=GAMMA_K)``
    is drawn independently per pixel. Noise is drawn once in ``__init__``,
    and items are ``(noisy, clean)`` pairs of 2-D tensors.

    Args:
        train: use the MNIST training split if True, else the test split.
    """

    def __init__(self, train=True):
        tfm = T.Compose([T.Resize((IMG_SIZE, IMG_SIZE)), T.ToTensor()])
        base = datasets.MNIST("./data", train=train, download=True, transform=tfm)

        # Unit-mean speckle (mean 1, variance 1/k). The distribution object is
        # loop-invariant, so build it once; it holds no RNG state, so the noise
        # drawn from it is unchanged.
        speckle = torch.distributions.Gamma(GAMMA_K, GAMMA_K)

        self.data = []
        for img, _ in base:
            # The EPS floor keeps pixels strictly positive, so log(clean) is finite.
            clean = img.squeeze(0).clamp(EPS, 1.0)
            noisy = (clean * speckle.sample(clean.shape)).clamp(EPS, 1.0)
            self.data.append((noisy, clean))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

# ============================================================
# BASELINES
# ============================================================

class IDCNN(nn.Module):
    """Three-layer conv baseline that maps the noisy image directly to a clean estimate.

    The network output itself, clamped to ``[EPS, 1]``, is the reconstruction.

    NOTE(review): the published ID-CNN despeckler divides the input by a
    learned speckle estimate (a division residual). This baseline has no such
    residual; confirm the naming is intended.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        estimate = self.net(x)
        return estimate.clamp(EPS, 1)

class DnCNN(nn.Module):
    """Shallow residual baseline: the network predicts the noise, which is subtracted.

    Output is ``clamp(x - net(x), EPS, 1)``.

    NOTE(review): the published DnCNN is much deeper and uses batch
    normalisation; this three-conv variant is a lightweight stand-in.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        predicted_noise = self.net(x)
        return (x - predicted_noise).clamp(EPS, 1)

# ============================================================
# PROPOSED MODEL: GMiSDE-Net
# ============================================================

class CHSN(nn.Module):
    """Single expert: a small conv encoder feeding three 1-channel heads.

    ``forward(x)`` returns ``(mu, sigma, unc)``:
      * mu    -- unconstrained per-pixel map,
      * sigma -- ``0.1 * softplus(.) + EPS``, hence positive,
      * unc   -- squashed to [0, 1] by a sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        # Three independent 64 -> 1 heads that share the encoder features.
        self.mu = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.sigma = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.unc = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        feats = self.enc(x)
        mean = self.mu(feats)
        spread = F.softplus(self.sigma(feats)).mul(0.1).add(EPS)
        confidence = self.unc(feats).sigmoid()
        return mean, spread, confidence

class GammaMoE(nn.Module):
    """Per-pixel soft mixture of NUM_EXPERTS CHSN experts.

    A 1x1 conv router maps the input to NUM_EXPERTS logits per pixel. A
    softmax over dim 1 turns these into weights that blend each expert's
    ``(mu, sigma, unc)`` maps. The three blended maps are returned.
    """

    def __init__(self):
        super().__init__()
        self.router = nn.Conv2d(1, NUM_EXPERTS, 1)
        self.experts = nn.ModuleList(CHSN() for _ in range(NUM_EXPERTS))

    def forward(self, x):
        gates = F.softmax(self.router(x), dim=1)
        heads = [expert(x) for expert in self.experts]
        # One weighted sum per output slot (0: mu, 1: sigma, 2: unc).
        mu, sigma, unc = (
            sum(gates[:, i:i + 1] * out[slot] for i, out in enumerate(heads))
            for slot in range(3)
        )
        return mu, sigma, unc

class ImplicitSDE(nn.Module):
    """Residual refinement head: ``clamp(x + f_theta([x, mu, sigma]), EPS, 1)``.

    No time-stepping happens here. One 3-layer conv pass over the
    channel-stacked inputs predicts the terminal correction.
    """

    def __init__(self):
        super().__init__()
        width = 64
        self.net = nn.Sequential(
            nn.Conv2d(3, width, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(width, width, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(width, 1, 3, 1, 1),
        )

    def forward(self, x, mu, sigma):
        conditioning = torch.cat((x, mu, sigma), dim=1)
        refined = x + self.net(conditioning)
        return refined.clamp(EPS, 1)

class GMiSDENet(nn.Module):
    """Full model: a GammaMoE front end followed by the ImplicitSDE head.

    The head is conditioned on the uncertainty-gated mean ``mu * unc`` and on
    ``sigma``, and returns the reconstruction.
    """

    def __init__(self):
        super().__init__()
        self.moe = GammaMoE()
        self.sde = ImplicitSDE()

    def forward(self, x):
        mean, spread, confidence = self.moe(x)
        gated_mean = mean * confidence
        return self.sde(x, gated_mean, spread)

# ============================================================
# TRAINING
# ============================================================

def train_model(model, name):
    """Train ``model`` with plain MSE on a freshly noised MNIST training split.

    Args:
        model: the network to train; it is moved to DEVICE.
        name: label used in the progress messages.

    Returns:
        The trained model, on DEVICE.
    """
    loader = DataLoader(MNISTSpeckle(train=True), BATCH, shuffle=True)
    model = model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LR)

    print(f"\nTraining {name}")
    for epoch in range(1, EPOCHS + 1):
        running = 0
        for noisy, clean in loader:
            inputs = noisy.unsqueeze(1).to(DEVICE)
            targets = clean.unsqueeze(1).to(DEVICE)
            loss = F.mse_loss(model(inputs), targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running += loss.item()
        print(f"[{name}][{epoch}] Loss={running / len(loader):.4f}")
    return model

# ============================================================
# EVALUATION (FIXED + NOISE MAP)
# ============================================================

def evaluate(models, num_samples=10):
    """Compare ``models`` on the first noisy MNIST test images.

    Every model denoises the same noisy input. PSNR/SSIM against the clean
    target are recorded, and a figure with Clean, Noisy, |noisy - clean| and
    one panel per model is saved to ``SAVE_DIR/sample_{i}.png``. The panel
    count follows ``len(models)``, so any number of models can be compared.

    Models run in eval mode under ``torch.no_grad()``; their previous modes
    are restored afterwards.

    Args:
        models: mapping name -> module on DEVICE.
        num_samples: leading test images to score (capped at the split size).

    Returns:
        dict name -> (mean PSNR, mean SSIM).
    """
    data = MNISTSpeckle(train=False)
    count = min(num_samples, len(data))

    psnr_scores = {k: [] for k in models}
    ssim_scores = {k: [] for k in models}
    n_cols = 3 + len(models)  # Clean, Noisy, Noise, then one panel per model

    prev_modes = {k: m.training for k, m in models.items()}
    for m in models.values():
        m.eval()
    try:
        for i in range(count):
            noisy, clean = data[i]
            # Named *_np so nothing shadows the torch.nn module alias.
            clean_np = clean.cpu().numpy()
            noisy_np = noisy.cpu().numpy()
            noise_map = np.abs(noisy_np - clean_np)
            batch = noisy[None, None].to(DEVICE)

            with torch.no_grad():
                outputs = {
                    name: m(batch)[0, 0].cpu().numpy()
                    for name, m in models.items()
                }

            fig, ax = plt.subplots(1, n_cols, figsize=(20 * n_cols / 6, 3))
            ax[0].imshow(clean_np, cmap="gray")
            ax[0].set_title("Clean")
            ax[1].imshow(noisy_np, cmap="gray")
            ax[1].set_title("Noisy")
            ax[2].imshow(noise_map, cmap="gray")
            ax[2].set_title("Noise")

            for col, (name, out_np) in enumerate(outputs.items(), start=3):
                p = peak_signal_noise_ratio(clean_np, out_np, data_range=1.0)
                s = structural_similarity(clean_np, out_np, data_range=1.0)
                psnr_scores[name].append(p)
                ssim_scores[name].append(s)

                ax[col].imshow(out_np, cmap="gray")
                ax[col].set_title(f"{name}\nPSNR={p:.2f}\nSSIM={s:.4f}")

            for a in ax:
                a.axis("off")
            fig.tight_layout()
            fig.savefig(f"{SAVE_DIR}/sample_{i}.png", dpi=150)
            plt.close(fig)
    finally:
        for name, m in models.items():
            m.train(prev_modes[name])

    print("\n=== FINAL RESULTS ===")
    summary = {}
    for k in models:
        summary[k] = (np.mean(psnr_scores[k]), np.mean(ssim_scores[k]))
        print(f"{k}: PSNR={summary[k][0]:.2f}, SSIM={summary[k][1]:.4f}")
    return summary

# ============================================================
# MAIN
# ============================================================

if __name__=="__main__":
    # Train each contender in turn (GMiSDE first), then compare them side by side.
    contenders = (("GMiSDE", GMiSDENet), ("ID-CNN", IDCNN), ("DnCNN", DnCNN))
    models = {}
    for label, build in contenders:
        models[label] = train_model(build(), label)
    evaluate(models)



Training GMiSDE
[GMiSDE][1] Loss=0.0033
[GMiSDE][2] Loss=0.0022
[GMiSDE][3] Loss=0.0021
[GMiSDE][4] Loss=0.0020
[GMiSDE][5] Loss=0.0020
[GMiSDE][6] Loss=0.0020
[GMiSDE][7] Loss=0.0019
[GMiSDE][8] Loss=0.0019
[GMiSDE][9] Loss=0.0019
[GMiSDE][10] Loss=0.0019

Training ID-CNN
[ID-CNN][1] Loss=0.0040
[ID-CNN][2] Loss=0.0025
[ID-CNN][3] Loss=0.0024
[ID-CNN][4] Loss=0.0023
[ID-CNN][5] Loss=0.0023
[ID-CNN][6] Loss=0.0022
[ID-CNN][7] Loss=0.0022
[ID-CNN][8] Loss=0.0022
[ID-CNN][9] Loss=0.0022
[ID-CNN][10] Loss=0.0021

Training DnCNN
[DnCNN][1] Loss=0.0033
[DnCNN][2] Loss=0.0025
[DnCNN][3] Loss=0.0024
[DnCNN][4] Loss=0.0023
[DnCNN][5] Loss=0.0023
[DnCNN][6] Loss=0.0023
[DnCNN][7] Loss=0.0022
[DnCNN][8] Loss=0.0022
[DnCNN][9] Loss=0.0022
[DnCNN][10] Loss=0.0022

=== FINAL RESULTS ===
GMiSDE: PSNR=28.06, SSIM=0.9783
ID-CNN: PSNR=27.49, SSIM=0.9749
DnCNN: PSNR=27.54, SSIM=0.9753


In [8]:
# ============================================================
# GMiSDE-Net vs ID-CNN vs DnCNN
# CIFAR-10 (Grayscale) + Gamma Speckle
# ============================================================

import os, torch, numpy as np, matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# ---------------- CONFIG ----------------
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Image size and optimisation settings.
IMG_SIZE, BATCH, EPOCHS, LR = 32, 32, 15, 2e-4

# EPS: lower clamp for images/outputs; GAMMA_K: Gamma speckle shape and rate;
# NUM_EXPERTS: experts in the GammaMoE mixture.
EPS, GAMMA_K, NUM_EXPERTS = 1e-4, 2.0, 4

# Output folder for the comparison figures, created eagerly.
SAVE_DIR = "./gmisde_cifar_results"
os.makedirs(SAVE_DIR, exist_ok=True)

# ============================================================
# DATASET: CIFAR-10 → GRAYSCALE + GAMMA SPECKLE
# ============================================================

class CIFAR10Speckle(Dataset):
    """Grayscale CIFAR-10 corrupted by multiplicative Gamma speckle.

    Each image is converted to grayscale, resized to ``IMG_SIZE x IMG_SIZE``
    and clamped to ``[EPS, 1]`` to form the clean target. The noisy input is
    ``clamp(clean * n, EPS, 1)``, where ``n ~ Gamma(GAMMA_K, rate=GAMMA_K)``
    is drawn per pixel. Noise is drawn once in ``__init__``, and items are
    ``(noisy, clean)`` pairs of 2-D tensors.

    Args:
        train: use the CIFAR-10 training split if True, else the test split.
    """

    def __init__(self, train=True):
        tfm = T.Compose([
            T.Grayscale(),
            T.Resize((IMG_SIZE, IMG_SIZE)),
            T.ToTensor(),
        ])
        base = datasets.CIFAR10(
            "./data", train=train, download=True
        )

        # Unit-mean speckle (mean 1, variance 1/k). The distribution object is
        # loop-invariant, so build it once; it holds no RNG state, so the noise
        # drawn from it is unchanged.
        speckle = torch.distributions.Gamma(GAMMA_K, GAMMA_K)

        self.data = []
        for img, _ in base:
            # The EPS floor keeps pixels strictly positive, so log(clean) is finite.
            clean = tfm(img).squeeze(0).clamp(EPS, 1.0)
            noisy = (clean * speckle.sample(clean.shape)).clamp(EPS, 1.0)
            self.data.append((noisy, clean))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

# ============================================================
# BASELINES
# ============================================================

class IDCNN(nn.Module):
    """Three-layer conv baseline that maps the noisy image directly to a clean estimate.

    The network output itself, clamped to ``[EPS, 1]``, is the reconstruction.

    NOTE(review): the published ID-CNN despeckler divides the input by a
    learned speckle estimate (a division residual). This baseline has no such
    residual; confirm the naming is intended.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        estimate = self.net(x)
        return estimate.clamp(EPS, 1)

class DnCNN(nn.Module):
    """Shallow residual baseline: the network predicts the noise, which is subtracted.

    Output is ``clamp(x - net(x), EPS, 1)``.

    NOTE(review): the published DnCNN is much deeper and uses batch
    normalisation; this three-conv variant is a lightweight stand-in.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        predicted_noise = self.net(x)
        return (x - predicted_noise).clamp(EPS, 1)

# ============================================================
# PROPOSED: GammaMoE + CHSN + Implicit SDE
# ============================================================

class CHSN(nn.Module):
    """Single expert: a small conv encoder feeding three 1-channel heads.

    ``forward(x)`` returns ``(mu, sigma, unc)``:
      * mu    -- unconstrained per-pixel map,
      * sigma -- ``0.1 * softplus(.) + EPS``, hence positive,
      * unc   -- squashed to [0, 1] by a sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        # Three independent 64 -> 1 heads that share the encoder features.
        self.mu = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.sigma = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.unc = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        feats = self.enc(x)
        mean = self.mu(feats)
        spread = F.softplus(self.sigma(feats)).mul(0.1).add(EPS)
        confidence = self.unc(feats).sigmoid()
        return mean, spread, confidence

class GammaMoE(nn.Module):
    """Per-pixel soft mixture of NUM_EXPERTS CHSN experts.

    A 1x1 conv router maps the input to NUM_EXPERTS logits per pixel. A
    softmax over dim 1 turns these into weights that blend each expert's
    ``(mu, sigma, unc)`` maps. The three blended maps are returned.
    """

    def __init__(self):
        super().__init__()
        self.router = nn.Conv2d(1, NUM_EXPERTS, 1)
        self.experts = nn.ModuleList(CHSN() for _ in range(NUM_EXPERTS))

    def forward(self, x):
        gates = F.softmax(self.router(x), dim=1)
        heads = [expert(x) for expert in self.experts]
        # One weighted sum per output slot (0: mu, 1: sigma, 2: unc).
        mu, sigma, unc = (
            sum(gates[:, i:i + 1] * out[slot] for i, out in enumerate(heads))
            for slot in range(3)
        )
        return mu, sigma, unc

class ImplicitSDE(nn.Module):
    """Residual refinement head: ``clamp(x + f_theta([x, mu, sigma]), EPS, 1)``.

    No time-stepping happens here. One 3-layer conv pass over the
    channel-stacked inputs predicts the terminal correction.
    """

    def __init__(self):
        super().__init__()
        width = 64
        self.net = nn.Sequential(
            nn.Conv2d(3, width, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(width, width, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(width, 1, 3, 1, 1),
        )

    def forward(self, x, mu, sigma):
        conditioning = torch.cat((x, mu, sigma), dim=1)
        refined = x + self.net(conditioning)
        return refined.clamp(EPS, 1)

class GMiSDENet(nn.Module):
    """Full model: a GammaMoE front end followed by the ImplicitSDE head.

    The head is conditioned on the uncertainty-gated mean ``mu * unc`` and on
    ``sigma``, and returns the reconstruction.
    """

    def __init__(self):
        super().__init__()
        self.moe = GammaMoE()
        self.sde = ImplicitSDE()

    def forward(self, x):
        mean, spread, confidence = self.moe(x)
        gated_mean = mean * confidence
        return self.sde(x, gated_mean, spread)

# ============================================================
# TRAINING
# ============================================================

def train_model(model, name):
    """Train ``model`` on a freshly noised grayscale CIFAR-10 training split.

    Per-batch objective: ``MSE + 0.2 * mean((log out - log clean)^2)``.

    Args:
        model: the network to train; it is moved to DEVICE.
        name: label used in the progress messages.

    Returns:
        The trained model, on DEVICE.
    """
    loader = DataLoader(CIFAR10Speckle(train=True), BATCH, shuffle=True)
    model = model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LR)

    print(f"\nTraining {name}")
    for epoch in range(1, EPOCHS + 1):
        running = 0
        for noisy, clean in loader:
            inputs = noisy.unsqueeze(1).to(DEVICE)
            targets = clean.unsqueeze(1).to(DEVICE)

            pred = model(inputs)
            pixel_term = F.mse_loss(pred, targets)
            log_term = (torch.log(pred) - torch.log(targets)).pow(2).mean()
            loss = pixel_term + 0.2 * log_term

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running += loss.item()
        print(f"[{name}][{epoch}] Loss={running / len(loader):.4f}")
    return model

# ============================================================
# EVALUATION (SAFE PLOTTING)
# ============================================================

def evaluate(models, num_samples=10):
    """Compare ``models`` on the first noisy grayscale CIFAR-10 test images.

    Every model denoises the same noisy input. PSNR/SSIM against the clean
    target are recorded, and a figure with Clean, Noisy and one panel per
    model is saved to ``SAVE_DIR/sample_{i}.png``. The panel count follows
    ``len(models)``, so any number of models can be compared.

    Models run in eval mode under ``torch.no_grad()``; their previous modes
    are restored afterwards.

    Args:
        models: mapping name -> module on DEVICE.
        num_samples: leading test images to score (capped at the split size).

    Returns:
        dict name -> (mean PSNR, mean SSIM).
    """
    data = CIFAR10Speckle(train=False)
    count = min(num_samples, len(data))

    psnr_scores = {k: [] for k in models}
    ssim_scores = {k: [] for k in models}
    n_cols = 2 + len(models)  # Clean, Noisy, then one panel per model

    prev_modes = {k: m.training for k, m in models.items()}
    for m in models.values():
        m.eval()
    try:
        for i in range(count):
            noisy, clean = data[i]
            clean_np = clean.cpu().numpy()
            noisy_np = noisy.cpu().numpy()
            batch = noisy[None, None].to(DEVICE)

            # Convert each output to numpy once; it is used for metrics and plotting.
            with torch.no_grad():
                outputs = {
                    name: m(batch)[0, 0].cpu().numpy()
                    for name, m in models.items()
                }

            fig, ax = plt.subplots(1, n_cols, figsize=(3.2 * n_cols, 3))
            ax[0].imshow(clean_np, cmap="gray")
            ax[0].set_title("Clean")
            ax[1].imshow(noisy_np, cmap="gray")
            ax[1].set_title("Noisy")

            for col, (name, out_np) in enumerate(outputs.items(), start=2):
                p = peak_signal_noise_ratio(clean_np, out_np, data_range=1.0)
                s = structural_similarity(clean_np, out_np, data_range=1.0)
                psnr_scores[name].append(p)
                ssim_scores[name].append(s)

                ax[col].imshow(out_np, cmap="gray")
                ax[col].set_title(f"{name}\nPSNR={p:.2f}\nSSIM={s:.4f}")

            for a in ax:
                a.axis("off")
            fig.tight_layout()
            fig.savefig(f"{SAVE_DIR}/sample_{i}.png", dpi=150)
            plt.close(fig)
    finally:
        for name, m in models.items():
            m.train(prev_modes[name])

    print("\n=== FINAL RESULTS (CIFAR-10) ===")
    summary = {}
    for k in models:
        summary[k] = (np.mean(psnr_scores[k]), np.mean(ssim_scores[k]))
        print(
            f"{k}: PSNR={summary[k][0]:.2f}, "
            f"SSIM={summary[k][1]:.4f}"
        )
    return summary

# ============================================================
# MAIN
# ============================================================

if __name__=="__main__":
    # Train each contender in turn (GMiSDE first), then compare them side by side.
    contenders = (("GMiSDE", GMiSDENet), ("ID-CNN", IDCNN), ("DnCNN", DnCNN))
    models = {}
    for label, build in contenders:
        models[label] = train_model(build(), label)
    evaluate(models)


100%|██████████| 170M/170M [00:04<00:00, 42.5MB/s]



Training GMiSDE
[GMiSDE][1] Loss=0.0391
[GMiSDE][2] Loss=0.0288
[GMiSDE][3] Loss=0.0271
[GMiSDE][4] Loss=0.0259
[GMiSDE][5] Loss=0.0252
[GMiSDE][6] Loss=0.0246
[GMiSDE][7] Loss=0.0243
[GMiSDE][8] Loss=0.0240
[GMiSDE][9] Loss=0.0237
[GMiSDE][10] Loss=0.0235
[GMiSDE][11] Loss=0.0233
[GMiSDE][12] Loss=0.0231
[GMiSDE][13] Loss=0.0229
[GMiSDE][14] Loss=0.0228
[GMiSDE][15] Loss=0.0226

Training ID-CNN
[ID-CNN][1] Loss=0.0463
[ID-CNN][2] Loss=0.0338
[ID-CNN][3] Loss=0.0320
[ID-CNN][4] Loss=0.0301
[ID-CNN][5] Loss=0.0287
[ID-CNN][6] Loss=0.0278
[ID-CNN][7] Loss=0.0281
[ID-CNN][8] Loss=0.0267
[ID-CNN][9] Loss=0.0267
[ID-CNN][10] Loss=0.0264
[ID-CNN][11] Loss=0.0259
[ID-CNN][12] Loss=0.0256
[ID-CNN][13] Loss=0.0255
[ID-CNN][14] Loss=0.0252
[ID-CNN][15] Loss=0.0250

Training DnCNN
[DnCNN][1] Loss=0.0434
[DnCNN][2] Loss=0.0317
[DnCNN][3] Loss=0.0295
[DnCNN][4] Loss=0.0290
[DnCNN][5] Loss=0.0279
[DnCNN][6] Loss=0.0270
[DnCNN][7] Loss=0.0267
[DnCNN][8] Loss=0.0263
[DnCNN][9] Loss=0.0263
[DnCNN][10]

In [12]:
# ============================================================
# GMiSDE-Net Robustness Study
# CIFAR-10 (Grayscale) + Gamma Speckle
# Gamma values: k = [2,4,6,8,10]
# ============================================================

import os, torch, numpy as np, matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# ---------------- CONFIG ----------------
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Image size and optimisation settings.
IMG_SIZE, BATCH, EPOCHS, LR = 32, 32, 12, 2e-4

# EPS: lower clamp for images/outputs; NUM_EXPERTS: experts in GammaMoE.
EPS, NUM_EXPERTS = 1e-4, 4

# Train at one speckle shape k, then test at k = 2, 4, ..., 10.
TRAIN_GAMMA = 4.0
TEST_GAMMAS = list(range(2, 11, 2))

# Output folder for the robustness curve, created eagerly.
SAVE_DIR = "./gamma_robustness_cifar"
os.makedirs(SAVE_DIR, exist_ok=True)

# ============================================================
# DATASET
# ============================================================

class CIFAR10Speckle(Dataset):
    """Grayscale CIFAR-10 corrupted by Gamma speckle with a caller-chosen shape.

    Each image is converted to grayscale, resized to ``IMG_SIZE x IMG_SIZE``
    and clamped to ``[EPS, 1]`` to form the clean target. The noisy input is
    ``clamp(clean * n, EPS, 1)``, where ``n ~ Gamma(gamma_k, rate=gamma_k)``
    is drawn per pixel. Noise is drawn once in ``__init__``, and items are
    ``(noisy, clean)`` pairs of 2-D tensors.

    Args:
        gamma_k: Gamma shape (and rate). The noise has mean 1 and variance
            1/gamma_k, so larger values mean weaker speckle.
        train: use the CIFAR-10 training split if True, else the test split.
    """

    def __init__(self, gamma_k, train=True):
        self.gamma_k = gamma_k
        tfm = T.Compose([
            T.Grayscale(),
            T.Resize((IMG_SIZE, IMG_SIZE)),
            T.ToTensor(),
        ])

        base = datasets.CIFAR10("./data", train=train, download=True)

        # The distribution object is loop-invariant, so build it once; it holds
        # no RNG state, so the noise drawn from it is unchanged.
        speckle = torch.distributions.Gamma(gamma_k, gamma_k)

        self.samples = []
        for img, _ in base:
            # The EPS floor keeps pixels strictly positive, so log(clean) is finite.
            clean = tfm(img).squeeze(0).clamp(EPS, 1.0)
            noisy = (clean * speckle.sample(clean.shape)).clamp(EPS, 1.0)
            self.samples.append((noisy, clean))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        return self.samples[i]

# ============================================================
# PROPOSED MODEL: GMiSDE-Net
# ============================================================

class CHSN(nn.Module):
    """Single expert: a small conv encoder feeding three 1-channel heads.

    ``forward(x)`` returns ``(mu, sigma, unc)``:
      * mu    -- unconstrained per-pixel map,
      * sigma -- ``0.1 * softplus(.) + EPS``, hence positive,
      * unc   -- squashed to [0, 1] by a sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        # Three independent 64 -> 1 heads that share the encoder features.
        self.mu = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.sigma = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.unc = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        feats = self.enc(x)
        mean = self.mu(feats)
        spread = F.softplus(self.sigma(feats)).mul(0.1).add(EPS)
        confidence = self.unc(feats).sigmoid()
        return mean, spread, confidence

class GammaMoE(nn.Module):
    """Per-pixel soft mixture of NUM_EXPERTS CHSN experts.

    A 1x1 conv router maps the input to NUM_EXPERTS logits per pixel. A
    softmax over dim 1 turns these into weights that blend each expert's
    ``(mu, sigma, unc)`` maps. The three blended maps are returned.
    """

    def __init__(self):
        super().__init__()
        self.router = nn.Conv2d(1, NUM_EXPERTS, 1)
        self.experts = nn.ModuleList(CHSN() for _ in range(NUM_EXPERTS))

    def forward(self, x):
        gates = F.softmax(self.router(x), dim=1)
        heads = [expert(x) for expert in self.experts]
        # One weighted sum per output slot (0: mu, 1: sigma, 2: unc).
        mu, sigma, unc = (
            sum(gates[:, i:i + 1] * out[slot] for i, out in enumerate(heads))
            for slot in range(3)
        )
        return mu, sigma, unc

class ImplicitSDE(nn.Module):
    """Residual refinement head: ``clamp(x + f_theta([x, mu, sigma]), EPS, 1)``.

    No time-stepping happens here. One 3-layer conv pass over the
    channel-stacked inputs predicts the terminal correction.
    """

    def __init__(self):
        super().__init__()
        width = 64
        self.net = nn.Sequential(
            nn.Conv2d(3, width, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(width, width, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(width, 1, 3, 1, 1),
        )

    def forward(self, x, mu, sigma):
        conditioning = torch.cat((x, mu, sigma), dim=1)
        refined = x + self.net(conditioning)
        return refined.clamp(EPS, 1)

class GMiSDENet(nn.Module):
    """Full model: a GammaMoE front end followed by the ImplicitSDE head.

    The head is conditioned on the uncertainty-gated mean ``mu * unc`` and on
    ``sigma``, and returns the reconstruction.
    """

    def __init__(self):
        super().__init__()
        self.moe = GammaMoE()
        self.sde = ImplicitSDE()

    def forward(self, x):
        mean, spread, confidence = self.moe(x)
        gated_mean = mean * confidence
        return self.sde(x, gated_mean, spread)

# ============================================================
# TRAINING (single Gamma)
# ============================================================

def train_model():
    """Train a fresh GMiSDENet on CIFAR-10 speckled with shape k = TRAIN_GAMMA.

    Per-batch objective: ``MSE + 0.2 * mean((log out - log clean)^2)``. The
    mean per-batch loss is printed after every epoch.

    Returns:
        The trained model, on DEVICE.
    """
    loader = DataLoader(CIFAR10Speckle(TRAIN_GAMMA, train=True), BATCH, shuffle=True)

    net = GMiSDENet().to(DEVICE)
    optimizer = optim.Adam(net.parameters(), lr=LR)

    print(f"\nTraining GMiSDE-Net (Gamma={TRAIN_GAMMA})\n")

    for epoch in range(1, EPOCHS + 1):
        running = 0
        for noisy, clean in loader:
            inputs = noisy.unsqueeze(1).to(DEVICE)
            targets = clean.unsqueeze(1).to(DEVICE)

            pred = net(inputs)
            pixel_term = F.mse_loss(pred, targets)
            log_term = (torch.log(pred) - torch.log(targets)).pow(2).mean()
            loss = pixel_term + 0.2 * log_term

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running += loss.item()

        print(f"[Epoch {epoch:02d}] Loss={running / len(loader):.4f}")

    return net

# ============================================================
# EVALUATION ACROSS GAMMA VALUES
# ============================================================

def evaluate_gamma_robustness(model, num_samples=100):
    """Score ``model`` on the CIFAR-10 test split re-noised at each k in TEST_GAMMAS.

    For every shape k, a fresh ``CIFAR10Speckle(k, train=False)`` is built.
    The first ``num_samples`` images (capped at the split size) are denoised
    and compared with their clean targets. Mean PSNR/SSIM for each k are
    printed.

    The model runs in eval mode under ``torch.no_grad()``, and its previous
    mode is restored afterwards.

    Args:
        model: denoiser mapping a ``(1, 1, H, W)`` tensor on DEVICE to an
            output of the same shape.
        num_samples: number of leading test images scored per k.

    Returns:
        dict k -> (mean PSNR, mean SSIM), in TEST_GAMMAS order.
    """
    results = {}

    was_training = model.training
    model.eval()
    try:
        for k in TEST_GAMMAS:
            dataset = CIFAR10Speckle(k, train=False)
            count = min(num_samples, len(dataset))
            psnr_list, ssim_list = [], []

            for i in range(count):
                noisy, clean = dataset[i]
                with torch.no_grad():
                    recon = model(noisy[None, None].to(DEVICE))

                cn = clean.cpu().numpy()
                rn = recon[0, 0].cpu().numpy()

                psnr_list.append(peak_signal_noise_ratio(cn, rn, data_range=1.0))
                ssim_list.append(structural_similarity(cn, rn, data_range=1.0))

            results[k] = (np.mean(psnr_list), np.mean(ssim_list))
            print(f"Gamma {k}: PSNR={results[k][0]:.2f}, SSIM={results[k][1]:.4f}")
    finally:
        model.train(was_training)

    return results

# ============================================================
# MAIN
# ============================================================

if __name__ == "__main__":
    model = train_model()
    gamma_results = evaluate_gamma_robustness(model)

    # Plot. PSNR (dB, ~20 here) and SSIM (unitless, <= 1) live on very
    # different scales. On a single shared y-axis the SSIM curve is squashed
    # flat near zero, so each metric gets its own axis.
    ks = list(gamma_results.keys())
    psnrs = [gamma_results[k][0] for k in ks]
    ssims = [gamma_results[k][1] for k in ks]

    fig, ax_psnr = plt.subplots(figsize=(6, 4))
    ax_ssim = ax_psnr.twinx()

    (line_psnr,) = ax_psnr.plot(ks, psnrs, marker="o", color="tab:blue", label="PSNR")
    (line_ssim,) = ax_ssim.plot(ks, ssims, marker="s", color="tab:orange", label="SSIM")
    # Mark the training shape so mismatch to either side is easy to read off.
    train_line = ax_psnr.axvline(
        TRAIN_GAMMA, color="gray", linestyle="--", linewidth=1,
        label=f"train k={TRAIN_GAMMA:g}",
    )

    ax_psnr.set_xlabel("Gamma Shape Parameter k")
    ax_psnr.set_ylabel("PSNR (dB)", color="tab:blue")
    ax_ssim.set_ylabel("SSIM", color="tab:orange")
    ax_psnr.set_title("GMiSDE-Net Robustness to Gamma Mismatch")
    ax_psnr.grid(True)

    handles = [line_psnr, line_ssim, train_line]
    ax_psnr.legend(handles, [h.get_label() for h in handles], loc="lower right")

    fig.tight_layout()
    fig.savefig(f"{SAVE_DIR}/gamma_robustness_curve.png", dpi=150)
    plt.close(fig)



Training GMiSDE-Net (Gamma=4.0)

[Epoch 01] Loss=0.0282
[Epoch 02] Loss=0.0195
[Epoch 03] Loss=0.0181
[Epoch 04] Loss=0.0172
[Epoch 05] Loss=0.0171
[Epoch 06] Loss=0.0162
[Epoch 07] Loss=0.0160
[Epoch 08] Loss=0.0155
[Epoch 09] Loss=0.0155
[Epoch 10] Loss=0.0153
[Epoch 11] Loss=0.0152
[Epoch 12] Loss=0.0149
Gamma 2: PSNR=19.07, SSIM=0.5988
Gamma 4: PSNR=22.29, SSIM=0.7226
Gamma 6: PSNR=22.89, SSIM=0.7501
Gamma 8: PSNR=23.12, SSIM=0.7549
Gamma 10: PSNR=23.14, SSIM=0.7560


In [13]:
# ============================================================
# Ablation Study: Sensitivity to Loss Constants (α, β)
# GMiSDE-Net | CIFAR-10 + Gamma Speckle
# ============================================================

import torch, numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# ---------------- CONFIG ----------------
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Image size and optimisation settings; EPOCHS is short because every grid
# point trains its own model.
IMG_SIZE, BATCH, EPOCHS, LR = 32, 32, 5, 2e-4

# EPS: lower clamp for images/outputs; GAMMA_K: Gamma speckle shape and rate;
# NUM_EXPERTS: experts in the mixture.
EPS, GAMMA_K, NUM_EXPERTS = 1e-4, 4.0, 4

# Ablation grid: alpha weights the log-domain loss term, beta the mean
# uncertainty regulariser.
ALPHAS = [0.0, 0.1, 0.2, 0.4]
BETAS = [0.0, 0.05, 0.1]

# ============================================================
# DATASET
# ============================================================

class CIFAR10Speckle(Dataset):
    """Grayscale CIFAR-10 corrupted by multiplicative Gamma speckle.

    Each image is converted to grayscale, resized to ``IMG_SIZE x IMG_SIZE``
    and clamped to ``[EPS, 1]`` to form the clean target. The noisy input is
    ``clamp(clean * n, EPS, 1)``, where ``n ~ Gamma(GAMMA_K, rate=GAMMA_K)``
    is drawn per pixel. Noise is drawn once in ``__init__``, and items are
    ``(noisy, clean)`` pairs of 2-D tensors.

    Args:
        train: use the CIFAR-10 training split if True, else the test split.
    """

    def __init__(self, train=True):
        tfm = T.Compose([
            T.Grayscale(),
            T.Resize((IMG_SIZE, IMG_SIZE)),
            T.ToTensor(),
        ])
        base = datasets.CIFAR10("./data", train=train, download=True)

        # Unit-mean speckle (mean 1, variance 1/k). The distribution object is
        # loop-invariant, so build it once; it holds no RNG state, so the noise
        # drawn from it is unchanged.
        speckle = torch.distributions.Gamma(GAMMA_K, GAMMA_K)

        self.data = []
        for img, _ in base:
            # The EPS floor keeps pixels strictly positive, so log(clean) is finite.
            clean = tfm(img).squeeze(0).clamp(EPS, 1.0)
            noisy = (clean * speckle.sample(clean.shape)).clamp(EPS, 1.0)
            self.data.append((noisy, clean))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

# ============================================================
# MODEL (compact)
# ============================================================

class CHSN(nn.Module):
    """Single expert (compact variant): conv encoder ``f`` plus three 1-channel heads.

    ``forward(x)`` returns ``(mu, sigma, unc)``. Here sigma is
    ``0.1 * softplus(.) + EPS`` (positive) and unc is sigmoid-squashed to
    [0, 1].
    """

    def __init__(self):
        super().__init__()
        self.f = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        self.mu = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.sg = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.unc = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        feats = self.f(x)
        mean = self.mu(feats)
        spread = F.softplus(self.sg(feats)).mul(0.1).add(EPS)
        confidence = self.unc(feats).sigmoid()
        return mean, spread, confidence

class GMiSDENet(nn.Module):
    """Compact GMiSDE-Net for the ablation: mixture, experts and SDE head in one module.

    A 1x1 router plus softmax blends the experts' ``(mu, sigma, unc)`` per
    pixel. A 2-layer conv head then predicts a residual from
    ``[x, mu * unc, sigma]``.

    Returns:
        ``(clamp(x + delta, EPS, 1), unc)``. The blended uncertainty map is
        exposed so the training loss can regularise it.
    """

    def __init__(self):
        super().__init__()
        self.router = nn.Conv2d(1, NUM_EXPERTS, 1)
        self.experts = nn.ModuleList(CHSN() for _ in range(NUM_EXPERTS))
        self.sde = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 1, 3, 1, 1),
        )

    def forward(self, x):
        gates = torch.softmax(self.router(x), dim=1)
        heads = [expert(x) for expert in self.experts]
        # One weighted sum per output slot (0: mu, 1: sigma, 2: unc).
        mu, sg, uc = (
            sum(gates[:, i:i + 1] * out[slot] for i, out in enumerate(heads))
            for slot in range(3)
        )
        delta = self.sde(torch.cat((x, mu * uc, sg), dim=1))
        return (x + delta).clamp(EPS, 1.0), uc

# ============================================================
# TRAIN + EVAL (single run)
# ============================================================

# (train, test) splits shared by every ablation run. Re-running this cell
# resets the cache, e.g. after changing GAMMA_K.
_SPLIT_CACHE = {}


def _speckle_splits():
    """Return the cached (train, test) CIFAR10Speckle pair, building it on first use."""
    if "splits" not in _SPLIT_CACHE:
        _SPLIT_CACHE["splits"] = (
            CIFAR10Speckle(train=True),
            CIFAR10Speckle(train=False),
        )
    return _SPLIT_CACHE["splits"]


def run_experiment(alpha, beta, train_data=None, test_data=None,
                   num_eval=50, seed=None):
    """Train one GMiSDENet with the given loss constants and score it.

    Loss = ``MSE + alpha * mean((log out - log clean)^2) + beta * mean(unc)``.

    By default every call reuses one cached pair of noisy splits. This way all
    (alpha, beta) cells see the same speckle realisation, and metric
    differences are not confounded by re-drawn noise. It also avoids
    rebuilding and re-noising CIFAR-10 for every grid point.

    Args:
        alpha: weight of the log-domain term.
        beta: weight of the mean-uncertainty regulariser.
        train_data, test_data: optional datasets yielding ``(noisy, clean)``
            2-D tensors. When None, the shared cached CIFAR10Speckle splits
            are used.
        num_eval: leading test images scored (capped at the split size).
        seed: if not None, ``torch.manual_seed(seed)`` is called before the
            model is built, so runs can share initial weights and shuffling.

    Returns:
        ``(mean PSNR, mean SSIM)`` over the scored test images.
    """
    if train_data is None or test_data is None:
        shared_train, shared_test = _speckle_splits()
        if train_data is None:
            train_data = shared_train
        if test_data is None:
            test_data = shared_test

    if seed is not None:
        torch.manual_seed(seed)

    train_loader = DataLoader(train_data, BATCH, shuffle=True)
    model = GMiSDENet().to(DEVICE)
    opt = optim.Adam(model.parameters(), lr=LR)

    # ---- training ----
    model.train()
    for _ in range(EPOCHS):
        for noisy, clean in train_loader:
            noisy = noisy.unsqueeze(1).to(DEVICE)
            clean = clean.unsqueeze(1).to(DEVICE)

            out, unc = model(noisy)

            mse = F.mse_loss(out, clean)
            log_loss = ((torch.log(out) - torch.log(clean)) ** 2).mean()
            unc_reg = unc.mean()

            loss = mse + alpha * log_loss + beta * unc_reg

            opt.zero_grad()
            loss.backward()
            opt.step()

    # ---- evaluation ----
    model.eval()
    psnr, ssim = [], []
    with torch.no_grad():
        for i in range(min(num_eval, len(test_data))):
            noisy, clean = test_data[i]
            out, _ = model(noisy[None, None].to(DEVICE))

            c = clean.cpu().numpy()
            r = out[0, 0].cpu().numpy()

            psnr.append(peak_signal_noise_ratio(c, r, data_range=1.0))
            ssim.append(structural_similarity(c, r, data_range=1.0))

    return np.mean(psnr), np.mean(ssim)

# ============================================================
# ABLATION LOOP
# ============================================================

if __name__ == "__main__":
    banner = (
        "\n=== CONSTANT SENSITIVITY ABLATION ===\n",
        " alpha | beta  |  PSNR  |  SSIM ",
        "----------------------------------",
    )
    for line in banner:
        print(line)

    # Full alpha x beta grid, alpha-major (same order as two nested loops).
    grid = [(a, b) for a in ALPHAS for b in BETAS]
    for a, b in grid:
        p, s = run_experiment(a, b)
        print(f" {a:4.2f} | {b:4.2f} | {p:6.2f} | {s:6.4f}")



=== CONSTANT SENSITIVITY ABLATION ===

 alpha | beta  |  PSNR  |  SSIM 
----------------------------------
 0.00 | 0.00 |  22.04 | 0.7261
 0.00 | 0.05 |  21.73 | 0.7112
 0.00 | 0.10 |  21.78 | 0.7109
 0.10 | 0.00 |  21.84 | 0.7152
 0.10 | 0.05 |  21.45 | 0.7143
 0.10 | 0.10 |  21.56 | 0.7070
 0.20 | 0.00 |  21.77 | 0.7133
 0.20 | 0.05 |  21.59 | 0.7169
 0.20 | 0.10 |  21.45 | 0.7097
 0.40 | 0.00 |  21.48 | 0.7188
 0.40 | 0.05 |  21.36 | 0.7047
 0.40 | 0.10 |  21.22 | 0.7100
