In [None]:
import os
os.environ['TF_USE_LEGACY_KERAS'] = 'True'
import re 
from scipy import ndimage, misc 
from tqdm import tqdm

from skimage.transform import resize, rescale
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

np.random.seed(0)
import cv2
import rasterio
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense ,Conv2D,MaxPooling2D ,Dropout
from tensorflow.keras.layers import Conv2DTranspose, UpSampling2D, add
from tensorflow.keras.models import Model
from tensorflow.keras import regularizers
from tensorflow.keras.utils import plot_model
from tensorflow.keras.preprocessing.image import img_to_array, load_img

print(tf.__version__)

# Seed TensorFlow as well as NumPy (seeded above) so weight initialisation and the
# per-epoch shuffling done by fit() repeat across Restart & Run All. Some GPU kernels
# may still be non-deterministic.
tf.random.set_seed(0)

# Verify that TensorFlow can access the GPU
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

### Load Raw Data

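Read the high-resolution images as RGB arrays. The low-resolution inputs are not read from disk; they are generated by downsampling each high-resolution image (by a factor of 5, from 500×500 to 100×100).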

In [None]:
hr_dataset_path = 'dataset/hr_dataset'
lr_dataset_path = 'dataset/lr_dataset'

# One sub-folder per AOI; keep directories only (skip stray files).
# The isdir check must join each name with the same root it was listed from.
hr_data = [d for d in os.listdir(hr_dataset_path) if os.path.isdir(os.path.join(hr_dataset_path, d))]
lr_data = [d for d in os.listdir(lr_dataset_path) if os.path.isdir(os.path.join(lr_dataset_path, d))]

In [None]:
# Report how many AOI folders were found under each dataset root.
for root, folders in ((hr_dataset_path, hr_data), (lr_dataset_path, lr_data)):
    print(f"Number of folders in '{root}': {len(folders)}")

In [None]:
# def load_lr_images(aoi_base_name, num_revisits):
#     """
#     Load low-resolution images from the dataset using rasterio.

#     Args:
#     - aoi_base_name: Base name of the AOI folder.
#     - num_revisits: Number of low-resolution images needed.

#     Returns:
#     - List of loaded images.
#     """
#     images = []
#     for i in range(1, num_revisits+1):
#         # lr_image_path = os.path.join('dataset/lr_dataset', aoi_base_name, 'L2A', f"{aoi_base_name}-{i}-L2A_data.TIFF")
#         lr_image_path = os.path.join('dataset/lr_dataset', aoi_base_name, 'L1C', f"{aoi_base_name}-{i}-L1C_data.TIFF")
#         if os.path.isfile(lr_image_path):
#             try:
#                 with rasterio.open(lr_image_path) as src:
#                     image = src.read([4, 3, 2])  # Read RGB channels if available
#                     image = np.moveaxis(image, 0, -1)  # Convert to HWC format
#                     images.append(image)
#             except Exception as e:
#                 print(f"Warning: Failed to read image {lr_image_path} with error {e}")
#         else:
#             print(f"Warning: Low-resolution image path {lr_image_path} does not exist")
#     return images

def load_hr_images(aoi, dataset_root='dataset/hr_dataset'):
    """
    Load the high-resolution image for one AOI.

    The file read is ``<dataset_root>/<aoi>/<aoi>_ps.TIFF``.

    Args:
    - aoi: Name of the AOI folder.
    - dataset_root: Directory holding one sub-folder per AOI. Defaults to
      'dataset/hr_dataset', the location this notebook has always used.

    Returns:
    - List of loaded images in (height, width, channels) layout. It holds at most
      one image and is empty when the file is missing or cannot be read (a warning
      is printed in either case).
    """
    images = []
    hr_image_path = os.path.join(dataset_root, aoi, f"{aoi}_ps.TIFF")
    if not os.path.isfile(hr_image_path):
        print(f"Warning: High-resolution image path {hr_image_path} does not exist")
        return images
    try:
        with rasterio.open(hr_image_path) as src:
            # NOTE(review): bands 1-3 are treated as the RGB channels; confirm the
            # band order for this product.
            image = src.read([1, 2, 3])
            image = np.moveaxis(image, 0, -1)  # (bands, H, W) -> (H, W, bands)
            images.append(image)
    except Exception as e:
        # Best-effort loading: report and skip unreadable files.
        print(f"Warning: Failed to read image {hr_image_path} with error {e}")
    return images

def preprocess_image(image, target_size):
    """
    Resize an image and scale it by its own peak value for SRCNN.

    Args:
    - image: Input image array.
    - target_size: (width, height) passed straight to cv2.resize.

    Returns:
    - The resized image divided by its maximum value, so its values lie in [0, 1]
      for non-negative input. An all-zero image is returned resized but unscaled.
    """
    resized = cv2.resize(image, target_size)
    peak = np.max(resized)
    # Guard against dividing an all-zero image by zero.
    return resized / peak if peak != 0 else resized

def process_aois(aoi_names, target_size_hr, target_size_lr, num_revisits):
    """
    Build (LR, HR) image pairs for each AOI.

    Each AOI's high-resolution image is loaded with load_hr_images, then resized and
    normalised with preprocess_image. The matching low-resolution input is synthesised
    by area-downsampling that HR image to target_size_lr. The real low-resolution
    revisits on disk are not read.

    Args:
    - aoi_names: Iterable of AOI folder names.
    - target_size_hr: (width, height) for the high-resolution images.
    - target_size_lr: (width, height) for the synthesised low-resolution images.
    - num_revisits: Accepted for interface compatibility; currently unused.

    Returns:
    - Dict mapping each AOI name to a tuple (lr_images, hr_images) of numpy arrays.
      Both arrays are empty when no HR image could be loaded for that AOI.
    """
    pairs = {}
    for aoi in aoi_names:
        hr_stack = [preprocess_image(raw, target_size_hr) for raw in load_hr_images(aoi)]
        lr_stack = [
            cv2.resize(hr, target_size_lr, interpolation=cv2.INTER_AREA)
            for hr in hr_stack
        ]
        pairs[aoi] = (np.array(lr_stack), np.array(hr_stack))
    return pairs

In [None]:
# Smoke test on two hand-picked AOIs before processing the full selection.
aoi_names = ['Amnesty POI-1-1-1', 'Amnesty POI-1-1-2']
target_size_lr = (100, 100)
target_size_hr = (500, 500)  # HR/output size for SRCNN; LR is 5x smaller per side
num_revisits = 1  # 1-16; ignored by process_aois while the real-LR loader is disabled

aois_data = process_aois(
    aoi_names,
    target_size_hr=target_size_hr,
    target_size_lr=target_size_lr,
    num_revisits=num_revisits,
)

In [None]:
np.shape(aois_data['Amnesty POI-1-1-1'][0])

In [None]:
np.shape(aois_data['Amnesty POI-1-1-1'][1])

In [None]:
# Peek at the top-left 5x5 corner of the first LR image rather than dumping the whole stack.
aois_data['Amnesty POI-1-1-1'][0][0, :5, :5]

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(aois_data['Amnesty POI-1-1-2'][0][0])

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(aois_data['Amnesty POI-1-1-2'][1][0])

### Prepare Training Data

In [None]:
metadata = pd.read_csv("dataset/metadata.csv")
# The first (unnamed) column holds the AOI name; keep one row per AOI.
metadata = (
    metadata
    .rename(columns={metadata.columns[0]: 'aoi_name'})
    .drop_duplicates(subset=['aoi_name'], keep='first')
)
metadata.head()

In [None]:
# Landcover AOIs whose cloud_cover is below 0.05.
df_select = metadata.loc[
    metadata['aoi_name'].str.contains('Landcover')
    & metadata['cloud_cover'].lt(0.05)
]
len(df_select)

In [None]:
# Distribution of IPCC land-cover classes among the selected AOIs
# (one row per AOI after de-duplication on aoi_name).
category_counts = df_select['IPCC Class'].value_counts()

fig, ax = plt.subplots(figsize=(10, 6))
category_counts.plot(kind='bar', ax=ax)
ax.set_title('Frequency of Each IPCC Class in the Selected AOIs')
ax.set_xlabel('IPCC Class')
ax.set_ylabel('Number of AOIs')
plt.show()

In [None]:
# Pick up to MAX_LANDCOVER_AOIS Landcover AOIs (from df_select) to train on.
# Folders are discovered under lr_dataset_path, but only the HR image
# <hr_dataset_path>/<aoi>/<aoi>_ps.TIFF is actually loaded later. An AOI is kept
# only if that file exists; otherwise it yields empty arrays that break the
# [0][0] indexing in the preview cell.
MAX_LANDCOVER_AOIS = 1000
aoi_names_set = set(df_select['aoi_name'])

landcover_folders = []
for folder_name in os.listdir(lr_dataset_path):
    # Cheap set lookup first, so the filesystem is only probed for candidate AOIs.
    if folder_name in aoi_names_set and os.path.isfile(
        os.path.join(hr_dataset_path, folder_name, f"{folder_name}_ps.TIFF")
    ):
        landcover_folders.append(folder_name)
    if len(landcover_folders) >= MAX_LANDCOVER_AOIS:
        break

In [None]:
# Load the selected Landcover AOIs. HR images are resized to 500x500 and the
# model's LR inputs are synthesised by downsampling them to 100x100.
target_size_lr = (100, 100)
target_size_hr = (500, 500)
sample_images = process_aois(landcover_folders, target_size_hr, target_size_lr, num_revisits)

In [None]:
# Show LR/HR pairs side by side for up to the first 20 selected AOIs.
# Slicing (rather than range(20)) keeps this working when fewer AOIs were loaded.
for aoi in landcover_folders[:20]:
    aoi_lr, aoi_hr = sample_images[aoi]
    if len(aoi_hr) == 0:
        # No readable HR image for this AOI, so there is nothing to show.
        continue

    fig, (ax_lr, ax_hr) = plt.subplots(1, 2, figsize=(20, 20))

    ax_lr.set_title('Low Resolution Image', color='red', fontsize=20)
    ax_lr.imshow(aoi_lr[0])

    ax_hr.set_title('High Resolution Image', color='green', fontsize=20)
    ax_hr.imshow(aoi_hr[0])

    fig.tight_layout()
    plt.show()

### Prepare Train, Test, and Validation

In [None]:
# Flatten the per-AOI (LR, HR) arrays into two index-aligned image lists.
all_high_images = []
all_low_images = []

for aoi_name in sample_images:
    lr_stack, hr_stack = sample_images[aoi_name]
    all_low_images.extend(lr_stack)   # Low-resolution images
    all_high_images.extend(hr_stack)  # High-resolution images

# SRCNN takes its input at the output resolution, so bicubic-upscale every LR image
# back to the HR size. Use target_size_hr (the size the HR images were resized to)
# rather than a hard-coded (500, 500) so the two cannot drift apart.
all_low_images = [
    cv2.resize(img, target_size_hr, interpolation=cv2.INTER_CUBIC)
    for img in all_low_images
]

# Convert lists to numpy arrays
all_high_images = np.array(all_high_images)
all_low_images = np.array(all_low_images)

# Sequential 80% / 10% / 10% train / validation / test split. With 1000 images this
# is exactly the previous 800 / 100 / 100 split. With fewer images (at least 10),
# validation and test are no longer left empty.
n_images = len(all_high_images)
n_train = n_images * 8 // 10
n_val = n_images // 10

train_high_image = all_high_images[:n_train]
train_low_image = all_low_images[:n_train]

validation_high_image = all_high_images[n_train:n_train + n_val]
validation_low_image = all_low_images[n_train:n_train + n_val]

test_high_image = all_high_images[n_train + n_val:]
test_low_image = all_low_images[n_train + n_val:]

# Print the shapes to verify
print("Shape of training images:", train_high_image.shape)
print("Shape of test images:", test_high_image.shape)
print("Shape of validation images:", validation_high_image.shape)

print("Shape of training low images:", train_low_image.shape)
print("Shape of test low images:", test_low_image.shape)
print("Shape of validation low images:", validation_low_image.shape)

### Model

In [None]:
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

In [None]:
# SRCNN-style network: 9x9 patch extraction, 3x3 and 1x1 non-linear mapping layers,
# and a 5x5 reconstruction layer back to 3 channels. Every layer uses 'same' padding,
# so the output has the input's spatial size. Height/width are left as None so the
# model (and its saved weights) can run on images of any size. It is trained below
# on 500x500 bicubic-upscaled inputs.
SRCNN = Sequential([
    Conv2D(128, (9, 9), padding='same', input_shape=(None, None, 3)),
    Activation('relu'),

    Conv2D(64, (3, 3), padding='same'),
    Activation('relu'),

    Conv2D(32, (1, 1), padding='same'),
    Activation('relu'),

    # Final ReLU keeps predictions non-negative.
    Conv2D(3, (5, 5), padding='same'),
    Activation('relu')
])


def pixel_mse_loss(y_true, y_pred):
    """Mean squared error averaged over every element of the batch (pixels and channels)."""
    return tf.reduce_mean((y_true - y_pred) ** 2)


SRCNN.compile(optimizer=tf.keras.optimizers.legacy.Adam(0.001), loss=pixel_mse_loss)

SRCNN.summary()

# Save the full model to disk each time val_loss improves (pass to fit's callbacks).
checkpoint = ModelCheckpoint('srcnn_model_checkpoint.h5', monitor='val_loss', save_best_only=True, verbose=1)

### Training

In [None]:
# Stop once val_loss has not improved for 10 epochs and roll back to the best weights.
early_stop = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

# `checkpoint` (defined in the model cell) must be listed here, or the
# best-model file is never written.
history = SRCNN.fit(
    train_low_image,
    train_high_image,
    epochs=100,
    batch_size=10,
    validation_data=(validation_low_image, validation_high_image),
    callbacks=[early_stop, checkpoint],
)

In [None]:
# Persist only the trained weights; reloading needs the architecture from the model cell.
SRCNN.save_weights(filepath='srcnn_weights_2.h5')

### Define Image Quality Metrics

In [None]:
from skimage.metrics import structural_similarity as ssim
import math

def psnr(target, ref, data_range=255.):
    """
    Compute the Peak Signal-to-Noise Ratio (PSNR) between two images.

    PSNR = 20 * log10(data_range / RMSE), where the RMSE is taken over every
    element (all pixels and channels). Higher values mean the images are closer.

    Args:
        target (numpy.ndarray): Image being evaluated.
        ref (numpy.ndarray): Reference image to compare against (same shape).
        data_range (float): Peak possible pixel value. Defaults to 255 (8-bit images).

    Returns:
        float: The PSNR value in decibels (dB), or ``math.inf`` when the images
        are identical (zero error).
    """
    diff = ref.astype(float) - target.astype(float)
    rmse = math.sqrt(np.mean(diff ** 2.))

    if rmse == 0:
        # Identical images: the noise term vanishes, so PSNR is unbounded.
        # Previously this raised ZeroDivisionError.
        return math.inf

    return 20 * math.log10(data_range / rmse)


def mse(target, ref):
    """
    Compute the Mean Squared Error (MSE) between two images.

    Squared differences are summed over every element and then divided by the
    number of pixels (height * width), not the number of elements. For a
    C-channel image the result is therefore C times the per-element mean. That
    differs from the error used by psnr(), which averages over all elements.

    Args:
        target (numpy.ndarray): First image, shape (H, W) or (H, W, C).
        ref (numpy.ndarray): Second image with the same shape.

    Returns:
        float: Sum of squared errors per pixel.
    """
    squared_error = (target.astype('float') - ref.astype('float')) ** 2
    pixel_count = float(target.shape[0] * target.shape[1])
    return np.sum(squared_error) / pixel_count


def compare_images(target, ref):
    """
    Score ``target`` against ``ref`` with three image-quality metrics.

    Args:
        target (numpy.ndarray): Image being evaluated, (H, W, C) with values on a 0-255 scale.
        ref (numpy.ndarray): Reference image with the same shape.

    Returns:
        list: [PSNR, MSE, SSIM], in that order. SSIM uses an 11x11 window,
        treats axis 2 as the channel axis and assumes a data range of 255.
    """
    psnr_score = psnr(target, ref)
    mse_score = mse(target, ref)
    ssim_score = ssim(target, ref, win_size=11, channel_axis=2, data_range=255)
    return [psnr_score, mse_score, ssim_score]


### Testing

In [None]:
# Use the training batch size (10) rather than Keras' default of 32. Each 500x500
# image produces 128-channel feature maps, so large batches can exhaust GPU memory.
pred_image = SRCNN.predict(test_low_image, batch_size=10)

In [None]:
# Post-process predictions: map the model's [0, 1]-scale output to 8-bit pixels.
# The dtype guard makes the cell idempotent. Re-running it no longer multiplies an
# already-converted uint8 array by 255 (which would overflow).
if pred_image.dtype != np.uint8:
    pred_image = np.clip(pred_image * 255, 0, 255).astype(np.uint8)


In [None]:
def _to_uint8(img):
    """Map a [0, 1]-scale float image to uint8, clipping first so out-of-range values saturate."""
    return np.clip(img * 255, 0, 255).astype(np.uint8)


# Initialize lists to hold scores for all images
pred_scores = []
original_scores = []

# Score the SRCNN output and the bicubic-upscaled input against the HR ground truth.
# Bicubic upscaling can overshoot [0, 1]. Casting such values straight to uint8 does
# not saturate (e.g. 1.01 * 255 = 257.55 wraps to 1), which corrupts the baseline
# scores, so both images are clipped first.
for i in range(len(test_low_image)):
    pred = pred_image[i]
    original = _to_uint8(test_low_image[i])
    truth = _to_uint8(test_high_image[i])

    pred_scores.append(compare_images(pred, truth))
    original_scores.append(compare_images(original, truth))

### Visualize the results

In [None]:
# One row per test image: bicubic input, SRCNN output and HR target. The input's
# and output's scores against the target are shown under their panels.
score_template = 'PSNR: {}\nMSE: {} \nSSIM: {}'

for idx in range(len(test_low_image)):
    fig, (ax_in, ax_out, ax_gt) = plt.subplots(1, 3, figsize=(20, 8))

    ax_in.imshow(test_low_image[idx])
    ax_in.set_title('Upscaled')
    ax_in.set(xlabel=score_template.format(*original_scores[idx]))

    ax_out.imshow(pred_image[idx])
    ax_out.set_title('Generated by SRCNN')
    ax_out.set(xlabel=score_template.format(*pred_scores[idx]))

    ax_gt.imshow(test_high_image[idx])
    ax_gt.set_title('Target')

    plt.tight_layout()
    plt.show()

### Effectiveness
Compare the PSNR, MSE, and SSIM of the bicubic-upscaled low-resolution inputs with those of the SRCNN outputs, each measured against the high-resolution ground truth.

In [None]:
# Extract PSNR, MSE, and SSIM for plotting
original_psnr = [score[0] for score in original_scores]
original_mse = [score[1] for score in original_scores]
original_ssim = [score[2] for score in original_scores]
pred_psnr = [score[0] for score in pred_scores]
pred_mse = [score[1] for score in pred_scores]
pred_ssim = [score[2] for score in pred_scores]


def _pct_change(new, old):
    """
    Percentage change from ``old`` (bicubic baseline) to ``new`` (SRCNN output).

    Returns NaN when the change is undefined: a zero baseline, or a non-finite
    score such as an infinite PSNR. Without this guard a zero baseline raises
    ZeroDivisionError or yields inf/NaN with a runtime warning.
    """
    if old == 0 or not (np.isfinite(new) and np.isfinite(old)):
        return float('nan')
    return (new - old) / old * 100


# Calculate percentage change. For MSE a negative value is an improvement;
# for PSNR and SSIM a positive value is.
psnr_increase = [_pct_change(p, o) for p, o in zip(pred_psnr, original_psnr)]
mse_increase = [_pct_change(p, o) for p, o in zip(pred_mse, original_mse)]
ssim_increase = [_pct_change(p, o) for p, o in zip(pred_ssim, original_ssim)]

# Number of test images
num_images = len(original_scores)
image_indices = range(num_images)

In [None]:
# PSNR: per-image percentage change (SRCNN vs bicubic), then the raw scores.
fig, ax = plt.subplots(figsize=(14, 6))
ax.bar(image_indices, psnr_increase, color='blue')
ax.set_xlabel('Test Image Index')
ax.set_ylabel('Percentage Increase (%)')
ax.set_title('Percentage Change in PSNR')
plt.show()

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(image_indices, original_psnr, label='Original Low-Res PSNR', marker='o')
ax.plot(image_indices, pred_psnr, label='Generated High-Res PSNR', marker='x')
ax.set_xlabel('Test Image Index')
ax.set_ylabel('PSNR')
ax.set_title('PSNR Comparison')
ax.legend()
plt.show()

In [None]:
# MSE: per-image percentage change (SRCNN vs bicubic; negative is better), then raw scores.
fig, ax = plt.subplots(figsize=(14, 6))
ax.bar(image_indices, mse_increase, color='green')
ax.set_xlabel('Test Image Index')
ax.set_ylabel('Percentage Increase (%)')
ax.set_title('Percentage Change in MSE')
plt.show()

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(image_indices, original_mse, label='Original Low-Res MSE', marker='o')
ax.plot(image_indices, pred_mse, label='Generated High-Res MSE', marker='x')
ax.set_xlabel('Test Image Index')
ax.set_ylabel('MSE')
ax.set_title('MSE Comparison')
ax.legend()
plt.show()

In [None]:
# SSIM: per-image percentage change (SRCNN vs bicubic), then the raw scores.
fig, ax = plt.subplots(figsize=(14, 6))
ax.bar(image_indices, ssim_increase, color='purple')
ax.set_xlabel('Test Image Index')
ax.set_ylabel('Percentage Increase (%)')
ax.set_title('Percentage Change in SSIM')
plt.show()

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(image_indices, original_ssim, label='Original Low-Res SSIM', marker='o')
ax.plot(image_indices, pred_ssim, label='Generated High-Res SSIM', marker='x')
ax.set_xlabel('Test Image Index')
ax.set_ylabel('SSIM')
ax.set_title('SSIM Comparison')
ax.legend()
plt.show()