# Explore implementation of FCN model

In [9]:
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import tensorflow_datasets as tfds

import src.data.datasets.deep_globe_2018

## Reuse code from previous notebooks

In [12]:
IMAGE_SIZE = 612
BATCH_SIZE = 2

In [4]:
def normalize(input_image):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  return input_image

In [5]:
def rgb_to_index(image):
    palette = [
        [0, 255, 255],   # urban_land
        [255, 255, 0],   # agriculture_land
        [255, 0, 255],   # rangeland
        [0, 255, 0],     # forest_land
        [0, 0, 255],     # water
        [255, 255, 255], # barren_land
        [0, 0, 0]        # unknown
    ]
    
    one_hot_map = []
    for colour in palette:
        class_map = tf.reduce_all(tf.equal(image, colour), axis=-1)
        one_hot_map.append(class_map)
    one_hot_map = tf.stack(one_hot_map, axis=-1)
    one_hot_map = tf.cast(one_hot_map, tf.uint8)
    indexed = tf.math.argmax(one_hot_map, axis=2)
    indexed = tf.cast(indexed, dtype=tf.uint8)
    indexed = tf.expand_dims(indexed, -1)

    return indexed

In [14]:
def load_images(datapoint, image_size):
    
    images = tf.image.resize(datapoint['image'], (image_size, image_size))

    annotations = tf.map_fn(rgb_to_index, datapoint['segmentation_mask'])
    annotations = tf.image.resize(annotations, (image_size, image_size), method='nearest')
    
    images = normalize(images)

    return images, annotations

In [6]:
def display(display_list):
    plt.figure(figsize=(15, 15))

    title = ['Input Image', 'True Mask', 'Predicted Mask']

    print(len(display_list))
    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i+1)
        plt.title(title[i])
        plt.imshow(tf.keras.utils.array_to_img(display_list[i]))
        plt.axis('off')
    plt.show()

In [19]:
(ds_train, ds_valid, ds_test), ds_info = tfds.load(
    name='deep_globe_2018',
    download=False,
    with_info=True,
    split=['all_images[700:710]', 'all_images[7:9]', 'all_images[9:10]']
)
train_batches = (
    ds_train
    .batch(BATCH_SIZE)
    .map(lambda x: load_images(x, IMAGE_SIZE), num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
validation_batches = (
    ds_valid
    .batch(BATCH_SIZE)
    .map(lambda x: load_images(x, IMAGE_SIZE), num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
test_batches = (
    ds_test
    .batch(BATCH_SIZE)
    .map(lambda x: load_images(x, IMAGE_SIZE), num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)    

## Explore model

In [23]:
base_model = tf.keras.applications.VGG16(
    include_top=False,
    weights='imagenet',
    input_tensor=None,
    input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
)

In [24]:
for i, m in train_batches.take(1):
    y = base_model(i)

In [26]:
base_model.summary()

Model: "vgg16"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 612, 612, 3)]     0         
                                                                 
 block1_conv1 (Conv2D)       (None, 612, 612, 64)      1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 612, 612, 64)      36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 306, 306, 64)      0         
                                                                 
 block2_conv1 (Conv2D)       (None, 306, 306, 128)     73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 306, 306, 128)     147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 153, 153, 128)     0     

In [25]:
inputs = tf.keras.Input(shape=(160, 160, 3))
x = base_model(inputs)
x = 

TensorShape([2, 19, 19, 512])