In [1]:
from IPython.display import display
from tqdm.notebook import tqdm

from multiprocessing import Pool
from time import sleep

from contextlib import ExitStack

import numpy as np
import pandas as pd

import datetime
from pathlib import Path
import json

import intake
import rasterio
import json
import pickle

import ee
from geemap import geemap

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [2]:
# Authenticate with Earth Engine first, then initialise the client library
# for the Cloud project named below.
ee.Authenticate()
ee.Initialize(project='sentinel-treeclassification')

In [3]:
class SentinelGetter:
    """Builds cloud-masked Sentinel-2 surface-reflectance composites with Earth Engine."""

    def mask_s2_clouds(self, image):
        """Mask out the pixels of ``image`` that QA60 flags as cloud or cirrus.

        Args:
            image: An image exposing a ``QA60`` band.

        Returns:
            ``image`` with its mask updated so that only pixels whose cloud
            (bit 10) and cirrus (bit 11) flags are both zero remain valid.
        """
        qa = image.select('QA60')
        # Bits 10 and 11 are clouds and cirrus, respectively.
        cloud_bit_mask = 1 << 10
        cirrus_bit_mask = 1 << 11
        # Both flags should be set to zero, indicating clear conditions.
        mask = (
            qa.bitwiseAnd(cloud_bit_mask)
            .eq(0)
            .And(qa.bitwiseAnd(cirrus_bit_mask).eq(0))
        )
        return image.updateMask(mask)

    def get_image(self, bbox, start_date, end_date):
        """Return the mean cloud-masked composite over ``bbox`` for a date range.

        Args:
            bbox: Earth Engine geometry used both to select granules and to
                clip the resulting composite.
            start_date: Start of the acquisition window, passed to ``filterDate``.
            end_date: End of the acquisition window, passed to ``filterDate``.

        Returns:
            The per-pixel mean of the selected bands over all cloud-masked
            granules in the window, clipped to ``bbox``.
        """
        # Entries may be regular expressions: 'B[2-8]' selects B2 through B8.
        selected_bands = ['B[2-8]', 'B8A', 'B11', 'B12', 'TCI_R', 'TCI_G', 'TCI_B']
        image = (
            ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
            .filterDate(start_date, end_date)
            # Drop granules that do not touch the area of interest. They have
            # no valid pixels inside bbox, so the composite is unchanged, but
            # the cloud mask and the mean no longer run over the whole globe.
            .filterBounds(bbox)
            # No granule-level cloud filter is applied; cloudy pixels are
            # masked per image before averaging.
            .map(self.mask_s2_clouds)
            .select(selected_bands)
            .mean()
            .clip(bbox)
        )
        return image

In [4]:
# Open the project catalog and read the 'treesat' source, keeping only the
# columns that the source metadata lists under 'usecols'.
catalog = intake.open_catalog(Path('../catalog.yml'))
source = catalog.treesat
gdf = source.read()[source.metadata['usecols']]

# Overall extent of the dataset in EPSG:4326. Every geometry is padded by 100
# CRS units with square caps *before* reprojecting, so that no sample ends up
# lying exactly on the border of the bounding box.
total_bounds = (
    gdf.buffer(100, cap_style=3)
    .to_crs(epsg=4326)
    .geometry
    .total_bounds
)
bbox = ee.Geometry.BBox(*total_bounds)

# Reproject the samples themselves only once the buffered extent is known.
gdf = gdf.to_crs(epsg=4326)

In [5]:
# Label column name, as declared under categories -> generic in the metadata.
target = source.metadata['categories']['generic']

# Encode the labels as integer category codes and remember how to map a code
# back to its original label.
# NOTE: this cell overwrites gdf[target] in place. Running it a second time
# without reloading gdf rebuilds category_map from the integer codes instead
# of the original labels.
gdf[target] = gdf[target].astype('category')
category_map = {
    code: label
    for code, label in zip(gdf[target].cat.codes, gdf[target])
}
gdf[target] = gdf[target].cat.codes

In [6]:
# feature_collection = geemap.geopandas_to_ee(gdf.sample(10000))
# map = geemap.Map()
# map.addLayerControl()
# map.setCenter(total_bounds[0::2].sum()/2, total_bounds[1::2].sum()/2, 8)
# url = 'https://mt1.google.com/vt/lyrs=y&x={x}&y={y}&z={z}'
# map.add_tile_layer(url, name='Google Map', attribution='Google')
# map.addLayer(feature_collection, {}, "geopandas to ee")
# map.addLayer(bbox, {}, "bounding box")

In [7]:
def _save_checkpoint(save_path, all_data):
    """Write every patch collected so far to ``save_path`` and return the array.

    The array is built and written to a sibling temporary file first and only
    then moved over ``save_path``, so an error during conversion or an
    interrupted write cannot destroy the previous checkpoint.
    """
    save_data = np.array(all_data)
    tmp_path = save_path.with_name(save_path.name + '.tmp')
    with open(tmp_path, 'wb') as f:
        np.save(f, save_data)
    tmp_path.replace(save_path)
    return save_data


def download_npy(bbox, start_date, end_date, gdf, sleep_time):
    """Download one 6x6 patch per row of ``gdf`` from a composite image.

    Progress is checkpointed to ``data/treesat_<MMYYYY>.npy`` (named after
    ``start_date``). When that file already exists, the download resumes
    after the rows it already contains.

    Args:
        bbox: Earth Engine geometry covering all samples; forwarded to
            ``SentinelGetter.get_image``.
        start_date: Start of the compositing window; also provides the month
            and year used in the checkpoint file name (needs ``strftime``).
        end_date: End of the compositing window.
        gdf: Table whose rows expose ``geometry.bounds``; one patch is
            downloaded per row, in row order.
        sleep_time: Seconds to wait before starting (staggers parallel
            workers) and, with a floor of one second, between retries.

    Returns:
        numpy.ndarray holding every patch, including those loaded from an
        existing checkpoint.
    """
    # Sleep time helps with parallel processing,
    # if you're brave enough to try it
    sleep(sleep_time)

    checkpoint_every = 1000
    # Never retry without a pause, even when sleep_time is 0.
    retry_delay = max(sleep_time, 1)

    # For further options, see
    # https://developers.google.com/earth-engine/apidocs/ee-image-getdownloadurl
    params = {
        'format': 'NPY',
        'dimensions': (6, 6),
        # 'scale': 10,
    }
    # Cloud masked, band selected, mean image of the bbox area.
    sentinel_image = SentinelGetter().get_image(bbox, start_date, end_date)

    save_path = Path('data').joinpath(f"treesat_{start_date.strftime('%m%Y')}.npy")
    # Make sure the first checkpoint does not fail on a missing directory.
    save_path.parent.mkdir(parents=True, exist_ok=True)
    print(f'Downloading {start_date}')
    # Continue from a previous run, else start new.
    if save_path.is_file():
        # Convert outer array to list for appending, avoid ndarray.tolist()
        # as that converts nested arrays to list as well.
        all_data = list(np.load(save_path))
    else:
        all_data = []

    # Skip the rows that are already downloaded. This is positional on
    # purpose: the checkpoint stores patches in row order, whatever the
    # index labels of gdf happen to be.
    continue_gdf = gdf.iloc[len(all_data):]

    # Progress bar, tracks continuations
    for _, row in tqdm(
        continue_gdf.iterrows(), total=gdf.shape[0], initial=len(all_data)):
        # Not ideal but a lot of connection errors can occur here.
        # They are (so far) not program ending, simply retry.
        while True:
            try:
                params['region'] = ee.Geometry.BBox(*row.geometry.bounds)
                # Create the download for the image within the bbox as defined in params
                url = sentinel_image.getDownloadURL(params)

                # There can be a delay before the URL becomes available,
                # in which case the loop simply retries (seems rare so far).
                # The handle is only needed for the name of the local copy;
                # close it as soon as the array has been loaded.
                with np.DataSource().open(url) as remote_file:
                    data = np.load(remote_file.name)
            except Exception as exc:
                # Only ordinary errors are retried, so KeyboardInterrupt and
                # SystemExit can still stop the download. The warning makes
                # a persistent failure visible instead of looping silently.
                warnings.warn(
                    f'Patch download failed, retrying: {exc!r}', RuntimeWarning)
                sleep(retry_delay)
            else:
                break

        # Numpy ndarray being appended to a list of ndarrays.
        # Ensure all_data uses python's list instead of ndarray.tolist().
        all_data.append(data)

        # Periodic checkpoint so that an aborted run can be resumed.
        if len(all_data) % checkpoint_every == 0:
            _save_checkpoint(save_path, all_data)

    return _save_checkpoint(save_path, all_data)

In [None]:
# Download August through December 2019 in parallel, one worker per month.
map_inputs = []
date = datetime.datetime(2019, 1, 1)
for i in range(8, 13):
    start_date = date.replace(month=i)
    # First day of the following month; December rolls over into January of
    # the next year.
    end_date = start_date.replace(
        month=start_date.month%12 + 1,
        year=start_date.year + start_date.month//12)
    # The last item is the worker's sleep time; deriving it from the month
    # gives every worker a different start delay.
    map_inputs.append((bbox, start_date, end_date, gdf, i*2))

# The context manager shuts the worker processes down even when a download
# raises, instead of leaving them running behind the notebook.
with Pool(processes=len(map_inputs)) as pool:
    pool.starmap(download_npy, map_inputs)

# Serial alternative for a single month:
# start_date = datetime.datetime(2019, 6, 1)
# end_date = start_date.replace(
#     month=start_date.month%12 + 1, 
#     year=start_date.year + start_date.month//12)
# d = download_npy(bbox, start_date, end_date, gdf, 0)

Downloading 2019-08-01 00:00:00


 67%|######7   | 34001/50381 [00:00<?, ?it/s]

Downloading 2019-09-01 00:00:00


 67%|######7   | 34001/50381 [00:00<?, ?it/s]

Downloading 2019-10-01 00:00:00


 67%|######7   | 34001/50381 [00:00<?, ?it/s]

Downloading 2019-11-01 00:00:00


 67%|######7   | 34001/50381 [00:00<?, ?it/s]

Downloading 2019-12-01 00:00:00


 67%|######7   | 34001/50381 [00:00<?, ?it/s]

In [None]:
# pool = Pool(processes=6)
# map_inputs = []
# date = datetime.datetime(2018, 1, 1)
# for i in range(1, 7):
#     start_date = date.replace(month=i)
#     end_date = start_date.replace(
#         month=start_date.month%12 + 1, 
#         year=start_date.year + start_date.month//12)
#     map_inputs.append((bbox, start_date, end_date, gdf, i*2))

# pool.starmap(download_npy, map_inputs)
# pool.close()