# U-Net with ResNet50 Encoder for Golf Course Segmentation

## Architecture Overview

This notebook implements a **U-Net** architecture with a **ResNet50 encoder** for semantic segmentation of golf course aerial imagery.

### Why U-Net?
- **Encoder-Decoder structure**: Captures both high-level semantic features and low-level spatial details
- **Skip connections**: Preserve fine-grained spatial information lost during downsampling
- **Proven for segmentation**: Originally designed for biomedical image segmentation, works well for aerial imagery

### Why ResNet50 Encoder?
- **Transfer learning**: ImageNet pretrained weights provide robust feature extraction
- **Residual connections**: Enable training of deeper networks without degradation
- **Multi-scale features**: Different ResNet blocks capture features at 1/2, 1/4, 1/8, 1/16, 1/32 resolution
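
As a quick check of the multi-scale claim, the sketch below computes the feature-map size at each encoder stride for a 512×832 input, the size used later in this notebook. The helper name `feature_map_sizes` is illustrative and not part of the notebook.

```python
def feature_map_sizes(height, width, strides=(2, 4, 8, 16, 32)):
    """Return {stride: (h, w)} for an input that divides evenly by each stride."""
    sizes = {}
    for s in strides:
        # Skip connections only line up cleanly if both dimensions divide evenly.
        if height % s or width % s:
            raise ValueError(f"{height}x{width} is not divisible by stride {s}")
        sizes[s] = (height // s, width // s)
    return sizes


for stride, (h, w) in feature_map_sizes(512, 832).items():
    print(f"1/{stride}: {h}x{w}")
```

Both 512 and 832 are multiples of 32, so every encoder level has an exact integer size and each decoder upsampling step can double it back without cropping.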

### Segmentation Classes (6)
| Class | Color | Description |
|-------|-------|-------------|
| 0 | Black | Background |
| 1 | Dark Green | Fairway |
| 2 | Bright Green | Green (putting surface) |
| 3 | Red | Tee box |
| 4 | Sandy Yellow | Bunker |
| 5 | Blue | Water hazard |

In [None]:
# Environment detection for Colab/local compatibility
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Running on Google Colab")
    !pip install -q kagglehub
else:
    print("Running locally")

In [None]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, callbacks
import kagglehub

## GPU Configuration

**Mixed Precision Training (float16)**:
- Stores activations in float16, roughly halving activation memory and leaving room for larger batch sizes
- Speeds up computation on modern GPUs with Tensor Cores
- Keeps model variables in float32, and the output layer is set to float32 so the loss is computed on full-precision logits
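
To put the memory saving in concrete terms, here is a back-of-the-envelope sketch for a single activation tensor. It assumes batch size 2 and the 1/2-resolution encoder output (256×416, 64 channels) that appear later in this notebook; `activation_bytes` is a hypothetical helper, and real GPU usage also includes weights, gradients, and optimizer state.

```python
def activation_bytes(batch, height, width, channels, bytes_per_value):
    """Bytes needed to store one activation tensor."""
    return batch * height * width * channels * bytes_per_value


# Assumed example: 1/2-resolution encoder output at batch size 2.
fp32 = activation_bytes(2, 256, 416, 64, bytes_per_value=4)
fp16 = activation_bytes(2, 256, 416, 64, bytes_per_value=2)

print(f"float32: {fp32 / 2**20:.1f} MiB")
print(f"float16: {fp16 / 2**20:.1f} MiB")
```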

In [None]:
print(f"TensorFlow version: {tf.__version__}")

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    # Memory growth prevents TF from allocating all GPU memory at once
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    # Enable mixed precision for faster training
    keras.mixed_precision.set_global_policy('mixed_float16')
    print(f"GPU configured: {len(gpus)} device(s)")
    for gpu in gpus:
        print(f"  - {gpu}")
else:
    print("No GPU detected, using CPU")

In [None]:
# Download Danish Golf Courses dataset from Kaggle
# First run may require Kaggle authentication (upload kaggle.json)
dataset_path = kagglehub.dataset_download('jacotaco/danish-golf-courses-orthophotos')
print(f"Dataset path: {dataset_path}")

## Hyperparameters

| Parameter | Value | Rationale |
|-----------|-------|----------|
| Image Size | 512×832 | Maintains aspect ratio of orthophotos, fits in GPU memory |
| Batch Size | 2 | Limited by GPU memory due to large image size |
| Learning Rate | 1e-4 | Conservative for fine-tuning pretrained encoder |
| Augmentation | 25% | Prevents overfitting without excessive distortion |

In [None]:
BATCH_SIZE = 2          # Reduce to 1 if OOM on Colab T4
IMAGE_SIZE = (512, 832) # Height × Width - maintains orthophoto aspect ratio
IN_CHANNELS = 3
LEARNING_RATE = 1e-4
NUM_CLASSES = 6
MAX_EPOCHS = 10
AUGMENTATION_PROBABILITY = 0.25

# Dataset structure:
# - 1. orthophotos/: RGB aerial orthophotos (.jpg)
# - 2. segmentation masks/: Color-coded visualization (.png)
# - 3. class masks/: Integer class labels 0-5 (.png)
base_path = dataset_path
IMAGES_DIR = os.path.join(base_path, '1. orthophotos')
SEGMASKS_DIR = os.path.join(base_path, '2. segmentation masks')
LABELMASKS_DIR = os.path.join(base_path, '3. class masks')

OUTPUT_DIR = '/content/output' if IN_COLAB else './output'
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# Preview a sample image and its segmentation mask
# Sort so that idx refers to the same file on every run (os.listdir order is arbitrary)
orthophoto_list = sorted(os.listdir(IMAGES_DIR))
print(f"Total images: {len(orthophoto_list)}")

idx = 5
golf_image = Image.open(os.path.join(IMAGES_DIR, orthophoto_list[idx]))
golf_segmask = Image.open(os.path.join(SEGMASKS_DIR, orthophoto_list[idx].replace(".jpg", ".png")))

fig, axes = plt.subplots(1, 2, figsize=(12, 6))
axes[0].set_title('Orthophoto')
axes[1].set_title('Segmentation Mask')
axes[0].imshow(golf_image)
axes[1].imshow(golf_segmask)
plt.show()

## Data Pipeline

### Preprocessing
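- **Images**: Resize to 512×832 and scale pixel values to [0, 1]. This differs from `keras.applications.resnet50.preprocess_input` (RGB→BGR, ImageNet mean subtraction, no scaling), which matches how the pretrained weights were trained, so the encoder has to adapt to the new input range during fine-tuning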
- **Masks**: Resize with nearest-neighbor interpolation (preserves integer class labels)
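
The reason for nearest-neighbor resizing can be seen in a simplified 1-D sketch. The two functions below are illustrative stand-ins written for this example, not TensorFlow's implementation.

```python
def resize_nearest(row, new_len):
    """Nearest-neighbour resize of a 1-D list: output values are copied from the input."""
    scale = len(row) / new_len
    return [row[min(int((i + 0.5) * scale), len(row) - 1)] for i in range(new_len)]


def resize_linear(row, new_len):
    """Linear-interpolation resize of a 1-D list: output values are blended."""
    scale = len(row) / new_len
    out = []
    for i in range(new_len):
        # Map the output pixel centre back into input coordinates.
        x = min(max((i + 0.5) * scale - 0.5, 0.0), len(row) - 1.0)
        left = int(x)
        right = min(left + 1, len(row) - 1)
        frac = x - left
        out.append((1 - frac) * row[left] + frac * row[right])
    return out


labels = [1, 1, 5, 5]  # Fairway (1) next to Water (5)
print(resize_nearest(labels, 8))
print(resize_linear(labels, 8))
```

Nearest-neighbor only ever emits labels that were already present. Linear interpolation blends across the boundary and produces intermediate values, which a later integer cast would read as unrelated classes.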

### Synchronized Augmentation
For segmentation, image and mask must be augmented identically:
1. Concatenate image (3 channels) and mask (1 channel) → 4-channel tensor
2. Apply geometric transforms (here, a random horizontal flip) to the combined tensor
3. Split back into image and mask
4. Apply photometric transforms (brightness, contrast) to image only
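
The four steps can be traced on a toy array. This NumPy sketch mirrors the idea only; the array sizes and the fixed brightness offset of 0.05 are assumptions made for illustration, and the notebook's own pipeline uses TensorFlow ops with a random flip.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy 2x3 RGB image and a matching integer class mask.
image = rng.random((2, 3, 3)).astype(np.float32)
mask = np.array([[0, 1, 2],
                 [3, 4, 5]], dtype=np.float32)

# 1. Stack the mask as a fourth channel.
combined = np.concatenate([image, mask[..., None]], axis=-1)  # shape (2, 3, 4)

# 2. One geometric transform on the stacked tensor (horizontal flip).
flipped = combined[:, ::-1, :]

# 3. Split back into image and mask.
aug_image, aug_mask = flipped[..., :3], flipped[..., 3]

# 4. Photometric change on the image only; the mask keeps its labels.
aug_image = np.clip(aug_image + 0.05, 0.0, 1.0)

print(aug_mask)
```

Because the flip acts on the stacked tensor, every pixel keeps its own label after the transform.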

In [None]:
def load_and_preprocess_image(image_path, mask_path):
    """Load image-mask pair with proper preprocessing."""
    # Load and normalize image to [0, 1]
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, IMAGE_SIZE)
    image = tf.cast(image, tf.float32) / 255.0

    # Load mask - use nearest neighbor to preserve class labels
    mask = tf.io.read_file(mask_path)
    mask = tf.image.decode_png(mask, channels=1)
    mask = tf.image.resize(mask, IMAGE_SIZE, method='nearest')
    mask = tf.cast(mask, tf.float32)
    mask = tf.squeeze(mask, axis=-1)  # Remove channel dimension

    return image, mask


def augment_image_and_mask(image, mask):
    """Synchronized augmentation - same transform applied to both."""
    def apply_augmentation():
        # Concatenate for synchronized geometric transforms
        mask_expanded = tf.expand_dims(mask, axis=-1)
        combined = tf.concat([image, mask_expanded], axis=-1)
        
        # Geometric: random horizontal flip
        combined = tf.image.random_flip_left_right(combined)
        
        # Split back
        aug_image = combined[:, :, :3]
        aug_mask = combined[:, :, 3]
        
        # Photometric: only apply to image (not mask)
        aug_image = tf.image.random_brightness(aug_image, 0.1)
        aug_image = tf.image.random_contrast(aug_image, 0.9, 1.1)
        aug_image = tf.clip_by_value(aug_image, 0.0, 1.0)
        
        return tf.cast(aug_image, tf.float32), aug_mask

    def keep_original():
        return tf.cast(image, tf.float32), mask

    # Stochastic augmentation: augment each sample with probability AUGMENTATION_PROBABILITY (0.25)
    should_augment = tf.random.uniform([]) < AUGMENTATION_PROBABILITY
    return tf.cond(should_augment, apply_augmentation, keep_original)


def create_dataset(images_dir, labelmasks_dir, shuffle=True):
    """Create tf.data.Dataset from directory."""
    image_filenames = sorted(os.listdir(images_dir))
    image_paths = [os.path.join(images_dir, fname) for fname in image_filenames]
    mask_paths = [os.path.join(labelmasks_dir, fname.replace('.jpg', '.png')) for fname in image_filenames]

    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
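    if shuffle:
        # reshuffle_each_iteration=False keeps the order fixed across epochs, so the
        # take/skip split below always yields the same train/val/test membership.
        dataset = dataset.shuffle(buffer_size=len(image_paths), seed=42,
                                  reshuffle_each_iteration=False)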
    dataset = dataset.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset, len(image_paths)


def prepare_datasets():
    """70/20/10 train/val/test split."""
    full_dataset, total_size = create_dataset(IMAGES_DIR, LABELMASKS_DIR, shuffle=True)

    train_size = int(0.7 * total_size)
    val_size = int(0.2 * total_size)
    test_size = total_size - train_size - val_size

    print(f"Split: {train_size} train, {val_size} val, {test_size} test")

    train_ds = full_dataset.take(train_size)
    remaining = full_dataset.skip(train_size)
    val_ds = remaining.take(val_size)
    test_ds = remaining.skip(val_size)

    # Only augment training data
    train_ds = train_ds.map(augment_image_and_mask, num_parallel_calls=tf.data.AUTOTUNE)
    train_ds = train_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
    val_ds = val_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
    test_ds = test_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

    return train_ds, val_ds, test_ds

## U-Net Architecture

```
Input (512×832×3)
    │
    ▼
┌─────────────────────────────────────────┐
│         ResNet50 ENCODER                │
│  (ImageNet pretrained, fine-tuned)      │
│                                         │
│  conv1_relu ──────────────────────────┐ │  1/2 (256×416)
│       │                               │ │
│  conv2_block3 ───────────────────┐    │ │  1/4 (128×208)
│       │                          │    │ │
│  conv3_block4 ──────────────┐    │    │ │  1/8 (64×104)
│       │                     │    │    │ │
│  conv4_block6 ─────────┐    │    │    │ │  1/16 (32×52)
│       │                │    │    │    │ │
│  conv5_block3 ───┐     │    │    │    │ │  1/32 (16×26) BOTTLENECK
└──────────────────│─────│────│────│────│─┘
                   │     │    │    │    │
┌──────────────────│─────│────│────│────│─┐
│         DECODER  │     │    │    │    │ │
│                  │     │    │    │    │ │
│  UpConv 512 ─────┘     │    │    │    │ │
│  + skip ───────────────┘    │    │    │ │  1/16
│  Conv 512×2                 │    │    │ │
│       │                     │    │    │ │
│  UpConv 256                 │    │    │ │
│  + skip ────────────────────┘    │    │ │  1/8
│  Conv 256×2                      │    │ │
│       │                          │    │ │
│  UpConv 128                      │    │ │
│  + skip ─────────────────────────┘    │ │  1/4
│  Conv 128×2                           │ │
│       │                               │ │
│  UpConv 64                            │ │
│  + skip ──────────────────────────────┘ │  1/2
│  Conv 64×2                              │
│       │                                 │
│  UpConv 32                              │  Full resolution
│  Conv 32×2                              │
│       │                                 │
│  Conv 1×1 (6 classes)                   │
└─────────────────────────────────────────┘
    │
    ▼
Output (512×832×6) logits
```

### Key Design Choices
- **Skip connections**: Concatenate encoder features at each resolution to preserve spatial detail
- **Conv2DTranspose**: Learned upsampling, in contrast to fixed bilinear interpolation
- **Double convolutions**: Two 3×3 convs at each decoder level for feature refinement
- **1×1 output conv**: Maps features to class logits without spatial reduction
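
One consequence of concatenating skip connections is that the first 3×3 convolution at each decoder level sees more channels than the up-convolution produced. The sketch below tallies this, assuming the up-conv filter counts from the diagram and the standard ResNet50 channel widths (64, 256, 512, 1024) for the four skip layers.

```python
# (decoder level, up-conv filters, skip-connection channels)
stages = [
    ("1/16", 512, 1024),  # skip from conv4_block6
    ("1/8", 256, 512),    # skip from conv3_block4
    ("1/4", 128, 256),    # skip from conv2_block3
    ("1/2", 64, 64),      # skip from conv1_relu
]

concat_channels = {}
for name, up_filters, skip_channels in stages:
    # Concatenation stacks channels, so the next conv sees their sum.
    concat_channels[name] = up_filters + skip_channels
    print(f"{name}: {up_filters} + {skip_channels} = {concat_channels[name]} channels")
```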

In [None]:
def build_unet_resnet50(input_shape=(512, 832, 3), num_classes=6):
    """U-Net with ResNet50 encoder pretrained on ImageNet."""
    inputs = keras.Input(shape=input_shape)

    # ENCODER: ResNet50 backbone
    # Using pretrained weights enables strong feature extraction
    # even with limited training data
    base_model = keras.applications.ResNet50(
        include_top=False,
        weights='imagenet',
        input_tensor=inputs
    )

    # Extract skip connections at each resolution
    # These preserve spatial details lost during downsampling
    skip_layer_names = [
        'conv1_relu',        # 1/2 resolution, 64 channels
        'conv2_block3_out',  # 1/4 resolution, 256 channels
        'conv3_block4_out',  # 1/8 resolution, 512 channels
        'conv4_block6_out',  # 1/16 resolution, 1024 channels
    ]
    skip_connections = [base_model.get_layer(name).output for name in skip_layer_names]
    bottleneck = base_model.get_layer('conv5_block3_out').output  # 1/32, 2048 channels

    # DECODER: Progressive upsampling with skip connections
    # Each decoder block: UpConv → Concat skip → Conv → Conv
    
    # 1/32 → 1/16
    x = layers.Conv2DTranspose(512, kernel_size=2, strides=2, padding='same')(bottleneck)
    x = layers.Concatenate()([x, skip_connections[3]])
    x = layers.Conv2D(512, 3, padding='same', activation='relu')(x)
    x = layers.Conv2D(512, 3, padding='same', activation='relu')(x)

    # 1/16 → 1/8
    x = layers.Conv2DTranspose(256, kernel_size=2, strides=2, padding='same')(x)
    x = layers.Concatenate()([x, skip_connections[2]])
    x = layers.Conv2D(256, 3, padding='same', activation='relu')(x)
    x = layers.Conv2D(256, 3, padding='same', activation='relu')(x)

    # 1/8 → 1/4
    x = layers.Conv2DTranspose(128, kernel_size=2, strides=2, padding='same')(x)
    x = layers.Concatenate()([x, skip_connections[1]])
    x = layers.Conv2D(128, 3, padding='same', activation='relu')(x)
    x = layers.Conv2D(128, 3, padding='same', activation='relu')(x)

    # 1/4 → 1/2
    x = layers.Conv2DTranspose(64, kernel_size=2, strides=2, padding='same')(x)
    x = layers.Concatenate()([x, skip_connections[0]])
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)

    # 1/2 → full resolution
    x = layers.Conv2DTranspose(32, kernel_size=2, strides=2, padding='same')(x)
    x = layers.Conv2D(32, 3, padding='same', activation='relu')(x)
    x = layers.Conv2D(32, 3, padding='same', activation='relu')(x)

    # Output: 1×1 conv to map to class logits
    # dtype='float32' ensures numerical stability with mixed precision
    outputs = layers.Conv2D(num_classes, kernel_size=1, padding='same', dtype='float32')(x)

    return keras.Model(inputs=inputs, outputs=outputs, name='UNet_ResNet50')

## Training Configuration

### Loss Function: Sparse Categorical Crossentropy
- `from_logits=True`: Model outputs raw logits, softmax applied in loss
- "Sparse": Masks contain integer class labels (0-5), not one-hot vectors
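
For a single pixel, the loss reduces to the negative log-softmax of the true class. The following is a minimal pure-Python sketch of that computation; the function name and the example logits are made up for illustration, and Keras additionally averages the value over all pixels in the batch.

```python
import math


def sparse_ce_from_logits(logits, label):
    """Cross-entropy for one pixel: integer label, unnormalized logits."""
    # Log-sum-exp with the max subtracted for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # -log(softmax(logits)[label]) == log_sum_exp - logits[label]
    return log_sum_exp - logits[label]


pixel_logits = [2.0, 1.0, 0.1, -1.0, 0.0, 0.5]  # one score per class (6 classes)
loss = sparse_ce_from_logits(pixel_logits, label=0)
print(f"{loss:.4f}")
```

The label is used only as an index into the logits, which is why no one-hot encoding of the mask is needed.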

### Optimizer: AdamW
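- Adam with decoupled weight decay: the decay is applied directly to the weights instead of being added to the gradient
- Often generalizes better than Adam with L2 regularization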
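
The decoupling is easiest to see in a single update step. This is a sketch for a scalar parameter; the function name and the default hyperparameter values are illustrative assumptions and are not read from the Keras source.

```python
import math


def adamw_step(theta, grad, m, v, t, lr=1e-4, wd=0.004,
               beta1=0.9, beta2=0.999, eps=1e-7):
    """One AdamW update for a scalar parameter (decoupled weight decay)."""
    # Standard Adam moment estimates with bias correction.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Decay shrinks the weight directly; it never enters m or v.
    theta = theta - lr * wd * theta
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


theta, m, v = adamw_step(theta=1.0, grad=0.0, m=0.0, v=0.0, t=1)
print(theta)
```

With a zero gradient the weight still shrinks by the factor `1 - lr * wd`, which shows that the decay term is independent of the adaptive gradient scaling.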

In [None]:
model = build_unet_resnet50(input_shape=(*IMAGE_SIZE, 3), num_classes=NUM_CLASSES)

model.compile(
    optimizer=keras.optimizers.AdamW(learning_rate=LEARNING_RATE),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

model.summary()

In [None]:
train_ds, val_ds, test_ds = prepare_datasets()

## Callbacks

- **ModelCheckpoint**: Save best model based on validation loss
- **EarlyStopping**: Stop if validation loss doesn't improve for 10 epochs (with `MAX_EPOCHS = 10` this cannot trigger; it only takes effect if `MAX_EPOCHS` is raised)
- **ReduceLROnPlateau**: Halve learning rate if plateau for 5 epochs
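
To illustrate how the plateau rule plays out, here is a simplified simulation of the learning-rate schedule. It assumes a minimum-improvement threshold of 1e-4 and no cooldown period; the function and the validation losses are hypothetical, not output from this notebook.

```python
def simulate_reduce_lr(val_losses, lr=1e-4, factor=0.5, patience=5,
                       min_lr=1e-7, min_delta=1e-4):
    """Return the learning rate in effect after each epoch (simplified, no cooldown)."""
    best = float("inf")
    wait = 0
    history = []
    for loss in val_losses:
        if loss < best - min_delta:
            best, wait = loss, 0  # improvement resets the counter
        else:
            wait += 1
            if wait >= patience:
                lr = max(lr * factor, min_lr)  # halve, but never below min_lr
                wait = 0
        history.append(lr)
    return history


# Hypothetical validation losses: improvement stops after the second epoch.
losses = [1.0, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8]
print(simulate_reduce_lr(losses))
```

In this example the rate is halved only after five consecutive epochs without improvement, so a short run sees at most one reduction.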

In [None]:
callback_list = [
    callbacks.ModelCheckpoint(
        filepath=os.path.join(OUTPUT_DIR, 'best_unet_resnet50.keras'),
        monitor='val_loss',
        save_best_only=True,
        verbose=1
    ),
    callbacks.EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=1
    ),
    callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        verbose=1,
        min_lr=1e-7
    )
]

In [None]:
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=MAX_EPOCHS,
    callbacks=callback_list,
    verbose=1
)

In [None]:
test_loss, test_accuracy = model.evaluate(test_ds)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f}")

## Visualization

Convert integer class masks to RGB using the colors from the Segmentation Classes table.
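
The conversion is a palette lookup: each integer label selects one row of a color table. As a sketch of the idea, NumPy integer-array indexing does this in one step; the two-class palette and 2×2 mask below are toy values chosen for illustration.

```python
import numpy as np

# Two-class toy palette; the notebook's color table has six rows.
palette = np.array([[0, 0, 0],
                    [0, 140, 0]], dtype=np.float32) / 255.0

mask = np.array([[0, 1],
                 [1, 0]], dtype=np.int32)

# Integer-array indexing picks one palette row per pixel: (H, W) -> (H, W, 3).
rgb = palette[mask]
print(rgb.shape)
```

This form requires an integer-typed mask whose values all lie within the palette's row range.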

In [None]:
# Class color mapping (RGB normalized to [0,1])
CLASS_COLORS = np.array([
    [0, 0, 0],        # 0: Background - Black
    [0, 140, 0],      # 1: Fairway - Dark Green
    [0, 255, 0],      # 2: Green - Bright Green
    [255, 0, 0],      # 3: Tee - Red
    [217, 230, 122],  # 4: Bunker - Sandy Yellow
    [7, 15, 247]      # 5: Water - Blue
], dtype=np.float32) / 255.0

CLASS_NAMES = ['Background', 'Fairway', 'Green', 'Tee', 'Bunker', 'Water']


def mask_to_rgb(mask):
    """Convert integer class mask to RGB visualization."""
    h, w = mask.shape
    rgb_mask = np.zeros((h, w, 3), dtype=np.float32)
    for class_id in range(NUM_CLASSES):
        rgb_mask[mask == class_id] = CLASS_COLORS[class_id]
    return rgb_mask

In [None]:
# Visualize predictions on test set
for images, masks in test_ds.take(3):
    # Get predictions (logits) and convert to class indices
    predictions = model.predict(images, verbose=0)
    pred_masks = np.argmax(predictions, axis=-1)
    
    for i in range(min(2, images.shape[0])):
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
        axes[0].imshow(images[i].numpy())
        axes[0].set_title('Input')
        axes[0].axis('off')
        
        axes[1].imshow(mask_to_rgb(masks[i].numpy().astype(np.int32)))
        axes[1].set_title('Ground Truth')
        axes[1].axis('off')
        
        axes[2].imshow(mask_to_rgb(pred_masks[i]))
        axes[2].set_title('Prediction')
        axes[2].axis('off')
        
        plt.tight_layout()
        plt.show()

In [None]:
model.save(os.path.join(OUTPUT_DIR, 'final_unet_resnet50.keras'))
print(f"Model saved to {OUTPUT_DIR}")

In [None]:
# Plot training history
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Train')
plt.plot(history.history['val_loss'], label='Val')
plt.title('Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Train')
plt.plot(history.history['val_accuracy'], label='Val')
plt.title('Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'training_history.png'), dpi=150)
plt.show()

In [None]:
# Download trained model (Colab only)
if IN_COLAB:
    from google.colab import files
    files.download(os.path.join(OUTPUT_DIR, 'final_unet_resnet50.keras'))