In [1]:
from IPython.display import display
from tqdm.notebook import tqdm

from multiprocessing import Pool
from time import sleep

from contextlib import ExitStack

import numpy as np
import pandas as pd

import datetime
from pathlib import Path
import json

import intake
import rasterio
import json
import pickle

import ee
from geemap import geemap

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [2]:
# Earth Engine session setup: authenticate first, then initialise the client
# against the project named below.
EE_PROJECT = 'sentinel-treeclassification'

ee.Authenticate()
ee.Initialize(project=EE_PROJECT)

In [3]:
class SentinelGetter:
    """Builds cloud-masked mean composites from the Sentinel-2 SR collection."""

    def mask_s2_clouds(self, image):
        """Mask the pixels of ``image`` that QA60 flags as cloud or cirrus.

        Args:
            image: Object exposing ``select`` and ``updateMask`` and carrying
                a ``QA60`` band (an Earth Engine image when called through
                ``get_image``).

        Returns:
            Whatever ``image.updateMask`` returns for the clear-pixel mask.
        """
        qa = image.select('QA60')
        # Bits 10 and 11 are clouds and cirrus, respectively.
        cloud_bit_mask = 1 << 10
        cirrus_bit_mask = 1 << 11
        # A pixel is kept only when both flags are zero (clear conditions).
        cloud_free = qa.bitwiseAnd(cloud_bit_mask).eq(0)
        cirrus_free = qa.bitwiseAnd(cirrus_bit_mask).eq(0)
        return image.updateMask(cloud_free.And(cirrus_free))

    def get_image(self, bbox, start_date, end_date):
        """Return the cloud-masked mean composite over ``bbox`` for a date range.

        Args:
            bbox: Region of interest; passed to both ``filterBounds`` and
                ``clip``.
            start_date: Start of the date range, passed to ``filterDate``.
            end_date: End of the date range, passed to ``filterDate``.

        Returns:
            The per-pixel mean of the selected bands over all matching,
            cloud-masked images, clipped to ``bbox``.
        """
        selected_bands = ['B[2-8]', 'B8A', 'B11', 'B12', 'TCI_R', 'TCI_G', 'TCI_B']
        collection = (
            ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
            .filterDate(start_date, end_date)
            # Drop granules that do not touch the region of interest so they
            # are never mapped or averaged; the result is clipped to the same
            # region below, so pixels inside it are unaffected.
            .filterBounds(bbox)
            # Mask cloud / cirrus pixels in every image before compositing.
            # (No granule-level cloud-percentage filter is applied.)
            .map(self.mask_s2_clouds)
            .select(selected_bands)
        )
        return collection.mean().clip(bbox)

In [4]:
# Load the 'treesat' source from the project catalog, keeping only the columns
# listed under the source's 'usecols' metadata entry.
catalog = intake.open_catalog(Path('../catalog.yml'))
source = catalog.treesat
gdf = source.read()[source.metadata['usecols']]

# Bounding box of all geometries, in EPSG:4326. The geometries are buffered
# first (in the source CRS) to avoid points lying on the borders.
# NOTE(review): cap_style=3 is presumably shapely's square cap -- confirm.
total_bounds = (
    gdf
    .buffer(100, cap_style=3)
    .to_crs(epsg=4326)
    .geometry
    .total_bounds
)
bbox = ee.Geometry.BBox(*total_bounds)

# From here on the frame itself is kept in EPSG:4326 as well.
gdf = gdf.to_crs(epsg=4326)

In [5]:
# Name of the label column, as declared in the source metadata.
target = source.metadata['categories']['generic']

# Replace the label column with integer category codes and remember how each
# code maps back to its original label.
# NOTE: the column is overwritten in place, so re-running this cell without
# reloading `gdf` would encode the codes themselves.
label_categorical = gdf[target].astype('category')
label_codes = label_categorical.cat.codes

category_map = dict(zip(label_codes, label_categorical))

gdf[target] = label_codes

In [6]:
# def download_temporal_data(date, download=False):
#     download_dates = []
#     for month in range(4, 11, 2):
#         start_date = date.replace(month=month, day=1)
#         end_date = start_date.replace(
#             month=start_date.month%12 + 1, 
#             year=start_date.year + start_date.month//12)
        
#         download_dates.append(start_date.strftime('%b%Y'))
        
#         if download:
#             sentinel_image = SentinelGetter().get_image(
#                 bbox, start_date, end_date)
#             task = ee.batch.Export.image.toDrive(
#                 image=sentinel_image,
#                 fileNamePrefix=f"treesat_{start_date.strftime('%b%Y')}",
#                 description='TreeSatAI labels',
#                 folder='treesat',
#                 scale=10,
#             )
#             task.start()
#     return download_dates
# test_date = datetime.datetime(2020, 4, 15)
# filename_prefix_list = download_temporal_data(test_date, download=False)

In [7]:
# %%time
# class TifProcessor:
#     def crop_center(self, img, cropx=6, cropy=6):
#         bands, y, x = img.shape
#         startx = x//2 - (cropx//2)
#         starty = y//2 - (cropy//2)    
#         return img[:, starty:starty+cropy, startx:startx+cropx]

#     def read_tif(self, data_dir, filename_prefix):
#         tif_paths = list(Path(data_dir).glob(f'{filename_prefix}*.tif'))
#         save_path = Path(data_dir).joinpath(f'{filename_prefix}.npy')
        
#         if save_path.is_file():
#             with open(save_path, 'rb') as f:
#                 processed_data = pickle.load(f)
                
#             labels = [label for label, features in processed_data]
#             features = [features for label, features in processed_data]
#             return labels, features
            
#         if len(tif_paths) > 1:
#             with ExitStack() as stack:
#                 tif_files = [stack.enter_context(rasterio.open(fname)) for fname in tif_paths]
#                 print('Beginning Rasterio tif merge')
#                 tif_data = rasterio.merge.merge(tif_files)
#         elif len(tif_paths) == 1:
#             with rasterio.open(tif_paths[0]) as f:
#                 tif_data = f.read()
#         else:
#             raise FileNotFoundError(f'No files containing "{filename_prefix}" found in "{data_dir}".')

#         return tif_data
    
#     def tif_to_numpy(self, data_dir, filename_prefix, gdf):
        
#         tif_data = self.read_tif(data_dir, filename_prefix)

#         labels = []
#         features = []
        
#         for i, row in tqdm(gdf.iterrows(), total=gdf.shape[0]):
#             out_image, out_transform = rasterio.mask.mask(tif_data, [row.geometry], crop=True)
#             out_image = self.crop_center(out_image)
            
#             features.append(out_image)
#             labels.append(row[target_column])       

#         data_list = list(zip(labels, features))
#         with open(save_path, 'wb') as f:
#             pickle.dump(data_list, f)
#         return labels, features

# temporal_labels, temporal_features = [], []
# for filename_prefix in filename_prefix_list:
#     labels, features = TifProcessor().tif_to_numpy('data', filename_prefix, gdf)
#     temporal_labels.append(labels)
#     temporal_features.append(features)

In [8]:
# # soft data check
# if np.array_equal(temporal_labels[0], temporal_labels[1]):
#     raise ValueError(f"Don't panic")

In [9]:
# y = labels
# classes = np.unique(labels)

# X = np.array(features)
# X = X.reshape(len(y), -1)

# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

In [10]:
# feature_collection = geemap.geopandas_to_ee(gdf.sample(10000))
# map = geemap.Map()
# map.addLayerControl()
# map.setCenter(total_bounds[0::2].sum()/2, total_bounds[1::2].sum()/2, 8)
# url = 'https://mt1.google.com/vt/lyrs=y&x={x}&y={y}&z={z}'
# map.add_tile_layer(url, name='Google Map', attribution='Google')
# map.addLayer(feature_collection, {}, "geopandas to ee")
# map.addLayer(bbox, {}, "bounding box")

In [11]:
def download_npy(bbox, start_date, end_date, gdf, sleep_time, max_retries=None):
    """Download one Sentinel-2 patch per row of ``gdf`` and save them as NPY.

    A composite image is built once via ``SentinelGetter().get_image``; then,
    for every row, the region given by the bounds of the row's geometry is
    requested in ``'NPY'`` format with ``'dimensions': (6, 6)``. After each of
    the 10 chunks the patches collected so far are written to
    ``data/treesat_<mmYYYY>.npy`` and the matching labels to
    ``data/treesat_<mmYYYY>_labels.npy`` (same order, one label per patch).

    The label column name is read from the module-level ``target`` variable.

    Args:
        bbox: Region passed to ``SentinelGetter.get_image``.
        start_date: Start of the composite's date range; it must support
            ``strftime`` because it is also used to name the output files.
        end_date: End of the composite's date range.
        gdf: Frame with a ``geometry`` column and the ``target`` label column.
        sleep_time: Seconds to sleep before doing any work (used to stagger
            several invocations).
        max_retries: Maximum number of retries per row after a failed
            download. ``None`` (the default) retries indefinitely.

    Raises:
        Exception: The last download error for a row, once ``max_retries`` is
            exhausted.
    """
    sleep(sleep_time)
    params = {
        'format': 'NPY',
        'dimensions': (6, 6),
        # 'scale': 10,
    }
    sentinel_image = SentinelGetter().get_image(bbox, start_date, end_date)
    all_data = []
    labels = []
    stamp = start_date.strftime('%m%Y')
    save_path = Path('data').joinpath(f"treesat_{stamp}.npy")
    labels_path = save_path.with_name(f"treesat_{stamp}_labels.npy")
    # Make sure the output directory exists before the first checkpoint.
    save_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Processing {start_date.strftime('%b %Y')}. ")
    # Split row *positions* rather than the frame itself: this avoids relying
    # on np.array_split accepting a DataFrame and keeps the lookup below
    # correct even if the index contains duplicate labels.
    for positions in tqdm(np.array_split(np.arange(len(gdf)), 10)):
        for pos in tqdm(positions):
            # Resolved outside the retry loop so that a malformed row fails
            # loudly instead of being retried forever.
            geometry = gdf.geometry.iloc[pos]
            label = gdf[target].iloc[pos]
            request = dict(params, region=ee.Geometry.BBox(*geometry.bounds))

            attempt = 0
            while True:
                try:
                    url = sentinel_image.getDownloadURL(request)
                    # DataSource fetches the URL to a local file; only its
                    # path is needed, so the handle is closed right away.
                    with np.DataSource().open(url) as file:
                        data = np.load(file.name)
                    break
                except Exception:
                    # `Exception` (not a bare except) so that interrupting the
                    # kernel still stops the loop.
                    attempt += 1
                    if max_retries is not None and attempt > max_retries:
                        raise
                    sleep(1)

            # Appended together so patches and labels stay aligned.
            labels.append(label)
            all_data.append(data)

        # Checkpoint everything collected so far after each chunk.
        with open(save_path, 'wb') as f:
            np.save(f, np.array(all_data))
        with open(labels_path, 'wb') as f:
            np.save(f, np.array(labels))

In [None]:
# Download a single month of data. `end_date` is the first day of the month
# after `start_date`, rolling over to January of the next year after December.
start_date = datetime.datetime(2019, 1, 1)
year_carry, month_index = divmod(start_date.month, 12)
end_date = start_date.replace(
    year=start_date.year + year_carry,
    month=month_index + 1,
)
# The trailing 0 is `sleep_time`: start immediately.
download_npy(bbox, start_date, end_date, gdf, 0)

Processing Jan 2019. 


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/5039 [00:00<?, ?it/s]

In [None]:
# pool = Pool(processes=6)
# map_inputs = []
# date = datetime.datetime(2019, 1, 1)
# for i in range(1, 13):
#     start_date = date.replace(month=i)
#     end_date = start_date.replace(
#         month=start_date.month%12 + 1, 
#         year=start_date.year + start_date.month//12)
#     map_inputs.append((bbox, start_date, end_date, gdf, i*2))

# pool.starmap(download_npy, map_inputs)
# pool.close()

In [None]:
# %%time
# dfs = []
# for chunk in tqdm(np.array_split(gdf, gdf.shape[0]//100)):
#     feature_collection = geemap.gdf_to_ee(chunk)
#     sample_regions = sentinel_image.sampleRegions(
#         collection=feature_collection, scale=10)
#     df = geemap.ee_to_gdf(sample_regions)
#     # info = sample_regions.getInfo()
#     # df = pd.json_normalize(info['features'])
#     dfs.append(df)

In [None]:
# gdfs = []
# for i, row in tqdm(gdf.head(100).iterrows(), total=gdf.head(100).shape[0]):
#     bbox = ee.Geometry.BBox(*row.geometry.bounds)
#     # sample = sentinel_image.stratifiedSample(
#     #     numPoints=36, classBand=target, region=bbox, scale=10, seed=42, geometries=True)
#     sample = sentinel_image.sample(
#         bbox, scale=10, numPixels=36, seed=42, geometries=True)
#     gdf = geemap.ee_to_gdf(sample)
#     gdfs.append(gdf)