# Reproducing Baseline Models: U-Net (4-band) & Attention U-Net (4-band) on Amazon Dataset

## 1. Introduction

This notebook reproduces the core baseline AI methodology from  
**“An Attention-Based U-Net for Detecting Deforestation Within Satellite Sensor Imagery”** (John and Zhang, 2022), focusing specifically on the **4-band Amazon dataset**.

In the original work, several convolutional neural network architectures were evaluated for semantic segmentation of deforestation on Sentinel-2 imagery. In this notebook, we reproduce the two key baseline models:

- **U-Net (4-band Amazon dataset)**  
- **Attention U-Net (4-band Amazon dataset)**  

These models use four spectral bands (RGB + NIR), which provide strong discriminative power for vegetation and forest loss. Attention U-Net extends the standard U-Net with attention gates in the skip connections, allowing the model to focus more strongly on informative regions such as deforestation patches.

The aim of this notebook is to provide a **fully reproducible** implementation of:

1. Environment and dependency configuration  
2. Loading preprocessed 4-band Amazon data  
3. Implementing U-Net and Attention U-Net architectures  
4. Training both models from scratch  
5. Evaluating the models on validation data (Accuracy, Precision, Recall, F1-score, IoU)  
6. Comparing the reproduced results with the values reported in the paper  


## 2. Environment and Reproducibility

In this section we:

- Print key library versions (Python, TensorFlow, etc.)
- Check whether a GPU is available
- Set random seeds for reproducibility


In [13]:
import os
import sys
import random
import platform
import numpy as np
import tensorflow as tf

# Print environment information
print("Python version:", sys.version)
print("Platform:", platform.platform())
print("TensorFlow version:", tf.__version__)

# Check GPUs
gpus = tf.config.list_physical_devices('GPU')
print("Available GPUs:", gpus)

# Set seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)


Python version: 3.10.18 | packaged by conda-forge | (main, Jun  4 2025, 14:45:41) [GCC 13.3.0]
Platform: Linux-5.15.0-94-generic-x86_64-with-glibc2.35
TensorFlow version: 2.20.0
Available GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
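
Seeding alone does not guarantee bit-identical GPU training, because some TensorFlow ops are non-deterministic by default. The sketch below wraps the seeding steps in one helper. The name `seed_everything` and the `deterministic_ops` flag are our own for this illustration, and the TensorFlow part is skipped if the library is not installed.

```python
import os
import random

import numpy as np


def seed_everything(seed: int = 42, deterministic_ops: bool = False) -> None:
    """Seed Python, NumPy and (if installed) TensorFlow random generators."""
    # Only affects interpreters started after this point (e.g. subprocesses).
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    try:
        import tensorflow as tf
    except ImportError:
        return  # TensorFlow not installed; Python and NumPy are still seeded
    tf.keras.utils.set_random_seed(seed)  # seeds Python, NumPy and TF together
    if deterministic_ops:
        # Trades speed for run-to-run determinism of TensorFlow ops.
        tf.config.experimental.enable_op_determinism()


seed_everything(42)
first = np.random.rand(3)
seed_everything(42)
second = np.random.rand(3)
print(np.array_equal(first, second))
```

Re-seeding reproduces the same NumPy draws. Whether two full training runs match exactly still depends on the hardware and on op determinism.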


## 3. Paths and Hyperparameters

We assume the following directory structure inside the `attention_unet` folder:

- `amazon-processed-large/training/images/*.npy`  
- `amazon-processed-large/training/masks/*.npy`  
- `amazon-processed-large/validation/images/*.npy`  
- `amazon-processed-large/validation/masks/*.npy`  

These `.npy` files are produced by the original `preprocess-4band-amazon-data.py` script and contain normalized 4-band image patches and binary masks.

Here we define basic hyperparameters (batch size, number of epochs, learning rate) 
for training U-Net and Attention U-Net.


In [14]:
from pathlib import Path

# Base path: this notebook is in baseline_replication/
ATT_UNET_ROOT = Path("attention_unet")
DATA_ROOT = ATT_UNET_ROOT / "amazon-processed-large"

train_img_dir = DATA_ROOT / "training" / "images"
train_mask_dir = DATA_ROOT / "training" / "masks"

val_img_dir = DATA_ROOT / "validation" / "images"
val_mask_dir = DATA_ROOT / "validation" / "masks"

print("Train images dir:", train_img_dir)
print("Train masks dir :", train_mask_dir)
print("Val images dir  :", val_img_dir)
print("Val masks dir   :", val_mask_dir)

for p in [train_img_dir, train_mask_dir, val_img_dir, val_mask_dir]:
    print(p, "exists:", p.exists())

# Hyperparameters
BATCH_SIZE = 4          # you can increase if GPU memory allows
EPOCHS_UNET = 30        # baseline U-Net training epochs
EPOCHS_ATT_UNET = 60    # Attention U-Net training epochs
LEARNING_RATE = 5e-4
IMG_HEIGHT = 512
IMG_WIDTH = 512
N_CHANNELS = 4



Train images dir: attention_unet/amazon-processed-large/training/images
Train masks dir : attention_unet/amazon-processed-large/training/masks
Val images dir  : attention_unet/amazon-processed-large/validation/images
Val masks dir   : attention_unet/amazon-processed-large/validation/masks
attention_unet/amazon-processed-large/training/images exists: True
attention_unet/amazon-processed-large/training/masks exists: True
attention_unet/amazon-processed-large/validation/images exists: True
attention_unet/amazon-processed-large/validation/masks exists: True
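
The checks above only confirm that the four folders exist. A stricter sanity check compares the file names in the image and mask folders. This is a minimal sketch that assumes each mask shares its file name with its image, which we have not verified against the preprocessing script. `check_pairs` is a hypothetical helper, shown here on a temporary directory.

```python
import tempfile
from pathlib import Path


def check_pairs(img_dir: Path, mask_dir: Path) -> list[str]:
    """Return sorted file stems present in both folders; raise on any mismatch."""
    img_stems = {p.stem for p in img_dir.glob("*.npy")}
    mask_stems = {p.stem for p in mask_dir.glob("*.npy")}
    unmatched = img_stems ^ mask_stems  # stems found in only one folder
    if unmatched:
        raise ValueError(f"Unpaired files: {sorted(unmatched)}")
    return sorted(img_stems)


# Tiny self-contained demo with empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    images = Path(tmp) / "images"
    masks = Path(tmp) / "masks"
    images.mkdir()
    masks.mkdir()
    for stem in ("0", "1", "2"):
        (images / f"{stem}.npy").touch()
        (masks / f"{stem}.npy").touch()
    paired = check_pairs(images, masks)
    print(paired)
```

On the real data this would be called as `check_pairs(train_img_dir, train_mask_dir)` before loading.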


## 4. Loading Preprocessed 4-band Amazon Data

In this section we:

- List all `.npy` files in the training and validation folders
- Load them into NumPy arrays
- Ensure that shapes are consistent: `(N, 512, 512, 4)` for images 
  and `(N, 512, 512, 1)` for masks
- Wrap them into `tf.data.Dataset` objects for efficient training
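
Loading everything into memory is feasible here because the arrays are small. As a rough estimate, assuming the 250 training patches that the loader reports further down and `float32` storage (`array_bytes` is a hypothetical helper):

```python
import numpy as np


def array_bytes(shape: tuple[int, ...], dtype: str = "float32") -> int:
    """Bytes needed for a dense array of the given shape and dtype."""
    return int(np.prod(shape, dtype=np.int64)) * np.dtype(dtype).itemsize


train_images = array_bytes((250, 512, 512, 4))
train_masks = array_bytes((250, 512, 512, 1))
print(f"images: {train_images / 2**20:.0f} MiB, masks: {train_masks / 2**20:.0f} MiB")
```

The training images need about 1000 MiB and the masks about 250 MiB, so the full training set fits in host RAM on most machines.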


In [15]:
import glob

def load_npy_stack(img_dir: Path, mask_dir: Path):
    """Load all .npy images and masks, stack into arrays with shape:
       X: (N, 512, 512, 4)
       Y: (N, 512, 512, 1)
    """
    img_paths = sorted(glob.glob(str(img_dir / "*.npy")))
    mask_paths = sorted(glob.glob(str(mask_dir / "*.npy")))

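    print(f"Found {len(img_paths)} image npys, {len(mask_paths)} mask npys")
    if len(img_paths) != len(mask_paths):
        # zip() below would otherwise silently drop the unmatched files
        raise ValueError("Image and mask file counts differ")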

    imgs = []
    masks = []

    for img_path, mask_path in zip(img_paths, mask_paths):
        img = np.load(img_path)   # often (1, 512, 512, 4)
        msk = np.load(mask_path)  # often (1, 512, 512, 1)

        # ---- Standardize image shape ----
        img = np.squeeze(img)  # drop all size-1 dimensions
        # img is now either (512, 512, 4) or (512, 512)
        if img.ndim == 2:  # single-channel fallback
            img = img[..., np.newaxis]  # (H, W) -> (H, W, 1)
        if img.ndim != 3:
            raise ValueError(f"Unexpected image shape {img.shape} for {img_path}")
        # img is expected to be (512, 512, 4) at this point
        imgs.append(img.astype("float32"))

        # ---- Standardize mask shape ----
        msk = np.squeeze(msk)  # drop size-1 dimensions
        # msk is now expected to be (512, 512)
        if msk.ndim == 2:
            msk = msk[..., np.newaxis]  # (H, W) -> (H, W, 1)
        if msk.ndim != 3:
            raise ValueError(f"Unexpected mask shape {msk.shape} for {mask_path}")
        masks.append(msk.astype("float32"))

    X = np.stack(imgs, axis=0)   # (N, 512, 512, 4)
    Y = np.stack(masks, axis=0)  # (N, 512, 512, 1)

    print("Images shape:", X.shape)
    print("Masks shape :", Y.shape)
    return X, Y


X_train, Y_train = load_npy_stack(train_img_dir, train_mask_dir)
X_val, Y_val = load_npy_stack(val_img_dir, val_mask_dir)


Found 250 image npys, 250 mask npys
Images shape: (250, 512, 512, 4)
Masks shape : (250, 512, 512, 1)
Found 100 image npys, 100 mask npys
Images shape: (100, 512, 512, 4)
Masks shape : (100, 512, 512, 1)
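
With 250 training and 100 validation patches and `BATCH_SIZE = 4`, the number of batches per epoch follows directly. The short calculation below also shows that the last training batch is smaller than the rest, since batching keeps the remainder by default.

```python
import math

BATCH_SIZE = 4
n_train, n_val = 250, 100  # counts printed by load_npy_stack above

steps_per_epoch = math.ceil(n_train / BATCH_SIZE)
val_steps = math.ceil(n_val / BATCH_SIZE)
last_batch = n_train - (steps_per_epoch - 1) * BATCH_SIZE

print(steps_per_epoch, val_steps, last_batch)
```

This gives 63 training steps, 25 validation steps, and a final training batch of 2 patches.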


In [17]:
AUTOTUNE = tf.data.AUTOTUNE  # let tf.data choose the prefetch buffer size

def make_dataset(X, Y, batch_size, shuffle=True):
    ds = tf.data.Dataset.from_tensor_slices((X, Y))
    if shuffle:
        ds = ds.shuffle(buffer_size=len(X), seed=SEED, reshuffle_each_iteration=True)
    ds = ds.batch(batch_size).prefetch(AUTOTUNE)
    return ds

train_ds = make_dataset(X_train, Y_train, BATCH_SIZE, shuffle=True)
val_ds   = make_dataset(X_val, Y_val, BATCH_SIZE, shuffle=False)

train_ds, val_ds


(<_PrefetchDataset element_spec=(TensorSpec(shape=(None, 512, 512, 4), dtype=tf.float32, name=None), TensorSpec(shape=(None, 512, 512, 1), dtype=tf.float32, name=None))>,
 <_PrefetchDataset element_spec=(TensorSpec(shape=(None, 512, 512, 4), dtype=tf.float32, name=None), TensorSpec(shape=(None, 512, 512, 1), dtype=tf.float32, name=None))>)

## 5. Model Architectures: U-Net (4-band) and Attention U-Net (4-band)

We implement two segmentation models:

1. **U-Net (4-band)**  
   - Standard encoder–decoder with skip connections  
   - Input shape: `(512, 512, 4)`  
   - Output: binary mask `(512, 512, 1)`

2. **Attention U-Net (4-band)**  
   - Same encoder–decoder backbone as U-Net  
   - Skip connections are modulated by **attention gates**  
   - Attention allows the model to focus on more informative spatial regions, which is particularly useful for detecting deforestation patches.
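
The gating idea can be sketched in plain NumPy. This is an illustration of an additive attention gate under simplifying assumptions, not the notebook's Keras implementation: random matrices stand in for learned 1×1 convolution kernels, biases are omitted, and upsampling is nearest-neighbour rather than bilinear. `attention_gate_np` is a hypothetical name.

```python
import numpy as np


def sigmoid(z: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-z))


def attention_gate_np(g, x, w_g, w_x, w_psi):
    """Additive attention gate on single feature maps.

    g: gating signal, shape (H/2, W/2, Cg)
    x: skip features, shape (H, W, Cx)
    w_g: (Cg, F), w_x: (Cx, F), w_psi: (F, 1), stand-ins for 1x1 conv kernels
    """
    theta_x = x[::2, ::2] @ w_x           # 1x1 conv with stride 2 -> (H/2, W/2, F)
    phi_g = g @ w_g                       # 1x1 conv -> (H/2, W/2, F)
    act = np.maximum(theta_x + phi_g, 0)  # ReLU
    psi = sigmoid(act @ w_psi)            # coefficients in [0, 1], (H/2, W/2, 1)
    # Nearest-neighbour upsampling back to the skip connection's size.
    psi_up = psi.repeat(2, axis=0).repeat(2, axis=1)
    return x * psi_up, psi_up


rng = np.random.default_rng(0)
g = rng.normal(size=(4, 4, 16))
x = rng.normal(size=(8, 8, 8))
gated, alpha = attention_gate_np(
    g, x, rng.normal(size=(16, 8)), rng.normal(size=(8, 8)), rng.normal(size=(8, 1))
)
print(gated.shape, alpha.shape)
```

Each skip-connection pixel is scaled by one coefficient between 0 and 1, so the gate can suppress regions but never amplify them.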

Both models are implemented in Keras (TensorFlow 2.x) using the Functional API. The definitions below closely follow the architecture described in the original paper and the reference implementation provided in the authors' repository.


In [18]:
from tensorflow.keras import layers, models, optimizers

def conv_block(x, filters, kernel_size=3, padding="same", strides=1):
    x = layers.Conv2D(filters, kernel_size, padding=padding, strides=strides)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.Conv2D(filters, kernel_size, padding=padding, strides=strides)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    return x

def encoder_block(x, filters):
    c = conv_block(x, filters)
    p = layers.MaxPooling2D((2, 2))(c)
    return c, p

def decoder_block(x, skip, filters):
    x = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same")(x)
    x = layers.Concatenate()([x, skip])
    x = conv_block(x, filters)
    return x



def attention_gate(g, x, filters):
    """
    Attention Gate implementation matching U-Net spatial dims:
    g: gating signal (coarse, e.g. 32×32)
    x: skip connection (larger, e.g. 64×64)
    """
    # 1×1 conv for skip connection
    theta_x = layers.Conv2D(filters, (1,1), strides=(2,2), padding="same")(x)  # DOWNsample
    # 1×1 conv for gating signal
    phi_g = layers.Conv2D(filters, (1,1), padding="same")(g)

    # Combine + nonlinearity
    add = layers.Add()([theta_x, phi_g])
    act = layers.Activation("relu")(add)

    # Attention coefficients
    psi = layers.Conv2D(1, (1,1), padding="same")(act)
    psi = layers.Activation("sigmoid")(psi)

    # Upsample attention map back to skip connection size
    psi_up = layers.UpSampling2D(size=(2,2), interpolation="bilinear")(psi)

    # Apply attention coefficients to skip connection
    att = layers.Multiply()([x, psi_up])
    return att


In [19]:
def build_unet_4band(input_shape=(512, 512, 4), base_filters=16):
    inputs = layers.Input(shape=input_shape)

    # Encoder
    c1, p1 = encoder_block(inputs, base_filters)
    c2, p2 = encoder_block(p1, base_filters * 2)
    c3, p3 = encoder_block(p2, base_filters * 4)
    c4, p4 = encoder_block(p3, base_filters * 8)

    # Bridge
    bn = conv_block(p4, base_filters * 16)

    # Decoder
    d1 = decoder_block(bn, c4, base_filters * 8)
    d2 = decoder_block(d1, c3, base_filters * 4)
    d3 = decoder_block(d2, c2, base_filters * 2)
    d4 = decoder_block(d3, c1, base_filters)

    outputs = layers.Conv2D(1, (1, 1), activation="sigmoid")(d4)

    model = models.Model(inputs, outputs, name="UNet_4band")
    return model


def build_attention_unet_4band(input_shape=(512, 512, 4), base_filters=16):
    inputs = layers.Input(shape=input_shape)

    # Encoder
    c1, p1 = encoder_block(inputs, base_filters)
    c2, p2 = encoder_block(p1, base_filters * 2)
    c3, p3 = encoder_block(p2, base_filters * 4)
    c4, p4 = encoder_block(p3, base_filters * 8)

    # Bridge
    bn = conv_block(p4, base_filters * 16)

    # Decoder with attention
    g1 = bn
    att4 = attention_gate(g1, c4, base_filters * 8)
    d1 = decoder_block(bn, att4, base_filters * 8)

    g2 = d1
    att3 = attention_gate(g2, c3, base_filters * 4)
    d2 = decoder_block(d1, att3, base_filters * 4)

    g3 = d2
    att2 = attention_gate(g3, c2, base_filters * 2)
    d3 = decoder_block(d2, att2, base_filters * 2)

    g4 = d3
    att1 = attention_gate(g4, c1, base_filters)
    d4 = decoder_block(d3, att1, base_filters)

    outputs = layers.Conv2D(1, (1, 1), activation="sigmoid")(d4)

    model = models.Model(inputs, outputs, name="Attention_UNet_4band")
    return model
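
To check the builders against the attention gate's docstring (gating signal 32×32, skip connection 64×64), the encoder's feature-map sizes can be traced by hand. The helper `trace_unet_shapes` is hypothetical and only mirrors the pooling and filter arithmetic used above.

```python
def trace_unet_shapes(size: int = 512, base_filters: int = 16, depth: int = 4):
    """Return (stage, spatial size, channels) for the encoder path and bridge."""
    stages = []
    for level in range(depth):
        stages.append((f"c{level + 1}", size, base_filters * 2**level))
        size //= 2  # 2x2 max pooling halves height and width
    stages.append(("bridge", size, base_filters * 2**depth))
    return stages


for name, size, channels in trace_unet_shapes():
    print(f"{name:>6}: {size}x{size}x{channels}")
```

For a 512×512 input with `base_filters=16`, `c4` is 64×64×128 and the bridge is 32×32×256, which matches the sizes quoted in the `attention_gate` docstring for the deepest gate.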


## 6. Loss Function, Metrics, and Optimizer

We treat deforestation segmentation as a **binary segmentation** problem  
(forest vs non-forest / deforested). We use:

- **Binary cross-entropy loss**
- **Accuracy** as a basic metric
- **Precision** and **Recall** as additional metrics

F1-score and IoU will be computed explicitly later for clearer comparison  
with the original paper.
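
For intuition, binary cross-entropy can be computed by hand on a few pixels. This NumPy sketch is illustrative only. The clipping constant `eps` is an assumption that avoids `log(0)`, and `binary_cross_entropy` is our own helper rather than the Keras loss.

```python
import numpy as np


def binary_cross_entropy(y_true: np.ndarray, y_prob: np.ndarray, eps: float = 1e-7) -> float:
    """Mean pixel-wise binary cross-entropy."""
    p = np.clip(y_prob, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))))


y_true = np.array([1.0, 1.0, 0.0, 0.0])
confident = np.array([0.9, 0.9, 0.1, 0.1])
hesitant = np.array([0.6, 0.6, 0.4, 0.4])

print(binary_cross_entropy(y_true, confident))  # -ln(0.9), about 0.105
print(binary_cross_entropy(y_true, hesitant))   # -ln(0.6), about 0.511
```

Both prediction vectors are 100% accurate at a 0.5 threshold, yet the loss differs by a factor of about five. This is why loss and accuracy curves can tell different stories during training.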


In [20]:
from tensorflow.keras.metrics import Precision, Recall

def compile_model(model, lr=LEARNING_RATE):
    model.compile(
        optimizer=optimizers.Adam(learning_rate=lr),
        loss="binary_crossentropy",
        metrics=[
            "accuracy",
            Precision(name="precision"),
            Recall(name="recall"),
        ],
    )
    return model

unet_4band = build_unet_4band(input_shape=(IMG_HEIGHT, IMG_WIDTH, N_CHANNELS))
attention_unet_4band = build_attention_unet_4band(input_shape=(IMG_HEIGHT, IMG_WIDTH, N_CHANNELS))

compile_model(unet_4band)
compile_model(attention_unet_4band)

unet_4band.summary()
attention_unet_4band.summary()

## 7. Training U-Net (4-band) and Attention U-Net (4-band)

We train each model on the preprocessed 4-band Amazon training set and  
monitor performance on the validation set. We use:

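- `ModelCheckpoint` to save the best model (based on validation accuracy)
- `ReduceLROnPlateau` to halve the learning rate when validation loss has not improved for 5 epochs

`EarlyStopping` is imported below but is not added to the callback lists, so each model trains for its full epoch budget.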

We first train the baseline **U-Net (4-band)**, then the **Attention U-Net (4-band)**.
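
To see how far `ReduceLROnPlateau` can lower the learning rate with the settings used in this notebook (start 5e-4, factor 0.5, floor 1e-6), the sequence of possible rates can be listed. This sketch only does the arithmetic and assumes every plateau triggers exactly one reduction. `plateau_lr_sequence` is a hypothetical helper, not a Keras function.

```python
def plateau_lr_sequence(lr: float = 5e-4, factor: float = 0.5, min_lr: float = 1e-6):
    """Learning rates visited if every plateau triggers one reduction."""
    rates = [lr]
    while rates[-1] > min_lr:
        rates.append(max(rates[-1] * factor, min_lr))
    return rates


rates = plateau_lr_sequence()
print(len(rates) - 1)  # number of reductions needed to reach min_lr
print([f"{r:.2e}" for r in rates])
```

Nine halvings are needed to reach the floor. With a patience of 5, a 30-epoch run can trigger at most 30 // 5 = 6 reductions, so the U-Net run stays above `min_lr`.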


In [21]:
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping

callbacks_unet = [
    ModelCheckpoint(
        filepath=ATT_UNET_ROOT / "unet-4d-amazon.h5",
        monitor="val_accuracy",
        save_best_only=True,
        verbose=1,
    ),
    ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=5, verbose=1, min_lr=1e-6
    ),
]

callbacks_att_unet = [
    ModelCheckpoint(
        filepath=ATT_UNET_ROOT / "unet-attention-4d-amazon.h5",
        monitor="val_accuracy",
        save_best_only=True,
        verbose=1,
    ),
    ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=5, verbose=1, min_lr=1e-6
    ),
]


In [22]:
history_unet = unet_4band.fit(
    train_ds,
    epochs=EPOCHS_UNET,
    validation_data=val_ds,
    callbacks=callbacks_unet,
)


Epoch 1/30


2025-12-15 23:11:50.825866: I external/local_xla/xla/service/service.cc:163] XLA service 0x7f526c003c20 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-12-15 23:11:50.825883: I external/local_xla/xla/service/service.cc:171]   StreamExecutor device (0): NVIDIA GeForce RTX 3060, Compute Capability 8.6
2025-12-15 23:11:51.180003: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.



In [None]:
history_att_unet = attention_unet_4band.fit(
    train_ds,
    epochs=EPOCHS_ATT_UNET,
    validation_data=val_ds,
    callbacks=callbacks_att_unet,
)


## 8. Training Curves

We plot training and validation loss/accuracy for both models to:

- Inspect convergence behaviour
- Compare U-Net vs Attention U-Net
- Check for signs of overfitting


In [None]:
import matplotlib.pyplot as plt

def plot_history(history, title_prefix=""):
    hist = history.history
    epochs = range(1, len(hist["loss"]) + 1)

    plt.figure()
    plt.plot(epochs, hist["loss"], label="train_loss")
    plt.plot(epochs, hist["val_loss"], label="val_loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"{title_prefix} Loss")
    plt.legend()
    plt.show()

    plt.figure()
    plt.plot(epochs, hist["accuracy"], label="train_acc")
    plt.plot(epochs, hist["val_accuracy"], label="val_acc")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title(f"{title_prefix} Accuracy")
    plt.legend()
    plt.show()

plot_history(history_unet, title_prefix="U-Net 4-band")
plot_history(history_att_unet, title_prefix="Attention U-Net 4-band")


## 9. Quantitative Evaluation on Validation Set

We now evaluate both models on the validation set using:

- Accuracy
- Precision
- Recall
- F1-score
- Intersection-over-Union (IoU)

Predictions are thresholded at 0.5 to obtain binary masks.
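
All five metrics follow from the four confusion counts of the positive class. The NumPy sketch below shows this on eight toy pixels. `binary_metrics` is a hypothetical helper written for illustration and the values are made up, not model outputs.

```python
import numpy as np


def binary_metrics(y_true: np.ndarray, y_prob: np.ndarray, threshold: float = 0.5) -> dict:
    """Pixel-wise metrics for the positive class from confusion counts."""
    t = y_true.ravel() > 0.5
    p = y_prob.ravel() >= threshold
    tp = int(np.sum(t & p))
    fp = int(np.sum(~t & p))
    fn = int(np.sum(t & ~p))
    tn = int(np.sum(~t & ~p))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * tp / (2 * tp + fp + fn) if tp + fp + fn else 0.0
    iou = tp / (tp + fp + fn) if tp + fp + fn else 0.0
    return {
        "accuracy": (tp + tn) / t.size,
        "precision": precision, "recall": recall, "f1": f1, "iou": iou,
    }


y_true = np.array([1, 1, 1, 0, 0, 0, 0, 0], dtype="float32")
y_prob = np.array([0.9, 0.8, 0.3, 0.7, 0.2, 0.1, 0.4, 0.5], dtype="float32")
m = binary_metrics(y_true, y_prob)
print(m)
```

Here TP = 2, FP = 2, FN = 1 and TN = 3, giving accuracy 0.625, precision 0.5, recall 2/3, F1 4/7 and IoU 0.4. F1 and IoU are tied by F1 = 2·IoU / (1 + IoU), so they always rank two models the same way on a given dataset.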


In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score, jaccard_score

def evaluate_model_on_val(model, X_val, Y_val, threshold=0.5):
    """
    Compute pixel-wise metrics on the validation set.
    """
    preds = model.predict(X_val, batch_size=BATCH_SIZE)
    preds_bin = (preds >= threshold).astype("uint8")

    # Flatten
    y_true = Y_val.flatten()
    y_pred = preds_bin.flatten()

    acc = (y_true == y_pred).mean()
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    iou = jaccard_score(y_true, y_pred, zero_division=0)

    return {
        "accuracy": acc,
        "precision": prec,
        "recall": rec,
        "f1": f1,
        "iou": iou,
    }


In [None]:
# Reload best weights (based on val_accuracy)
best_unet_path = ATT_UNET_ROOT / "unet-4d-amazon.h5"
best_attunet_path = ATT_UNET_ROOT / "unet-attention-4d-amazon.h5"

unet_best = build_unet_4band(input_shape=(IMG_HEIGHT, IMG_WIDTH, N_CHANNELS))
attention_unet_best = build_attention_unet_4band(input_shape=(IMG_HEIGHT, IMG_WIDTH, N_CHANNELS))

compile_model(unet_best)
compile_model(attention_unet_best)

unet_best.load_weights(best_unet_path)
attention_unet_best.load_weights(best_attunet_path)

metrics_unet = evaluate_model_on_val(unet_best, X_val, Y_val)
metrics_attunet = evaluate_model_on_val(attention_unet_best, X_val, Y_val)

print("U-Net 4-band metrics:", metrics_unet)
print("Attention U-Net 4-band metrics:", metrics_attunet)


## 10. Comparison with Original Paper

Below we summarise the reproduced metrics on the 4-band Amazon validation set
and compare them to the values reported in the paper.

> **TODO (manual step):** Fill in the metrics reported in the original paper  
> (Accuracy, F1, IoU for U-Net 4-band and Attention U-Net 4-band).

| Model                      | Metric    | Paper Value | Reproduced | Δ (absolute) |
|----------------------------|-----------|-------------|------------|--------------|
| U-Net (4-band, Amazon)     | Accuracy  | ...         | ...        | ...          |
| U-Net (4-band, Amazon)     | F1-score  | ...         | ...        | ...          |
| U-Net (4-band, Amazon)     | IoU       | ...         | ...        | ...          |
| Attention U-Net (4-band)   | Accuracy  | ...         | ...        | ...          |
| Attention U-Net (4-band)   | F1-score  | ...         | ...        | ...          |
| Attention U-Net (4-band)   | IoU       | ...         | ...        | ...          |

If all absolute differences Δ are within **±5%**, we satisfy the coursework 
requirement of reproducing the baseline performance.
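
Once the table is filled in, the tolerance check can be automated. The sketch below reads "±5%" as an absolute difference of at most 0.05 on a 0-1 metric scale, which is our assumption about the requirement. `within_tolerance` is a hypothetical helper, and the numbers are placeholders, not values from the paper or from this notebook.

```python
def within_tolerance(paper: float, reproduced: float, tol: float = 0.05):
    """Return (delta, passed) for one metric on a 0-1 scale."""
    delta = reproduced - paper
    return delta, abs(delta) <= tol


# Placeholder numbers for illustration only; they are NOT values from the paper
# or from this notebook's runs.
rows = [("example_metric_a", 0.90, 0.87), ("example_metric_b", 0.90, 0.80)]
for name, paper, reproduced in rows:
    delta, ok = within_tolerance(paper, reproduced)
    print(f"{name}: delta={delta:+.3f} pass={ok}")
```

If the requirement is meant as a relative difference instead, `delta` would need to be divided by the paper value before comparing against the tolerance.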


In [None]:
import matplotlib.pyplot as plt

def show_example(idx):
    img = X_val[idx]
    gt = Y_val[idx]

    pred_unet = (unet_best.predict(img[np.newaxis, ...]) >= 0.5).astype("uint8")[0]
    pred_att  = (attention_unet_best.predict(img[np.newaxis, ...]) >= 0.5).astype("uint8")[0]

    fig, axs = plt.subplots(1, 4, figsize=(16, 4))
    axs[0].imshow(img[..., :3])  # RGB for visualisation
    axs[0].set_title("Input (RGB)")
    axs[0].axis("off")

    axs[1].imshow(gt[..., 0], cmap="gray")
    axs[1].set_title("Ground Truth")
    axs[1].axis("off")

    axs[2].imshow(pred_unet[..., 0], cmap="gray")
    axs[2].set_title("U-Net Prediction")
    axs[2].axis("off")

    axs[3].imshow(pred_att[..., 0], cmap="gray")
    axs[3].set_title("Attention U-Net Prediction")
    axs[3].axis("off")

    plt.show()

# Show a few samples
for i in range(3):
    show_example(i)
