In [2]:
import pandas as pd
import geopandas as gpd
import leafmap.foliumap as leafmapfol
import leafmap
import ast
import geopandas as gpd
from shapely.geometry import Polygon
import datacube
from odc.ui import with_ui_cbk
from deafrica_tools.plotting import rgb, display_map

%matplotlib inline
import matplotlib.pyplot as plt
from datacube.utils import geometry

from deafrica_tools.datahandling import load_ard
from deafrica_tools.bandindices import calculate_indices
from deafrica_tools.plotting import rgb, map_shapefile
from deafrica_tools.spatial import xr_rasterize
from deafrica_tools.classification import HiddenPrints
from odc.io.cgroups import get_cpu_quota
from deafrica_tools.classification import collect_training_data

import numpy as np
import xarray as xr
from odc.algo import xr_reproject
from pyproj import Proj, transform
from datacube.utils.geometry import assign_crs
from datacube.testutils.io import rio_slurp_xarray
from deafrica_tools.bandindices import calculate_indices
from deafrica_tools.dask import create_local_dask_cluster
pd.set_option('display.max_colwidth', 500)
from odc.algo import xr_geomedian

## Datacube initialisation

In [3]:
# Notebook-level datacube connection, used below to list product measurements.
dc = datacube.Datacube(app='training_data_extraction')

In [3]:
# List every measurement known to the datacube and display the rows of the
# chosen product. NOTE: the name `measurements` is rebound to plain lists of
# band names by the query cells further down.
product = "s2_l2a"
measurements = dc.list_measurements()
measurements.loc[product]

Unnamed: 0_level_0,name,dtype,units,nodata,aliases,flags_definition
measurement,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
B01,B01,uint16,1,0.0,"[band_01, coastal_aerosol]",
B02,B02,uint16,1,0.0,"[band_02, blue]",
B03,B03,uint16,1,0.0,"[band_03, green]",
B04,B04,uint16,1,0.0,"[band_04, red]",
B05,B05,uint16,1,0.0,"[band_05, red_edge_1]",
B06,B06,uint16,1,0.0,"[band_06, red_edge_2]",
B07,B07,uint16,1,0.0,"[band_07, red_edge_3]",
B08,B08,uint16,1,0.0,"[band_08, nir, nir_1]",
B8A,B8A,uint16,1,0.0,"[band_8a, nir_narrow, nir_2]",
B09,B09,uint16,1,0.0,"[band_09, water_vapour]",


## Read vector file

In [4]:
# Crop geometries; passed as `gdf` to collect_training_data in the extraction cell.
crop_vectors = gpd.read_file('../Numeric data models/data/crop_vectors_v4.geojson')

## Define feature layers

### Annual model data

##### Biannual geomedian (extent: yearly)

In [1]:
# Datacube query for the yearly biannual-geomedian features.
# `time_range` is a plain string: one calendar year.
time_range = '2019'
measurements = [
    "blue", "green", "red", "nir",
    "swir_1", "swir_2",
    "red_edge_1", "red_edge_2", "red_edge_3",
    "sdev", "bcdev", "edev",
]
resolution = (-10, 10)
output_crs = 'EPSG:6933'
query_year_biannual_gm = dict(
    time=time_range,
    measurements=measurements,
    resolution=resolution,
    output_crs=output_crs,
    group_by='solar_day',
    resampling='bilinear',
)

In [None]:
def feature_layer(ds, era):
    """Scale the bands of one epoch, add spectral indices and suffix every name.

    Parameters
    ----------
    ds
        Dataset holding one epoch. It is modified in place: every data
        variable except "sdev" and "bcdev" is divided by 10000 (so "edev",
        when present, is scaled too).
    era : str
        Suffix appended to the name of every data variable in the result.

    Returns
    -------
    The output of ``calculate_indices`` (input bands kept, NDVI, LAI, EVI,
    SAVI and NDMI added) with each data variable renamed to ``<name><era>``.
    """
    unscaled = ("sdev", "bcdev")
    for name in list(ds.data_vars):
        if name in unscaled:
            continue
        ds[name] = ds[name] / 10000

    with_indices = calculate_indices(
        ds,
        index=["NDVI", "LAI", "EVI", "SAVI", "NDMI"],
        drop=False,
        normalise=False,
        collection="s2",
    )

    # The -log transform of sdev/bcdev/edev is intentionally not applied here.
    suffixed = {name: name + era for name in with_indices.data_vars}
    return with_indices.rename(suffixed)


def biannual_gm(query):
    """Build the two-semester feature stack from the semi-annual S2 geomedian.

    Loads product ``gm_s2_semiannual`` with ``query``, takes the first two
    time steps as semester 1 and semester 2, runs ``feature_layer`` on each
    (suffixes ``_S1`` / ``_S2``) and merges the two results.

    Parameters
    ----------
    query : dict
        Keyword arguments forwarded to ``dc.load``.

    Returns
    -------
    The merged features cast to float32 and squeezed.
    """
    # Connect to the datacube
    dc = datacube.Datacube(app='feature_layers')

    # Load the semi-annual S2 geomedian
    ds = dc.load(product='gm_s2_semiannual', **query)

    # Assumes the query covers a single year so that time index 0 / 1 are its
    # first / second semester -- TODO confirm against the product's time axis.
    dss = {"S1_n": ds.isel(time=0),
           "S2_n": ds.isel(time=1)}

    # The lookups must use the keys defined just above; the previous
    # dss["S1"] / dss["S2"] raised KeyError.
    epoch1 = feature_layer(dss["S1_n"], era="_S1")
    epoch2 = feature_layer(dss["S2_n"], era="_S2")

    result = xr.merge([epoch1, epoch2], compat="override")

    return result.astype(np.float32).squeeze()

##### Biannual geomedian (extent: September 2018 – July 2019)

In [5]:
# Datacube query for the biannual geomedian spanning two calendar years.
time_range = ('2018', '2019')
measurements = [
    "blue", "green", "red", "nir",
    "swir_1", "swir_2",
    "red_edge_1", "red_edge_2", "red_edge_3",
    "sdev", "bcdev", "edev",
]
resolution = (-10, 10)
output_crs = 'EPSG:6933'
query_sep_jul_biannual_gm = dict(
    time=time_range,
    measurements=measurements,
    resolution=resolution,
    output_crs=output_crs,
    group_by='solar_day',
    resampling='bilinear',
)

In [6]:
def feature_layer(ds, era):
    """Prepare the features of a single epoch.

    Steps, in order:
      1. divide every data variable of ``ds`` by 10000, except "sdev" and
         "bcdev" (``ds`` itself is updated in place);
      2. append NDVI, LAI, EVI, SAVI and NDMI via ``calculate_indices``,
         keeping the input bands;
      3. append ``era`` to the name of every data variable.

    Returns the renamed output of ``calculate_indices``.
    """
    to_scale = [band for band in ds.data_vars if band not in ["sdev", "bcdev"]]
    for band in to_scale:
        ds[band] = ds[band] / 10000

    index_names = ["NDVI", "LAI", "EVI", "SAVI", "NDMI"]
    feature_data = calculate_indices(
        ds, index=index_names, drop=False, normalise=False, collection="s2"
    )

    # No log transform is applied to sdev / bcdev / edev.
    return feature_data.rename(
        {band: band + era for band in feature_data.data_vars}
    )


def biannual_gm(query):
    """Build the two-semester feature stack for a season straddling two years.

    Loads product ``gm_s2_semiannual`` with ``query``, picks time steps 1 and
    2, runs ``feature_layer`` on each and merges the two results.

    Parameters
    ----------
    query : dict
        Keyword arguments forwarded to ``dc.load``.

    Returns
    -------
    The merged features cast to float32 and squeezed.
    """
    # Connect to the datacube
    dc = datacube.Datacube(app='feature_layers')

    # Load the semi-annual S2 geomedian
    ds = dc.load(product='gm_s2_semiannual', **query)

    # Assumes the query spans two full years (four semi-annual composites), so
    # index 1 is semester 2 of the first year and index 2 is semester 1 of the
    # second year -- TODO confirm against the product's time axis.
    dss = {"S2_n-1": ds.isel(time=1),
           "S1_n": ds.isel(time=2)}

    # The lookups must use the keys defined just above; the previous
    # dss["S1"] / dss["S2"] raised KeyError. Semester-1 data keeps the "_S1"
    # suffix and semester-2 data the "_S2" suffix.
    epoch1 = feature_layer(dss["S1_n"], era="_S1")
    epoch2 = feature_layer(dss["S2_n-1"], era="_S2")

    result = xr.merge([epoch1, epoch2], compat="override")

    return result.astype(np.float32).squeeze()

In [None]:
# Datacube query for the custom (self-computed) geomedians, Oct 2018 - Jul 2019.
time_range = ('2018-10-01', '2019-07-31')
measurements = [
    "blue", "green", "red", "nir",
    "swir_1", "swir_2",
    "red_edge_1", "red_edge_2", "red_edge_3",
]
resolution = (-10, 10)
output_crs = 'EPSG:6933'
query_sep_jul_biannual_custom_gm = dict(
    time=time_range,
    measurements=measurements,
    resolution=resolution,
    output_crs=output_crs,
    group_by='solar_day',
    resampling='bilinear',
)

In [None]:
def apply_function_over_custom_times(ds, func, func_name, time_ranges):
    """Apply ``func`` to each named time selection of ``ds``.

    Parameters
    ----------
    ds
        Object exposing ``.sel(time=...)``.
    func : callable
        Called with each time selection; its result must expose ``keys()``,
        ``rename(name_dict=...)`` and ``coords``.
    func_name : str
        Inserted into every output variable name.
    time_ranges : dict
        Maps a label to either a ``slice`` (range selection) or a single
        value (selected with ``method="nearest"``).

    Returns
    -------
    list
        One result per entry of ``time_ranges``, in iteration order, with
        variables renamed to ``<var>_<func_name>_<label>``. When a result
        still has a "time" coordinate, ``reset_coords()`` is called and the
        "time" and "spatial_ref" variables are dropped.
    """

    def _process(label, selector):
        # Range selection for slices, nearest-match selection otherwise.
        if isinstance(selector, slice):
            subset = ds.sel(time=selector)
        else:
            subset = ds.sel(time=selector, method="nearest")

        transformed = func(subset)

        new_names = {}
        for var in list(transformed.keys()):
            new_names[var] = f"{var}_{func_name}_{label}"
        transformed = transformed.rename(name_dict=new_names)

        if "time" in list(transformed.coords):
            transformed = transformed.reset_coords().drop_vars(["time", "spatial_ref"])
        return transformed

    return [_process(label, selector) for label, selector in time_ranges.items()]


def geomedian_with_indices_wrapper(ds):
    """Reduce ``ds`` with ``xr_geomedian`` and append spectral indices.

    The composite returned by ``xr_geomedian`` is passed to
    ``calculate_indices`` with NDVI, LAI, SAVI, MSAVI and NDMI requested,
    ``drop=False`` (input bands kept) and ``satellite_mission="s2"``.

    Returns the output of ``calculate_indices``.
    """
    composite = xr_geomedian(ds)
    return calculate_indices(
        composite,
        index=["NDVI", "LAI", "SAVI", "MSAVI", "NDMI"],
        drop=False,
        satellite_mission="s2",
    )


def custom_gm(query):
    """Compute per-window geomedians (plus indices) from Sentinel-2 data.

    Loads ``s2_l2a`` through ``load_ard`` using ``query``, then for each of
    two fixed time windows (Oct 2018 - Feb 2019 and Mar - Jul 2019) applies
    ``geomedian_with_indices_wrapper`` and merges the per-window results.

    Parameters
    ----------
    query : dict
        Keyword arguments forwarded to ``load_ard``.

    Returns
    -------
    The ``xr.merge`` of the per-window datasets; variables are named
    ``<var>_s2_<window label>``.
    """
    # Connect to datacube
    dc = datacube.Datacube(app="crop_type_ml")

    # The first window previously ended on "2019-02-31", which is not a
    # calendar date and cannot be parsed as a slice bound; 2019 is not a leap
    # year, so February ends on the 28th.
    time_ranges = {
        "_oct_feb_gm": slice("2018-10-01", "2019-02-28"),
        "_mar_jul_gm": slice("2019-03-01", "2019-07-31")}

    ds = load_ard(
        dc=dc,
        products=["s2_l2a"],
        verbose=False,
        **query)

    # Apply geomedian over time ranges and calculate band indices
    s2_geomad_list = apply_function_over_custom_times(
        ds, geomedian_with_indices_wrapper, "s2", time_ranges)

    return xr.merge(list(s2_geomad_list))

##### Custom geomedian (extent: September 2018 – July 2019)
* To add: other data sources (for example, S1 data and climate data)

In [9]:
# Datacube query for the three-window custom geomedian, Oct 2018 - Jul 2019.
time_range = ('2018-10-01', '2019-07-31')
measurements = [
    "blue", "green", "red", "nir",
    "swir_1", "swir_2",
    "red_edge_1", "red_edge_2", "red_edge_3",
]
resolution = (-10, 10)
output_crs = 'EPSG:6933'
query1 = dict(
    time=time_range,
    measurements=measurements,
    resolution=resolution,
    output_crs=output_crs,
    group_by='solar_day',
    resampling='bilinear',
)

In [None]:
def apply_function_over_custom_times(ds, func, func_name, time_ranges):
    """Run ``func`` on every labelled time window of ``ds``.

    For each ``label -> window`` entry of ``time_ranges``:
      * ``ds`` is subset with ``ds.sel(time=window)`` when ``window`` is a
        ``slice``, otherwise with ``method="nearest"`` added;
      * ``func`` is applied to the subset;
      * every key of the result is renamed to ``<key>_<func_name>_<label>``;
      * if the result has a "time" coordinate, ``reset_coords()`` is called
        and "time" and "spatial_ref" are dropped.

    Returns a list with one processed result per window, in dict order.
    """
    results = []

    for label in time_ranges:
        window = time_ranges[label]

        sel_kwargs = {"time": window}
        if not isinstance(window, slice):
            sel_kwargs["method"] = "nearest"

        processed = func(ds.sel(**sel_kwargs))
        processed = processed.rename(
            name_dict={
                var: f"{var}_{func_name}_{label}" for var in list(processed.keys())
            }
        )

        if "time" in list(processed.coords):
            processed = processed.reset_coords()
            processed = processed.drop_vars(["time", "spatial_ref"])

        results.append(processed)

    return results


def geomedian_with_indices_wrapper(ds):
    """Geomedian composite of ``ds`` followed by band-index calculation.

    ``xr_geomedian`` is applied first; its output goes to
    ``calculate_indices``, which is asked for NDVI, LAI, SAVI, MSAVI and NDMI
    with the input bands kept (``drop=False``) and the mission set to "s2".

    Returns whatever ``calculate_indices`` returns.
    """
    index_kwargs = dict(
        index=["NDVI", "LAI", "SAVI", "MSAVI", "NDMI"],
        drop=False,
        satellite_mission="s2",
    )
    return calculate_indices(xr_geomedian(ds), **index_kwargs)


def custom_gm(query):
    """Compute geomedians (plus indices) over three fixed windows of S2 data.

    Loads ``s2_l2a`` through ``load_ard`` using ``query``, applies
    ``geomedian_with_indices_wrapper`` to the windows labelled "18_1"
    (Oct - Dec 2018), "19_1" (Jan - Mar 2019) and "19_2" (Apr - Jul 2019),
    and merges the per-window datasets.

    Parameters
    ----------
    query : dict
        Keyword arguments forwarded to ``load_ard``.

    Returns
    -------
    The ``xr.merge`` of the per-window results.
    """
    # Connect to datacube
    dc = datacube.Datacube(app="crop_type_ml")

    windows = {
        "18_1": slice("2018-10-01", "2018-12-31"),
        "19_1": slice("2019-01-01", "2019-03-31"),
        "19_2": slice("2019-04-01", "2019-07-31"),
    }

    ds = load_ard(dc=dc, products=["s2_l2a"], verbose=False, **query)

    # One geomedian-with-indices dataset per window, merged into one dataset.
    per_window = apply_function_over_custom_times(
        ds, geomedian_with_indices_wrapper, "s2", windows
    )
    return xr.merge(list(per_window))

### In-season model data

##### Time series (TS) extraction

In [None]:
# Datacube query for the weekly time-series features, Oct 2018 - Jul 2019.
# `products` is defined here but is not part of the query dict.
time_range = ('2018-10', '2019-07')
products = ['s2_l2a']
measurements = [
    "blue", "green", "red",
    "red_edge_1", "red_edge_2", "red_edge_3",
    "nir", "nir_narrow",
    "swir_1", "swir_2",
    "mask",
]
resolution = (-20, 20)
output_crs = 'EPSG:6933'
query_sp_jul_ts = dict(
    time=time_range,
    measurements=measurements,
    resolution=resolution,
    output_crs=output_crs,
    group_by='solar_day',
)

In [None]:
def _ts_feature_layer(ds, era):
    """Append time-series indices to one weekly composite and suffix names.

    Deliberately NOT named ``feature_layer``: ``biannual_gm`` calls the
    notebook-level ``feature_layer`` (5 indices, /10000 scaling), and a
    same-named definition here would silently replace it for every cell run
    afterwards.

    Parameters
    ----------
    ds
        Dataset for a single week.
    era : str
        Suffix appended to every data-variable name.

    Returns
    -------
    Output of ``calculate_indices`` (NDVI, EVI, MSAVI, NDMI added with
    ``normalise=True``, input bands kept) with variables renamed to
    ``<name><era>``.
    """
    feature_data = calculate_indices(
        ds,
        index=["NDVI", "EVI", "MSAVI", "NDMI"],
        drop=False,
        normalise=True,
        collection="s2")

    return feature_data.rename(
        {band: band + era for band in feature_data.data_vars}
    )


def ts_indices(query):
    """Build weekly-median Sentinel-2 features with spectral indices.

    Loads ``s2_l2a`` with ``query`` (nearest resampling for "mask", bilinear
    for everything else), takes the median of each ``time.week`` group, adds
    indices to every weekly composite and merges them, suffixing variable
    names with ``_<week>``.

    Parameters
    ----------
    query : dict
        Keyword arguments forwarded to ``dc.load``; it must not contain a
        "resampling" key because that argument is set here.

    Returns
    -------
    The merged weekly features cast to float32 and squeezed.
    """
    # Connect to the datacube
    dc = datacube.Datacube(app='feature_layers')

    # Load S2
    ds = dc.load(product='s2_l2a', **query, resampling={"mask": "nearest", "*": "bilinear"})

    # NOTE(review): 'time.week' may be unavailable in recent xarray/pandas
    # releases -- confirm against the installed versions.
    ds = ds.groupby('time.week').median(dim='time')

    # One feature dataset per week, suffixed with the week value.
    week_labels = [week.astype(str) for week in ds.week.values]
    epochs = [
        _ts_feature_layer(ds.isel(week=i), era="_{}".format(label))
        for i, label in enumerate(week_labels)
    ]
    result = xr.merge(epochs, compat="override")

    return result.astype(np.float32).squeeze()

## Save queries

In [6]:
import json

# Save the yearly query under the file that carries its name. The Sep-Jul
# query was previously written to this file, so its content did not match
# its name.
with open('./data/query_year_biannual_gm.json', 'w') as f:
    json.dump(query_year_biannual_gm, f)

In [14]:
# Point the yearly query at another calendar year (updates the dict in place).
year = '2020'
query_year_biannual_gm.update(time=year)

In [15]:
# Display the yearly query after the `time` override.
query_year_biannual_gm

{'time': '2020',
 'measurements': ['blue',
  'green',
  'red',
  'nir',
  'swir_1',
  'swir_2',
  'red_edge_1',
  'red_edge_2',
  'red_edge_3',
  'sdev',
  'bcdev',
  'edev'],
 'resolution': (-10, 10),
 'output_crs': 'EPSG:6933',
 'group_by': 'solar_day',
 'resampling': 'bilinear'}

In [8]:
# Write both biannual queries to one JSON file, keyed by run name.
with open('./data/models_data_queries.json', 'w') as fp:
    fp.write(json.dumps(dict(year_biannual_gm=query_year_biannual_gm,
                             run2=query_sep_jul_biannual_gm)))

## Extracting multispectral data

In [7]:
%%time
#zonal_stats = 'median'
return_coords = True
field = 'field'
ncpus=round(get_cpu_quota())
print('ncpus = '+str(ncpus))

column_names, array_values = collect_training_data(gdf=crop_vectors,
                                                    dc_query=query0,
                                                    ncpus=15,
                                                    clean=False,
                                                    return_coords=return_coords,
                                                    field=field,
                                                    #zonal_stats=zonal_stats,
                                                    fail_threshold=0.0075,
                                                    feature_func=biannual_gm)

ncpus = 4
Collecting training data in parallel mode


  0%|          | 0/2355 [00:00<?, ?it/s]

Percentage of possible fails after run 1 = 0.0 %
Returning data without cleaning
Output shape:  (402236, 37)
CPU times: user 26.5 s, sys: 2.19 s, total: 28.7 s
Wall time: 40min 14s
