# In [None]:
import os
import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv3D, UpSampling3D, ReLU, Add, Flatten, Dense, LeakyReLU
from tensorflow.keras.models import Model
from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr

# Root directories of the training and test datasets (placeholder values -- point these at real data before running)
train_dir, test_dir = "path_to_train_data", "path_to_test_data"

# Helper functions
def load_training_data(base_dir, channel="NIR", patch_size=(64, 64), stride=32):
    """Load paired LR/HR training patches for one channel (NIR or RED).

    Directory layout read by this function::

        base_dir/<channel>/<imgset>/HR.png    high-resolution target
        base_dir/<channel>/<imgset>/SM.png    status map for HR.png
        base_dir/<channel>/<imgset>/LR*.png   low-resolution frames
        base_dir/<channel>/<imgset>/QM*.png   quality map for each LR frame

    HR pixels whose status-map value is not 1 are zeroed, and likewise LR
    pixels whose quality-map value is not 1. Every LR frame of an image set
    is paired with that set's (single) HR image. Image sets missing HR/SM, LR
    frames missing their QM, and files OpenCV cannot decode are skipped.

    Args:
        base_dir: Root directory containing one sub-directory per channel.
        channel: Channel sub-directory name, e.g. ``"NIR"`` or ``"RED"``.
        patch_size: ``(height, width)`` of each extracted patch.
        stride: Step in pixels between neighbouring patches.

    Returns:
        ``(lr_patches, hr_patches)`` -- arrays stacked along axis 0. When no
        patches are found each is an empty ``(0, height, width)`` uint8 array,
        so it can still be concatenated with non-empty results.
    """
    channel_dir = os.path.join(base_dir, channel)
    lr_patches, hr_patches = [], []

    # Sorted so the patch order (and therefore training batches) is reproducible;
    # os.listdir returns entries in arbitrary order.
    for imgset in sorted(os.listdir(channel_dir)):
        imgset_path = os.path.join(channel_dir, imgset)
        if not os.path.isdir(imgset_path):
            continue

        hr_path = os.path.join(imgset_path, "HR.png")
        sm_path = os.path.join(imgset_path, "SM.png")
        if not os.path.exists(hr_path) or not os.path.exists(sm_path):
            continue

        hr_image = cv2.imread(hr_path, cv2.IMREAD_GRAYSCALE)
        status_map = cv2.imread(sm_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals an unreadable file by returning None, not by raising.
        if hr_image is None or status_map is None:
            continue
        hr_image = np.where(status_map == 1, hr_image, 0)
        # The HR patches are the same for every LR frame of this set: extract once.
        hr_set_patches = extract_patches(hr_image, patch_size, stride)

        for file in sorted(os.listdir(imgset_path)):
            if not (file.startswith("LR") and file.endswith(".png")):
                continue
            lr_path = os.path.join(imgset_path, file)
            # Swap only the file-name prefix. Replacing "LR" across the whole
            # path would also rewrite any directory whose name contains "LR".
            qm_path = os.path.join(imgset_path, "QM" + file[len("LR"):])
            if not os.path.exists(qm_path):
                continue

            lr_image = cv2.imread(lr_path, cv2.IMREAD_GRAYSCALE)
            quality_map = cv2.imread(qm_path, cv2.IMREAD_GRAYSCALE)
            if lr_image is None or quality_map is None:
                continue
            lr_image = np.where(quality_map == 1, lr_image, 0)

            # NOTE(review): LR and HR patches are paired purely by list position.
            # If LR and HR images differ in size, the two calls yield different
            # patch counts and the pairs misalign -- confirm image sizes.
            lr_patches.extend(extract_patches(lr_image, patch_size, stride))
            hr_patches.extend(hr_set_patches)

    def _stack(patches):
        # np.array([]) would have shape (0,), which cannot be concatenated with
        # (N, h, w) arrays from another channel; keep the patch dimensions.
        if patches:
            return np.array(patches)
        return np.empty((0, *patch_size), dtype=np.uint8)

    return _stack(lr_patches), _stack(hr_patches)

def extract_patches(image, patch_size=(64, 64), stride=32):
    """Slide a window over ``image`` and collect every full-size patch.

    Windows start at multiples of ``stride`` and are emitted row by row (top
    to bottom, then left to right within a row). Windows that would run past
    the bottom or right edge are dropped, so an image smaller than
    ``patch_size`` yields an empty list. Each patch is a slice (view) of
    ``image``.
    """
    patch_h, patch_w = patch_size[0], patch_size[1]
    top_limit = image.shape[0] - patch_h + 1
    left_limit = image.shape[1] - patch_w + 1
    return [
        image[top:top + patch_h, left:left + patch_w]
        for top in range(0, top_limit, stride)
        for left in range(0, left_limit, stride)
    ]

def load_test_data(base_dir, channel="NIR"):
    """Load low-resolution test frames and their quality maps for one channel.

    Every ``base_dir/<channel>/<imgset>/LR*.png`` with a matching ``QM*.png``
    is read as grayscale together with that map. Frames without a QM file,
    and files OpenCV cannot decode, are skipped.

    Args:
        base_dir: Root directory containing one sub-directory per channel.
        channel: Channel sub-directory name, e.g. ``"NIR"`` or ``"RED"``.

    Returns:
        ``(test_images, status_maps)`` stacked along axis 0. Despite its name,
        ``status_maps`` holds each frame's QM (quality) map.
    """
    channel_dir = os.path.join(base_dir, channel)
    test_images, status_maps = [], []

    # Sorted for a reproducible frame order (os.listdir order is arbitrary).
    for imgset in sorted(os.listdir(channel_dir)):
        imgset_path = os.path.join(channel_dir, imgset)
        if not os.path.isdir(imgset_path):
            continue

        for file in sorted(os.listdir(imgset_path)):
            if not (file.startswith("LR") and file.endswith(".png")):
                continue
            lr_path = os.path.join(imgset_path, file)
            # Swap only the file-name prefix, never "LR" inside directory names.
            qm_path = os.path.join(imgset_path, "QM" + file[len("LR"):])
            if not os.path.exists(qm_path):
                continue

            lr_image = cv2.imread(lr_path, cv2.IMREAD_GRAYSCALE)
            status_map = cv2.imread(qm_path, cv2.IMREAD_GRAYSCALE)
            # cv2.imread returns None (rather than raising) for unreadable files.
            if lr_image is None or status_map is None:
                continue
            # NOTE(review): load_training_data zeroes LR pixels where the QM is
            # not 1 before training, but test frames are returned unmasked --
            # confirm whether inference inputs should be masked the same way.
            test_images.append(lr_image)
            status_maps.append(status_map)

    # NOTE(review): np.array stacks cleanly only if all frames share one shape.
    return np.array(test_images), np.array(status_maps)

# Load training data: one call per spectral channel
lr_nir_patches, hr_nir_patches = load_training_data(train_dir, channel="NIR")
lr_red_patches, hr_red_patches = load_training_data(train_dir, channel="RED")
# Pool both channels into one training set and divide by 255 so that
# 8-bit pixel values land in [0, 1].
lr_patches = np.concatenate((lr_nir_patches, lr_red_patches), axis=0) / 255.0
hr_patches = np.concatenate((hr_nir_patches, hr_red_patches), axis=0) / 255.0

# Load test data (NIR channel only) with the same [0, 1] scaling
test_images, test_status_maps = load_test_data(test_dir, channel="NIR")
test_images = test_images / 255.0

# Generator Model
def build_generator(input_shape=(None, 64, 64, 1)):
    """Build the residual 3-D convolutional super-resolution generator.

    Pipeline: 3x3x3 conv + ReLU stem, 16 residual blocks (conv, ReLU, conv,
    plus identity shortcut), then two x2 height/width upsampling stages, the
    first followed by conv + ReLU. A final single-filter tanh conv produces
    the output. Every convolution uses a 3x3x3 kernel with 'same' padding.

    ``input_shape`` excludes the batch axis; with the default the model takes
    ``(batch, frames, 64, 64, 1)`` tensors and returns ``(batch, frames, 256,
    256, 1)``.

    NOTE(review): the training loop in this file feeds rank-4 batches
    ``(batch, 64, 64, 1)``, and HR patches use the same 64x64 size as the LR
    patches, while this model expects rank-5 input and upsamples 4x. The
    tanh output range [-1, 1] also differs from the [0, 1] targets produced
    by dividing by 255. Confirm the intended shapes and output range.
    """
    def conv(filters, activation=None):
        # Shared convolution configuration for every layer in the network.
        return Conv3D(filters, kernel_size=(3, 3, 3), padding='same', activation=activation)

    inputs = Input(input_shape)
    features = conv(64)(inputs)
    features = ReLU()(features)

    for _block in range(16):
        shortcut = features
        residual = conv(64)(shortcut)
        residual = ReLU()(residual)
        residual = conv(64)(residual)
        features = Add()([shortcut, residual])

    # Upsample height and width only; the frame axis is left untouched.
    features = UpSampling3D(size=(1, 2, 2))(features)
    features = conv(64)(features)
    features = ReLU()(features)
    features = UpSampling3D(size=(1, 2, 2))(features)

    outputs = conv(1, activation='tanh')(features)
    return Model(inputs, outputs, name="Generator")

# Discriminator Model
def build_discriminator(input_shape=(64, 64, 1)):
    """Build the convolutional real/fake discriminator.

    Args:
        input_shape: Per-sample shape, excluding the batch axis.
            - A 3-tuple ``(height, width, channels)`` builds a 2-D network
              (Conv2D).
            - A 4-tuple ``(depth, height, width, channels)`` builds a 3-D
              network (Conv3D).
            All entries must be known integers, because the final conv
            features are flattened into a ``Dense`` layer.

    Returns:
        A Keras ``Model`` named ``"Discriminator"`` that maps each sample to a
        single sigmoid score.

    Raises:
        ValueError: If ``input_shape`` has neither 3 nor 4 entries.

    NOTE(review): the generator in this file emits rank-5 outputs at 4x the
    LR height/width, while HR patches are rank-4 ``(64, 64, 1)``. One
    discriminator cannot accept both as-is; confirm the intended shapes.
    """
    # Conv3D needs rank-5 tensors (batch, depth, height, width, channels). The
    # default (64, 64, 1) shape only produces rank-4 tensors, so building it
    # with Conv3D raised at construction time. Pick the convolution whose
    # rank matches input_shape instead.
    spatial_rank = len(input_shape) - 1
    if spatial_rank == 2:
        conv_cls = tf.keras.layers.Conv2D
    elif spatial_rank == 3:
        conv_cls = Conv3D
    else:
        raise ValueError(
            f"input_shape must have 3 or 4 entries, got {len(input_shape)}: {input_shape}"
        )
    kernel_size = (3,) * spatial_rank

    inputs = Input(input_shape)
    x = conv_cls(64, kernel_size=kernel_size, padding='same')(inputs)
    x = LeakyReLU(alpha=0.2)(x)

    for filters in [128, 256, 512]:
        x = conv_cls(filters, kernel_size=kernel_size, padding='same')(x)
        x = LeakyReLU(alpha=0.2)(x)

    x = Flatten()(x)
    outputs = Dense(1, activation='sigmoid')(x)
    return Model(inputs, outputs, name="Discriminator")

# Loss Functions
def content_loss(hr, sr):
    """Return the pixel-wise mean squared error between target ``hr`` and prediction ``sr``."""
    squared_error = tf.square(hr - sr)
    return tf.reduce_mean(squared_error)

def adversarial_loss(real_output, fake_output):
    """Return the critic-style loss ``mean(fake_output) - mean(real_output)``.

    Minimising it pushes scores on real samples up and scores on generated
    samples down.
    """
    return tf.reduce_mean(fake_output) - tf.reduce_mean(real_output)

# Instantiate the two networks with their default input shapes
generator, discriminator = build_generator(), build_discriminator()

# One Adam optimizer per network, both at learning rate 1e-4
gen_optimizer, disc_optimizer = (
    tf.keras.optimizers.Adam(learning_rate=1e-4),
    tf.keras.optimizers.Adam(learning_rate=1e-4),
)

@tf.function
def train_step(lr_batch, hr_batch):
    """Run one joint optimisation step for the generator and discriminator.

    Args:
        lr_batch: Low-resolution input batch fed to ``generator``.
        hr_batch: Matching high-resolution target batch.

    Returns:
        ``(g_loss, d_loss)`` -- scalar loss tensors for this step.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        sr_batch = generator(lr_batch, training=True)
        real_output = discriminator(hr_batch, training=True)
        fake_output = discriminator(sr_batch, training=True)

        # Discriminator: minimise mean(D(sr)) - mean(D(hr)), i.e. score real high
        # and generated low.
        d_loss = adversarial_loss(real_output, fake_output)
        # Generator: it must push D(sr) *up*, so its adversarial term is
        # -mean(D(sr)). Reusing d_loss (as before) had the generator minimise
        # the discriminator's objective and learn to look fake. The dropped
        # -mean(D(hr)) term has no gradient w.r.t. the generator anyway.
        g_loss = content_loss(hr_batch, sr_batch) - tf.reduce_mean(fake_output)

    gen_gradients = gen_tape.gradient(g_loss, generator.trainable_variables)
    disc_gradients = disc_tape.gradient(d_loss, discriminator.trainable_variables)
    gen_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
    # Was `d_gradients`, an undefined name (NameError on the first step).
    disc_optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))
    return g_loss, d_loss

# Training Loop
EPOCHS = 50
BATCH_SIZE = 16
SHUFFLE_SEED = 0  # fixed seed so the per-epoch batch order is reproducible

# Patches are paired by index; a length mismatch would silently misalign them.
if len(lr_patches) != len(hr_patches):
    raise ValueError(
        f"LR/HR patch counts differ ({len(lr_patches)} vs {len(hr_patches)}); "
        "training pairs would be misaligned."
    )
# Without this guard an empty dataset failed later with a NameError on g_loss.
if len(lr_patches) == 0:
    raise ValueError("No training patches were loaded; check train_dir.")

rng = np.random.default_rng(SHUFFLE_SEED)
for epoch in range(EPOCHS):
    # The patch arrays hold all NIR patches followed by all RED ones; visit
    # them in a fresh random order each epoch so batches are not channel-sorted.
    # Indexing per batch avoids copying the full arrays every epoch.
    order = rng.permutation(len(lr_patches))
    g_losses, d_losses = [], []
    for start in range(0, len(order), BATCH_SIZE):
        batch_idx = order[start:start + BATCH_SIZE]
        # Keras layers default to float32; the /255.0 scaling produced float64.
        lr_batch = lr_patches[batch_idx][..., np.newaxis].astype(np.float32)
        hr_batch = hr_patches[batch_idx][..., np.newaxis].astype(np.float32)
        g_loss, d_loss = train_step(lr_batch, hr_batch)
        g_losses.append(float(g_loss))
        d_losses.append(float(d_loss))
    # Report the epoch mean rather than only the last (possibly partial) batch.
    print(f"Epoch {epoch+1}/{EPOCHS}, Gen Loss: {np.mean(g_losses)}, Disc Loss: {np.mean(d_losses)}")

# Evaluate Test Data
def _upsample_mask_nearest(mask, out_shape):
    """Resize a 2-D mask to ``out_shape`` (height, width) by nearest-neighbour sampling."""
    out_h, out_w = out_shape
    rows = np.arange(out_h) * mask.shape[0] // out_h
    cols = np.arange(out_w) * mask.shape[1] // out_w
    return mask[np.ix_(rows, cols)]


def evaluate(generator, test_images, test_status_maps, hr_images=None, data_range=1.0):
    """Super-resolve each test frame, mask invalid pixels, and optionally score it.

    Args:
        generator: Callable Keras model, invoked as ``generator(x, training=False)``.
        test_images: Iterable of 2-D LR frames.
        test_status_maps: Per-frame validity maps, same size as the LR frames.
            Pixels equal to 1 are valid. Each map is nearest-neighbour resized
            to the output resolution before masking.
        hr_images: Optional ground-truth HR images, one per test frame. PSNR
            and SSIM are computed only when these are supplied. A validity
            mask is not a reference image, so scoring against it (as before)
            was meaningless.
        data_range: Value range of the images passed to PSNR/SSIM; 1.0 suits
            images scaled to [0, 1].

    Returns:
        ``(reconstructed_images, psnr_vals, ssim_vals)``. The metric lists are
        empty when ``hr_images`` is None.

    Raises:
        ValueError: If ``hr_images`` is given but its length differs from
            ``test_images``.
    """
    if hr_images is not None and len(hr_images) != len(test_images):
        raise ValueError(
            f"hr_images has {len(hr_images)} entries but test_images has {len(test_images)}"
        )

    psnr_vals, ssim_vals = [], []
    reconstructed_images = []
    references = [None] * len(test_images) if hr_images is None else hr_images

    for lr_image, sm, hr_image in zip(test_images, test_status_maps, references):
        batch = lr_image[np.newaxis, ..., np.newaxis].astype(np.float32)  # Add batch and channel dimensions
        # Direct model calls avoid predict()'s per-call setup cost for single samples.
        sr_image = np.asarray(generator(batch, training=False))[0, ..., 0]  # Remove batch and channel dimensions
        # The map is LR-sized while the output is upsampled; resize before masking.
        valid = _upsample_mask_nearest(np.asarray(sm) == 1, sr_image.shape[-2:])
        sr_image = np.where(valid, sr_image, 0)  # Mask invalid pixels
        reconstructed_images.append(sr_image)

        if hr_image is not None:
            # Zero the same pixels in the reference so masked pixels are not scored as errors.
            hr_masked = np.where(valid, hr_image, 0)
            psnr_vals.append(psnr(hr_masked, sr_image, data_range=data_range))
            ssim_vals.append(ssim(hr_masked, sr_image, data_range=data_range))

    if psnr_vals:
        print(f"Average PSNR: {np.mean(psnr_vals):.2f}, Average SSIM: {np.mean(ssim_vals):.2f}")
    else:
        print("No HR reference images supplied; PSNR/SSIM not computed.")
    return reconstructed_images, psnr_vals, ssim_vals

# Super-resolve the test set and unpack per-image outputs and metric lists
evaluation = evaluate(generator, test_images, test_status_maps)
reconstructed_images, psnr_vals, ssim_vals = evaluation
