In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf  # For tf.data and preprocessing only.
import tensorflow_addons as tfa
import keras
from keras import layers
# from keras import ops
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import os
from skimage.io import imread, imshow
from skimage.transform import resize
from skimage import io  # for reading TIFF images
from sklearn.model_selection import train_test_split
import h5py
from sklearn.metrics import f1_score, recall_score, precision_score, cohen_kappa_score
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, ReduceLROnPlateau, EarlyStopping
from sklearn.metrics import r2_score
import rasterio
import h5py
import time

In [None]:
# Regression target: one above-ground-biomass (AGB) value per image -> one output unit.
num_classes = 1
input_shape = (90, 90, 15)  # (height, width, channels): 13 optical + 2 SAR bands per patch

patch_size = (5, 5)  # 5-by-5 pixel patches (a 90 x 90 image becomes an 18 x 18 token grid)
dropout_rate = 0.03  # Dropout rate
num_heads = 64  # Attention heads; embed_dim is split evenly across them
embed_dim = 128  # Embedding dimension
num_mlp = 1024  # MLP layer size
# Add a learnable bias to the projection that turns embedded patches into
# queries, keys and values.
qkv_bias = True
window_size = 2  # Side length of the attention window, in tokens
shift_size = 0  # Cyclic window shift, in tokens (0 disables shifted-window attention)
image_dimension = input_shape[0]  # Side length of the random crop used by `augment`

num_patch_x = input_shape[0] // patch_size[0]
num_patch_y = input_shape[1] // patch_size[1]

# Catch incompatible hyper-parameters here rather than deep inside a reshape.
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
assert num_patch_x % window_size == 0 and num_patch_y % window_size == 0, (
    "the token grid must be divisible by window_size for window_partition"
)

learning_rate = 1e-3
batch_size = 64
num_epochs = 140
weight_decay = 0.001
label_smoothing = 0.1  # NOTE: not referenced elsewhere in this notebook (MSE regression loss)

# Seed numpy and TensorFlow so weight init and augmentation are reproducible.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# Maximum number of patches to load (fewer are used if the directories hold fewer files).
desired_num_patches = 15000

# Your data directories
optic_images_dir = 'Sentinel-2-patches/'
sar_images_dir = 'Sentinel-1-patches/'
agb_maps_dir = ''  # TODO: set to the AGB map directory; '' is not a listable directory

# Fail early with a clear message instead of an opaque error from os.listdir.
for data_dir in (optic_images_dir, sar_images_dir, agb_maps_dir):
    if not os.path.isdir(data_dir):
        raise FileNotFoundError(f"Data directory not found: {data_dir!r}")

# Files are paired across the three directories by position. os.listdir order is
# arbitrary and filesystem-dependent, so sort each listing to make the pairing
# deterministic. This assumes corresponding files sort to the same position in
# every directory (e.g. shared tile IDs in their names) -- TODO confirm naming scheme.
agb_map_files = sorted(os.listdir(agb_maps_dir))
optic_image_files = sorted(os.listdir(optic_images_dir))
sar_image_files = sorted(os.listdir(sar_images_dir))

# Never index past the shortest listing.
num_patches = min(
    desired_num_patches, len(optic_image_files), len(sar_image_files), len(agb_map_files)
)


def _minmax_scale(image):
    """Scale an array to [0, 1] using its global min/max; a constant array maps to 0.

    Returns float32, which halves the memory of the stacked dataset compared
    with the float64 that numpy division would otherwise produce.
    """
    image = image.astype(np.float32)
    lo, hi = image.min(), image.max()
    if hi == lo:
        return np.zeros_like(image)
    return (image - lo) / (hi - lo)


x_train_list = []
agb_values_list = []

file_triples = zip(
    optic_image_files[:num_patches],
    sar_image_files[:num_patches],
    agb_map_files[:num_patches],
)
for optic_file, sar_file, agb_file in tqdm(file_triples, total=num_patches, desc='Loading Patches'):
    optic_image = io.imread(os.path.join(optic_images_dir, optic_file))
    sar_image = io.imread(os.path.join(sar_images_dir, sar_file))
    agb_map = io.imread(os.path.join(agb_maps_dir, agb_file))

    # Keep every AGB pixel; the check after the loop requires one value per patch.
    agb_values_list.append(agb_map.flatten())

    # Ensure optic image has 13 bands
    if optic_image.shape[-1] != 13:
        raise ValueError("Optic image should have 13 bands. Found: {}".format(optic_image.shape[-1]))

    # A single-band (2-D) SAR image is duplicated into two identical bands.
    if sar_image.ndim == 2:
        sar_image = np.stack([sar_image, sar_image], axis=-1)

    # Each source is min-max scaled on its own, then stacked along the channel axis.
    x_train_list.append(
        np.concatenate([_minmax_scale(optic_image), _minmax_scale(sar_image)], axis=-1)
    )

# Combine all patches into single arrays
x_data = np.stack(x_train_list)
agb_values = np.concatenate(agb_values_list)

# Min-max scale the AGB target; min_agb / max_agb are reused later to undo the scaling.
# NOTE(review): computed over all samples before the split, so test targets influence it.
min_agb = np.min(agb_values)
max_agb = np.max(agb_values)
agb_scaled = (agb_values - min_agb) / (max_agb - min_agb)

# Reshape scaled AGB values to (num_samples, 1) for the single-unit model output.
agb_normalized = np.expand_dims(agb_scaled, axis=-1)

# The model predicts one value per image, so each AGB map must contribute exactly one value.
if agb_normalized.shape[0] != x_data.shape[0]:
    raise ValueError(
        f"Expected one AGB value per patch, got {agb_normalized.shape[0]} values "
        f"for {x_data.shape[0]} patches"
    )

# Split the data into training (64%), validation (16%) and test (20%) sets
x_train, x_test, y_train, y_test = train_test_split(x_data, agb_normalized, test_size=0.20, random_state=42)

x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.20, random_state=42)
print('x_train:', x_train.shape)
print('y_train:', y_train.shape)
print('x_val shape:', x_val.shape)
print('y_val shape:', y_val.shape)
print('x_test shape:', x_test.shape)
print('y_test shape:', y_test.shape)
input_shape = x_train.shape[1:]

In [None]:

def window_partition(x, window_size):
    """Cut a (batch, height, width, channels) map into square, non-overlapping windows.

    Returns a tensor of shape (-1, window_size, window_size, channels); the
    windows of one image are laid out row by row. Height and width must be
    multiples of window_size, otherwise the first reshape fails.
    """
    _, h, w, c = x.shape
    rows = h // window_size
    cols = w // window_size
    grid = tf.reshape(x, (-1, rows, window_size, cols, window_size, c))
    # Put the two window-grid axes side by side before flattening them into the batch.
    grid = tf.transpose(grid, (0, 1, 3, 2, 4, 5))
    return tf.reshape(grid, (-1, window_size, window_size, c))


def window_reverse(windows, window_size, height, width, channels):
    """Inverse of window_partition: stitch windows back into (-1, height, width, channels)."""
    rows = height // window_size
    cols = width // window_size
    six_d = tf.reshape(windows, (-1, rows, cols, window_size, window_size, channels))
    # Interleave window-grid axes with in-window axes so rows/cols become contiguous again.
    interleaved = tf.transpose(six_d, (0, 1, 3, 2, 4, 5))
    return tf.reshape(interleaved, (-1, height, width, channels))


In [None]:

class WindowAttention(layers.Layer):
    """Multi-head self-attention inside a window, with a learned relative-position bias.

    `call` expects tokens of shape (num_windows, window_h * window_w, dim) and
    returns a tensor of the same shape.

    Args:
        dim: channel size of each token; split evenly across `num_heads`.
        window_size: (height, width) of the attention window, in tokens.
        num_heads: number of attention heads.
        qkv_bias: whether the fused query/key/value projection uses a bias.
        dropout_rate: dropout applied to the attention weights and to the output.
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        # Stored so get_config() can round-trip every constructor argument.
        self.qkv_bias = qkv_bias
        self.dropout_rate = dropout_rate
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias)
        self.dropout = layers.Dropout(dropout_rate)
        self.proj = layers.Dense(dim)

        # One learnable bias per (relative offset, head); the offset along each
        # axis can take 2 * size - 1 values.
        num_window_elements = (2 * self.window_size[0] - 1) * (
            2 * self.window_size[1] - 1
        )
        self.relative_position_bias_table = self.add_weight(
            shape=(num_window_elements, self.num_heads),
            initializer=keras.initializers.Zeros(),
            trainable=True,
        )
        # For every (query, key) token pair in the window, precompute the row of
        # the bias table that holds their relative offset.
        coords_h = np.arange(self.window_size[0])
        coords_w = np.arange(self.window_size[1])
        coords_matrix = np.meshgrid(coords_h, coords_w, indexing="ij")
        coords = np.stack(coords_matrix)
        coords_flatten = coords.reshape(2, -1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.transpose([1, 2, 0])
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)

        # Non-trainable index tensor, kept as a tf.Variable (as before) so the
        # layer's weight list stays compatible with previously saved models.
        self.relative_position_index = tf.Variable(
            initial_value=relative_position_index,
            trainable=False,
            dtype=tf.int32,
        )

    def call(self, x, mask=None):
        """Attend within each window.

        Args:
            x: (num_windows, size, channels) tokens, size = window_h * window_w.
            mask: optional (nW, size, size) additive bias; the leading dim of `x`
                must be a multiple of nW (see the reshape below).
        """
        _, size, channels = x.shape
        head_dim = channels // self.num_heads
        x_qkv = self.qkv(x)
        x_qkv = tf.reshape(x_qkv, (-1, size, 3, self.num_heads, head_dim))
        # -> (3, num_windows, heads, size, head_dim)
        x_qkv = tf.transpose(x_qkv, (2, 0, 3, 1, 4))
        q, k, v = x_qkv[0], x_qkv[1], x_qkv[2]
        q = q * self.scale
        k = tf.transpose(k, (0, 1, 3, 2))
        attn = q @ k

        num_window_elements = self.window_size[0] * self.window_size[1]
        relative_position_index_flat = tf.reshape(self.relative_position_index, (-1,))
        relative_position_bias = tf.gather(
            self.relative_position_bias_table,
            relative_position_index_flat,
            axis=0,
        )
        relative_position_bias = tf.reshape(
            relative_position_bias,
            (num_window_elements, num_window_elements, -1),
        )
        relative_position_bias = tf.transpose(relative_position_bias, (2, 0, 1))
        attn = attn + tf.expand_dims(relative_position_bias, axis=0)

        if mask is not None:
            nW = mask.shape[0]
            # Match the attention dtype rather than assuming float32.
            mask_float = tf.cast(
                tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0),
                attn.dtype,
            )
            attn = tf.reshape(attn, (-1, nW, self.num_heads, size, size)) + mask_float
            attn = tf.reshape(attn, (-1, self.num_heads, size, size))
        attn = keras.activations.softmax(attn, axis=-1)
        attn = self.dropout(attn)

        x_qkv = attn @ v
        x_qkv = tf.transpose(x_qkv, (0, 2, 1, 3))
        x_qkv = tf.reshape(x_qkv, (-1, size, channels))
        x_qkv = self.proj(x_qkv)
        x_qkv = self.dropout(x_qkv)
        return x_qkv

    def get_config(self):
        config = super().get_config()
        config.update({
            "dim": self.dim,
            "window_size": self.window_size,
            "num_heads": self.num_heads,
            "qkv_bias": self.qkv_bias,
            "dropout_rate": self.dropout_rate,
        })
        return config


In [None]:

class SwinTransformer(layers.Layer):
    """One Swin Transformer block: (shifted-)window attention followed by an MLP.

    Takes flattened tokens of shape (batch, H * W, dim) with (H, W) = num_patch
    and returns the same shape. Both sub-layers use pre-LayerNorm and a
    residual connection.

    Args:
        dim: token embedding size.
        num_patch: (height, width) of the token grid.
        num_heads: attention heads passed to WindowAttention.
        window_size: side length of the square attention window, in tokens.
        shift_size: cyclic shift applied before windowing; 0 disables it.
        num_mlp: hidden units of the MLP.
        qkv_bias: whether the query/key/value projection uses a bias.
        dropout_rate: dropout rate used throughout the block.
    """

    def __init__(
        self,
        dim,
        num_patch,
        num_heads,
        window_size=7,
        shift_size=0,
        num_mlp=1024,
        qkv_bias=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.dim = dim  # number of input dimensions
        self.num_patch = num_patch  # number of embedded patches
        self.num_heads = num_heads  # number of attention heads
        self.window_size = window_size  # size of window
        self.shift_size = shift_size  # size of window shift
        self.num_mlp = num_mlp  # number of MLP nodes
        self.qkv_bias = qkv_bias
        self.dropout_rate = dropout_rate

        # A window larger than the grid is clamped to the grid, and shifting is
        # disabled. This must happen BEFORE WindowAttention is created so its
        # relative-position table matches the window size used in call().
        if min(self.num_patch) < self.window_size:
            self.shift_size = 0
            self.window_size = min(self.num_patch)

        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(
            dim,
            window_size=(self.window_size, self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            dropout_rate=dropout_rate,
        )
        # Plain dropout on each residual branch (not stochastic depth, despite the name).
        self.drop_path = layers.Dropout(dropout_rate)
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)

        self.mlp = keras.Sequential(
            [
                layers.Dense(num_mlp),
                layers.Activation(keras.activations.gelu),
                layers.Dropout(dropout_rate),
                layers.Dense(dim),
                layers.Dropout(dropout_rate),
            ]
        )

    def build(self, input_shape):
        """Precompute the shifted-window attention mask (None when not shifting)."""
        if self.shift_size == 0:
            self.attn_mask = None
        else:
            # Label the regions that the cyclic roll brings together. Tokens with
            # different labels get a -100 bias so they effectively do not attend
            # to each other.
            height, width = self.num_patch
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            mask_array = np.zeros((1, height, width, 1))
            count = 0
            for h in h_slices:
                for w in w_slices:
                    mask_array[:, h, w, :] = count
                    count += 1
            mask_array = tf.convert_to_tensor(mask_array)

            # mask array to windows
            mask_windows = window_partition(mask_array, self.window_size)
            mask_windows = tf.reshape(
                mask_windows, [-1, self.window_size * self.window_size]
            )
            attn_mask = tf.expand_dims(mask_windows, axis=1) - tf.expand_dims(
                mask_windows, axis=2
            )
            attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
            attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask)
            self.attn_mask = tf.Variable(
                initial_value=attn_mask,
                dtype=attn_mask.dtype,
                trainable=False,
            )
        super().build(input_shape)

    def call(self, x, training=False):
        height, width = self.num_patch
        _, _, channels = x.shape
        x_skip = x
        x = self.norm1(x)
        x = tf.reshape(x, (-1, height, width, channels))
        if self.shift_size > 0:
            shifted_x = tf.roll(
                x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]
            )
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = tf.reshape(
            x_windows, (-1, self.window_size * self.window_size, channels)
        )
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        attn_windows = tf.reshape(
            attn_windows,
            (-1, self.window_size, self.window_size, channels),
        )
        shifted_x = window_reverse(
            attn_windows, self.window_size, height, width, channels
        )
        if self.shift_size > 0:
            # Undo the cyclic shift applied before windowing.
            x = tf.roll(
                shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2]
            )
        else:
            x = shifted_x

        x = tf.reshape(x, (-1, height * width, channels))
        x = self.drop_path(x, training=training)
        x = x_skip + x
        x_skip = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.drop_path(x, training=training)
        x = x_skip + x
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            "dim": self.dim,
            "num_patch": self.num_patch,
            "num_heads": self.num_heads,
            "window_size": self.window_size,
            "shift_size": self.shift_size,
            "num_mlp": self.num_mlp,
            # Without these, a reloaded model silently falls back to the defaults.
            "qkv_bias": self.qkv_bias,
            "dropout_rate": self.dropout_rate,
        })
        return config

In [None]:

# Using tf ops since it is only used in tf.data.
def patch_extract(images):
    """Cut a batch of images into flattened, non-overlapping patches.

    Uses the module-level `patch_size` (height, width). For images of shape
    (batch, H, W, C) it returns (batch, rows * cols, patch_h * patch_w * C).
    Here rows = H // patch_h and cols = W // patch_w; VALID padding drops any
    remainder pixels.
    """
    batch_size = tf.shape(images)[0]
    patches = tf.image.extract_patches(
        images=images,
        sizes=(1, patch_size[0], patch_size[1], 1),
        strides=(1, patch_size[0], patch_size[1], 1),
        rates=(1, 1, 1, 1),
        padding="VALID",
    )
    # Count grid rows and columns separately so non-square images/patches work too.
    num_rows = patches.shape[1]
    num_cols = patches.shape[2]
    patch_dim = patches.shape[-1]
    return tf.reshape(patches, (batch_size, num_rows * num_cols, patch_dim))


class PatchEmbedding(layers.Layer):
    """Project flattened patches to `embed_dim` and add a learned position embedding.

    Args:
        num_patch: number of patches per image (size of the position table).
        embed_dim: width of the projected embedding.
    """

    def __init__(self, num_patch, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patch = num_patch
        self.embed_dim = embed_dim
        self.proj = layers.Dense(embed_dim)
        self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patch)
        projected = self.proj(patch)
        # The (num_patch, embed_dim) position table broadcasts over the batch axis.
        return projected + self.pos_embed(positions)

    def get_config(self):
        base_config = super().get_config()
        return {**base_config, 'num_patch': self.num_patch, 'embed_dim': self.embed_dim}


class PatchMerging(keras.layers.Layer):
    """Halve the token grid per side by concatenating each 2x2 neighbourhood.

    Tokens (batch, H * W, C) with (H, W) = num_patch become
    (batch, (H // 2) * (W // 2), 2 * embed_dim). The concatenated 4 * C
    features go through a bias-free linear projection.
    """

    def __init__(self, num_patch, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patch = num_patch
        self.embed_dim = embed_dim
        self.linear_trans = layers.Dense(2 * embed_dim, use_bias=False)

    def call(self, x):
        height, width = self.num_patch
        _, _, channels = x.shape
        grid = tf.reshape(x, (-1, height, width, channels))
        # Concatenation order: (even row, even col), (odd, even), (even, odd), (odd, odd).
        quadrants = [grid[:, r::2, c::2, :] for c in (0, 1) for r in (0, 1)]
        merged = tf.concat(quadrants, axis=-1)
        merged = tf.reshape(merged, (-1, (height // 2) * (width // 2), 4 * channels))
        return self.linear_trans(merged)

    def get_config(self):
        base_config = super().get_config()
        return {**base_config, 'num_patch': self.num_patch, 'embed_dim': self.embed_dim}

In [None]:

def augment(x):
    """Random crop to image_dimension x image_dimension, then a random left/right flip.

    When image_dimension equals the image height/width, the crop keeps the
    whole image and only the flip has an effect.
    """
    # Take the channel count from the data instead of hard-coding it.
    x = tf.image.random_crop(x, size=(image_dimension, image_dimension, x.shape[-1]))
    x = tf.image.random_flip_left_right(x)
    return x


def make_dataset(x, y, augment_fn=None):
    """Build a batched, patch-extracted tf.data pipeline from in-memory arrays.

    No shuffling is applied, so batches keep the order of `y`. That lets
    predictions be compared element-wise with the labels.
    """
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    if augment_fn is not None:
        ds = ds.map(lambda xi, yi: (augment_fn(xi), yi))
    return (
        ds.batch(batch_size=batch_size)
        .map(lambda xb, yb: (patch_extract(xb), yb))
        .prefetch(tf.data.experimental.AUTOTUNE)
    )


# NOTE: these pipelines are written to disk and reloaded in the next cells.
# Training then uses the reloaded copy, so the random augmentation is drawn
# once at save time and not re-drawn each epoch.
dataset = make_dataset(x_train, y_train, augment_fn=augment)
dataset_val = make_dataset(x_val, y_val)
dataset_test = make_dataset(x_test, y_test)


In [None]:
# Materialise each pipeline to disk under its split name.
for ds_to_save, save_path in (
    (dataset, 'dataset_train'),
    (dataset_val, 'dataset_val'),
    (dataset_test, 'dataset_test'),
):
    tf.data.experimental.save(ds_to_save, save_path)

In [None]:
# Reload the saved splits (same order as saved: train, val, test).
dataset_train, dataset_val, dataset_test = (
    tf.data.experimental.load(split_path)
    for split_path in ('dataset_train', 'dataset_val', 'dataset_test')
)

In [None]:
# One token per patch: patch_extract yields num_patch_x * num_patch_y tokens, each
# holding patch_h * patch_w * channels values (18 * 18 = 324 tokens of
# 5 * 5 * 15 = 375 values for the default config).
num_tokens = num_patch_x * num_patch_y
token_dim = patch_size[0] * patch_size[1] * input_shape[-1]
# NOTE: `input` shadows the Python builtin but is kept because keras.Model(input, output) uses it.
input = layers.Input(shape=(num_tokens, token_dim))
x = PatchEmbedding(num_tokens, embed_dim)(input)
x = SwinTransformer(
    dim=embed_dim,
    num_patch=(num_patch_x, num_patch_y),
    num_heads=num_heads,
    window_size=window_size,
    shift_size=0,
    num_mlp=num_mlp,
    qkv_bias=qkv_bias,
    dropout_rate=dropout_rate,
)(x)
# NOTE(review): the config cell sets shift_size = 0, so this second block is also
# unshifted; standard Swin alternates with shift_size = window_size // 2.
x = SwinTransformer(
    dim=embed_dim,
    num_patch=(num_patch_x, num_patch_y),
    num_heads=num_heads,
    window_size=window_size,
    shift_size=shift_size,
    num_mlp=num_mlp,
    qkv_bias=qkv_bias,
    dropout_rate=dropout_rate,
)(x)
x = PatchMerging((num_patch_x, num_patch_y), embed_dim=embed_dim)(x)
x = layers.GlobalAveragePooling1D()(x)
# Sigmoid keeps predictions in [0, 1], the range of the min-max scaled AGB target.
output = layers.Dense(num_classes, activation="sigmoid")(x)

In [None]:
from tensorflow.keras.callbacks import Callback

class SaveBestModel(Callback):
    """Keras callback that saves the model when a monitored metric improves.

    Args:
        filepath: destination passed to `model.save` (overwritten each time).
        monitor: key of the epoch-end `logs` dict to watch, e.g. 'val_mae'.
        mode: 'min' (lower is better), 'max' (higher is better) or 'auto'.
            'auto', or any other value, resolves to 'max' when the monitor name
            contains 'acc' and to 'min' otherwise.
        save_best_only: when False, the model is also saved after every epoch
            that did not improve.
    """

    def __init__(self, filepath, monitor='val_mae', mode='auto', save_best_only=True):
        super().__init__()
        self.filepath = filepath
        self.monitor = monitor
        self.mode = mode
        self.save_best_only = save_best_only

        # The comparison needs a concrete direction and a comparable starting
        # value. An equality test against an unset (None) best would never save.
        if mode in ('min', 'max'):
            resolved_mode = mode
        else:
            resolved_mode = 'max' if 'acc' in monitor else 'min'

        if resolved_mode == 'min':
            self.monitor_op = lambda a, b: a < b
            self.best_value = float('inf')
        else:
            self.monitor_op = lambda a, b: a > b
            self.best_value = float('-inf')

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        current_value = logs.get(self.monitor)
        if current_value is None:
            print(f"WARNING: Can't find {self.monitor} in logs. Model not saved.")
            return

        if self.monitor_op(current_value, self.best_value):
            print(f"\nEpoch {epoch + 1}: {self.monitor} improved from {self.best_value} to {current_value}, saving model to {self.filepath}")
            self.model.save(self.filepath, overwrite=True)
            self.best_value = current_value
        elif not self.save_best_only:
            print(f"\nEpoch {epoch + 1}: saving model to {self.filepath}")
            self.model.save(self.filepath, overwrite=True)

In [None]:
import tensorflow_addons as tfa

from tensorflow.keras.callbacks import ModelCheckpoint

model = keras.Model(input, output)

# Adam with decoupled weight decay.
optimizer = tfa.optimizers.AdamW(
    learning_rate=learning_rate, weight_decay=weight_decay
)

model.compile(
    optimizer=optimizer,
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras.metrics.MeanAbsoluteError(name="mae")],
)

# Keep the model with the lowest validation MAE on disk.
checkpoint_filepath = 'best_model.keras'
checkpoint_callback = SaveBestModel(
    filepath=checkpoint_filepath, monitor='val_mae', mode='min', save_best_only=True
)

# The datasets are already batched (batch_size samples per element), so fit()
# is not given its own batch_size; Keras documents that argument as not to be
# used with datasets.
history = model.fit(
    dataset_train,
    epochs=num_epochs,
    validation_data=dataset_val,
    callbacks=[checkpoint_callback],
)

In [None]:
# Reload the best checkpoint that SaveBestModel wrote during training (the same
# path the callback saves to). The custom layer classes must be supplied so
# Keras can rebuild them from their configs.
loaded_model = keras.models.load_model(
    checkpoint_filepath,
    custom_objects={
        'SwinTransformer': SwinTransformer,
        'WindowAttention': WindowAttention,
        'PatchEmbedding': PatchEmbedding,
        'PatchMerging': PatchMerging,
    },
)

In [None]:
# Evaluate the reloaded best checkpoint rather than the last-epoch weights in `model`.
# dataset_test is built without shuffling, so predictions line up with y_test.
y_pred1 = loaded_model.predict(dataset_test)
# Undo the min-max scaling of the AGB target.
y_pred = y_pred1 * (max_agb - min_agb) + min_agb
y_test2 = y_test * (max_agb - min_agb) + min_agb

In [None]:
# R^2 between de-normalised test targets and predictions, on 1-D copies.
y_test2_flat, y_pred_flat = (arr.flatten() for arr in (y_test2, y_pred))
r2 = r2_score(y_test2_flat, y_pred_flat)

In [None]:
# Predictions on training data from the same (best) checkpoint used for the test metrics.
# The training dataset is not shuffled, so predictions line up with y_train.
y_train_pred = loaded_model.predict(dataset_train)
y_train_pred = y_train_pred * (max_agb - min_agb) + min_agb  # Denormalize predictions
y_train_denormalized = y_train * (max_agb - min_agb) + min_agb  # Denormalize true values

In [None]:
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
def _regression_metrics(y_true, y_hat):
    """Return (RMSE, MAE, R^2) for two equally long 1-D arrays."""
    rmse = np.sqrt(mean_squared_error(y_true, y_hat))
    mae = mean_absolute_error(y_true, y_hat)
    r2_value = r2_score(y_true, y_hat)
    return rmse, mae, r2_value


# 1-D views of the de-normalised targets and predictions for both splits.
y_train_denormalized_flat = y_train_denormalized.flatten()
y_train_pred_flat = y_train_pred.flatten()
y_test_denormalized_flat = y_test2.flatten()
y_test_pred_flat = y_pred.flatten()

train_rmse, train_mae, train_r2 = _regression_metrics(y_train_denormalized_flat, y_train_pred_flat)
test_rmse, test_mae, test_r2 = _regression_metrics(y_test_denormalized_flat, y_test_pred_flat)

print(f"Train RMSE: {train_rmse:.4f}")
print(f"Train MAE: {train_mae:.4f}")
print(f"Train R2: {train_r2:.4f}")
print(f"Test RMSE: {test_rmse:.4f}")
print(f"Test MAE: {test_mae:.4f}")
print(f"Test R2: {test_r2:.4f}")
