# Download and process sentinel 2 data

This notebook downloads and processes one year of the training and validation plots as labelled on Collect Earth Online. 

## John Brandt
## Last edit: Sept 20, 2021

## Package imports, API import, source scripts

In [1]:
import datetime
import logging
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math
import os
import hickle as hkl
import scipy.sparse as sparse

import seaborn as sns
import yaml

from collections import Counter
from random import shuffle
from scipy.sparse.linalg import splu
from sentinelhub import WmsRequest, WcsRequest, MimeType
from sentinelhub import CRS, BBox, constants, DataSource, CustomUrlParam
from skimage.transform import resize
from typing import Tuple, List
from scipy.ndimage import median_filter
from sentinelhub.config import SHConfig

with open("../config.yaml", 'r') as stream:
        key = (yaml.safe_load(stream))
        API_KEY = key['key'] 
        
%matplotlib inline
%run ../src/preprocessing/slope.py
%run ../src/preprocessing/indices.py
%run ../src/downloading/utils.py
%run ../src/preprocessing/cloud_removal.py
%run ../src/preprocessing/whittaker_smoother.py
%run ../src/downloading/io.py

In [2]:
with open("../config.yaml", 'r') as stream:
    key = (yaml.safe_load(stream))
    API_KEY = key['key']
    SHUB_SECRET = key['shub_secret']
    SHUB_KEY = key['shub_id']
    AWSKEY = key['awskey']
    AWSSECRET = key['awssecret']
            
shconfig = SHConfig()
shconfig.instance_id = API_KEY
shconfig.sh_client_id = SHUB_KEY
shconfig.sh_client_secret = SHUB_SECRET
    

uploader = FileUploader(awskey = AWSKEY, awssecret = AWSSECRET)

## Parameters

In [3]:
# Parameters
YEAR = 2020
TIME = ('{}-11-15'.format(str(YEAR - 1)), '{}-02-15'.format(str(YEAR + 1)))
EPSG = CRS.WGS84
IMSIZE = 32

# Constants
starting_days = np.cumsum([0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30])

# Geographic utility functions

These cell blocks calculate the min_x, min_y, max_x, max_y of the area of interest (AOI)

In [4]:
def calc_bbox(plot_id: int, df: pd.DataFrame) -> List:
    """ Calculates the corners of a bounding box from an input
        pandas dataframe as output by Collect Earth Online

        Parameters:
         plot_id (int): plot_id of associated plot
         df (pandas.DataFrame): dataframe of associated CEO survey
    
        Returns:
         bounding_box (list): [(min(x), min(y)),
                              (max(x), max_y))]
    """
    subs = df[df['PLOT_ID'] == plot_id]
    # (min x, min y), (max x, max y)
    return [(min(subs['LON']), min(subs['LAT'])),
            (max(subs['LON']), max(subs['LAT']))]


def bounding_box(points: List[Tuple[float, float]], 
                 expansion: int = 160) -> ((Tuple, Tuple), str):
    """ Calculates the corners of a bounding box with an
        input expansion in meters from a given bounding_box
        
        Subcalls:
         calculate_epsg, convertCoords

        Parameters:
         points (list): output of calc_bbox
         expansion (float): number of meters to expand or shrink the
                            points edges to be
    
        Returns:
         bl (tuple): x, y of bottom left corner with edges of expansion meters
         tr (tuple): x, y of top right corner with edges of expansion meters
    """
    bl = list(points[0])
    tr = list(points[1])
    inproj = Proj('epsg:4326')
    outproj_code = calculate_epsg(bl)
    outproj = Proj('epsg:' + str(outproj_code))
    bl_utm =  transform(inproj, outproj, bl[1], bl[0])
    tr_utm =  transform(inproj, outproj, tr[1], tr[0])

    distance1 = tr_utm[0] - bl_utm[0]
    distance2 = tr_utm[1] - bl_utm[1]
    expansion1 = (expansion - distance1)/2
    expansion2 = (expansion - distance2)/2
        
    bl_utm = [bl_utm[0] - expansion1, bl_utm[1] - expansion2]
    tr_utm = [tr_utm[0] + expansion1, tr_utm[1] + expansion2]

    zone = str(outproj_code)[3:]
    zone = zone[1:] if zone[0] == "0" else zone
    direction = 'N' if tr[1] >= 0 else 'S'
    utm_epsg = "UTM_" + zone + direction
    return (bl_utm, tr_utm), CRS[utm_epsg]

# Data download

In [5]:
def extract_dates(date_dict: dict, year: int) -> List:
    """ Transforms a SentinelHub date dictionary to a
         list of integer calendar dates
    """
    dates = []
    days_per_month = [0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30]
    starting_days = np.cumsum(days_per_month)
    for date in date_dict:
        if date.year == year - 1:
            dates.append(-365 + starting_days[(date.month-1)] + date.day)
        if date.year == year:
            dates.append(starting_days[(date.month-1)] + date.day)
        if date.year == year + 1:
            dates.append(365 + starting_days[(date.month-1)]+date.day)
    return dates

def to_float32(array: np.array) -> np.array:
    """Converts an int_x array to float32"""
    print(f'The original max value is {np.max(array)}')
    if not isinstance(array.flat[0], np.floating):
        assert np.max(array) > 1
        array = np.float32(array) / 65535.
    assert np.max(array) <= 1
    return array

## Cloud and cloud shadow calculation functions

This cell block identifies clouds using s2Cloudless, and identifies clouds (and shadows) using Candra et al. 2019.

The output is per-px cloud/shadow masks, and a list of sentinel 2 dates to download.

In [6]:
def identify_clouds(bbox: List[Tuple[float, float]], epsg: 'CRS', time: dict = TIME):
    """ Downloads and calculates cloud cover and shadow
        
        Parameters:
         bbox (list): output of calc_bbox
         epsg (float): EPSG associated with bbox 
         time (tuple): YY-MM-DD - YY-MM-DD bounds for downloading 
    
        Returns:
         cloud_img (np.array): (X, 96, 96) array of cloud probs
         shadows (np.array):  (X, 96, 96) array of shadow binary
         clean_steps (np.array): (N,) array of clean idx
         cloud_dates (np.array): (N,) array of clean cloud datets
    """
    box = BBox(bbox, crs = epsg)
    cloud_request = WcsRequest(
        layer='CLOUD_NEW',
        bbox=box, time=time,
        resx='160m', resy='160m',
        image_format = MimeType.TIFF,
        maxcc=0.75, config=shconfig,
        custom_url_params = {constants.CustomUrlParam.UPSAMPLING: 'NEAREST'},
        time_difference=datetime.timedelta(hours=96))

    shadow_request = WcsRequest(
        layer='SHADOW',
        bbox=box, time=time,
        resx='60m', resy='60m',
        image_format =  MimeType.TIFF,
        maxcc=0.75, config=shconfig,
        custom_url_params = {constants.CustomUrlParam.UPSAMPLING: 'NEAREST'},
        time_difference=datetime.timedelta(hours=96))

    cloud_img = np.array(cloud_request.get_data())
    if not isinstance(cloud_img.flat[0], np.floating):
        assert np.max(cloud_img) > 1
        cloud_img = cloud_img / 255.
    assert np.max(cloud_img) <= 1

    cloud_img = resize(cloud_img, (cloud_img.shape[0], 96, 96), order = 0)
    n_cloud_px = np.sum(cloud_img > 0.33, axis = (1, 2))
    cloud_steps = np.argwhere(n_cloud_px > (96**2 * 0.15))
    clean_steps = [x for x in range(cloud_img.shape[0]) if x not in cloud_steps]
    
    
    cloud_dates_dict = [x for x in cloud_request.get_dates()]
    cloud_dates = extract_dates(cloud_dates_dict, YEAR)
    cloud_dates = [val for idx, val in enumerate(cloud_dates) if idx in clean_steps]
    
    shadow_dates_dict = [x for x in shadow_request.get_dates()]
    shadow_dates = extract_dates(shadow_dates_dict, YEAR)
    shadow_steps = [idx for idx, val in enumerate(shadow_dates) if val in cloud_dates]    
    
    shadow_img = np.array(shadow_request.get_data(data_filter = shadow_steps))
    shadow_pus = (shadow_img.shape[1]*shadow_img.shape[2])/(512*512) * shadow_img.shape[0]
    shadow_img = resize(shadow_img, (shadow_img.shape[0], 96, 96, shadow_img.shape[-1]), order = 0,
                        anti_aliasing = False, preserve_range = True).astype(np.uint16)

    cloud_img = np.delete(cloud_img, cloud_steps, 0)
    assert shadow_img.shape[0] == cloud_img.shape[0], (shadow_img.shape, cloud_img.shape)
    shadows = mcm_shadow_mask(shadow_img, cloud_img) # Make usre this makes sense??
    print(f"Shadows ({shadows.shape}) used {round(shadow_pus, 1)} processing units")
    return cloud_img, shadows, clean_steps, np.array(cloud_dates)

# DEM and slope

In [7]:
def download_dem(plot_id: int, df: 'DataFrame', epsg: 'CRS') -> (np.ndarray, np.ndarray):
    """ Downloads MapZen digital elevation model and return slope

        Parameters:
         plot_id (tuple): plot id from collect earth online (CEO)
         df (pandas.DataFrame): data associated with plot_id from CEO
         epsg (int): UTM EPSG associated with plot_id
    
        Returns:
         slope (arr): (X, Y, 1) array of per-pixel slope from [0, 1]
    """
    location = calc_bbox(plot_id, df = df)
    bbox, epsg = bounding_box(location, expansion = (32+2)*10)
    box = BBox(bbox, crs = epsg)
    dem_request = WcsRequest(
                         layer='DEM', bbox=box,
                         resx = "10m", resy = "10m",
                         config=shconfig,
                         image_format= MimeType.TIFF,
                         custom_url_params={CustomUrlParam.SHOWLOGO: False})
    dem_image_init = dem_request.get_data()[0]
    dem_image_init = dem_image_init - 12000
    dem_image_init = dem_image_init.astype(np.float32)
    dem_image = np.copy(dem_image_init)
    dem_image = median_filter(dem_image_init, size = 5)
    slope = calcSlope(dem_image.reshape((1, 32+2, 32+2)),
                      np.full((32+2, 32+2), 10),
                      np.full((32+2, 32+2), 10), 
                      zScale = 1, minSlope = 0.02)
    slope = slope / 90
    slope = slope.reshape((32+2, 32+2, 1))
    slope = slope[1:32+1, 1:32+1, :]
    return slope, dem_image_init

## 10 and 20 meter L2A bands

In [8]:
def download_layer(bbox: List[Tuple[float, float]],
                   clean_steps: np.ndarray, epsg: 'CRS',
                   dates: dict = TIME, year: int = YEAR) -> (np.ndarray, np.ndarray):
    """ Downloads the L2A sentinel layer with 10 and 20 meter bands
        
        Parameters:
         bbox (list): output of calc_bbox
         epsg (float): EPSG associated with bbox 
         time (tuple): YY-MM-DD - YY-MM-DD bounds for downloading 
    
        Returns:
         img (arr):
         img_request (obj): 
    """
    try:
        box = BBox(bbox, crs = epsg)
        image_request = WcsRequest(
                layer='L2A20',
                bbox=box, time=dates,
                image_format = MimeType.TIFF,
                data_source = DataSource.SENTINEL2_L2A,
                maxcc=0.75,
                resx='20m', resy='20m',
                config=shconfig,
                custom_url_params = {constants.CustomUrlParam.DOWNSAMPLING: 'NEAREST',
                                    constants.CustomUrlParam.UPSAMPLING: 'NEAREST'},
                time_difference=datetime.timedelta(hours=96),
            )
        
        image_dates = []
        for date in image_request.get_dates():
            if date.year == YEAR - 1:
                image_dates.append(-365 + starting_days[(date.month-1)] + date.day)
            if date.year == YEAR:
                image_dates.append(starting_days[(date.month-1)] + date.day)
            if date.year == YEAR + 1:
                image_dates.append(365 + starting_days[(date.month-1)]+date.day)
        
        steps_to_download = [i for i, val in enumerate(image_dates) if val in clean_steps]
        dates_to_download = [val for i, val in enumerate(image_dates) if val in clean_steps]
              
        img_bands = image_request.get_data(data_filter = steps_to_download)
        img_20 = np.stack(img_bands)
        img_20 = to_float32(img_20)

        s2_20_usage = (img_20.shape[1]*img_20.shape[2])/(512*512) * (6/3) * img_20.shape[0]
        if (img_20.shape[1] * img_20.shape[2]) != 14*14:
            print(f"Original 20 meter bands size: {img_20.shape}, using {s2_20_usage} PU")
        img_20 = resize(img_20, (img_20.shape[0], IMSIZE, IMSIZE, img_20.shape[-1]), order = 0)
        
        image_request = WcsRequest(
                layer='L2A10',
                bbox=box, time=dates,
                image_format = MimeType.TIFF,
                data_source = DataSource.SENTINEL2_L2A,
                maxcc=0.75,
                resx='10m', resy='10m',
                config=shconfig,
                custom_url_params = {constants.CustomUrlParam.DOWNSAMPLING: 'BICUBIC',
                                    constants.CustomUrlParam.UPSAMPLING: 'BICUBIC'},
                time_difference=datetime.timedelta(hours=96),
        )
        
        img_bands = image_request.get_data(data_filter = steps_to_download)
        img_10 = np.stack(img_bands)
        if (img_10.shape[1] * img_10.shape[2]) != 28*28:
            print(f"The original L2A image size is: {img_10.shape}")
        img_10 = to_float32(img_10)
            
        img_10 = resize(img_10, (img_10.shape[0], IMSIZE, IMSIZE, img_10.shape[-1]), order = 0)
        img = np.concatenate([img_10, img_20], axis = -1)

        
        return img, np.array(dates_to_download)

    except Exception as e:
        logging.fatal(e, exc_info=True)

# Super resolution

Super-resolve the 20 meter bands to 10 meters using DSen2.

In [9]:
import tensorflow as tf
sess = tf.Session()
from keras import backend as K
K.set_session(sess)

MDL_PATH = "../models/supres/"

model = tf.train.import_meta_graph(MDL_PATH + 'model.meta')
model.restore(sess, tf.train.latest_checkpoint(MDL_PATH))

logits = tf.get_default_graph().get_tensor_by_name("Add_6:0")
inp = tf.get_default_graph().get_tensor_by_name("Placeholder:0")
inp_bilinear = tf.get_default_graph().get_tensor_by_name("Placeholder_1:0")

def superresolve(input_data):
    bilinear_upsample = input_data[..., 4:]
    x = sess.run([logits], 
                 feed_dict={inp: input_data,
                            inp_bilinear: bilinear_upsample})
    return x[0]

def superresolve_tile(x):
    twentym = x[..., 4:]
    imsize = x.shape[1]
    twentym = np.reshape(twentym, (x.shape[0], imsize // 2, 2, imsize // 2, 2, 6))
    twentym = np.mean(twentym, (2, 4))
    twentym = resize(twentym, (x.shape[0], imsize, imsize, 6), 1)
    x[..., 4:] = twentym
    x[..., 4:] = superresolve(x)
    if imsize > 28:
        crop_amt = (imsize - 28) // 2
        x = x[:, crop_amt:-crop_amt, crop_amt:-crop_amt, :]
    return x

Using TensorFlow backend.


INFO:tensorflow:Restoring parameters from ../models/supres/model


# Download function

In [10]:
def download_new_dem(data_location: 'os.Path',
                     output_folder: 'os.Path',
                     image_format: 'MimeType' = MimeType.TIFF):
    """ Downloads and saves DEM and slope files
        
        Parameters:
         data_location (os.path): 
         output_folder (os.path): 
         image_format (MimeType): 
    
        Returns:
         None
    """
    
    df = pd.read_csv(data_location)
    df.columns = [x.upper() for x in df.columns]
    for column in ['IMAGERY_TITLE', 'STACKINGPROFILEDG', 'PL_PLOTID', 'IMAGERYYEARDG',
                  'IMAGERYMONTHPLANET', 'IMAGERYYEARPLANET', 'IMAGERYDATESECUREWATCH',
                  'IMAGERYENDDATESECUREWATCH', 'IMAGERYFEATUREPROFILESECUREWATCH',
                  'IMAGERYSTARTDATESECUREWATCH','IMAGERY_ATTRIBUTIONS',
                  'SAMPLE_GEOM']:
        if column in df.columns:
            df = df.drop(column, axis = 1)
            
    df = df.dropna(axis = 0)
    plot_ids = sorted(df['PLOT_ID'].unique())
    existing = [int(x[:-4]) for x in os.listdir(output_folder) if ".DS" not in x]
    existing = existing + [139089844]
    to_download = [x for x in plot_ids if x not in existing]
    print(f"Starting download of {len(to_download)}"
          f" plots from {data_location} to {output_folder}")
    errors = []
    for i, val in enumerate(to_download):
        print(f"Downloading {i + 1}/{len(to_download)}, {val}")
        initial_bbx = calc_bbox(val, df = df)
        dem_bbx, epsg = bounding_box(initial_bbx, expansion = 32*10)
        slope, dem = download_dem(val, epsg = epsg, df = df)
        np.save(output_folder + str(val), dem)
        np.save("../data/train-slope/" + str(val), slope)
        


def concatenate_dem(x, dem):
    dem = np.tile(dem.reshape((1, 32, 32, 1)), (x.shape[0], 1, 1, 1))
    dem = dem[:, 2:-2, 2:-2, :]
    x = np.concatenate([x, dem], axis = -1)
    assert x.shape[1] == x.shape[2] == 28
    return x

In [11]:
def id_missing_px(sentinel2: np.ndarray, thresh: int = 100) -> np.ndarray:
    missing_images_0 = np.sum(sentinel2[..., :10] == 0.0, axis = (1, 2, 3))
    missing_images_p = np.sum(sentinel2[..., :10] >= 1., axis = (1, 2, 3))
    missing_images = missing_images_0 + missing_images_p
    
    missing_images = np.argwhere(missing_images >= (sentinel2.shape[1]**2) / thresh).flatten()
    return missing_images


def download_raw_data(data_location, output_folder, fmt = "train", image_format = MimeType.TIFF):
    """ Downloads slope and sentinel-2 data for all plots associated
        with an input CSV from a collect earth online survey
        
        Parameters:
         data_location (os.path)
         output_folder (os.path)
        
        Creates:
         output_folder/{plot_id}.npy
    
        Returns:
         None
    """
    df = pd.read_csv(data_location)
    df.columns = [x.upper() for x in df.columns]
    for column in ['IMAGERY_TITLE', 'STACKINGPROFILEDG', 'PL_PLOTID', 'IMAGERYYEARDG',
                  'IMAGERYMONTHPLANET', 'IMAGERYYEARPLANET', 'IMAGERYDATESECUREWATCH',
                  'IMAGERYENDDATESECUREWATCH', 'IMAGERYFEATUREPROFILESECUREWATCH',
                  'IMAGERYSTARTDATESECUREWATCH','IMAGERY_ATTRIBUTIONS',
                  'SAMPLE_GEOM']:
        if column in df.columns:
            df = df.drop(column, axis = 1)

    df = df.dropna(axis = 0)
    plot_ids = sorted(df['PLOT_ID'].unique())
    existing = [int(x[:-4]) for x in os.listdir(f"../data/{fmt}-dates/") if ".DS" not in x]
    existing = existing + [139190271, 139187199, 139319876, 139319877]
    to_download = [x for x in plot_ids if x not in existing]
    print(f"Starting download of {len(to_download)}"
          f" plots from {data_location} to {output_folder}")
    for i, val in enumerate(reversed(to_download)):
        print(f"Downloading {i + 1}/{len(to_download)}, {val}")
        initial_bbx = calc_bbox(val, df = df)
        sentinel2_bbx, epsg = bounding_box(initial_bbx, expansion = IMSIZE*10)
        cloud_bbx, _ = bounding_box(initial_bbx, expansion = 96*10)
        try:
            # Identify cloud steps, download DEM, and download L2A series
            cloud_probs, shadows, _, clean_dates = identify_clouds(cloud_bbx, epsg = epsg)
            dem, _ = download_dem(val, epsg = epsg, df = df)
            #to_remove, _ = calculate_cloud_steps(cloud_probs, clean_dates)
            
            #if len(to_remove) > 0:
            #    cloud_probs = np.delete(cloud_probs, to_remove, 0)
            #    clean_dates = np.delete(clean_dates, to_remove)
            #    shadows = np.delete(shadows, to_remove, 0)
                
            to_remove = subset_contiguous_sunny_dates(clean_dates, cloud_probs)
            if len(to_remove) > 0:
                cloud_probs = np.delete(cloud_probs, to_remove, 0)
                clean_dates = np.delete(clean_dates, to_remove)
                shadows = np.delete(shadows, to_remove, 0)
                
            _ = print_dates(clean_dates, np.mean(cloud_probs, axis = (1, 2)))
        
            s2, s2_dates = download_layer(sentinel2_bbx, clean_steps = clean_dates, epsg = epsg)    
            
            # Step to ensure that shadows, clouds, sentinel l2a have aligned dates
            to_remove_clouds = [i for i, val in enumerate(clean_dates) if val not in s2_dates]
            to_remove_dates = [val for i, val in enumerate(clean_dates) if val not in s2_dates]
            if len(to_remove_clouds) > 0:
                print(f"Removing {to_remove_dates} from clouds because not in S2")
                cloud_probs = np.delete(cloud_probs, to_remove_clouds, 0)
                shadows = np.delete(shadows, to_remove_clouds, 0)
            print(f"Shadows {shadows.shape}, clouds {cloud_probs.shape},"
                  f" S2, {s2.shape}, S2d, {s2_dates.shape}")
            
            print(s2.shape)
            
            cloud_probs = cloud_probs[:, 24:-24, 24:-24]
            shadows = shadows[:, 24:-24, 24:-24]
            x, interp = remove_cloud_and_shadows(s2, cloud_probs, shadows, s2_dates)
            to_remove = np.argwhere(np.mean(interp, axis = (1, 2)) > 0.5)
            if len(to_remove) > 0:
                print(f"Removing {len(to_remove)} steps with >50% interpolation: {to_remove}")
                x = np.delete(x, to_remove, 0)
                cloud_probs = np.delete(cloud_probs, to_remove, 0)
                s2_dates = np.delete(s2_dates, to_remove)
                shadows = np.delete(shadows, to_remove, 0)
                print(np.sum(shadows, axis = (1, 2)))
            
            x_to_save = np.copy(x)
            x_to_save = np.clip(x_to_save, 0, 1)
            x_to_save = np.trunc(x_to_save * 65535).astype(np.uint16)
            np.save(f"../data/{fmt}-dates/{str(val)}", s2_dates)
            np.save(f"../data/{fmt}-raw/{str(val)}", x_to_save)
            file = f"../data/{fmt}-raw/{str(val)}.npy"
            key = f'restoration-mapper/model-data/{fmt}/raw/{str(val)}.npy'
            uploader.upload(bucket = 'restoration-monitoring', key = key, file = file)
            print("\n")

        except Exception as e:
            print(e)
            logging.fatal(e, exc_info=True)

In [12]:
def select_dates(dates):
    """For imagery that was downloaded prior to capping the number 
       of monthly images to be 3, it is necessary to enforce that cap
       on the training / testing data.
       
       This function identifies the indices of the imagery to deletet
       such that there is a maximum of three images per month.
    
    """
    
    
    before = len(dates)
    selected_indices = np.arange(len(dates))
    begin = [-60, 0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334]
    end = [0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 390]
    indices_to_remove = []
    for x, y in zip(begin, end):
        indices_month = np.argwhere(np.logical_and(dates >= x, dates < y)).flatten()
        if len(indices_month) > 3:
            to_delete = np.empty((0,))
            if begin == -60:
                to_delete = indices_month[:-3]
            elif begin == 334:
                to_delete = indices_month[3:]
            elif len(indices_month) == 4:
                to_delete = indices_month[1]
            elif len(indices_month) == 5:
                to_delete = np.array([indices_month[1],
                                      indices_month[3]])
            elif len(indices_month) == 6:
                to_delete = np.array([indices_month[1],
                                      indices_month[3],
                                      indices_month[4]])
                
            to_delete = np.array(to_delete)
            if to_delete.size > 0:
                indices_to_remove.append(to_delete.flatten())
                
    if len(indices_to_remove) > 0:
        indices_to_remove = np.concatenate(indices_to_remove)
        after = before - len(indices_to_remove)
        print(f"Keeping {after}/{before}")
        return indices_to_remove
    
    else:
        return []
    

def subset_contiguous_sunny_dates(dates, probs):
    """
    The general imagery subsetting strategy is as below:
        - Select all images with < 30% cloud cover
        - For each month, select up to 2 images that are <30% CC and are the closest to
          the beginning and the midde of the month
        - Select only one image per month for each month if the following criteria are met
              - Within Q1 and Q4, apply if at least 3 images in quarter
              - Otherwise, apply if at least 8 total images for year
              - Select the second image if max CC < 15%, otherwise select least-cloudy image
        - If more than 10 images remain, remove any images for April and September

    """
    
    probs = np.mean(probs, axis = (1, 2))
    begin = [-60, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334]
    end = [31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 410]
    n_per_month = []
    months_to_adjust = []
    months_to_adjust_again = []
    indices_to_rm = []
    indices = [x for x in range(len(dates))]
    
    def _indices_month(dates, x, y):
        indices_month = np.argwhere(np.logical_and(
                    dates >= x, dates < y)).flatten()
        return indices_month

    
    _ = print_dates(dates, probs)
    # Select the best 2 images per month to start with
    best_two_per_month = []
    for x, y in zip(begin, end):
        indices_month = np.argwhere(np.logical_and(
            dates >= x, dates < y)).flatten()

        month_dates = dates[indices_month]
        month_clouds = probs[indices_month]
        month_good_dates = month_dates[month_clouds < 0.20]
        indices_month = indices_month[month_clouds < 0.20]

        if len(month_good_dates) >= 2:
            if x > 0:
                ideal_dates = [x, x + 15]
            else:
                ideal_dates = [0, 15]

            # We first pick the 2 images with <30% cloud cover that are the closest
            # to the 1st and 15th of the month
            # todo: if both these images are above 15%, and one below 15% is available, include it
            closest_to_first_img = np.argmin(abs(month_good_dates - ideal_dates[0]))
            closest_to_second_img = np.argmin(abs(month_good_dates - ideal_dates[1]))
            if closest_to_second_img == closest_to_first_img:
                distances = abs(month_good_dates - ideal_dates[1])
                closest_to_second_img = np.argsort(distances)[1]

            first_image = indices_month[closest_to_first_img]
            second_image = indices_month[closest_to_second_img]
            best_two_per_month.append(first_image)
            best_two_per_month.append(second_image)
                    
        elif len(month_good_dates) >= 1:
            if x > 0:
                ideal_dates = [x, x + 15]
            else:
                ideal_dates = [0, 15]

            closest_to_second_img = np.argmin(abs(month_good_dates - ideal_dates[1]))
            second_image = indices_month[closest_to_second_img]
            best_two_per_month.append(second_image)
                
    dates_round_2 = dates[best_two_per_month]
    probs_round_2 = probs[best_two_per_month]
    
    # We then select between those two images to keep a max of one per month
    # We select the least cloudy image if the most cloudy has >15% cloud cover
    # Otherwise we select the second image

    # If there are more than 8 images, subset so only 1 image per month,
    # To bring down to a min of 8 images
    if len(dates_round_2) >= 8:
        n_to_rm = len(dates_round_2) - 8
        monthly_dates = []
        monthly_probs = []
        monthly_dates_date = []
        removed = 0
        for x, y in zip(begin, end):
            indices_month = np.argwhere(np.logical_and(
                dates >= x, dates < y)).flatten()
            dates_month = dates[indices_month]
            indices_month = [val for i, val in enumerate(indices_month) if dates_month[i] in dates_round_2]
            if len(indices_month) > 1:
                month_dates = dates[indices_month]
                month_clouds = probs[indices_month]

                subset_month = True
                if x == -60:
                    feb_mar = np.argwhere(np.logical_and(
                        dates >= 31, dates < 90)).flatten()
                    subset_month = False if len(feb_mar) < 2 else True
                if x == 334:
                    oct_nov = np.argwhere(np.logical_and(
                        dates >= 273, dates < 334)).flatten()
                    subset_month = False if len(oct_nov) < 2 else True

                if subset_month:
                    subset_month = True if removed <= n_to_rm else False
                if subset_month:
                    if np.max(month_clouds) >= 0.10:
                        month_best_date = [indices_month[np.argmin(month_clouds)]]
                    else:
                        month_best_date = [indices_month[1]]
                else:
                    month_best_date = indices_month
                monthly_dates.extend(month_best_date)
                monthly_probs.extend(probs[month_best_date])
                monthly_dates_date.extend(dates[month_best_date])
                removed += 1
            elif len(indices_month) == 1:
                monthly_dates.append(indices_month[0])
                monthly_probs.append(probs[indices_month[0]])
                monthly_dates_date.append(dates[indices_month[0]])
    else:
        monthly_dates = best_two_per_month
        
    indices_to_rm = [x for x in indices if x not in monthly_dates]


    dates_round_3 = dates[monthly_dates]
    probs_round_3 = probs[monthly_dates]

    if len(dates_round_3) >= 10:
        delete_max = False
        if np.max(probs_round_3) >= 0.15:
            delete_max = True
            indices_to_rm.append(monthly_dates[np.argmax(probs_round_3)])
        for x, y in zip(begin, end):
            indices_month = np.argwhere(np.logical_and(
                dates >= x, dates < y)).flatten()
            dates_month = dates[indices_month]
            indices_month = [x for x in indices_month if x in monthly_dates]

            n_removed = 0
            if len(indices_month) >= 1:
                if len(monthly_dates) == 11 and delete_max:
                    continue
                elif len(monthly_dates) >= 11:
                    if x in [90, 243]:
                        indices_to_rm.append(indices_month[0])

    return indices_to_rm


def to_int16(array: np.array) -> np.array:
    '''Converts a float32 array to uint16, reducing storage costs by three-fold'''
    array = np.clip(array, 0, 1)
    array = np.trunc(array * 65535)
    assert np.min(array >= 0)
    assert np.max(array <= 65535)
    
    return array.astype(np.uint16)


def process_raw(plot_id, path = 'train'):
    """ Downloads slope and sentinel-2 data for all plots associated
        with an input CSV from a collect earth online survey
        
        Parameters:
         data_location (os.path)
         output_folder (os.path)
        
        Creates:
         output_folder/{plot_id}.npy
    
        Returns:
         None
    """         

    x = np.load(f"../data/{path}-raw/{plot_id}.npy")
    x = np.float32(x) / 65535
    if x.shape[-1] == 10:
        s2_dates = np.load(f"../data/{path}-dates/{plot_id}.npy")
        dem = np.load(f"../data/{path}-slope/{plot_id}.npy")

        assert x.shape[0] == s2_dates.shape[0]

        missing_px = id_missing_px(x)
        if len(missing_px) > 0:
            print(f"Deleting {missing_px} because of missing data")
            x = np.delete(x, missing_px, 0)
            s2_dates = np.delete(s2_dates, missing_px)


        n_images = x.shape[0]

        to_remove = select_dates(s2_dates)
        if len(to_remove) > 0:
            x = np.delete(x, to_remove, 0)
            s2_dates = np.delete(s2_dates, to_remove)

        print(x.shape)
        #to_remove = subset_contiguous_sunny_dates(s2_dates)
        #if len(to_remove) > 0:
        #    x = np.delete(x, to_remove, 0)
        #    s2_dates = np.delete(s2_dates, to_remove)

        for band in range(0, 10):
            for time in range(0, x.shape[0]):
                x_i = x[time, :, :, band]
                x_i[np.argwhere(np.isnan(x_i))] = np.mean(x_i)
                x[time, :, :, band] = x_i

        # Interpolate linearly to 5 day frequency
        tiles, max_distance = calculate_and_save_best_images(x, s2_dates)
        sm = Smoother(lmbd = 150, size = tiles.shape[0],
                      nbands = 10, dimx = tiles.shape[1], dimy = tiles.shape[2])
        x = sm.interpolate_array(tiles)
        x = superresolve_tile(x)
        tiles = concatenate_dem(x, dem)
        dates = [0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334]
        dates = np.array(dates) + 15
        closest_date = []
        for date in dates:
            date_diff = s2_dates[np.argmin(abs(s2_dates - date))]
            closest_date.append(date_diff)

        closest_date = np.array(closest_date)
        closest_date = closest_date[:, np.newaxis, np.newaxis, np.newaxis]
        closest_date = np.broadcast_to(closest_date, (12, 28, 28, 1))
        closest_date = (closest_date + 45)  / 411

        tiles = np.concatenate([tiles, closest_date], axis = -1)

        if np.sum(np.isnan(tiles)) == 0:
            print(f"There are {np.sum(np.isnan(tiles))} NA values")
            if max_distance <= 300 and n_images >= 5:
                tiles = to_int16(tiles)
                #np.save(f"../data/{path}-s2-new/{plot_id}", tiles)
                tile_path = f"../data/{path}-s2/{plot_id}"
                tile_path = tile_path + ".hkl"
                hkl.dump(tiles, tile_path, mode='w', compression='gzip')
                print(f"Saved {tiles.shape} shape, {n_images} img,"
                      f" to {tile_path} \n")
            else:
                print(f"Skipping {plot_id} because {max_distance} distance, and {n_images} img \n")

        return tiles

# Function execution
## 1. Download DEM and Slope

In [13]:
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 

for i in (os.listdir("../data/train-csv/")):
    if "ceo-asia" in i:
        tile = download_new_dem("../data/train-csv/" + i,
                                "../data/train-dem/",
                                image_format = MimeType.TIFF)        

Starting download of 0 plots from ../data/train-csv/ceo-asia-otherland-sample-data-20.csv to ../data/train-dem/
Starting download of 0 plots from ../data/train-csv/ceo-asia-tml-finetune-sample-data-22.csv to ../data/train-dem/
Starting download of 0 plots from ../data/train-csv/ceo-asia-rice-sample-data-21.csv to ../data/train-dem/


## 2. Download Raw data files

In [None]:
for i in (os.listdir("../data/train-csv/")):
    if "ceo-asia" in i:
        download_raw_data("../data/train-csv/" + i,
                          "../data/train-raw/", 
                          fmt = 'train',
                          image_format = MimeType.TIFF)

Starting download of 0 plots from ../data/train-csv/ceo-asia-otherland-sample-data-20.csv to ../data/train-raw/
Starting download of 24 plots from ../data/train-csv/ceo-asia-tml-finetune-sample-data-22.csv to ../data/train-raw/
Downloading 1/24, 220039
Shadows ((28, 96, 96)) used 0.0 processing units
1, Dates: [17, 22, 366, 371, 386], Probs: [0.09, 0.05, 0.07, 0.06, 0.18]
2, Dates: [32, 37, 42, 52, 396, 401, 406], Probs: [0.11, 0.18, 0.12, 0.1, 0.08, 0.04, 0.05]
3, Dates: [86], Probs: [0.04]
4, Dates: [91, 96, 101, 106], Probs: [0.07, 0.14, 0.12, 0.17]
5, Dates: [], Probs: []
6, Dates: [156], Probs: [0.16]
7, Dates: [206], Probs: [0.09]
8, Dates: [236], Probs: [0.04]
9, Dates: [261, 271], Probs: [0.14, 0.1]
10, Dates: [281], Probs: [0.12]
11, Dates: [-33, 311, 316], Probs: [0.07, 0.07, 0.05]
12, Dates: [336, 356], Probs: [0.12, 0.0]
1, Dates: [22], Probs: [0.05]
2, Dates: [32], Probs: [0.11]
3, Dates: [86], Probs: [0.04]
4, Dates: [91], Probs: [0.07]
5, Dates: [], Probs: []
6, Dates: [



Downloading 7/24, 220033
Shadows ((25, 96, 96)) used 0.0 processing units
1, Dates: [366, 371], Probs: [0.05, 0.17]
2, Dates: [32, 37, 57, 399], Probs: [0.09, 0.14, 0.06, 0.14]
3, Dates: [84, 89], Probs: [0.06, 0.12]
4, Dates: [94, 99, 109], Probs: [0.18, 0.19, 0.1]
5, Dates: [129, 134], Probs: [0.13, 0.2]
6, Dates: [], Probs: []
7, Dates: [206], Probs: [0.18]
8, Dates: [], Probs: []
9, Dates: [259], Probs: [0.12]
10, Dates: [281, 286], Probs: [0.06, 0.07]
11, Dates: [-33, 306, 311, 316], Probs: [0.13, 0.1, 0.02, 0.03]
12, Dates: [-3, 334, 351, 361], Probs: [0.11, 0.14, 0.15, 0.13]
1, Dates: [], Probs: []
2, Dates: [32], Probs: [0.09]
3, Dates: [84], Probs: [0.06]
4, Dates: [109], Probs: [0.1]
5, Dates: [129], Probs: [0.13]
6, Dates: [], Probs: []
7, Dates: [], Probs: []
8, Dates: [], Probs: []
9, Dates: [259], Probs: [0.12]
10, Dates: [286], Probs: [0.07]
11, Dates: [316], Probs: [0.03]
12, Dates: [-3, 334], Probs: [0.11, 0.14]
The original max value is 31057
Original 20 meter bands

CRITICAL:root:need at least one array to stack
Traceback (most recent call last):
  File "<ipython-input-8-cdb85e55c715>", line 43, in download_layer
    img_20 = np.stack(img_bands)
  File "<__array_function__ internals>", line 6, in stack
  File "/Users/jbrandt.terminal/opt/anaconda3/envs/tf/lib/python3.7/site-packages/numpy/core/shape_base.py", line 422, in stack
    raise ValueError('need at least one array to stack')
ValueError: need at least one array to stack
CRITICAL:root:cannot unpack non-iterable NoneType object
Traceback (most recent call last):
  File "<ipython-input-11-c86e3e63af6b>", line 65, in download_raw_data
    s2, s2_dates = download_layer(sentinel2_bbx, clean_steps = clean_dates, epsg = epsg)
TypeError: cannot unpack non-iterable NoneType object


cannot unpack non-iterable NoneType object
Downloading 13/24, 220017


CRITICAL:root:tuple index out of range
Traceback (most recent call last):
  File "<ipython-input-11-c86e3e63af6b>", line 48, in download_raw_data
    cloud_probs, shadows, _, clean_dates = identify_clouds(cloud_bbx, epsg = epsg)
  File "<ipython-input-6-ded57245c76e>", line 55, in identify_clouds
    shadow_pus = (shadow_img.shape[1]*shadow_img.shape[2])/(512*512) * shadow_img.shape[0]
IndexError: tuple index out of range


tuple index out of range
Downloading 14/24, 220016
Shadows ((6, 96, 96)) used 0.0 processing units
1, Dates: [6], Probs: [0.15]
2, Dates: [36, 46], Probs: [0.18, 0.17]
3, Dates: [60], Probs: [0.18]
4, Dates: [], Probs: []
5, Dates: [], Probs: []
6, Dates: [], Probs: []
7, Dates: [], Probs: []
8, Dates: [], Probs: []
9, Dates: [], Probs: []
10, Dates: [300], Probs: [0.16]
11, Dates: [], Probs: []
12, Dates: [-4], Probs: [0.17]
1, Dates: [6], Probs: [0.15]
2, Dates: [36, 46], Probs: [0.18, 0.17]
3, Dates: [60], Probs: [0.18]
4, Dates: [], Probs: []
5, Dates: [], Probs: []
6, Dates: [], Probs: []
7, Dates: [], Probs: []
8, Dates: [], Probs: []
9, Dates: [], Probs: []
10, Dates: [300], Probs: [0.16]
11, Dates: [], Probs: []
12, Dates: [-4], Probs: [0.17]
The original max value is 31870
Original 20 meter bands size: (6, 16, 16, 6), using 0.01171875 PU
The original L2A image size is: (6, 32, 32, 4)
The original max value is 34786
Shadows (6, 96, 96), clouds (6, 96, 96), S2, (6, 32, 32, 10), 

CRITICAL:root:tuple index out of range
Traceback (most recent call last):
  File "<ipython-input-11-c86e3e63af6b>", line 48, in download_raw_data
    cloud_probs, shadows, _, clean_dates = identify_clouds(cloud_bbx, epsg = epsg)
  File "<ipython-input-6-ded57245c76e>", line 55, in identify_clouds
    shadow_pus = (shadow_img.shape[1]*shadow_img.shape[2])/(512*512) * shadow_img.shape[0]
IndexError: tuple index out of range


tuple index out of range
Downloading 17/24, 220013


CRITICAL:root:tuple index out of range
Traceback (most recent call last):
  File "<ipython-input-11-c86e3e63af6b>", line 48, in download_raw_data
    cloud_probs, shadows, _, clean_dates = identify_clouds(cloud_bbx, epsg = epsg)
  File "<ipython-input-6-ded57245c76e>", line 55, in identify_clouds
    shadow_pus = (shadow_img.shape[1]*shadow_img.shape[2])/(512*512) * shadow_img.shape[0]
IndexError: tuple index out of range


tuple index out of range
Downloading 18/24, 220010
Shadows ((1, 96, 96)) used 0.0 processing units
1, Dates: [], Probs: []
2, Dates: [], Probs: []
3, Dates: [], Probs: []
4, Dates: [], Probs: []
5, Dates: [150], Probs: [0.18]
6, Dates: [], Probs: []
7, Dates: [], Probs: []
8, Dates: [], Probs: []
9, Dates: [], Probs: []
10, Dates: [], Probs: []
11, Dates: [], Probs: []
12, Dates: [], Probs: []
1, Dates: [], Probs: []
2, Dates: [], Probs: []
3, Dates: [], Probs: []
4, Dates: [], Probs: []
5, Dates: [150], Probs: [0.18]
6, Dates: [], Probs: []
7, Dates: [], Probs: []
8, Dates: [], Probs: []
9, Dates: [], Probs: []
10, Dates: [], Probs: []
11, Dates: [], Probs: []
12, Dates: [], Probs: []
The original max value is 38469
Original 20 meter bands size: (1, 16, 16, 6), using 0.001953125 PU
The original L2A image size is: (1, 32, 32, 4)
The original max value is 35572
Shadows (1, 96, 96), clouds (1, 96, 96), S2, (1, 32, 32, 10), S2d, (1,)
(1, 32, 32, 10)
uploading ../data/train-raw/220010.npy 

## 3. Process train / test data

In [None]:
i = 0
plots = [str(x[:-4]) for x in os.listdir("../data/train-raw/") if ".npy" in x]
for plot in plots:
    i += 1
    if not os.path.exists("../data/train-s2/" + plot + ".hkl"):
        print(plot)
        try:
            tiles = process_raw(plot, path = 'train')
            print(i, plot)
        except:
            continue