# Explore the EarlyConvNet model class

In [9]:
from src.models import early_convnet

2023-12-31 21:26:20.203161: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX512F, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [12]:
model.summary()

Model: "early_convnet"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 C1 (Conv2D)                 multiple                  888       
                                                                 
 S2 (Subsampling)            multiple                  12        
                                                                 
 C3 (Conv2D)                 multiple                  3472      
                                                                 
 S4 (Subsampling)            multiple                  32        
                                                                 
 C5 (Conv2D)                 multiple                  23080     
                                                                 
 F6 (Dense)                  multiple                  287       
                                                                 
Total params: 27771 (108.48 KB)
Trainable params: 277

## Imports

In [13]:
import src.data.datasets.deep_globe_2018

import tensorflow_datasets as tfds
import tensorflow as tf
import keras

  from .autonotebook import tqdm as notebook_tqdm


## Parameters

In [30]:
## Pipeline
batch_size_images = 1
batch_size_patches = 8
img_size = 612
patch_size = 64
patch_size_annotation = 2
patch_stride = 32

## Training
epochs = 1

## Utility functions

In [4]:
def normalize(input_image):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  return input_image

In [5]:
def rgb_to_index(image):
    palette = [
        [0, 255, 255],   # urban_land
        [255, 255, 0],   # agriculture_land
        [255, 0, 255],   # rangeland
        [0, 255, 0],     # forest_land
        [0, 0, 255],     # water
        [255, 255, 255], # barren_land
        [0, 0, 0]        # unknown
    ]
    
    one_hot_map = []
    for colour in palette:
        class_map = tf.reduce_all(tf.equal(image, colour), axis=-1)
        one_hot_map.append(class_map)
    one_hot_map = tf.stack(one_hot_map, axis=-1)
    one_hot_map = tf.cast(one_hot_map, tf.uint8)
    indexed = tf.math.argmax(one_hot_map, axis=2)
    indexed = tf.cast(indexed, dtype=tf.uint8)
    indexed = tf.expand_dims(indexed, -1)

    return indexed

In [6]:
def load_patches_labels(datapoint, image_size, patch_size, patch_size_annotation, stride):
    crop_fraction = patch_size_annotation / patch_size
    
    images = tf.image.resize(datapoint['image'], (image_size, image_size))
    img_patches = tf.image.extract_patches(
        images = images,
        sizes = [1, patch_size, patch_size, 1],
        strides = [1, stride, stride, 1],
        rates = [1, 1, 1, 1],
        padding = 'VALID'
    )
    img_patches_flat = tf.reshape(img_patches, shape=(-1, patch_size, patch_size, 3))

    annotations = tf.map_fn(rgb_to_index, datapoint['segmentation_mask'])
    annotations = tf.image.resize(annotations, (image_size, image_size), method='nearest')

    ann_patches = tf.image.extract_patches(
        images = annotations,
        sizes = [1, patch_size, patch_size, 1],
        strides = [1, stride, stride, 1],
        rates = [1, 1, 1, 1],
        padding = 'VALID'
    )
    ann_patches_flat = tf.reshape(ann_patches, shape=(-1, patch_size, patch_size, 1))
    central_pixels = tf.image.central_crop(ann_patches_flat, crop_fraction)
    dim = tf.reduce_prod(tf.shape(central_pixels)[1:])
    central_pixels = tf.reshape(central_pixels, [-1, dim])
    pixel_category_idx = tf.reduce_max(central_pixels, axis=1) # reduce_mode is probably preferred but I chose a simpler implementation

    img_patches_flat = normalize(img_patches_flat)
    pixel_category_one_hot = tf.one_hot(
        pixel_category_idx,
        depth = 7, # TODO: make depth configurable
        on_value = 1,
        off_value = -1
    )

    return img_patches_flat, pixel_category_one_hot

## Optimizer, loss function, dataset and model

### Dataset

In [34]:
(ds_train, ds_valid, ds_test), ds_info = tfds.load(
    name='deep_globe_2018',
    download=False,
    with_info=True,
    split=['all_images[:7]', 'all_images[7:9]', 'all_images[9:10]']
)
train_batches = (
    ds_train
    .batch(batch_size_images)
    .map(lambda x: load_patches_labels(x, img_size, patch_size, patch_size_annotation, patch_stride), num_parallel_calls=tf.data.AUTOTUNE)
    .unbatch() # Flatten the batches for training
    .batch(batch_size_patches) # Rebatch patches as desired
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
# ds_info

In [None]:
# train_batches = (
#     ds_train
#     .take(1)
#     .batch(1)
#     .map(lambda x: load_patches_labels(x, img_size, patch_size, patch_size_annotation, patch_stride), num_parallel_calls=tf.data.AUTOTUNE)
#     .prefetch(buffer_size=tf.data.AUTOTUNE)
# )
# for i, m in train_batches.take(1):
#     sample_images = i[1190:1210]
#     sample_masks = m[1190:1210]
#     print(i.shape, m.shape)
#     # samples = list(zip(sample_images, sample_masks))

### Loss function

In [16]:
loss_fn = tf.keras.losses.MeanSquaredError()

### Optimizer

In [17]:
optimizer = keras.optimizers.SGD(learning_rate=1e-3)

### Model

In [25]:
model = early_convnet.EarlyConvnet()

In [26]:
model.build((None, patch_size,patch_size,3))


## Training loop

In [None]:
def train_on_patch_in_batch():

In [37]:
# Source: https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch (last accessed 31.12.2023)

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    
    for step, (img_batch_train, cat_batch_train) in enumerate(train_batches):
        with tf.GradientTape() as tape:
            logits = model(img_batch_train, training=True)
            print(logits.shape)
            print(cat_batch_train.shape)
            loss_value = loss_fn(cat_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %s samples" % ((step + 1) * batch_size))


Start of epoch 0
(8, 7, 7, 7)
(8, 7)


InvalidArgumentError: {{function_node __wrapped__SquaredDifference_device_/job:localhost/replica:0/task:0/device:CPU:0}} Incompatible shapes: [8,7,7,7] vs. [8,7] [Op:SquaredDifference] name: 