# Satellite Image Segmentation — TensorFlow (Converted)

**Converted from**: `Notebook2_satellite_image_segmentation_pytorch_final.ipynb`

**What this notebook contains:**

- Dataset pipeline using `tf.data`
- Image / mask loading & preprocessing
- A U-Net implementation in `tf.keras`
- Training loop using `model.fit` with callbacks
- Visualization utilities

**Notes:** Adjust `IMAGE_DIR` and `MASK_DIR` to match your local paths, and change the train/validation/test split ratios if needed.

In [1]:
# Basic imports and GPU setup
import os
import math
import glob
import random
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

print('TensorFlow version:', tf.__version__)

# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all
# up front. This must run before any GPU has been initialised.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print('GPUs detected:', gpus)
    except (RuntimeError, ValueError) as e:
        # set_memory_growth documents RuntimeError (runtime already
        # initialised) and ValueError (invalid device); anything else is
        # unexpected and should surface instead of being printed away.
        print('GPU config error:', e)
else:
    print('No GPU detected; running on CPU.')


TensorFlow version: 2.10.1


In [8]:
# === USER SETTINGS ===
IMAGE_DIR = 'Water Bodies Dataset/Images'   # Directory holding the input images
MASK_DIR  = 'Water Bodies Dataset/Masks'    # Directory holding the matching masks

# Derive the glob patterns from the directories above so that editing
# IMAGE_DIR / MASK_DIR is enough to repoint the whole notebook.
IMAGE_GLOB = os.path.join(IMAGE_DIR, "*.jpg")
MASK_GLOB  = os.path.join(MASK_DIR, "*.jpg")

BATCH_SIZE = 8
IMG_HEIGHT = 256
IMG_WIDTH  = 256
AUTOTUNE = tf.data.AUTOTUNE
SEED = 42

# Seed the global RNGs so that unseeded random ops (e.g. the augmentation
# coin flips) are repeatable across Restart & Run All.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)


In [9]:
# List image and mask files (matched by filename)
import os, glob

def pair_image_mask_paths(image_glob, mask_dir):
    """Pair every image with the mask of the same file name.

    Args:
        image_glob: Glob pattern selecting the image files.
        mask_dir: Directory searched for a mask with the same base name.

    Returns:
        A list of ``(image_path, mask_path)`` tuples ordered by sorted image
        path. Images without an existing mask file are left out.
    """
    candidates = (
        (image_path, os.path.join(mask_dir, os.path.basename(image_path)))
        for image_path in sorted(glob.glob(image_glob))
    )
    return [
        (image_path, mask_path)
        for image_path, mask_path in candidates
        if os.path.exists(mask_path)
    ]

pairs = pair_image_mask_paths(IMAGE_GLOB, MASK_DIR)
# Fail fast with an actionable message: an empty list would otherwise only
# surface later as an obscure error while building the tf.data pipeline.
if not pairs:
    raise FileNotFoundError(
        f'No image-mask pairs found for {IMAGE_GLOB!r} with masks in {MASK_DIR!r}. '
        'Check IMAGE_DIR / MASK_DIR.'
    )
print('Found', len(pairs), 'image-mask pairs. Example:', pairs[:3])


Found 2841 image-mask pairs. Example: [('Water Bodies Dataset/Images\\water_body_1.jpg', 'Water Bodies Dataset/Masks\\water_body_1.jpg'), ('Water Bodies Dataset/Images\\water_body_10.jpg', 'Water Bodies Dataset/Masks\\water_body_10.jpg'), ('Water Bodies Dataset/Images\\water_body_100.jpg', 'Water Bodies Dataset/Masks\\water_body_100.jpg')]


In [10]:
# Image & mask loading + preprocessing functions
import tensorflow as tf

def read_image(path):
    """Load an image file as a float32 RGB tensor resized to the model input.

    Args:
        path: Path of the image file (string or scalar string tensor).

    Returns:
        A float32 tensor of shape ``[IMG_HEIGHT, IMG_WIDTH, 3]`` with values
        scaled to ``[0, 1]`` by ``convert_image_dtype``.
    """
    raw = tf.io.read_file(path)
    # The dataset is globbed as *.jpg, so use the format-agnostic decoder
    # rather than the PNG-specific one. expand_animations=False guarantees a
    # 3-D result, which tf.image.resize needs.
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, [IMG_HEIGHT, IMG_WIDTH])
    return image

def read_mask(path):
    """Load a mask file as a binary float32 tensor resized to the model input.

    Args:
        path: Path of the mask file (string or scalar string tensor).

    Returns:
        A float32 tensor of shape ``[IMG_HEIGHT, IMG_WIDTH, 1]`` holding 1.0
        where the decoded grey value exceeds 127 and 0.0 elsewhere.
    """
    raw = tf.io.read_file(path)
    # The masks are globbed as *.jpg, so use the format-agnostic decoder
    # rather than the PNG-specific one. expand_animations=False guarantees a
    # 3-D result, which tf.image.resize needs.
    mask = tf.io.decode_image(raw, channels=1, expand_animations=False)
    # Nearest-neighbour resizing avoids inventing intermediate grey levels.
    mask = tf.image.resize(mask, [IMG_HEIGHT, IMG_WIDTH], method='nearest')
    # Threshold at mid-grey to binarise the mask.
    mask = tf.cast(mask > 127, tf.float32)
    return mask

def load_pair(image_path, mask_path):
    """Read one training example.

    Returns:
        The ``(image, mask)`` tuple produced by ``read_image(image_path)`` and
        ``read_mask(mask_path)``.
    """
    return read_image(image_path), read_mask(mask_path)

# Simple augmentation example
def augment(image, mask):
    """Randomly augment an image/mask pair.

    With probability 0.5 both tensors are flipped left-right together, and
    independently with probability 0.5 the image brightness is shifted by a
    random delta in ``[-0.1, 0.1]``. The mask is never brightness-adjusted.

    Args:
        image: Image tensor; expected to be float in ``[0, 1]`` as produced by
            ``read_image``.
        mask: Mask tensor with the same spatial layout as ``image``.

    Returns:
        The augmented ``(image, mask)`` tuple.
    """
    if tf.random.uniform(()) > 0.5:
        image = tf.image.flip_left_right(image)
        mask = tf.image.flip_left_right(mask)
    if tf.random.uniform(()) > 0.5:
        image = tf.image.random_brightness(image, 0.1)
        # The brightness shift can push pixels outside [0, 1]; clip so the
        # model keeps seeing the same value range as unaugmented images.
        image = tf.clip_by_value(image, 0.0, 1.0)
    return image, mask


In [11]:
# Build tf.data pipeline
import random

def make_dataset(pairs, batch_size=8, augment_fn=None, shuffle=True):
    """Build a batched, prefetched ``tf.data`` pipeline from path pairs.

    Args:
        pairs: Sequence of ``(image_path, mask_path)`` tuples.
        batch_size: Number of examples per batch.
        augment_fn: Optional ``(image, mask) -> (image, mask)`` callable
            applied to every example after loading.
        shuffle: Whether to shuffle the file paths (seeded with ``SEED``).

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, masks)`` batches with
        per-example shapes ``[IMG_HEIGHT, IMG_WIDTH, 3]`` and
        ``[IMG_HEIGHT, IMG_WIDTH, 1]``.
    """
    # Explicit string dtype keeps an empty `pairs` list from being inferred
    # as float32.
    image_paths = tf.constant([p for p, _ in pairs], dtype=tf.string)
    mask_paths = tf.constant([m for _, m in pairs], dtype=tf.string)
    ds = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
    if shuffle:
        # Shuffle the cheap path strings, not decoded images; the buffer size
        # must be at least 1 even when `pairs` is empty.
        ds = ds.shuffle(buffer_size=max(len(pairs), 1), seed=SEED)

    def _load(image_path, mask_path):
        # load_pair only uses TensorFlow ops, so it can run directly in the
        # dataset graph; wrapping it in tf.py_function would force every call
        # through the Python interpreter and defeat num_parallel_calls.
        image, mask = load_pair(image_path, mask_path)
        image = tf.ensure_shape(image, [IMG_HEIGHT, IMG_WIDTH, 3])
        mask = tf.ensure_shape(mask, [IMG_HEIGHT, IMG_WIDTH, 1])
        return image, mask

    ds = ds.map(_load, num_parallel_calls=AUTOTUNE)
    if augment_fn is not None:
        ds = ds.map(augment_fn, num_parallel_calls=AUTOTUNE)
    return ds.batch(batch_size).prefetch(AUTOTUNE)

# Shuffle a copy rather than `pairs` itself: shuffling in place would make
# this cell non-idempotent (re-running it would reshuffle the already
# shuffled list and produce a different split).
shuffled_pairs = list(pairs)
random.Random(SEED).shuffle(shuffled_pairs)

# 70 % train / 20 % validation / remainder (~10 %) test.
n = len(shuffled_pairs)
n_train = int(0.7 * n)
n_val   = int(0.2 * n)
train_pairs = shuffled_pairs[:n_train]
val_pairs   = shuffled_pairs[n_train:n_train+n_val]
test_pairs  = shuffled_pairs[n_train+n_val:]

# Pass BATCH_SIZE explicitly so the setting in the config cell takes effect.
train_ds = make_dataset(train_pairs, batch_size=BATCH_SIZE, augment_fn=augment, shuffle=True)
val_ds   = make_dataset(val_pairs, batch_size=BATCH_SIZE, augment_fn=None, shuffle=False)
test_ds  = make_dataset(test_pairs, batch_size=BATCH_SIZE, augment_fn=None, shuffle=False)

print('Train pairs:', len(train_pairs), 'Val pairs:', len(val_pairs), 'Test pairs:', len(test_pairs))


Train pairs: 1988 Val pairs: 568 Test pairs: 285


In [12]:
# U-Net model (Keras)
from tensorflow.keras import layers

def conv_block(x, filters, kernel_size=3, padding='same', activation='relu'):
    """Apply two consecutive Conv2D -> BatchNormalization -> Activation stages.

    Args:
        x: Input Keras tensor.
        filters: Number of filters for both convolutions.
        kernel_size: Convolution kernel size.
        padding: Padding mode passed to ``Conv2D``.
        activation: Activation applied after each batch normalisation.

    Returns:
        The output tensor of the second stage.
    """
    out = x
    for _ in range(2):
        out = layers.Conv2D(filters, kernel_size, padding=padding)(out)
        out = layers.BatchNormalization()(out)
        out = layers.Activation(activation)(out)
    return out

def encoder_block(x, filters):
    """Run a conv block followed by 2x2 max pooling.

    Returns:
        ``(features, pooled)``: the conv block output (used as the skip
        connection) and its max-pooled version (fed to the next stage).
    """
    features = conv_block(x, filters)
    pooled = layers.MaxPooling2D(pool_size=(2, 2))(features)
    return features, pooled

def decoder_block(x, skip, filters):
    """Upsample ``x`` by 2, concatenate it with ``skip`` and run a conv block.

    Args:
        x: Tensor coming from the previous (coarser) stage.
        skip: Encoder tensor to concatenate after upsampling.
        filters: Number of filters for the conv block.

    Returns:
        The conv block output.
    """
    upsampled = layers.UpSampling2D(size=(2, 2))(x)
    merged = layers.Concatenate()([upsampled, skip])
    return conv_block(merged, filters)

def build_unet(input_shape=(256,256,3), num_classes=1):
    """Build a U-Net with four encoder stages, a bottleneck and four decoder stages.

    Args:
        input_shape: ``(height, width, channels)`` of the input images. Known
            heights and widths must be divisible by 16 because the four 2x2
            poolings are undone by four 2x upsamplings before each skip
            concatenation.
        num_classes: Number of output channels; each goes through a sigmoid.

    Returns:
        A ``keras.Model`` named ``'UNet'``.

    Raises:
        ValueError: If a known spatial dimension is not divisible by 16.
    """
    encoder_filters = (64, 128, 256, 512)
    downscale = 2 ** len(encoder_filters)
    # Validate up front: otherwise the mismatch only shows up as a shape error
    # inside a Concatenate layer deep in the decoder.
    for dim in input_shape[:2]:
        if dim is not None and dim % downscale != 0:
            raise ValueError(
                f'input_shape height/width must be divisible by {downscale}, '
                f'got {tuple(input_shape)}'
            )

    inputs = layers.Input(shape=input_shape)

    # Encoder: remember each stage's pre-pooling output for the skip links.
    skips = []
    x = inputs
    for filters in encoder_filters:
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    # Bottleneck.
    x = conv_block(x, 1024)

    # Decoder: walk the encoder stages in reverse, deepest skip first.
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = decoder_block(x, skip, filters)

    outputs = layers.Conv2D(num_classes, (1,1), activation='sigmoid')(x)
    return keras.Model(inputs, outputs, name='UNet')

# Instantiate the network for the configured image size and print its layers.
model = build_unet(
    input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
)
model.summary()


Model: "UNet"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 256, 256, 64  1792        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 256, 256, 64  256        ['conv2d[0][0]']                 
 alization)                     )                                                              

In [13]:
# Losses & metrics
import tensorflow as tf

def dice_coef(y_true, y_pred, smooth=1e-6):
    """Soft Dice coefficient computed over all elements of the batch at once.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted probabilities, same shape as ``y_true``.
        smooth: Small constant added to numerator and denominator so the ratio
            stays defined when both masks are empty.

    Returns:
        A scalar tensor: ``(2 * sum(y_true * y_pred) + smooth) /
        (sum(y_true) + sum(y_pred) + smooth)``.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Align dtypes so integer masks (or mixed-precision predictions) do not
    # fail in the element-wise product below.
    y_true = tf.cast(y_true, y_pred.dtype)
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)

def bce_dice_loss(y_true, y_pred):
    """Binary cross-entropy plus Dice loss (``1 - dice_coef``)."""
    bce_fn = tf.keras.losses.BinaryCrossentropy()
    return bce_fn(y_true, y_pred) + (1 - dice_coef(y_true, y_pred))

# Train with Adam (learning rate 1e-4) on the combined BCE + Dice loss and
# track the Dice coefficient and accuracy.
model.compile(
    loss=bce_dice_loss,
    metrics=[dice_coef, 'accuracy'],
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
)


In [16]:
# Callbacks and training
EPOCHS = 20
checkpoint_path = 'unet_checkpoint.h5'

# Save the model whenever val_loss improves.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_loss',
    save_best_only=True,
)
# Stop after 10 epochs without val_loss improvement and roll back to the best weights.
early_stop_cb = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)
callbacks = [checkpoint_cb, early_stop_cb]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=callbacks,
)


Epoch 1/20

KeyboardInterrupt: 

In [None]:
# Score the model on the held-out test split; the values follow the order of
# the compiled loss and metrics.
results = model.evaluate(x=test_ds)
print('Test results (loss, dice, acc):', results)

# Visualize predictions
import matplotlib.pyplot as plt

def visualize_samples(dataset, model, n=6):
    """Plot image, ground-truth mask and thresholded prediction side by side.

    Args:
        dataset: Batched ``tf.data.Dataset`` yielding ``(images, masks)``.
        model: Model whose ``predict`` output is thresholded at 0.5.
        n: Maximum number of examples (rows) to draw.
    """
    samples = list(dataset.unbatch().take(n))
    if not samples:
        print('No samples available to visualize.')
        return

    # One batched forward pass instead of one model.predict call per sample.
    images = tf.stack([image for image, _ in samples])
    preds = model.predict(images, verbose=0)

    # Size the grid by the samples actually available so that a short dataset
    # does not leave blank rows. squeeze=False keeps `axes` 2-D for one row.
    n_rows = len(samples)
    fig, axes = plt.subplots(n_rows, 3, figsize=(12, 4 * n_rows), squeeze=False)
    for row, ((image, mask), pred) in enumerate(zip(samples, preds)):
        panels = (
            ('Image', image.numpy(), None),
            ('Mask', mask.numpy().squeeze(), 'gray'),
            ('Prediction', (pred.squeeze() > 0.5).astype('float32'), 'gray'),
        )
        for ax, (title, data, cmap) in zip(axes[row], panels):
            ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')
    fig.tight_layout()

# Show four validation examples with their predicted masks.
visualize_samples(dataset=val_ds, model=model, n=4)


## Saving the model

You can save the trained model for later inference:

```python
model.save('unet_saved_model')
```

Adjust any dataset paths or hyperparameters to match your original PyTorch notebook behaviour.

## Conclusion

Model used: U-Net (a convolutional encoder–decoder network for image segmentation).

Loss function used: A combination of Binary Cross-Entropy (BCE) and Dice Loss.