In [None]:
import ee

# Initialize Earth Engine. If initialization fails (e.g. because no usable
# credentials are available yet), run the authentication flow once and retry.
# `except Exception` is used instead of a bare `except:` so that
# KeyboardInterrupt / SystemExit are not swallowed and turned into an
# unwanted authentication prompt.
try:
    ee.Initialize()
except Exception:
    ee.Authenticate()
    ee.Initialize()

import geemap
import pandas as pd
import numpy as np
from retry import retry

In [None]:
# --- analysis configuration -------------------------------------------------
aoi = "Kenya"                          # country name (matched against GAUL ADM0_NAME)
start, end = 2015, 2020                # first and last year of the change period
collection = 'CCI'                     # one of: 'CCI', 'GFC'

# --- sampling configuration -------------------------------------------------
grid_sizes = [2_000, 5_000, 20_000]    # grid spacings in metres
nr_of_runs_per_grid = 10               # number of times we sample on a specific grid
random_seed = 7                        # fixed only for reproducibility

In [None]:
#hansen = ee.Image("UMD/hansen/global_forest_change_2021_v1_9").select('lossyear').unmask(0)
#change = hansen.gt(ee.Number(start).subtract(2000)).And(hansen.lt(ee.Number(end).subtract(2000)))
#
#Map = geemap.Map(center=[9, 42], zoom=6)
#Map.add_basemap('HYBRID')
#Map.addLayer(change.randomVisualizer())
#Map

In [None]:
def get_sampling_errors(aoi, start, end, collection):
    """Estimate the relative sampling error of grid-based change sampling.

    A binary change image is built for the area of interest and its
    area-weighted proportion of change is taken as the reference value. For
    every grid size in the module-level ``grid_sizes`` the change image is
    sub-sampled ``nr_of_runs_per_grid`` times (one random seed per run), the
    proportion of change is re-computed from the sampled pixels only, and the
    absolute deviation from the reference value is expressed as a percentage
    of the reference value. The mean and standard deviation of these
    percentages over all runs are reported per grid size.

    Besides its arguments, the function reads the module-level settings
    ``grid_sizes``, ``nr_of_runs_per_grid`` and ``random_seed``.

    Args:
        aoi: Country name; used to filter ``FAO/GAUL/2015/level1`` on its
            ``ADM0_NAME`` property. The matching features are unioned.
        start: Start year. For 'CCI' it selects the land-cover image whose
            ``system:index`` equals the year; for 'GFC' only ``lossyear``
            values greater than ``start - 2000`` count as change.
        end: End year. For 'CCI' it selects the second land-cover image; for
            'GFC' only ``lossyear`` values lower than ``end - 2000`` count as
            change.
        collection: Either 'CCI' (change = land-cover class differs between
            start and end image, evaluated at 300 m) or 'GFC' (change = Hansen
            ``lossyear`` strictly between the two years, evaluated at 30 m).

    Returns:
        pandas.DataFrame with one row per grid size and the columns ``idx``,
        ``grid_size``, ``sample_size``, ``mean`` and ``stddev``. ``mean`` and
        ``stddev`` describe the relative error in percent. The frame has a
        plain 0..n-1 index.

    Raises:
        ValueError: If ``collection`` is neither 'CCI' nor 'GFC'.

    Note:
        The result is fetched with a single blocking ``getInfo()`` call.
    """
    # Fail fast on an unknown collection. Without this check neither branch
    # below would run and the function would crash later with an
    # UnboundLocalError on `change`.
    supported_collections = ('CCI', 'GFC')
    if collection not in supported_collections:
        raise ValueError(
            f"Unsupported collection {collection!r}; "
            f"expected one of {supported_collections}"
        )

    gaul = ee.FeatureCollection("FAO/GAUL/2015/level1")
    # TODO: also accept an ee.FeatureCollection as aoi (only a country name works for now)
    aoi = gaul.filter(ee.Filter.eq("ADM0_NAME", aoi)).union()

    if collection == 'CCI':

        lc = ee.ImageCollection("users/amitghosh/sdg_module/esa/cci_landcover")
        lc_start = lc.filter(ee.Filter.eq("system:index", f'{start}')).first()
        lc_end = lc.filter(ee.Filter.eq("system:index", f'{end}')).first()
        # 1 where the class differs between the two years, 0 otherwise
        change = lc_start.neq(lc_end).clip(aoi)
        scale = 300

    else:  # collection == 'GFC'

        hansen = ee.Image("UMD/hansen/global_forest_change_2021_v1_9").select('lossyear').unmask(0)
        # lossyear is compared against (year - 2000); both bounds are exclusive
        change = hansen.gt(ee.Number(start).subtract(2000)).And(hansen.lt(ee.Number(end).subtract(2000)))
        scale = 30

    # Create one distinct integer seed per run.
    # A local generator keeps NumPy's global random state untouched, and
    # drawing without replacement guarantees that no two runs share a seed
    # (rounding uniform draws to 0..100 could produce the same seed twice,
    # i.e. two identical samples). `.tolist()` yields plain Python ints.
    rng = np.random.default_rng(random_seed)
    seeds = rng.choice(1_000_000, size=nr_of_runs_per_grid, replace=False).tolist()

    # -----------------------------------------------------------------
    # getting total area of change
    # band 1: changed area per pixel, band 2: total area per pixel
    total_image = change.addBands(ee.Image(1)).multiply(ee.Image.pixelArea()).rename(['total_change', 'total_area'])
    areas = total_image.reduceRegion(**{
            'reducer': ee.Reducer.sum(),
            'geometry': aoi,
            'scale': scale,
            'maxPixels': 1e14
        })
    # reference value: share of the AOI area that changed according to the map
    proportional_change_map = ee.Number(areas.get('total_change')).divide(ee.Number(areas.get('total_area')))

    # -----------------------------------------------------------------
    # nested function for getting proportional change per grid
    def get_grid_sample_error(grid):
        """Return ``{'stats': [mean, stddev], 'size': sample_size}`` for one grid size."""

        # set pixel size
        proj = ee.Projection("EPSG:3857").atScale(grid)

        # get total sample size: sum of a constant-1 image at grid resolution
        # over the AOI, i.e. the number of grid cells
        sample_size = ee.Image(1).rename('sample_size').reproject(proj).reduceRegion(**{
            'reducer': ee.Reducer.sum(),
            'geometry': aoi,
            'scale': grid,
            'maxPixels': 1e14
        }).get('sample_size')

        # -----------------------------------------------------------------
        # nested function for getting proportional change per seed and grid
        def get_sampled_proportional_change(seed, proj):
            """Return the relative error (in percent) of one sampling run."""

            # create a subsample of our change image:
            # `cells` labels every grid cell with a random integer, `random`
            # holds a random integer per pixel; within each labelled cell the
            # pixel(s) carrying the maximum random value are kept as sample.
            # NOTE(review): this appears intended to pick one random pixel per
            # grid cell - confirm against the Earth Engine documentation.
            cells = ee.Image.random(seed).multiply(1000000).int().reproject(proj)
            random = ee.Image.random(seed).multiply(1000000).int()
            maximum = cells.addBands(random).reduceConnectedComponents(ee.Reducer.max())
            points = random.eq(maximum).selfMask().clip(aoi).reproject(proj.atScale(scale))

            # create a stack with change and total pixels as 1
            stack = (change.updateMask(points)          # masked sample change
                .addBands(points)                       # all samples
                .multiply(ee.Image.pixelArea())         # multiply both for pixel area
                .rename(['sampled_change', 'sampled_area'])
            )

            # sum them up
            sampled_areas = stack.reduceRegion(**{
                'reducer': ee.Reducer.sum(),
                'geometry': aoi,
                'scale': scale,
                'maxPixels': 1e14
            })

            # calculate proportional change to entire sampled area
            proportional_change_sampled = ee.Number(sampled_areas.get('sampled_change')).divide(
                ee.Number(sampled_areas.get('sampled_area'))
            )

            # return absolute difference from expected value, relative to the
            # expected value and expressed in percent
            return ee.Number(
                ((proportional_change_sampled.subtract(proportional_change_map)).abs())
                .divide(proportional_change_map)
                .multiply(100))
        # -----------------------------------------------------------------

        # one relative error per seed
        sampling_iter = ee.List(seeds).map(lambda x: get_sampled_proportional_change(x, proj))

        # get sample error mean and stddev over all runs.
        # NOTE(review): the caller indexes this list as [mean, stddev], which
        # relies on the order in which `values()` returns the reducer outputs -
        # confirm against the Earth Engine documentation.
        sampling_iter = ee.List(
            ee.Dictionary(ee.List(sampling_iter).reduce(ee.Reducer.mean().combine(ee.Reducer.stdDev(), None, True))).values()
        )

        # add to a dict of all grids
        return ee.Dictionary({'stats': sampling_iter, 'size': sample_size})

    # we map over all different grid sizes
    results = ee.List(grid_sizes).map(lambda x: get_grid_sample_error(x))

    # and create the final dataframe in one go: one record per grid size
    # (building it once gives a regular 0..n-1 index instead of repeating
    # index 0 for every row, and also copes with an empty result list)
    records = [
        {
            'idx': idx,
            'grid_size': grid_sizes[idx],
            'sample_size': r['size'],
            'mean': r['stats'][0],
            'stddev': r['stats'][1],
        }
        for idx, r in enumerate(results.getInfo())
    ]

    return pd.DataFrame(records, columns=['idx', 'grid_size', 'sample_size', 'mean', 'stddev'])

In [None]:
#from matplotlib import pyplot as plt
# run the analysis with the settings from the configuration cell
df = get_sampling_errors(aoi, start, end, collection)
#df['cv'] = df['stddev'] / df['mean']
display(df)

# mean relative sampling error as a function of the number of samples
#fig, ax = plt.subplots()
ax = df.plot(x='sample_size', y='mean', kind='scatter')
#ax.set_xscale('log')
# show plain numbers on the axes instead of offset / scientific notation
ax.ticklabel_format(useOffset=False, style='plain')
# descriptive labels so the figure can be read on its own
# ('mean' is the relative error in percent, averaged over all runs per grid)
ax.set(
    xlabel='sample size (number of grid cells within the AOI)',
    ylabel='mean relative sampling error (%)',
    title=f'Sampling error vs. sample size ({collection}, {aoi}, {start}-{end})',
);