# Download and process Sentinel-2 data

This notebook downloads and processes one year of Sentinel-2 imagery for the training and validation plots labelled in Collect Earth Online (CEO).

## John Brandt
## Last edit: July 12, 2021

## Package imports, API import, source scripts

In [1]:
import datetime
import logging
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math
import os
import hickle as hkl
import scipy.sparse as sparse

import seaborn as sns
import yaml

from collections import Counter
from random import shuffle
from scipy.sparse.linalg import splu
from sentinelhub import WmsRequest, WcsRequest, MimeType
from sentinelhub import CRS, BBox, constants, DataSource, CustomUrlParam
from skimage.transform import resize
from typing import Tuple, List
from scipy.ndimage import median_filter

with open("../config.yaml", 'r') as stream:
        key = (yaml.safe_load(stream))
        API_KEY = key['key'] 
        
%matplotlib inline
%run ../src/preprocessing/slope.py
%run ../src/preprocessing/indices.py
%run ../src/downloading/utils.py
%run ../src/preprocessing/cloud_removal.py
%run ../src/preprocessing/whittaker_smoother.py
%run ../src/downloading/io.py

In [2]:
# Load the Sentinel Hub API key and AWS credentials from ../config.yaml,
# falling back to placeholders when the file does not exist.
if os.path.exists("../config.yaml"):
    with open("../config.yaml", 'r') as stream:
        key = (yaml.safe_load(stream))
        API_KEY = key['key']
        AWSKEY = key['awskey']
        AWSSECRET = key['awssecret']
else:
    # No config: define every credential name so later cells do not fail
    # with NameError.
    API_KEY = "none"
    AWSKEY = None
    AWSSECRET = None

# FileUploader is presumably provided by one of the %run helper scripts.
# Only construct it when credentials exist; without them, uploads in
# download_raw_data will fail for each plot (and be logged there).
if AWSKEY is not None and AWSSECRET is not None:
    uploader = FileUploader(awskey = AWSKEY, awssecret = AWSSECRET)
else:
    uploader = None

## Parameters

In [3]:
# Parameters
YEAR = 2020
# Download window: Nov 15 of the previous year through Feb 15 of the next
# year (presumably padding the target year for interpolation / smoothing)
TIME = (f'{YEAR - 1}-11-15', f'{YEAR + 1}-02-15')
EPSG = CRS.WGS84
# Side length, in 10 m pixels, that Sentinel-2 tiles are resized to
IMSIZE = 32

# Constants
# Day-of-year offset of the first day of each month (leap days ignored)
starting_days = np.cumsum([0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30])

# Geographic utility functions

These cells calculate the bounding box (min_x, min_y, max_x, max_y) of each area of interest (AOI).

In [4]:
def calc_bbox(plot_id: int, df: pd.DataFrame) -> List:
    """ Returns the lon/lat extent of all sample points of one CEO plot

        Parameters:
         plot_id (int): value of the PLOT_ID column identifying the plot
         df (pandas.DataFrame): Collect Earth Online survey export with
            PLOT_ID, LON and LAT columns

        Returns:
         bounding_box (list): [(min_lon, min_lat), (max_lon, max_lat)]
    """
    plot_rows = df.loc[df['PLOT_ID'] == plot_id]
    lons = plot_rows['LON']
    lats = plot_rows['LAT']
    lower_left = (min(lons), min(lats))
    upper_right = (max(lons), max(lats))
    return [lower_left, upper_right]


def bounding_box(points: List[Tuple[float, float]], 
                 expansion: float = 160) -> Tuple[Tuple[list, list], 'CRS']:
    """ Re-sizes a lon/lat bounding box, in the local UTM projection, so that
        each side is exactly `expansion` meters long, keeping it centred on
        the input box
        
        Subcalls:
         calculate_epsg; Proj / transform (presumably pyproj, provided by
         the %run helper scripts)

        Parameters:
         points (list): output of calc_bbox,
                        [(min_lon, min_lat), (max_lon, max_lat)]
         expansion (float): target side length of the output box in meters;
                            a value smaller than the input extent shrinks it
    
        Returns:
         (bl_utm, tr_utm) (tuple of lists): [x, y] UTM coordinates of the
              bottom left and top right corners
         crs (CRS): sentinelhub CRS of the UTM zone, looked up as
              CRS["UTM_<zone><N|S>"]
    """
    bl = list(points[0])
    tr = list(points[1])
    inproj = Proj('epsg:4326')
    # UTM EPSG code chosen from the bottom left corner
    outproj_code = calculate_epsg(bl)
    outproj = Proj('epsg:' + str(outproj_code))
    # Coordinates are passed as (lat, lon); presumably this matches
    # EPSG:4326's lat/lon axis order in pyproj >= 2 — confirm for the
    # installed pyproj version
    bl_utm =  transform(inproj, outproj, bl[1], bl[0])
    tr_utm =  transform(inproj, outproj, tr[1], tr[0])

    # Current extent along x and y in projected units
    distance1 = tr_utm[0] - bl_utm[0]
    distance2 = tr_utm[1] - bl_utm[1]
    # Half of the size change goes on each side, so the box stays centred
    # and ends up `expansion` long per side (negative values shrink it)
    expansion1 = (expansion - distance1)/2
    expansion2 = (expansion - distance2)/2
        
    bl_utm = [bl_utm[0] - expansion1, bl_utm[1] - expansion2]
    tr_utm = [tr_utm[0] + expansion1, tr_utm[1] + expansion2]

    # Zone number from the EPSG code after its first three digits, without a
    # leading zero; assumes calculate_epsg returns a 326xx / 327xx UTM code
    zone = str(outproj_code)[3:]
    zone = zone[1:] if zone[0] == "0" else zone
    # Hemisphere from the top right latitude. NOTE(review): the EPSG code
    # above came from the bottom left corner; the two could disagree for a
    # box straddling the equator.
    direction = 'N' if tr[1] >= 0 else 'S'
    utm_epsg = "UTM_" + zone + direction
    return (bl_utm, tr_utm), CRS[utm_epsg]

# Data download

In [5]:
def extract_dates(date_dict: dict, year: int) -> List:
    """ Converts acquisition datetimes into integer day offsets relative
        to `year`

        January 1st of `year` is day 1. Dates in the previous year are
        shifted by -365 (so Dec 31st of year - 1 is day 0) and dates in the
        following year by +365 (Jan 1st of year + 1 is day 366). Every year
        is treated as 365 days long. Dates outside year - 1 .. year + 1
        are dropped.

        Parameters:
         date_dict (iterable): datetime-like objects with year/month/day
            attributes, e.g. a SentinelHub request's get_dates() output
         year (int): reference year

        Returns:
         dates (list): day offsets, in input order
    """
    month_offsets = np.cumsum([0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30])
    year_offsets = {year - 1: -365, year: 0, year + 1: 365}
    return [
        year_offsets[d.year] + month_offsets[d.month - 1] + d.day
        for d in date_dict
        if d.year in year_offsets
    ]

def to_float32(array: np.array) -> np.array:
    """Scales an integer array to float32 in [0, 1] by dividing by 65535.

    Floating-point input is returned unchanged (same object). The original
    maximum is printed for debugging.

    Raises:
     AssertionError: if integer input has a maximum <= 1, or if the result
        has a maximum above 1.
    """
    peak = np.max(array)
    print(f'The original max value is {peak}')
    if isinstance(array.flat[0], np.floating):
        scaled = array
    else:
        assert peak > 1
        scaled = np.float32(array) / 65535.
    assert np.max(scaled) <= 1
    return scaled

## Cloud and cloud shadow calculation functions

This cell identifies clouds using s2cloudless, and clouds and cloud shadows using the method of Candra et al. (2019).

The outputs are per-pixel cloud and shadow masks and a list of Sentinel-2 dates to download.

In [6]:
def identify_clouds(bbox: List[Tuple[float, float]], epsg: 'CRS', time: tuple = TIME):
    """ Downloads and calculates cloud cover and shadow
        
        Cloud probabilities are requested at 160 m and nearest-neighbour
        resized to 96 x 96. Time steps where more than 15% of the pixels have
        a cloud probability above 0.33 are discarded. Shadow data (60 m) is
        then requested only for dates matching the kept cloud dates, and
        passed together with the cloud probabilities to mcm_shadow_mask.

        Parameters:
         bbox (list): bounding box corners (as returned by bounding_box)
         epsg (CRS): CRS associated with bbox 
         time (tuple): YYYY-MM-DD - YYYY-MM-DD bounds for downloading 
    
        Returns:
         cloud_img (np.array): (N, 96, 96) array of cloud probs in [0, 1]
         shadows (np.array): output of mcm_shadow_mask for the N kept steps
         clean_steps (list): indices (into the full cloud series) of kept steps
         cloud_dates (np.array): (N,) day offsets (see extract_dates) of kept steps
    """
    box = BBox(bbox, crs = epsg)
    cloud_request = WcsRequest(
        layer='CLOUD_NEW',
        bbox=box, time=time,
        resx='160m', resy='160m',
        image_format = MimeType.TIFF_d8,
        maxcc=0.75, instance_id=API_KEY,
        custom_url_params = {constants.CustomUrlParam.UPSAMPLING: 'NEAREST'},
        time_difference=datetime.timedelta(hours=96))

    shadow_request = WcsRequest(
        layer='SHADOW',
        bbox=box, time=time,
        resx='60m', resy='60m',
        image_format =  MimeType.TIFF_d16,
        maxcc=0.75, instance_id=API_KEY,
        custom_url_params = {constants.CustomUrlParam.UPSAMPLING: 'NEAREST'},
        time_difference=datetime.timedelta(hours=96))

    cloud_img = np.array(cloud_request.get_data())
    # Integer responses are rescaled from [0, 255] to [0, 1]
    if not isinstance(cloud_img.flat[0], np.floating):
        assert np.max(cloud_img) > 1
        cloud_img = cloud_img / 255.
    assert np.max(cloud_img) <= 1

    # Nearest-neighbour (order = 0) resize to a fixed 96 x 96 grid
    cloud_img = resize(cloud_img, (cloud_img.shape[0], 96, 96), order = 0)
    # A step is cloudy when more than 15% of its pixels have prob > 0.33
    n_cloud_px = np.sum(cloud_img > 0.33, axis = (1, 2))
    # (k, 1) array of cloudy step indices
    cloud_steps = np.argwhere(n_cloud_px > (96**2 * 0.15))
    clean_steps = [x for x in range(cloud_img.shape[0]) if x not in cloud_steps]
    
    
    # Day offsets of the clean steps. NOTE(review): uses the global YEAR,
    # not a year derived from the `time` argument.
    cloud_dates_dict = [x for x in cloud_request.get_dates()]
    cloud_dates = extract_dates(cloud_dates_dict, YEAR)
    cloud_dates = [val for idx, val in enumerate(cloud_dates) if idx in clean_steps]
    
    # Only request shadow data for acquisitions whose day offset matches a
    # clean cloud date
    shadow_dates_dict = [x for x in shadow_request.get_dates()]
    shadow_dates = extract_dates(shadow_dates_dict, YEAR)
    shadow_steps = [idx for idx, val in enumerate(shadow_dates) if val in cloud_dates]    
    
    shadow_img = np.array(shadow_request.get_data(data_filter = shadow_steps))
    # Rough processing-unit estimate: (pixels / 512^2) per downloaded step
    shadow_pus = (shadow_img.shape[1]*shadow_img.shape[2])/(512*512) * shadow_img.shape[0]
    shadow_img = resize(shadow_img, (shadow_img.shape[0], 96, 96, shadow_img.shape[-1]), order = 0,
                        anti_aliasing = False, preserve_range = True).astype(np.uint16)

    cloud_img = np.delete(cloud_img, cloud_steps, 0)
    assert shadow_img.shape[0] == cloud_img.shape[0], (shadow_img.shape, cloud_img.shape)
    # NOTE(review): the original author flagged this call for verification
    # ("Make sure this makes sense??")
    shadows = mcm_shadow_mask(shadow_img, cloud_img)
    print(f"Shadows ({shadows.shape}) used {round(shadow_pus, 1)} processing units")
    return cloud_img, shadows, clean_steps, np.array(cloud_dates)

# DEM and slope

In [7]:
def download_dem(plot_id: int, df: 'DataFrame', epsg: 'CRS') -> (np.ndarray, np.ndarray):
    """ Downloads a digital elevation model for a CEO plot and derives slope

        The DEM is requested at 10 m over a (32 + 2) x (32 + 2) pixel window,
        with one pixel of padding on each side so the slope computation has
        neighbours at the edges. The DEM is median filtered (size 5),
        converted to slope, and the slope is cropped back to 32 x 32.

        Parameters:
         plot_id (int): plot id from collect earth online (CEO)
         df (pandas.DataFrame): data associated with plot_id from CEO
         epsg (CRS): unused; the UTM CRS is recomputed from the plot location
             by bounding_box. Kept for backwards compatibility.

        Returns:
         slope (arr): (32, 32, 1) array of slope / 90 (presumably degrees
             mapped to [0, 1])
         dem (arr): float32 elevation grid of the padded window with the
             12000 offset removed, before median filtering
    """
    tile = 32           # output size in 10 m pixels
    padded = tile + 2   # one pixel of padding on each side
    location = calc_bbox(plot_id, df = df)
    bbox, epsg = bounding_box(location, expansion = padded * 10)
    box = BBox(bbox, crs = epsg)
    dem_request = WcsRequest(
                         layer='DEM', bbox=box,
                         resx = "10m", resy = "10m",
                         instance_id=API_KEY,
                         image_format= MimeType.TIFF_d16,
                         custom_url_params={CustomUrlParam.SHOWLOGO: False})
    # Cast BEFORE removing the offset: the TIFF_d16 response is expected to
    # be unsigned 16-bit, and subtracting 12000 in that dtype wraps around
    # for encoded values below 12000 instead of going negative.
    dem_image_init = dem_request.get_data()[0].astype(np.float32) - 12000
    dem_image = median_filter(dem_image_init, size = 5)
    slope = calcSlope(dem_image.reshape((1, padded, padded)),
                      np.full((padded, padded), 10),
                      np.full((padded, padded), 10),
                      zScale = 1, minSlope = 0.02)
    slope = slope / 90
    slope = slope.reshape((padded, padded, 1))
    # Drop the padding pixels
    slope = slope[1:tile + 1, 1:tile + 1, :]
    return slope, dem_image_init

## 10 and 20 meter L2A bands

In [8]:
def download_layer(bbox: List[Tuple[float, float]],
                   clean_steps: np.ndarray, epsg: 'CRS',
                   dates: tuple = TIME, year: int = YEAR) -> (np.ndarray, np.ndarray):
    """ Downloads the L2A sentinel layer with 10 and 20 meter bands

        Only acquisitions whose day offset (see extract_dates) appears in
        `clean_steps` are downloaded, so the result lines up with the cloud
        and shadow masks computed for the same dates.

        Parameters:
         bbox (list): bounding box corners (as returned by bounding_box)
         clean_steps (array): day offsets of the acquisitions to download
         epsg (CRS): CRS associated with bbox
         dates (tuple): YYYY-MM-DD - YYYY-MM-DD bounds for downloading
         year (int): reference year for converting acquisition dates to day
             offsets; must match the year used to compute `clean_steps`

        Returns:
         img (arr): (T, IMSIZE, IMSIZE, B) float32 array, 10 m bands
             followed by 20 m bands
         dates (arr): (T,) day offsets of the downloaded acquisitions

        Raises:
         Any error from the requests is logged and then re-raised.
    """
    try:
        box = BBox(bbox, crs = epsg)
        image_request = WcsRequest(
                layer='L2A20',
                bbox=box, time=dates,
                image_format = MimeType.TIFF_d16,
                data_source = DataSource.SENTINEL2_L2A,
                maxcc=0.75,
                resx='20m', resy='20m',
                instance_id=API_KEY,
                custom_url_params = {constants.CustomUrlParam.DOWNSAMPLING: 'NEAREST',
                                    constants.CustomUrlParam.UPSAMPLING: 'NEAREST'},
                time_difference=datetime.timedelta(hours=96),
            )

        # Day offsets relative to the requested `year`
        image_dates = extract_dates(image_request.get_dates(), year)

        steps_to_download = [i for i, day in enumerate(image_dates) if day in clean_steps]
        dates_to_download = [image_dates[i] for i in steps_to_download]

        img_bands = image_request.get_data(data_filter = steps_to_download)
        img_20 = np.stack(img_bands)
        img_20 = to_float32(img_20)

        s2_20_usage = (img_20.shape[1]*img_20.shape[2])/(512*512) * (6/3) * img_20.shape[0]
        if (img_20.shape[1] * img_20.shape[2]) != 14*14:
            print(f"Original 20 meter bands size: {img_20.shape}, using {s2_20_usage} PU")
        img_20 = resize(img_20, (img_20.shape[0], IMSIZE, IMSIZE, img_20.shape[-1]), order = 0)

        image_request = WcsRequest(
                layer='L2A10',
                bbox=box, time=dates,
                image_format = MimeType.TIFF_d16,
                data_source = DataSource.SENTINEL2_L2A,
                maxcc=0.75,
                resx='10m', resy='10m',
                instance_id=API_KEY,
                custom_url_params = {constants.CustomUrlParam.DOWNSAMPLING: 'BICUBIC',
                                    constants.CustomUrlParam.UPSAMPLING: 'BICUBIC'},
                time_difference=datetime.timedelta(hours=96),
        )

        # Same step filter so the 10 m and 20 m stacks share acquisitions
        img_bands = image_request.get_data(data_filter = steps_to_download)
        img_10 = np.stack(img_bands)
        if (img_10.shape[1] * img_10.shape[2]) != 28*28:
            print(f"The original L2A image size is: {img_10.shape}")
        img_10 = to_float32(img_10)

        img_10 = resize(img_10, (img_10.shape[0], IMSIZE, IMSIZE, img_10.shape[-1]), order = 0)
        img = np.concatenate([img_10, img_20], axis = -1)

        return img, np.array(dates_to_download)

    except Exception as e:
        logging.fatal(e, exc_info=True)
        # Re-raise: the caller unpacks the return value, so returning None
        # would only surface later as a confusing unpacking error.
        raise

# Super resolution

Super-resolve the 20 meter bands to 10 meters using DSen2.

In [9]:
# TF1-style session, shared with Keras so both use the same session
import tensorflow as tf
sess = tf.Session()
from keras import backend as K
K.set_session(sess)

# Directory holding the super-resolution checkpoint (model.meta + checkpoints)
MDL_PATH = "../models/supres/"

# Rebuild the saved graph from its meta file and restore the latest weights
model = tf.train.import_meta_graph(MDL_PATH + 'model.meta')
model.restore(sess, tf.train.latest_checkpoint(MDL_PATH))

# Fetch the output tensor and the two input placeholders by their
# auto-generated graph names (these depend on how the graph was exported).
# superresolve feeds the full stack to `inp` and bands 4+ to `inp_bilinear`.
logits = tf.get_default_graph().get_tensor_by_name("Add_6:0")
inp = tf.get_default_graph().get_tensor_by_name("Placeholder:0")
inp_bilinear = tf.get_default_graph().get_tensor_by_name("Placeholder_1:0")

def superresolve(input_data):
    """Runs the restored super-resolution graph on a tile stack.

    The whole stack is fed to the `inp` placeholder and its channels 4
    onward to `inp_bilinear`. Returns the value of the `logits` tensor.
    """
    feed = {
        inp: input_data,
        inp_bilinear: input_data[..., 4:],
    }
    fetched = sess.run([logits], feed_dict=feed)
    return fetched[0]

def superresolve_tile(x):
    """Super-resolves the six bands in channels 4-9 of a (T, H, H, 10) stack.

    Channels 4-9 are averaged over 2 x 2 pixel blocks and resampled back to
    H x H with linear interpolation (order 1), then replaced by the output of
    `superresolve`. Tiles wider than 28 px are centre-cropped to 28 x 28.

    Note: `x` is modified in place (channels 4-9 are overwritten).
    """
    n_steps = x.shape[0]
    side = x.shape[1]
    half = side // 2
    pooled = np.reshape(x[..., 4:], (n_steps, half, 2, half, 2, 6)).mean(axis=(2, 4))
    x[..., 4:] = resize(pooled, (n_steps, side, side, 6), order=1)
    x[..., 4:] = superresolve(x)
    if side <= 28:
        return x
    margin = (side - 28) // 2
    return x[:, margin:-margin, margin:-margin, :]

Using TensorFlow backend.


INFO:tensorflow:Restoring parameters from ../models/supres/model


# Download function

In [10]:
def download_new_dem(data_location: 'os.Path',
                     output_folder: 'os.Path',
                     image_format: 'MimeType' = MimeType.TIFF_d16,
                     slope_folder: 'os.Path' = "../data/train-slope/"):
    """ Downloads and saves DEM and slope files for plots not yet downloaded

        Parameters:
         data_location (os.path): CSV exported from a Collect Earth Online survey
         output_folder (os.path): folder for the DEM arrays ({plot_id}.npy);
             its existing .npy files mark plots that are skipped
         image_format (MimeType): unused, kept for backwards compatibility
         slope_folder (os.path): folder for the slope arrays ({plot_id}.npy)

        Returns:
         None. A plot that fails to download is logged and listed at the
         end instead of aborting the remaining downloads.
    """
    df = pd.read_csv(data_location)
    df.columns = [x.upper() for x in df.columns]
    # Drop CEO imagery-metadata columns before dropna, so missing values in
    # them do not discard rows
    unused_columns = ['IMAGERY_TITLE', 'STACKINGPROFILEDG', 'PL_PLOTID', 'IMAGERYYEARDG',
                      'IMAGERYMONTHPLANET', 'IMAGERYYEARPLANET', 'IMAGERYDATESECUREWATCH',
                      'IMAGERYENDDATESECUREWATCH', 'IMAGERYFEATUREPROFILESECUREWATCH',
                      'IMAGERYSTARTDATESECUREWATCH', 'IMAGERY_ATTRIBUTIONS',
                      'SAMPLE_GEOM']
    df = df.drop(columns = [c for c in unused_columns if c in df.columns])

    df = df.dropna(axis = 0)
    plot_ids = sorted(df['PLOT_ID'].unique())
    # Only {plot_id}.npy files count as done (ignores .DS_Store, checkpoint
    # folders, ...). 139089844 is excluded explicitly.
    # NOTE(review): the reason for that exclusion is not documented here.
    existing = {int(x[:-4]) for x in os.listdir(output_folder) if x.endswith(".npy")}
    existing.add(139089844)
    to_download = [x for x in plot_ids if x not in existing]
    print(f"Starting download of {len(to_download)}"
          f" plots from {data_location} to {output_folder}")
    errors = []
    for i, val in enumerate(to_download):
        print(f"Downloading {i + 1}/{len(to_download)}, {val}")
        try:
            initial_bbx = calc_bbox(val, df = df)
            _, epsg = bounding_box(initial_bbx, expansion = 32*10)
            slope, dem = download_dem(val, epsg = epsg, df = df)
            np.save(os.path.join(output_folder, str(val)), dem)
            np.save(os.path.join(slope_folder, str(val)), slope)
        except Exception as e:
            logging.error(e, exc_info=True)
            errors.append(val)
    if errors:
        print(f"Failed to download {len(errors)} plots: {errors}")
        


def concatenate_dem(x, dem):
    """Appends the centre 28 x 28 of a 32 x 32 DEM/slope grid to every step.

    Parameters:
     x (np.ndarray): (T, 28, 28, C) image stack
     dem (np.ndarray): 32 * 32 values (any shape reshapeable to 32 x 32)

    Returns:
     np.ndarray: (T, 28, 28, C + 1) stack with the cropped grid as the
     last channel

    Raises:
     AssertionError: if the result is not 28 x 28 spatially
    """
    centre = dem.reshape((32, 32, 1))[2:-2, 2:-2, :]
    per_step = np.broadcast_to(centre, (x.shape[0],) + centre.shape)
    stacked = np.concatenate([x, per_step], axis = -1)
    assert stacked.shape[1] == stacked.shape[2] == 28
    return stacked

In [11]:
def id_missing_px(sentinel2: np.ndarray, thresh: int = 100) -> np.ndarray:
    """Flags time steps with too many missing or saturated values.

    A value in the first 10 bands is bad when it equals 0.0 (missing) or is
    >= 1.0 (saturated). A step is flagged when its bad-value count, summed
    over rows, columns and bands, reaches ``sentinel2.shape[1] ** 2 / thresh``.

    Parameters:
     sentinel2 (np.ndarray): (T, H, W, B) array, presumably scaled to [0, 1]
     thresh (int): divisor of the squared tile width giving the threshold

    Returns:
     np.ndarray: 1-D array of flagged time-step indices
    """
    bands = sentinel2[..., :10]
    is_bad = (bands == 0.0) | (bands >= 1.)
    bad_per_step = is_bad.sum(axis=(1, 2, 3))
    limit = (sentinel2.shape[1] ** 2) / thresh
    return np.flatnonzero(bad_per_step >= limit)

def download_raw_data(data_location, output_folder, fmt = "train", image_format = MimeType.TIFF_d16):
    """ Downloads cloud-masked sentinel-2 data for all plots associated
        with an input CSV from a collect earth online survey

        For each plot without a file in ../data/{fmt}-dates/: cloud
        probabilities and shadows are computed, cloudy and redundant dates
        are removed, the L2A series is downloaded for the remaining dates,
        clouds and shadows are removed with remove_cloud_and_shadows, and
        steps that are more than 50% interpolated are dropped.
        
        Parameters:
         data_location (os.path): CEO survey CSV
         output_folder (os.path): NOTE(review): unused; outputs go to
             ../data/{fmt}-raw/ and ../data/{fmt}-dates/
         fmt (str): dataset prefix used in output paths and the upload key
         image_format (MimeType): NOTE(review): unused
        
        Creates:
         ../data/{fmt}-dates/{plot_id}.npy (day offsets)
         ../data/{fmt}-raw/{plot_id}.npy (uint16, values scaled by 65535),
         also uploaded via uploader.upload to bucket 'restoration-monitoring'
         under restoration-mapper/model-data/{fmt}/raw/
    
        Returns:
         None. Per-plot errors are printed and logged, not raised.
    """
    df = pd.read_csv(data_location)
    df.columns = [x.upper() for x in df.columns]
    # Drop CEO imagery-metadata columns before dropna, so missing values in
    # them do not discard rows
    for column in ['IMAGERY_TITLE', 'STACKINGPROFILEDG', 'PL_PLOTID', 'IMAGERYYEARDG',
                  'IMAGERYMONTHPLANET', 'IMAGERYYEARPLANET', 'IMAGERYDATESECUREWATCH',
                  'IMAGERYENDDATESECUREWATCH', 'IMAGERYFEATUREPROFILESECUREWATCH',
                  'IMAGERYSTARTDATESECUREWATCH','IMAGERY_ATTRIBUTIONS',
                  'SAMPLE_GEOM']:
        if column in df.columns:
            df = df.drop(column, axis = 1)

    df = df.dropna(axis = 0)
    plot_ids = sorted(df['PLOT_ID'].unique())
    # A plot counts as done once its dates file exists; the listed ids are
    # excluded explicitly (NOTE(review): reason not documented here)
    existing = [int(x[:-4]) for x in os.listdir(f"../data/{fmt}-dates/") if ".DS" not in x]
    existing = existing + [139190271, 139187199, 139319876, 139319877]
    to_download = [x for x in plot_ids if x not in existing]
    print(f"Starting download of {len(to_download)}"
          f" plots from {data_location} to {output_folder}")
    for i, val in enumerate(to_download):
        print(f"Downloading {i + 1}/{len(to_download)}, {val}")
        initial_bbx = calc_bbox(val, df = df)
        # IMSIZE*10 m box for Sentinel-2, 960 m box for cloud detection
        sentinel2_bbx, epsg = bounding_box(initial_bbx, expansion = IMSIZE*10)
        cloud_bbx, _ = bounding_box(initial_bbx, expansion = 96*10)
        try:
            # Identify cloud steps, download DEM, and download L2A series
            cloud_probs, shadows, _, clean_dates = identify_clouds(cloud_bbx, epsg = epsg)
            # NOTE(review): `dem` is never used below; this download could be skipped
            dem, _ = download_dem(val, epsg = epsg, df = df)
            to_remove, _ = calculate_cloud_steps(cloud_probs, clean_dates)
            
            if len(to_remove) > 0:
                cloud_probs = np.delete(cloud_probs, to_remove, 0)
                clean_dates = np.delete(clean_dates, to_remove)
                shadows = np.delete(shadows, to_remove, 0)
                
            # Thin months with many clear dates before downloading imagery
            to_remove = subset_contiguous_sunny_dates(clean_dates)
            if len(to_remove) > 0:
                cloud_probs = np.delete(cloud_probs, to_remove, 0)
                clean_dates = np.delete(clean_dates, to_remove)
                shadows = np.delete(shadows, to_remove, 0)
        
            s2, s2_dates = download_layer(sentinel2_bbx, clean_steps = clean_dates, epsg = epsg)    
            
            # Step to ensure that shadows, clouds, sentinel l2a have aligned dates
            to_remove_clouds = [i for i, val in enumerate(clean_dates) if val not in s2_dates]
            to_remove_dates = [val for i, val in enumerate(clean_dates) if val not in s2_dates]
            if len(to_remove_clouds) > 0:
                print(f"Removing {to_remove_dates} from clouds because not in S2")
                cloud_probs = np.delete(cloud_probs, to_remove_clouds, 0)
                shadows = np.delete(shadows, to_remove_clouds, 0)
            print(f"Shadows {shadows.shape}, clouds {cloud_probs.shape},"
                  f" S2, {s2.shape}, S2d, {s2_dates.shape}")
            
            # Drop steps flagged by remove_missed_clouds from every array
            to_remove = remove_missed_clouds(s2)
            if len(to_remove) > 0:
                print(f"Removing {len(to_remove)} steps based on ratio")
                s2 = np.delete(s2, to_remove, 0)
                cloud_probs = np.delete(cloud_probs, to_remove, 0)
                s2_dates = np.delete(s2_dates, to_remove)
                shadows = np.delete(shadows, to_remove, 0)
            
            # Keep the central 48 x 48 of the cloud grid (96 x 96 from
            # identify_clouds) and of the shadow mask
            cloud_probs = cloud_probs[:, 24:-24, 24:-24]
            shadows = shadows[:, 24:-24, 24:-24]
            x, interp = remove_cloud_and_shadows(s2, cloud_probs, shadows, s2_dates)
            # Drop steps whose mean interpolation fraction exceeds 0.5
            to_remove = np.argwhere(np.mean(interp, axis = (1, 2)) > 0.5)
            if len(to_remove) > 0:
                print(f"Removing {len(to_remove)} steps with >50% interpolation: {to_remove}")
                x = np.delete(x, to_remove, 0)
                cloud_probs = np.delete(cloud_probs, to_remove, 0)
                s2_dates = np.delete(s2_dates, to_remove)
                shadows = np.delete(shadows, to_remove, 0)
                print(np.sum(shadows, axis = (1, 2)))
            
            # Quantize [0, 1] values to uint16, save locally, then upload
            x_to_save = np.copy(x)
            x_to_save = np.clip(x_to_save, 0, 1)
            x_to_save = np.trunc(x_to_save * 65535).astype(np.uint16)
            np.save(f"../data/{fmt}-dates/{str(val)}", s2_dates)
            np.save(f"../data/{fmt}-raw/{str(val)}", x_to_save)
            file = f"../data/{fmt}-raw/{str(val)}.npy"
            key = f'restoration-mapper/model-data/{fmt}/raw/{str(val)}.npy'
            uploader.upload(bucket = 'restoration-monitoring', key = key, file = file)
            print("\n")

        except Exception as e:
            # Report and continue with the next plot
            print(e)
            logging.fatal(e, exc_info=True)

In [17]:
def select_dates(dates):
    """Identifies images to delete so that no month keeps more than three.

       Imagery downloaded before the number of monthly images was capped
       at 3 needs the same cap enforced on the training / testing data.

       Months use the notebook's day-offset scale (see extract_dates). The
       first window [-60, 0) covers the end of the previous year, and the
       last window [334, 390) covers December plus early next January.

       For a month with more than three images:
        - the first window keeps its last three images (closest to day 0),
        - the last window keeps its first three images,
        - other months keep the first, a middle and the last image
          (4 -> drop 2nd; 5 -> drop 2nd and 4th; 6 -> drop 2nd, 4th and
          5th; more than 6 -> keep first, middle and last).

       Parameters:
        dates (array-like): (N,) day offsets of the images

       Returns:
        indices_to_remove (np.ndarray or list): indices into `dates` to
        delete, or an empty list when nothing needs removing.
    """
    dates = np.asarray(dates)
    before = len(dates)
    begin = [-60, 0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334]
    end = [0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 390]
    indices_to_remove = []
    for x, y in zip(begin, end):
        indices_month = np.argwhere(np.logical_and(dates >= x, dates < y)).flatten()
        n_month = len(indices_month)
        if n_month <= 3:
            continue
        # Branch on this window's start `x` (not on the `begin` list itself)
        if x == -60:
            to_delete = indices_month[:-3]
        elif x == 334:
            to_delete = indices_month[3:]
        elif n_month == 4:
            to_delete = indices_month[[1]]
        elif n_month == 5:
            to_delete = indices_month[[1, 3]]
        elif n_month == 6:
            to_delete = indices_month[[1, 3, 4]]
        else:
            # More than six images: keep the first, middle and last
            to_delete = np.delete(indices_month, [0, n_month // 2, n_month - 1])
        indices_to_remove.append(np.asarray(to_delete).flatten())

    if len(indices_to_remove) > 0:
        indices_to_remove = np.concatenate(indices_to_remove)
        after = before - len(indices_to_remove)
        print(f"Keeping {after}/{before}")
        return indices_to_remove
    return []
    

def subset_contiguous_sunny_dates(dates):
    """ Returns indices of images to drop so that image-rich stretches of
        the year are thinned out.

        Only applies when there are at least 9 dates. Dates are bucketed into
        the 13 month windows defined by `begin` / `end` (window 0 is
        [-60, 0), the end of the previous year; window 12 is [334, 390)).

        Pass 1: every month that belongs to some 3-month window holding at
        least 4 images loses the image at position 1 of its indices, if it
        has 3 or more images.
        Pass 2 (only if at least 12 images would remain): in the windows
        starting at months 0, 2, ..., 10 that hold >= 5 images with >= 1 in
        every month, one month is chosen and loses the image at position 1
        of its not-yet-removed indices, if it still has 2 or more.

        No cloud information is used; position 1 is in array order (date
        order if `dates` is sorted — presumably it is; TODO confirm).

        Parameters:
         dates (np.ndarray): (N,) day offsets (see extract_dates)

        Returns:
         indices_to_rm (list): indices into `dates` to delete (may be empty)
    """
    begin = [-60, 0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334]
    end = [0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 390]
    n_per_month = []
    months_to_adjust = []
    months_to_adjust_again = []
    indices_to_rm = []
    
    if len(dates) >= 9:
        # Count images in each of the 13 month windows
        for x, y in zip(begin, end):
            indices_month = np.argwhere(np.logical_and(
                dates >= x, dates < y)).flatten()
            n_per_month.append(len(indices_month))

        # Sliding 3-month windows over month indices 0..12
        for x in range(11):
            three_m_sum = np.sum(n_per_month[x:x+3])
            # Window holds at least 4 images (the min >= 0 check is always true)
            if three_m_sum >= 4 and np.min(n_per_month[x:x+3]) >= 0:
                # Add the months to be adjusted
                months_to_adjust += [x, x+1, x+2]

        months_to_adjust = list(set(months_to_adjust))

        # 3-month windows starting at month indices 0, 2, ..., 10 (index 0 is
        # the previous-year window, index 1 is January). Overlapping windows
        # can select the same month twice.
        for x in [0, 2, 4, 6, 8, 10]:
            three_m_sum = np.sum(n_per_month[x:x+3])
            # At least 5 images in the window and at least 1 in every month
            if three_m_sum >= 5 and np.min(n_per_month[x:x+3]) >= 1: 
                # Prefer to adjust the middle month if possible
                if n_per_month[x + 1] == 3: # 3, 3, 3 or 2, 3, 2
                    months_to_adjust_again.append(x + 1)
                elif n_per_month[x] == 3: # e.g. 3, 2, 2 
                    months_to_adjust_again.append(x)
                elif n_per_month[x + 1] == 2: # 2, 2, 2 or 2, 2, 3
                    months_to_adjust_again.append(x + 1)
                elif n_per_month[x + 2] == 3: # e.g. 2, 1, 3
                    months_to_adjust_again.append(x + 2)

        if len(months_to_adjust) > 0:
            for month in months_to_adjust:
                indices_month = np.argwhere(np.logical_and(
                    dates >= begin[month], dates < end[month])).flatten()

                cloudiest_idx = 1
                # Drop the image at position 1 (no cloud data is consulted)
                if len(indices_month) >= 3:
                    indices_to_rm.append(indices_month[cloudiest_idx])
                    
        n_remaining = len(dates) - len(indices_to_rm)
                    
        # Second pass only when enough images would remain
        if len(months_to_adjust_again) > 0 and n_remaining >= 12:
            for month in months_to_adjust_again:
                indices_month = np.argwhere(np.logical_and(
                    dates >= begin[month], dates < end[month])).flatten()
                indices_month = [x for x in indices_month if x not in indices_to_rm]
                cloudiest_idx = 1
                # Drop the image at position 1 of those still kept
                if len(indices_month) >= 2:
                    indices_to_rm.append(indices_month[cloudiest_idx])
                    
        print(f"Removing {len(indices_to_rm)} sunny dates")

    return indices_to_rm


def to_int16(array: np.array) -> np.array:
    '''Quantizes values to uint16 (unsigned, despite the name), halving
    storage relative to float32.

    Values are clipped to [0, 1], scaled by 65535 and truncated toward zero.
    '''
    bounded = np.clip(array, 0, 1)
    quantized = np.trunc(65535 * bounded)
    above_floor = quantized >= 0
    below_ceiling = quantized <= 65535
    assert np.min(above_floor)
    assert np.max(below_ceiling)
    return quantized.astype(np.uint16)


def process_raw(plot_id, path = 'train'):
    """ Cleans, smooths and super-resolves the raw Sentinel-2 series of one
        plot and saves the resulting monthly tile stack

        Reads ../data/{path}-raw/{plot_id}.npy (values scaled by 65535),
        ../data/{path}-dates/{plot_id}.npy (day offsets) and
        ../data/{path}-slope/{plot_id}.npy (slope), then:
         1. drops steps with too many missing / saturated pixels,
         2. drops steps flagged by remove_missed_clouds,
         3. caps each month at three images (select_dates) and thins
            image-rich months (subset_contiguous_sunny_dates),
         4. replaces NaNs in each (time, band) slice with the slice mean,
         5. builds and smooths the series (calculate_and_save_best_images,
            Smoother),
         6. super-resolves the 20 m bands, appends slope and the scaled
            acquisition date closest to each mid-month.

        Parameters:
         plot_id (int): Collect Earth Online plot id
         path (str): dataset prefix, e.g. 'train'

        Creates:
         ../data/{path}-s2-new/{plot_id}.hkl (uint16, gzip) when the stack
         has no NaNs, max_distance <= 300 and at least 5 images.

        Returns:
         tiles (np.ndarray): the processed stack (uint16 when it was saved,
         float otherwise), or None when the raw stack does not have 10 bands
    """
    x = np.load(f"../data/{path}-raw/{plot_id}.npy")
    x = np.float32(x) / 65535
    if x.shape[-1] != 10:
        # Only 10-band stacks can be processed; report and skip the rest
        print(f"Skipping {plot_id} because it has {x.shape[-1]} bands instead of 10 \n")
        return None

    s2_dates = np.load(f"../data/{path}-dates/{plot_id}.npy")
    dem = np.load(f"../data/{path}-slope/{plot_id}.npy")

    assert x.shape[0] == s2_dates.shape[0]

    # 1. Steps with too many zero / saturated values
    missing_px = id_missing_px(x)
    if len(missing_px) > 0:
        print(f"Deleting {missing_px} because of missing data")
        x = np.delete(x, missing_px, 0)
        s2_dates = np.delete(s2_dates, missing_px)

    # Image count used by the quality gate before saving
    n_images = x.shape[0]

    # 2. Residual clouds
    to_remove = remove_missed_clouds(x)
    if len(to_remove) > 0:
        x = np.delete(x, to_remove, 0)
        s2_dates = np.delete(s2_dates, to_remove)
        print(f"Removing {len(to_remove)} steps based on MCM mask")

    # 3. At most three images per month, then thin image-rich months
    to_remove = select_dates(s2_dates)
    if len(to_remove) > 0:
        x = np.delete(x, to_remove, 0)
        s2_dates = np.delete(s2_dates, to_remove)

    to_remove = subset_contiguous_sunny_dates(s2_dates)
    if len(to_remove) > 0:
        x = np.delete(x, to_remove, 0)
        s2_dates = np.delete(s2_dates, to_remove)

    # 4. Replace NaNs in each (time, band) slice with the mean of that
    # slice's valid pixels. A boolean mask ensures only NaN pixels change.
    nan_mask = np.isnan(x)
    if nan_mask.any():
        slice_means = np.nanmean(x, axis = (1, 2), keepdims = True)
        x = np.where(nan_mask, slice_means, x)

    # 5. Interpolate linearly to 5 day frequency
    tiles, max_distance = calculate_and_save_best_images(x, s2_dates)
    sm = Smoother(lmbd = 800, size = tiles.shape[0],
                  nbands = 10, dim = tiles.shape[1])
    x = sm.interpolate_array(tiles)

    # 6. Super-resolve the 20 m bands and append slope
    x = superresolve_tile(x)
    print(x.shape)
    tiles = concatenate_dem(x, dem)

    # Acquisition date closest to the middle (day 15) of each month
    month_midpoints = np.array([0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334]) + 15
    closest_date = np.array([s2_dates[np.argmin(abs(s2_dates - date))]
                             for date in month_midpoints])
    print(closest_date)
    closest_date = closest_date[:, np.newaxis, np.newaxis, np.newaxis]
    closest_date = np.broadcast_to(closest_date, (12, 28, 28, 1))
    # (d + 45) / 411 maps day offsets in [-45, 366] onto [0, 1].
    # NOTE(review): the download window starts at day -46 (Nov 15 of the
    # previous year), so values can fall slightly outside [0, 1].
    closest_date = (closest_date + 45)  / 411

    tiles = np.concatenate([tiles, closest_date], axis = -1)

    n_nan = np.sum(np.isnan(tiles))
    if n_nan > 0:
        print(f"Skipping {plot_id} because there are {n_nan} NA values \n")
        return tiles

    print(f"There are {n_nan} NA values")
    if max_distance <= 300 and n_images >= 5:
        tiles = to_int16(tiles)
        tile_path = f"../data/{path}-s2-new/{plot_id}.hkl"
        hkl.dump(tiles, tile_path, mode='w', compression='gzip')
        print(f"Saved {tiles.shape} shape, {n_images} img,"
              f" to {tile_path} \n")
    else:
        print(f"Skipping {plot_id} because {max_distance} distance, and {n_images} img \n")

    return tiles

# Function execution
## 1. Download DEM and Slope

In [13]:
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Download the DEM for every Kenya Collect Earth Online plot CSV found in the
# training CSV folder. Files whose names do not mention "kenya" are ignored.
csv_dir = "../data/train-csv/"
for csv_name in os.listdir(csv_dir):
    if "kenya" not in csv_name:
        continue
    # csv_dir already ends with "/", so this yields the same path as string concatenation.
    tile = download_new_dem(os.path.join(csv_dir, csv_name),
                            "../data/train-dem/",
                            image_format=MimeType.TIFF_d16)

Starting download of 177 plots from ../data/train-csv/ceo-kenya_shrubs-sample-data-2021-09-02.csv to ../data/train-dem/
Downloading 1/177, 141370594
Downloading 2/177, 141370595
Downloading 3/177, 141370596
Downloading 4/177, 141370597
Downloading 5/177, 141370598
Downloading 6/177, 141370599
Downloading 7/177, 141370600
Downloading 8/177, 141370601
Downloading 9/177, 141370602
Downloading 10/177, 141370603
Downloading 11/177, 141370604
Downloading 12/177, 141370605
Downloading 13/177, 141370606
Downloading 14/177, 141370607
Downloading 15/177, 141370608
Downloading 16/177, 141370609
Downloading 17/177, 141370610
Downloading 18/177, 141370611
Downloading 19/177, 141370612
Downloading 20/177, 141370613
Downloading 21/177, 141370614
Downloading 22/177, 141370615
Downloading 23/177, 141370616
Downloading 24/177, 141370617
Downloading 25/177, 141370618
Downloading 26/177, 141370619
Downloading 27/177, 141370620
Downloading 28/177, 141370621
Downloading 29/177, 141370622
Downloading 30/177,

## 2. Download Raw data files

In [19]:
# Download raw imagery for every Kenya Collect Earth Online plot CSV in the
# training CSV folder (same directory-listing order as os.listdir returns).
kenya_csvs = [name for name in os.listdir("../data/train-csv/") if "kenya" in name]
for csv_name in kenya_csvs:
    download_raw_data(os.path.join("../data/train-csv/", csv_name),
                      "../data/train-raw/",
                      fmt='train',
                      image_format=MimeType.TIFF_d16)

Starting download of 5 plots from ../data/train-csv/ceo-kenya_shrubs-sample-data-2021-09-02.csv to ../data/train-raw/
Downloading 1/5, 141370604


CRITICAL:root:Failed to download from:
with HTTPError:
Server response: "Illegal request to https://sentinel-s2-l1c-index.s3.amazonaws.com/tiles/37/M/ER/2020/11/1/0/B8A.index. HTTP Status: 403"
Traceback (most recent call last):
  File "/Users/jbrandt.terminal/opt/anaconda3/envs/tf/lib/python3.7/site-packages/sentinelhub/download/handlers.py", line 44, in new_download_func
    return download_func(self, request)
  File "/Users/jbrandt.terminal/opt/anaconda3/envs/tf/lib/python3.7/site-packages/sentinelhub/download/handlers.py", line 28, in new_download_func
    raise exception from exception
  File "/Users/jbrandt.terminal/opt/anaconda3/envs/tf/lib/python3.7/site-packages/sentinelhub/download/handlers.py", line 22, in new_download_func
    return download_func(self, request)
  File "/Users/jbrandt.terminal/opt/anaconda3/envs/tf/lib/python3.7/site-packages/sentinelhub/download/sentinelhub_client.py", line 56, in _execute_download
    response.raise_for_status()
  File "/Users/jbrandt.ter

Failed to download from:
with HTTPError:
Server response: "Illegal request to https://sentinel-s2-l1c-index.s3.amazonaws.com/tiles/37/M/ER/2020/11/1/0/B8A.index. HTTP Status: 403"
Downloading 2/5, 141370605


CRITICAL:root:Failed to download from:
with HTTPError:
Server response: "Illegal request to https://sentinel-s2-l1c-index.s3.amazonaws.com/tiles/37/M/ER/2020/11/1/0/B8A.index. HTTP Status: 403"
Traceback (most recent call last):
  File "/Users/jbrandt.terminal/opt/anaconda3/envs/tf/lib/python3.7/site-packages/sentinelhub/download/handlers.py", line 44, in new_download_func
    return download_func(self, request)
  File "/Users/jbrandt.terminal/opt/anaconda3/envs/tf/lib/python3.7/site-packages/sentinelhub/download/handlers.py", line 28, in new_download_func
    raise exception from exception
  File "/Users/jbrandt.terminal/opt/anaconda3/envs/tf/lib/python3.7/site-packages/sentinelhub/download/handlers.py", line 22, in new_download_func
    return download_func(self, request)
  File "/Users/jbrandt.terminal/opt/anaconda3/envs/tf/lib/python3.7/site-packages/sentinelhub/download/sentinelhub_client.py", line 56, in _execute_download
    response.raise_for_status()
  File "/Users/jbrandt.ter

Failed to download from:
with HTTPError:
Server response: "Illegal request to https://sentinel-s2-l1c-index.s3.amazonaws.com/tiles/37/M/ER/2020/11/1/0/B8A.index. HTTP Status: 403"
Downloading 3/5, 141370606


KeyboardInterrupt: 

## 3. Process train / test data

In [20]:
# Plot IDs that are deliberately excluded from processing.
SKIP_PLOTS = {'139189131', '139190072', '139320327',
              '136029661', '139190071', '141018056'}

# Plot IDs are the basenames of the raw ``.npy`` downloads. Match on the file
# suffix rather than a substring: a name like ``123.npy.tmp`` would otherwise be
# selected and then turned into a bogus ID by blindly stripping 4 characters.
plots = [os.path.splitext(name)[0]
         for name in os.listdir("../data/train-raw/")
         if name.endswith(".npy")]

# ``i`` counts every plot (including skipped ones), matching the progress
# numbers printed below.
for i, plot in enumerate(plots, start=1):
    # Already-processed plots are skipped so the cell can be re-run safely.
    if os.path.exists(os.path.join("../data/train-s2-new/", plot + ".hkl")):
        continue
    if plot in SKIP_PLOTS:
        continue
    tiles = process_raw(plot, path='train')
    print(i, plot)

Missed shadow 4: 0.23046875
Removing 0 sunny dates
(12, 28, 28, 10)
[ 16  56  60  95 145 150 185 230 245 285 305 345]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 14 img, to ../data/train-s2-new/141370680.hkl 

1 141370680
(12, 28, 28, 10)
[ 11  21  85  85  85  85 260 260 260 275 325 325]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 7 img, to ../data/train-s2-new/141370694.hkl 

2 141370694
Removing 1 sunny dates
(12, 28, 28, 10)
[ 21  56  65  65 155 155 195 230 245 285 340 340]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 16 img, to ../data/train-s2-new/141370643.hkl 

3 141370643
Removing 0 sunny dates
(12, 28, 28, 10)
[ 11  56  60  95 130 130 220 230 270 270 335 345]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 14 img, to ../data/train-s2-new/141370657.hkl 

4 141370657
Missed shadow 0: 0.12109375
Removing 0 sunny dates
(12, 28, 28, 10)
[-14  56  60  60 155 165 210 210 285 285 305 335]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 10 img, to ../data

(12, 28, 28, 10)
[ 21  65  65 120 120 155 185 220 260 290 315 350]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 15 img, to ../data/train-s2-new/141370719.hkl 

80 141370719
Removing 0 sunny dates
(12, 28, 28, 10)
[-14  65  65 120 120 155 195 225 260 285 330 330]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 13 img, to ../data/train-s2-new/141370718.hkl 

81 141370718
Removing 0 sunny dates
(12, 28, 28, 10)
[ 16  65  80  95 120 170 170 230 270 275 275 275]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 12 img, to ../data/train-s2-new/141370730.hkl 

82 141370730
Missed shadow 1: 0.25390625
Removing 0 sunny dates
(12, 28, 28, 10)
[ 11  41  85 120 120 170 195 240 245 285 295 350]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 15 img, to ../data/train-s2-new/141370724.hkl 

83 141370724
Removing 0 sunny dates
(12, 28, 28, 10)
[ 11  51  80  80 185 185 205 245 270 270 335 345]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 12 img, to ../data/train-s2-new/141370678

(12, 28, 28, 10)
[ 16  36  60  60  60 245 245 245 245 245 345 350]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 10 img, to ../data/train-s2-new/141370756.hkl 

164 141370756
Removing 0 sunny dates
(12, 28, 28, 10)
[ 21  41  65  65 150 175 175 245 255 275 305 305]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 13 img, to ../data/train-s2-new/141370742.hkl 

165 141370742
Removing 0 sunny dates
(12, 28, 28, 10)
[ 11  41  65 120 120 150 185 210 245 285 345 345]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 14 img, to ../data/train-s2-new/141370622.hkl 

171 141370622
Missed shadow 6: 0.111328125
Removing 0 sunny dates
(12, 28, 28, 10)
[ 11  46  65  90 120 155 195 220 245 275 315 335]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 17 img, to ../data/train-s2-new/141370636.hkl 

172 141370636
(12, 28, 28, 10)
[ -1  -1 148 148 148 148 218 228 243 288 348 348]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 8 img, to ../data/train-s2-new/141370632.hkl 

173 14137063

Removing 1 sunny dates
(12, 28, 28, 10)
[ 21  46  46 145 145 160 195 225 245 330 330 330]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 13 img, to ../data/train-s2-new/141370610.hkl 

269 141370610
Missed shadow 3: 0.203125
Removing 0 sunny dates
(12, 28, 28, 10)
[ 21  36  60 120 120 155 205 230 260 285 285 385]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 12 img, to ../data/train-s2-new/141370638.hkl 

270 141370638
Removing 0 sunny dates
(12, 28, 28, 10)
[ 11  41  80 120 120 155 185 225 245 290 305 305]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 17 img, to ../data/train-s2-new/141370599.hkl 

274 141370599
Removing 1 sunny dates
(12, 28, 28, 10)
[ 16  36  60 105 130 170 180 215 245 285 345 345]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 16 img, to ../data/train-s2-new/141370770.hkl 

279 141370770
Removing 0 sunny dates
(12, 28, 28, 10)
[ 21  36  60  90 125 155 185 230 245 275 305 385]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 17 img, to ../da

(12, 28, 28, 10)
[ 21  41  95  95 145 165 205 205 285 285 340 340]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 11 img, to ../data/train-s2-new/141370704.hkl 

363 141370704
Missed shadow 8: 0.296875
Missed shadow 9: 0.1171875
Removing 1 steps based on MCM mask
Removing 0 sunny dates
(12, 28, 28, 10)
[ 41  41  41 120 120 180 185 245 270 270 270 385]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 11 img, to ../data/train-s2-new/141370710.hkl 

364 141370710
Missed shadow 6: 0.2607421875
Removing 1 steps based on MCM mask
Removing 0 sunny dates
(12, 28, 28, 10)
[-29  60  65 120 120 155 210 210 285 285 330 330]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 13 img, to ../data/train-s2-new/141370738.hkl 

365 141370738
Missed shadow 4: 0.203125
(12, 28, 28, 10)
[ 21  21  90  90 135 175 175 240 240 290 290 290]
There are 0 NA values
Saved (12, 28, 28, 12) shape, 7 img, to ../data/train-s2-new/141370739.hkl 

366 141370739
Removing 0 sunny dates
(12, 28, 28, 10)
[ 21  46  46