In [1]:
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import Callback

2025-01-14 17:28:47.079653: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-14 17:28:47.087640: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1736855927.097412   99570 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1736855927.100139   99570 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-14 17:28:47.110757: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
tf.keras.backend.clear_session()

In [3]:
# Ensure GPU memory growth
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

In [4]:
class RealTimePlottingCallback(Callback):
    def __init__(self):
        super().__init__()
        self.epochs = []
        self.losses = []
        self.psnrs = []
        self.ssims = []
        self.snrs = []

    def on_epoch_end(self, epoch, logs=None):
        self.epochs.append(epoch)
        self.losses.append(logs['loss'])
        
        # Add the values for PSNR, SSIM, and SNR (modify these based on your metrics' names)
        if 'psnr' in logs:
            self.psnrs.append(logs['psnr'])
        else:
            self.psnrs.append(None)
        
        if 'ssim' in logs:
            self.ssims.append(logs['ssim'])
        else:
            self.ssims.append(None)
        
        if 'snr' in logs:
            self.snrs.append(logs['snr'])
        else:
            self.snrs.append(None)

        # Clear the plot and re-draw for each epoch
        plt.clf()
        
        # Plotting Loss
        plt.subplot(2, 2, 1)
        plt.plot(self.epochs, self.losses, label='Loss', color='red')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training Loss')

        # Plotting PSNR
        if self.psnrs[0] is not None:
            plt.subplot(2, 2, 2)
            plt.plot(self.epochs, self.psnrs, label='PSNR', color='blue')
            plt.xlabel('Epoch')
            plt.ylabel('PSNR')
            plt.title('PSNR')

        # Plotting SSIM
        if self.ssims[0] is not None:
            plt.subplot(2, 2, 3)
            plt.plot(self.epochs, self.ssims, label='SSIM', color='green')
            plt.xlabel('Epoch')
            plt.ylabel('SSIM')
            plt.title('SSIM')

        # Plotting SNR
        if self.snrs[0] is not None:
            plt.subplot(2, 2, 4)
            plt.plot(self.epochs, self.snrs, label='SNR', color='purple')
            plt.xlabel('Epoch')
            plt.ylabel('SNR')
            plt.title('SNR')

        # Adjust the layout and show the plot
        plt.tight_layout()
        plt.pause(0.1)

In [5]:
# Metrics
class PSNR(tf.keras.metrics.Metric):
    def __init__(self, name='psnr', **kwargs):
        super(PSNR, self).__init__(name=name, **kwargs)
        self.psnr_value = self.add_weight(name='psnr_value', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        max_pixel = 1.0
        psnr_value = tf.image.psnr(y_true, y_pred, max_val=max_pixel)
        self.psnr_value.assign(tf.reduce_mean(psnr_value))

    def result(self):
        return self.psnr_value

class SSIM(tf.keras.metrics.Metric):
    def __init__(self, name='ssim', **kwargs):
        super(SSIM, self).__init__(name=name, **kwargs)
        self.ssim_value = self.add_weight(name='ssim_value', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        ssim_value = tf.image.ssim(y_true, y_pred, max_val=1.0)
        self.ssim_value.assign(tf.reduce_mean(ssim_value))  # Update the state variable

    def result(self):
        return self.ssim_value

class SignalNoiseRatio(tf.keras.metrics.Metric):
    def __init__(self, name='snr', **kwargs):
        super(SignalNoiseRatio, self).__init__(name=name, **kwargs)
        self.snr_value = self.add_weight(name='snr_value', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        noise = tf.math.reduce_std(y_true - y_pred, axis=[1, 2])
        signal = tf.math.reduce_std(y_true, axis=[1, 2])
        snr_value = signal / noise
        self.snr_value.assign(tf.reduce_mean(snr_value))  # Update the state variable

    def result(self):
        return self.snr_value


In [6]:
LOW_RES_PATH = "./LR"
HIGH_RES_PATH = "./HR"
LR_SHAPE = (128, 64, 3)
HR_SHAPE = (512, 256, 3)
BATCH_SIZE = 4
EPOCHS = 100

In [7]:
# Generator Model
def build_generator(input_shape):
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv2D(64, kernel_size=9, padding='same', activation='relu')(inputs)
    skip_connection = x

    for _ in range(16):
        res = layers.Conv2D(64, kernel_size=3, padding='same', activation='relu')(x)
        res = layers.Conv2D(64, kernel_size=3, padding='same')(res)
        x = layers.Add()([x, res])

    x = layers.Conv2D(64, kernel_size=3, padding='same')(x)
    x = layers.Add()([x, skip_connection])

    for _ in range(2):
        x = layers.Conv2DTranspose(256, kernel_size=3, strides=2, padding='same', activation='relu')(x)

    outputs = layers.Conv2D(3, kernel_size=9, padding='same', activation='tanh')(x)
    return models.Model(inputs, outputs, name="Generator")

# Discriminator Model
def build_discriminator(input_shape):
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv2D(64, kernel_size=3, strides=2, padding='same', activation=layers.LeakyReLU(0.2))(inputs)

    filters = 32
    for _ in range(4):
        # filters *= 2
        x = layers.Conv2D(filters, kernel_size=3, strides=2, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)

    x = layers.Flatten()(x)
    x = layers.Dense(1024, activation=layers.LeakyReLU(0.2))(x)
    outputs = layers.Dense(1, activation='sigmoid')(x)
    return models.Model(inputs, outputs, name="Discriminator")


In [8]:
# Combined GAN Model
def build_gan(generator, discriminator):
    discriminator.trainable = False
    inputs = layers.Input(shape=LR_SHAPE)
    generated_image = generator(inputs)
    validity = discriminator(generated_image)
    return models.Model(inputs, [generated_image, validity], name="GAN")


In [9]:
# Dataset Loading and Preprocessing
def preprocess_image(image_path, target_size):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, channels=3)
    image = tf.image.resize(image, target_size)
    image = tf.cast(image, tf.float32) / 127.5 - 1.0
    return image

def load_dataset(lr_path, hr_path, lr_shape, hr_shape, batch_size):
    lr_files = tf.data.Dataset.list_files(os.path.join(lr_path, "*.png"), shuffle=True)
    hr_files = tf.data.Dataset.list_files(os.path.join(hr_path, "*.png"), shuffle=True)

    lr_images = lr_files.map(lambda x: preprocess_image(x, lr_shape[:2]), num_parallel_calls=tf.data.AUTOTUNE)
    hr_images = hr_files.map(lambda x: preprocess_image(x, hr_shape[:2]), num_parallel_calls=tf.data.AUTOTUNE)

    dataset = tf.data.Dataset.zip((lr_images, hr_images))
    dataset = dataset.shuffle(buffer_size=256).batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset


In [10]:
generator = build_generator(LR_SHAPE)
generator.summary()

I0000 00:00:1736855929.023628   99570 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 5580 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3070, pci bus id: 0000:01:00.0, compute capability: 8.6


In [11]:
discriminator = build_discriminator(HR_SHAPE)
discriminator.summary()

In [12]:
# Training
gan = build_gan(generator, discriminator)


In [13]:
# Define optimizers
optimizer = Adam(learning_rate=1e-4)
# disc_optimizer = Adam(learning_rate=1e-4)

discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'      ])
gan.compile(loss=['mse', 'binary_crossentropy'], loss_weights=[1, 1e-3], optimizer=optimizer,  metrics=[PSNR(), SignalNoiseRatio()])

In [18]:
print("Generator weights:", generator.trainable_weights)
print("Discriminator weights:", discriminator.trainable_weights)


Generator weights: [<Variable path=conv2d/kernel, shape=(9, 9, 3, 64), dtype=float32, value=[[[[-9.99196805e-03  1.03591606e-02 -3.13785113e-03 ... -2.24511251e-02
    -1.72486529e-03 -1.47933569e-02]
   [ 2.18704753e-02  3.16380933e-02 -2.23640241e-02 ... -4.48725373e-03
     1.94262378e-02  2.66004428e-02]
   [-3.89705971e-03 -1.31082032e-02 -1.89909954e-02 ... -3.18742469e-02
    -3.31456661e-02  1.72141567e-03]]

  [[ 1.67386830e-02  1.03368871e-02 -1.67574175e-02 ...  8.05681199e-03
    -6.08116388e-05 -3.08744051e-03]
   [-2.74033174e-02  1.71444118e-02  2.57267803e-02 ...  1.24175362e-02
    -1.22151300e-02  2.67993696e-02]
   [-1.38478130e-02 -1.28016789e-02 -5.98468445e-03 ...  1.68253332e-02
    -1.73481982e-02 -2.18508132e-02]]

  [[-1.41652860e-02 -2.61110719e-02  2.65461504e-02 ... -1.99728720e-02
     1.29503123e-02 -1.86886173e-02]
   [ 5.24412468e-03  1.34550706e-02 -7.90726021e-03 ...  2.17454880e-03
    -1.69465188e-02  1.74736120e-02]
   [ 2.53019929e-02  1.61973797e

In [14]:
dataset = load_dataset(LOW_RES_PATH, HIGH_RES_PATH, LR_SHAPE, HR_SHAPE, BATCH_SIZE)
real = tf.ones((BATCH_SIZE, 1))
fake = tf.zeros((BATCH_SIZE, 1))


In [15]:
for epoch in range(EPOCHS):
    i = 1
    for low_res, high_res in dataset:
        fake_high_res = generator.predict(low_res)

        # Train Discriminator
        d_loss_real = discriminator.train_on_batch(high_res, real)
        d_loss_fake = discriminator.train_on_batch(fake_high_res, fake)
        d_loss = 0.5 * (d_loss_real[0] + d_loss_fake[0])

        # Train Generator
        g_loss = gan.train_on_batch(low_res, [high_res, real])

        print(f"{i} images done!", end='\r')
        i += 1

    print(f"Epoch {epoch + 1}/{EPOCHS}, D Loss: {d_loss:.4f}, G Loss: {g_loss[0]:.4f}")


I0000 00:00:1736855933.905620   99664 service.cc:148] XLA service 0x707ae0006d60 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1736855933.905640   99664 service.cc:156]   StreamExecutor device (0): NVIDIA GeForce RTX 3070, Compute Capability 8.6
2025-01-14 17:28:53.916162: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1736855933.961579   99664 cuda_dnn.cc:529] Loaded cuDNN version 90300


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4s/step


2025-01-14 17:28:57.814935: W external/local_xla/xla/service/gpu/nvptx_compiler.cc:930] The NVIDIA driver's CUDA version is 12.2 which is older than the PTX compiler version 12.5.82. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
I0000 00:00:1736855937.895773   99664 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


ValueError: Invalid reduction dimension 2 for input with 2 dimensions. for '{{node reduce_std/reduce_variance/Mean}} = Mean[T=DT_FLOAT, Tidx=DT_INT32, keep_dims=true](sub, reduce_std/reduce_variance/Mean/reduction_indices)' with input shapes: [4,1], [2] and with computed input tensors: input[1] = <1 2>.

In [24]:
# Prediction
def predict_on_image(img_path):
    img_array = preprocess_image(img_path, (256, 256))
    img_array = tf.expand_dims(img_array, axis=0)
    prediction = generator.predict(img_array)
    prediction = (prediction + 1.0) * 127.5
    prediction = np.clip(prediction, 0, 255).astype(np.uint8)

    original_img = tf.image.decode_png(tf.io.read_file(img_path))
    plt.figure(figsize=(10, 5))

    # Original Image
    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(original_img)

    # Predicted Image
    plt.subplot(1, 2, 2)
    plt.title("Predicted High-Resolution Image")
    plt.imshow(prediction[0])
    plt.show()