In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, Conv2DTranspose, MaxPooling2D, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import time
import tensorflow.keras.backend as K

In [None]:
# Load an image from disk and normalise it for the models below.
def load_image(img_path, target_size=(256, 256)):
    """Load `img_path`, resize it to `target_size`, and divide pixel values by 255.

    Uses Keras' `load_img` (default RGB color mode) and `img_to_array`, so the
    result is an array of shape (height, width, channels). For ordinary 8-bit
    images the values end up in [0, 1].
    """
    pixels = img_to_array(load_img(img_path, target_size=target_size))
    return pixels / 255.0


In [None]:
# UNet
def build_unet(input_shape=(256, 256, 3)):
    """Build a 4-level U-Net mapping an image to a 3-channel image in [0, 1].

    Encoder: at each of four levels (64, 128, 256, 512 filters), two 3x3 ReLU
    convolutions followed by 2x2 max-pooling; the pre-pool features are kept
    as skip connections. Bottleneck: two 3x3 ReLU convolutions with 1024
    filters. Decoder: at each level, a 2x2 stride-2 transposed convolution,
    concatenation with the matching skip tensor, then two 3x3 ReLU
    convolutions. A 1x1 sigmoid convolution produces the output.

    Because there are four pooling steps, height and width should be divisible
    by 16 so the upsampled tensors line up with their skip connections.
    """
    def double_conv(tensor, filters):
        # Two same-padded 3x3 ReLU convolutions, applied in sequence.
        for _ in range(2):
            tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return tensor

    inputs = Input(input_shape)

    skip_tensors = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = double_conv(x, filters)
        skip_tensors.append(x)
        x = MaxPooling2D((2, 2))(x)

    x = double_conv(x, 1024)  # bottleneck

    # Walk back up, pairing each decoder level with its encoder counterpart.
    for filters, skip in zip((512, 256, 128, 64), reversed(skip_tensors)):
        x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = concatenate([x, skip])
        x = double_conv(x, filters)

    outputs = Conv2D(3, (1, 1), activation='sigmoid')(x)
    return Model(inputs, outputs)


In [None]:
# DnCNN
def build_dncnn(input_shape=(256, 256, 3), depth=17, num_filters=64):
    """Build a DnCNN-style plain convolutional stack.

    The network is `num_filters`-wide 3x3 ReLU convolutions (depth - 1 of
    them, but never fewer than one), followed by a linear 3x3 convolution with
    3 output channels.

    Unlike the residual formulation in the DnCNN paper, this variant has no
    batch normalisation and no input-minus-prediction skip. The network
    regresses the output image directly.
    """
    inputs = Input(shape=input_shape)
    # One leading conv plus (depth - 2) hidden convs, i.e. max(depth - 1, 1).
    relu_layers = max(depth - 1, 1)
    features = inputs
    for _ in range(relu_layers):
        features = Conv2D(num_filters, (3, 3), padding='same', activation='relu')(features)
    outputs = Conv2D(3, (3, 3), padding='same')(features)
    return Model(inputs, outputs)


In [None]:
# ResNet
def build_resnet(input_shape=(256, 256, 3)):
    """Build an image-to-image denoiser with a ResNet50 encoder.

    A randomly initialised ResNet50 backbone (no classification top) encodes
    the image. Five stride-2 transposed convolutions then upsample the
    features back to the input resolution, and a 1x1 sigmoid convolution
    produces a 3-channel image in [0, 1]. The output has the same spatial
    shape as the input, as the training loop's image targets and PSNR/SSIM
    require.

    Parameters
    ----------
    input_shape : tuple
        (height, width, channels). Height and width must be multiples of 32,
        because ResNet50 downsamples by 32 and the decoder upsamples by 32.

    Raises
    ------
    ValueError
        If a known height/width is not a multiple of 32.
    """
    # ResNet50 has five stride-2 stages: overall spatial reduction of 32.
    downsample = 32
    for dim in input_shape[:2]:
        if dim is not None and dim % downsample:
            raise ValueError(
                f"build_resnet needs height and width divisible by {downsample}, got {input_shape}"
            )

    base_model = tf.keras.applications.ResNet50(weights=None, include_top=False, input_shape=input_shape)
    x = base_model.output
    # A GlobalAveragePooling + Dense head would collapse each image to a single
    # 3-vector; decode back to a full image instead (5 x 2 = 32x upsampling).
    for filters in (512, 256, 128, 64, 32):
        x = Conv2DTranspose(filters, (3, 3), strides=(2, 2), padding='same', activation='relu')(x)
    x = Conv2D(3, (1, 1), activation='sigmoid')(x)
    model = Model(inputs=base_model.input, outputs=x)
    return model


In [None]:
# PSNR
def compute_psnr(original, denoised, data_range=1.0):
    """Peak signal-to-noise ratio (dB) of `denoised` against `original`.

    Both images are clipped to [0, data_range] first. `data_range` is then
    passed to scikit-image explicitly, so the result does not depend on
    dtype-based range inference (or warn when the two dtypes differ).

    Parameters
    ----------
    original, denoised : np.ndarray
        Images of identical shape.
    data_range : float, default 1.0
        Maximum possible pixel value; 1.0 for images scaled to [0, 1].
    """
    original = np.clip(original, 0.0, data_range)
    denoised = np.clip(denoised, 0.0, data_range)
    return peak_signal_noise_ratio(original, denoised, data_range=data_range)

# SSIM
def compute_ssim(original, denoised, data_range=1.0):
    """Structural similarity of `denoised` against `original` (channels last).

    Both images are clipped to [0, data_range] first. `data_range` must be
    given explicitly for floating-point images. Recent scikit-image versions
    raise ValueError without it. Older versions fall back to the float dtype
    range [-1, 1] (data_range=2), which inflates SSIM for [0, 1] images.

    Parameters
    ----------
    original, denoised : np.ndarray
        Images of identical shape with the channel axis last.
    data_range : float, default 1.0
        Maximum possible pixel value; 1.0 for images scaled to [0, 1].
    """
    original = np.clip(original, 0.0, data_range)
    denoised = np.clip(denoised, 0.0, data_range)
    return structural_similarity(original, denoised, channel_axis=-1, data_range=data_range)

# generation time
def measure_time(model, noisy_img):
    """Run `model.predict` on one image and time the prediction.

    Parameters
    ----------
    model : object
        Anything with a Keras-style ``predict(batch, verbose=...)`` method.
    noisy_img : np.ndarray
        A single image without a batch axis; a batch axis of size 1 is added.

    Returns
    -------
    tuple of (np.ndarray, float)
        The prediction, still batched (take ``[0]`` for the image), and the
        elapsed time of the predict call in seconds.
    """
    batch = np.expand_dims(noisy_img, axis=0)
    # perf_counter is monotonic and high-resolution; time.time() can jump.
    start = time.perf_counter()
    # verbose=0: avoid a progress bar on every call (called once per epoch).
    denoised_img = model.predict(batch, verbose=0)
    elapsed = time.perf_counter() - start
    return denoised_img, elapsed


In [None]:
# "Perceptual" loss term used by the combined training loss.
def compute_perceptual_loss(y_true, y_pred):
    """Mean of the squared element-wise difference between y_true and y_pred.

    NOTE(review): despite its name, this is plain pixel-space MSE, not a
    feature-space (e.g. pretrained-VGG) perceptual loss. Adding it to an MSE
    term therefore only rescales that term.
    """
    difference = y_true - y_pred
    return K.mean(K.square(difference))


In [None]:
# Combined training loss; the notebook-level flag decides whether the
# "perceptual" term is included.
def custom_loss_function(y_true, y_pred):
    """Keras loss: MSE, plus 0.1 * compute_perceptual_loss when enabled.

    The global `use_perceptual_loss` is read when this function runs, so it
    must be defined before training starts. The perceptual weight is fixed at
    0.1.
    """
    total_loss = tf.keras.losses.MeanSquaredError()(y_true, y_pred)
    if not use_perceptual_loss:
        return total_loss
    return total_loss + 0.1 * compute_perceptual_loss(y_true, y_pred)


In [None]:
# Build the per-step beta (noise variance) schedule that controls how much
# noise each forward-diffusion step adds.
def generate_beta_scheduler(scheduler_type, num_steps, beta_min=0.0001, beta_max=0.1):
    """Return `num_steps` betas running from `beta_min` to `beta_max`.

    Parameters
    ----------
    scheduler_type : {'linear', 'scaled_linear'}
        'linear' spaces the betas evenly. 'scaled_linear' spaces sqrt(beta)
        evenly and squares the result (the convention used by e.g. the
        diffusers / Stable Diffusion schedulers). The schedule still starts at
        beta_min and ends at beta_max, but grows more slowly at first.
    num_steps : int
        Number of diffusion steps (length of the returned array).
    beta_min, beta_max : float
        First and last beta.

    Returns
    -------
    np.ndarray
        Array of shape (num_steps,).

    Raises
    ------
    ValueError
        If `scheduler_type` is not recognised.
    """
    if scheduler_type == 'linear':
        return np.linspace(beta_min, beta_max, num_steps)
    elif scheduler_type == 'scaled_linear':
        # Squaring linspace(beta_min, beta_max) would give betas in
        # [beta_min**2, beta_max**2]; interpolate in sqrt-space instead.
        return np.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps) ** 2
    else:
        raise ValueError(f"Unknown scheduler type: {scheduler_type}")


In [None]:
# DDPM forward (noising) process: sample a noisy image x_t directly from x_0.
def ddpm_forward(x0, step, alphabars, rng=None):
    """Sample x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps, eps ~ N(0, I).

    Parameters
    ----------
    x0 : array_like
        Clean image.
    step : int
        Index t into `alphabars`.
    alphabars : sequence of float
        Alpha-bar schedule (cumulative products of the alphas), indexed by step.
    rng : np.random.Generator, optional
        Source of the Gaussian noise. If omitted, NumPy's global RNG
        (`np.random.normal`) is used, so `np.random.seed` still controls it.

    Returns
    -------
    np.ndarray
        Noisy image with x0's shape. A floating-point x0 keeps its dtype.
    """
    x0 = np.asarray(x0)
    a_bar = alphabars[step]
    if rng is None:
        noise = np.random.normal(size=x0.shape)
    else:
        noise = rng.normal(size=x0.shape)
    noisy_img = a_bar ** 0.5 * x0 + (1 - a_bar) ** 0.5 * noise
    # The noise is float64; cast back so e.g. a float32 image stays float32.
    if np.issubdtype(x0.dtype, np.floating):
        noisy_img = noisy_img.astype(x0.dtype, copy=False)
    return noisy_img


In [None]:
# Per-stage epoch budget from an exponential in beta. This is only the
# maximum number of epochs; early stopping is decided separately.
def calculate_epochs(beta, Emax, k, alpha):
    """Return the training-epoch budget for a stage with noise variance `beta`.

    E = Emax * (1 - alpha * exp(-k * beta)), truncated to int and clamped to at
    least 1. With positive Emax, alpha and k, a larger beta (more noise) gives a
    larger budget.

    Parameters
    ----------
    beta : float
        The stage's beta from the noise schedule.
    Emax : int or float
        Upper bound on the number of epochs.
    k : float
        Decay rate in the exponent.
    alpha : float
        Amplitude of the exponential term.
    """
    epochs = int(Emax * (1 - alpha * np.exp(-k * beta)))
    # A zero or negative budget would mean no training at all for the stage.
    return max(1, epochs)


In [None]:
img_path = '/content/image_0.PNG'  # Colab path; point this at your own image
input_img = load_image(img_path)
# load_image defaults to a 256x256 RGB image scaled by 1/255; fail fast otherwise.
assert input_img.shape == (256, 256, 3), input_img.shape
assert 0.0 <= input_img.min() and input_img.max() <= 1.0

# Seed Python, NumPy (used for ddpm_forward's noise) and TensorFlow (weight
# initialisation) so a Restart & Run All reproduces the same results.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

scheduler_type = 'linear'  # 'linear' or 'scaled_linear'
num_steps = 3
beta_min = 0.001    # tunable
beta_max = 0.01     # tunable
beta_scheduler = generate_beta_scheduler(scheduler_type, num_steps, beta_min, beta_max)
alphas = 1 - beta_scheduler
# alpha_bars[t] = alphas[0] * ... * alphas[t]; cumprod avoids recomputing every prefix.
alpha_bars = np.cumprod(alphas)

# One noisy copy of the clean image per diffusion step.
noisy_images = [ddpm_forward(input_img, i, alpha_bars) for i in range(num_steps)]

In [None]:
# Epoch-budget parameters for calculate_epochs:
# maximum epochs per stage, exponent decay rate, exponential amplitude.
Emax, k, alpha = 100, 0.1, 0.5

# Whether custom_loss_function adds the "perceptual" loss term.
use_perceptual_loss = True  # tunable


In [None]:
# training
# For each noise stage, train a freshly built model to map that stage's noisy
# image to the stage's target image. PSNR/SSIM are tracked per epoch, and
# training stops early once both stop changing.
MODEL_BUILDERS = {'unet': build_unet, 'dncnn': build_dncnn, 'resnet': build_resnet}

results = {}
model_architecture = 'unet'  # unet or dncnn or resnet
psnr_threshold = 0.05    # early stop when |dPSNR| between epochs is below this...
ssim_threshold = 0.0005  # ...and |dSSIM| is below this
all_denoised_images = []  # denoised output of each stage, in stage order

if model_architecture not in MODEL_BUILDERS:
    raise ValueError(f"Unknown model architecture: {model_architecture}")


def _show_stage_images(target, noisy, denoised):
    """Display the stage's target, noisy input and denoised output side by side."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = [('Target Image', target), ('Noisy Image', noisy), ('Denoised Image', denoised)]
    for ax, (title, img) in zip(axes, panels):
        ax.set_title(title)
        # Noisy images fall outside [0, 1]; clip for display only.
        ax.imshow(np.clip(img, 0.0, 1.0))
    plt.show()
    plt.close(fig)


def _show_stage_curves(psnr_history, ssim_history):
    """Plot one stage's per-epoch PSNR and SSIM (epochs numbered from 1)."""
    epochs = range(1, len(psnr_history) + 1)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    for ax, values, name in ((ax1, psnr_history, 'PSNR'), (ax2, ssim_history, 'SSIM')):
        ax.plot(epochs, values, label=name)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(f'{name} Value')
        ax.set_title(f'{name} over Epochs')
        ax.legend()
    fig.tight_layout()
    plt.show()
    plt.close(fig)


# Track the per-stage target separately so the clean `input_img` is never
# overwritten and this cell can be re-run on its own.
target_img = input_img

for stage, (noisy_img, beta) in enumerate(zip(noisy_images, beta_scheduler)):
    # Release layers/graph state left over from the previous stage's model.
    tf.keras.backend.clear_session()
    model = MODEL_BUILDERS[model_architecture]()
    model.compile(optimizer=Adam(learning_rate=0.001), loss=custom_loss_function)

    x_batch = np.expand_dims(noisy_img, axis=0)
    y_batch = np.expand_dims(target_img, axis=0)

    stage_psnr = []
    stage_ssim = []
    previous_psnr = 0
    previous_ssim = 0

    # At least one epoch, so the metrics recorded below belong to this stage.
    n_epochs = max(1, calculate_epochs(beta, Emax, k, alpha))
    for epoch in range(n_epochs):
        model.fit(x_batch, y_batch, epochs=1, batch_size=1, verbose=0)

        # Time one forward pass on the noisy image.
        denoised_img, gen_time = measure_time(model, noisy_img)

        psnr_value = compute_psnr(target_img, denoised_img[0])
        ssim_value = compute_ssim(target_img, denoised_img[0])
        # Value of custom_loss_function (MSE, plus the perceptual term if enabled).
        loss_value = model.evaluate(x_batch, y_batch, verbose=0)

        print(f"Stage: {stage + 1}, Epoch: {epoch + 1}, PSNR: {psnr_value}, SSIM: {ssim_value}, Loss: {loss_value}")

        stage_psnr.append(psnr_value)
        stage_ssim.append(ssim_value)

        if abs(psnr_value - previous_psnr) < psnr_threshold and abs(ssim_value - previous_ssim) < ssim_threshold:
            print(f"Early stopping at Stage: {stage + 1}, Epoch: {epoch + 1}")
            break

        previous_psnr = psnr_value
        previous_ssim = ssim_value

    results[f'Stage: {stage + 1}, Noise Level: {alpha_bars[stage]}, Beta: {beta}'] = {
        'PSNR': psnr_value,
        'SSIM': ssim_value,
        'Generation Time': gen_time,
    }

    stage_output = denoised_img[0]
    all_denoised_images.append(stage_output)

    _show_stage_images(target_img, noisy_img, stage_output)
    _show_stage_curves(stage_psnr, stage_ssim)

    # NOTE(review): the next stage is trained and scored against this stage's
    # output rather than the clean image, although every noisy image was
    # generated from the clean image. Confirm this chaining is intended.
    target_img = stage_output
