This is the dataset currently loaded:

1. [SPARCS Dataset ~2GB](https://www.usgs.gov/core-science-systems/nli/landsat/spatial-procedures-automated-removal-cloud-and-shadow-sparcs). Each scene comes with:
   1. a multi-band satellite TIFF file (bands include infrared as well as RGB)
   2. a text file of metadata about the image
   3. a satellite image PNG
   4. a mask PNG whose pixel colors encode the per-pixel class (e.g. cloud)

Other datasets we could use instead:

1. [Landsat 8 Cloud Cover Assessment Validation Data ~100GB](https://www.usgs.gov/core-science-systems/nli/landsat/landsat-8-cloud-cover-assessment-validation-data?qt-science_support_page_related_con=1#qt-science_support_page_related_con)

2. [Kaggle 95-Cloud Dataset ~20GB](https://www.kaggle.com/sorour/95cloud-cloud-segmentation-on-satellite-images)

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import Sequential, layers, preprocessing 

import matplotlib.pyplot as plt
import numpy as np

from PIL import Image

# Keyword arguments shared by the parallel tf.data transformations below:
# tf.data chooses the parallelism level itself (AUTOTUNE), and elements may be
# produced out of order (deterministic=False) in exchange for throughput.
parallel_map_kwargs = {
    "num_parallel_calls": tf.data.AUTOTUNE,
    "deterministic": False,
}

In [None]:
#### DOWNLOAD THE SPARCS DATASET ####
def download(data_url, download_dir='junk', extract_dir='/content/clouds'):
    """Download and extract an archive, returning the path of its data folder.

    Args:
      data_url: URL of the archive to fetch.
      download_dir: directory where the raw downloaded archive is kept.
      extract_dir: directory the archive is extracted into.

    Returns:
      The extracted location with ``/sending`` appended (the sub-folder the
      SPARCS archive unpacks into), as a ``str`` so callers can keep building
      glob patterns with ``+``.
    """
    dl_manager = tfds.download.DownloadManager(
        download_dir=download_dir, extract_dir=extract_dir)
    extracted = dl_manager.download_and_extract(data_url)
    # NOTE(review): newer tfds versions may return a path object rather than a
    # str; str() makes the concatenation below work in either case.
    return str(extracted) + '/sending'  # weird USGS quirks


# Download and extract SPARCS; dataset_path points at the folder of image files.
SPARCS_DATA_URL = 'https://landsat.usgs.gov/cloud-validation/sparcs/l8cloudmasks.zip'
dataset_path = download(SPARCS_DATA_URL)

In [25]:
# Overwrite dataset_path with a literal location under /content/clouds
# (presumably a copy extracted by an earlier run of the download cell -- the
# hashed folder name may differ between runs; confirm before reusing).
dataset_path = (
    '/content/clouds/'
    'ZIP.landsa.usgs.gov_cloud-valida_sparcs_l8clouN5mc1TWFYYYxSYyyS6tlUpIEWUINgMuHXOHfkoDGofw.zip'
    '/sending'
)

In [36]:
#### READ IMAGE & MASK TO DATASET ####
def read_img_and_mask(img_path: tf.Tensor):
    """Load a photo PNG and its matching mask PNG.

    Must run eagerly (it calls ``.numpy()``), e.g. inside ``tf.py_function``.

    Args:
      img_path: scalar string tensor holding the path of a ``*photo.png`` file.

    Returns:
      ``(img, mask)``: the photo decoded as a ``uint8`` tensor with exactly 3
      channels, and the mask as the raw pixel array PIL reads from the
      corresponding ``*mask.png``.
    """
    # Force 3 channels so an RGBA or grayscale PNG cannot change the channel
    # count: cropping later stacks image + mask into exactly 4 channels and
    # the model expects 3-channel input.
    img = tf.image.decode_png(tf.io.read_file(img_path), channels=3)

    # "...photo.png" -> "...mask.png". Anchored on the file suffix so that a
    # "photo" elsewhere in the directory path is not rewritten.
    mask_path = tf.strings.regex_replace(img_path, r"photo\.png$", "mask.png")
    # Read via PIL and np.array to get the raw pixel values, which later steps
    # compare to class ids (e.g. 5). NOTE(review): presumably the mask is a
    # palette PNG that tf.image.decode_png would expand to RGB -- confirm.
    # The context manager closes the file once the pixels have been copied.
    with Image.open(mask_path.numpy()) as mask_img:
        mask = tf.convert_to_tensor(np.array(mask_img))

    return img, mask

In [37]:
# One element per "*photo.png" file found in the dataset folder.
ds = tf.data.Dataset.list_files(dataset_path + "/*photo.png")

# Decode each photo together with its mask. read_img_and_mask needs eager
# execution, so it is wrapped in tf.py_function.
ds = ds.map(
    lambda path: tf.py_function(read_img_and_mask, [path], (tf.uint8, tf.uint8)),
    **parallel_map_kwargs,
)

# Number of (image, mask) pairs in the dataset.
CARDINALITY = ds.cardinality()

In [38]:
# take n random crops of an image and its mask
@tf.function
def sample_crop(img, mask, w, h, n, channels=4):
    """Return a dataset of ``n`` random ``h`` x ``w`` crops of an image and its mask.

    The mask is stacked onto the image as an extra channel before cropping, so
    every crop keeps image and mask pixels aligned.

    Args:
      img: image tensor; presumably (height, width, 3) -- confirm.
      mask: mask tensor with the same height and width as ``img``.
      w: crop width in pixels.
      h: crop height in pixels.
      n: number of crops. It drives a Python comprehension, so it must be a
        Python int.
      channels: channel count of the stacked image + mask (3 + 1 by default).

    Returns:
      A ``tf.data.Dataset`` yielding ``n`` tensors of shape ``(h, w, channels)``.
    """
    img_and_mask = tf.experimental.numpy.dstack((img, mask))
    # random_crop's size is ordered (height, width, channels).
    crops = tf.stack([
        tf.image.random_crop(img_and_mask, (h, w, channels)) for _ in range(n)
    ])
    return tf.data.Dataset.from_tensor_slices(crops)

In [39]:
# Expand every (image, mask) pair into n random w x h crops.
n, w, h = 5, 64, 64


def _crops_of(img, mask):
    return sample_crop(img, mask, w, h, n)


ds = ds.interleave(_crops_of, **parallel_map_kwargs)

# interleave hides the output size from tf.data, so declare it explicitly:
# every source pair yields exactly n crops.
CARDINALITY = CARDINALITY * n
ds = ds.apply(tf.data.experimental.assert_cardinality(CARDINALITY))

In [40]:
# Unstack each crop into an (image, mask) tuple: channels 0-2 are the photo,
# channel 3 is the mask.
ds = ds.map(lambda crop: (crop[..., :3], crop[..., 3]), **parallel_map_kwargs)

In [41]:
@tf.function
def cloud_score(img, mask, cloud_value=5, threshold=0.5):
    """Label an (image, mask) crop by whether it is mostly cloud.

    Args:
      img: image crop, returned unchanged.
      mask: per-pixel class mask for the same crop.
      cloud_value: mask value marking a cloud pixel (5 in this dataset).
      threshold: a crop is labelled cloudy when strictly more than this
        fraction of its pixels are cloud.

    Returns:
      ``(img, is_cloudy)``, where ``is_cloudy`` is a scalar bool tensor.
    """
    cloud_pixels = tf.math.count_nonzero(mask == cloud_value)
    total_pixels = tf.size(mask, out_type=tf.int64)
    # int64 / int64 true-divides to a floating-point fraction in [0, 1].
    cloud_fraction = cloud_pixels / total_pixels
    return img, cloud_fraction > threshold


ds = ds.map(cloud_score, **parallel_map_kwargs)

In [70]:
# Shuffle once, then split: the first 20% of elements become the test set and
# the remaining 80% the training set.
#
# - Dataset.shuffle returns a NEW dataset; it must be reassigned or it has no
#   effect.
# - The upstream crops are random and the interleave is nondeterministic.
#   cache() freezes the elements produced on the first full pass, and
#   reshuffle_each_iteration=False replays the same order on every pass.
#   Without both, take/skip would pick different elements each epoch, leaking
#   test crops into training.
ds = ds.cache().shuffle(buffer_size=CARDINALITY, reshuffle_each_iteration=False)

test_ds = ds.take(CARDINALITY // 5)
train_ds = ds.skip(CARDINALITY // 5)

# Batch first so prefetch buffers whole batches rather than single elements.
test_ds = test_ds.batch(32).prefetch(tf.data.AUTOTUNE)
train_ds = train_ds.batch(32).prefetch(tf.data.AUTOTUNE)

In [71]:
num_classes = 2  # binary label: the boolean "mostly cloud" flag

# Fully connected classifier over 64x64x3 crops. The output layer has no
# activation, so the model emits raw logits (one per class); the loss should
# treat them as logits (from_logits=True).
model = Sequential([
    layers.experimental.preprocessing.Rescaling(1./255, input_shape=(64, 64, 3)),
    layers.Flatten(),
    layers.Dense(64 * 64, activation='relu'),
    layers.Dense(64 * 64, activation='relu'),
    # ReLU here: without a nonlinearity this layer and the output layer compose
    # into a single linear map, adding parameters but no capacity.
    layers.Dense(64 * 64, activation='relu'),
    layers.Dense(num_classes),
])

In [72]:
# The model's final Dense layer has no softmax, so it outputs logits.
# from_logits=True makes the loss apply the softmax itself instead of treating
# the raw scores as probabilities.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

In [73]:
# Early stopping: end training after 3 epochs without a val_loss improvement
# and restore the weights from the best epoch seen.
callback = tf.keras.callbacks.EarlyStopping(
    restore_best_weights=True,
    patience=3,
    monitor="val_loss",
)

In [74]:
# Fit for at most 10 epochs (the callback may stop earlier), evaluating on
# test_ds after every epoch; per-epoch metrics are kept in `history`.
history = model.fit(
    x=train_ds,
    epochs=10,
    validation_data=test_ds,
    callbacks=[callback],
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
