# Sample patchlets from downloaded eopatches

This notebook takes an eopatch as input and samples smaller patchlets of a given size (e.g. 256x256) to be used for training/validation/testing.

Modify the parameters of the `SamplePatchlets` task to influence the distribution of the sampled patchlets.


In [None]:
import os
from abc import abstractmethod
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime, timedelta
from typing import Any, Callable, List

import boto3
import dateutil
import fs
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
from eolearn.core import FeatureType, EOPatch, EOTask, EOWorkflow, SaveTask, OverwritePermission, EOExecutor, FeatureTypeSet
from fs_s3fs import S3FS
from s2cloudless import S2PixelCloudDetector
from sentinelhub import CRS, BBox
from tqdm import tqdm_notebook as tqdm

In [None]:
def multiprocess(process_fun: Callable, arguments: List[Any], max_workers: int = 4) -> List[Any]:
    """
    Run ``process_fun`` on every item of ``arguments`` in a process pool, showing a tqdm progress bar.

    Parameters
    ----------
    process_fun: A function that processes a single item. It is dispatched to worker
        processes, so it (and each argument) has to be picklable.
    arguments: Arguments with which the function is called, one call per item.
    max_workers: Max workers for the process pool executor.

    Returns
    -------
    A list of results, in the same order as ``arguments``.
    """
    # Nothing to do: avoid creating a pool (and an empty progress bar) for no work.
    if not arguments:
        return []

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # ``executor.map`` is lazy; wrapping it in ``list`` drives the progress bar to completion.
        results = list(tqdm(executor.map(process_fun, arguments), total=len(arguments)))
    return results

### Define filesystem and eopatches location 

In [None]:
# S3-backed filesystem used for both reading eopatches and writing patchlets.
# NOTE(review): the bucket name and the empty keys look like placeholders - fill them in before running.
filesystem = S3FS(
    "bucket-name",
    aws_access_key_id="",
    aws_secret_access_key="",
    region="eu-central-1",
)

In [None]:
# Folder within `filesystem` that holds the downloaded eopatches (one sub-folder per eopatch).
EOPATCHES_LOCATION = 'data/Lithuania/eopatches/2019/'

In [None]:
# Folder within `filesystem` where the sampled patchlets are saved.
OUT_PATH = 'data/Lithuania/patchlets/2019'

In [None]:
# Full path of every entry found directly under EOPATCHES_LOCATION.
EOPATCHES_PATHS = [
    os.path.join(EOPATCHES_LOCATION, patch_dir)
    for patch_dir in filesystem.listdir(EOPATCHES_LOCATION)
]

### Define the sampling EOTask.

Several feature names (`CLP`, `IS_DATA`, `EXTENT`, `BOUNDARY`, `DISTANCE`) are hard-coded in the task, so use it with caution outside this notebook.

In [None]:
class SamplePatchlets(EOTask):
    """
    The task samples patchlets of a certain size in a given timeless feature different from no valid data value with
    a certain percentage of valid pixels.

    The feature names ``CLP``, ``IS_DATA``, ``EXTENT``, ``BOUNDARY`` and ``DISTANCE`` are hard-coded: the input
    eopatch is expected to provide them.
    """

    def __init__(self, feature, buffer, patch_size, num_samples, max_retries, fraction_valid=0.2, no_data_value=0,
                 sample_features=..., filter_cloudy=True):
        """ Task to sample square patchlets from an eopatch, guided by a reference timeless raster mask

        :param feature: Reference timeless mask feature used to judge whether a candidate patchlet is valid
        :param buffer: Number of pixels kept free between the eopatch border and any sampled patchlet
        :param patch_size: Side length, in pixels, of the sampled square patchlets
        :param num_samples: Number of patchlets that are attempted per eopatch
        :param max_retries: Maximum number of random locations tried for each patchlet
        :param fraction_valid: Minimum fraction of patchlet pixels that must differ from `no_data_value`
        :param no_data_value: Value of non-valid points to be ignored
        :param sample_features: Features to be cropped into the patchlets; only raster spatial features are used
        :param filter_cloudy: If True, time frames that are cloudy or contain invalid data are dropped from
            time-dependent features; if False, all time frames are kept
        """
        self.feature_type, self.feature_name, self.new_feature_name = next(
            self._parse_features(feature, new_names=True,
                                 default_feature_type=FeatureType.MASK_TIMELESS,
                                 allowed_feature_types={FeatureType.MASK_TIMELESS},
                                 rename_function='{}_SAMPLED'.format)())
        self.max_retries = max_retries
        self.fraction = fraction_valid
        self.no_data_value = no_data_value
        self.sample_features = self._parse_features(sample_features)
        self.num_samples = num_samples
        self.patch_size = patch_size
        self.buffer = buffer
        self.filter_cloudy = filter_cloudy
        self.s2_cd = S2PixelCloudDetector(average_over=24)

    def _get_clear_indices(self, clp, vld):
        """ Return the indices of time frames that pass the cloud test and contain no invalid pixel

        :param clp: Cloud probabilities cropped to the patchlet, iterated frame by frame and scaled by 1/255
        :param vld: Valid-data mask cropped to the patchlet, iterated in lockstep with `clp`
        :return: List of accepted frame indices
        """
        idxs = []
        for i, (probas, vld_mask) in enumerate(zip(clp, vld)):
            # NOTE(review): `probas` is a single frame here, so `probas.shape[1:3]` is not its (height, width);
            # the score is therefore not a plain cloudy-pixel fraction. Left unchanged because fixing it alters
            # which frames are selected - confirm the intended normalisation and the input layout expected by
            # `get_mask_from_prob`.
            cloud_score = self.s2_cd.get_mask_from_prob(probas/255.0).sum()/np.prod(probas.shape[1:3])
            if cloud_score < 0.05 and np.sum(~vld_mask.astype(bool)) == 0:
                idxs.append(i)
        return idxs

    def _sample_location(self, mask):
        """ Draw random top-left corners until a patchlet with enough valid pixels is found

        :param mask: 2D reference mask
        :return: Tuple (row, col) of the accepted location, or None if `max_retries` draws all failed
        :raises ValueError: If the mask is too small to fit a patchlet plus the buffer on both sides
        """
        n_rows, n_cols = mask.shape
        # `np.random.randint` excludes the upper bound, hence the +1 so that the last valid offset can be drawn.
        row_high = n_rows - self.patch_size - self.buffer + 1
        col_high = n_cols - self.patch_size - self.buffer + 1
        if row_high <= self.buffer or col_high <= self.buffer:
            raise ValueError(f'Reference map of shape {mask.shape} is too small for patchlets of size '
                             f'{self.patch_size} with buffer {self.buffer}.')

        for _ in range(self.max_retries):
            row = np.random.randint(self.buffer, row_high)
            col = np.random.randint(self.buffer, col_high)
            patchlet = mask[row:row+self.patch_size, col:col+self.patch_size]
            ratio = np.sum(patchlet != self.no_data_value) / self.patch_size**2
            if ratio >= self.fraction:
                # Accept as soon as the criterion is met, including on the very last allowed draw.
                return row, col
        return None

    def execute(self, eopatch, seed=None):
        """ Sample up to `num_samples` patchlets from the given eopatch

        :param eopatch: Input eopatch holding the reference mask, the features to sample and the hard-coded
            auxiliary features
        :param seed: Seed passed to `np.random.seed` before sampling
        :return: List of new eopatches, one per successfully sampled patchlet
        :raises ValueError: If the squeezed reference mask is not 2D or is too small for the requested patchlets
        """
        mask = eopatch[self.feature_type][self.feature_name].squeeze()

        # Validate the rank before using the shape, so that the dedicated message is the one raised.
        if mask.ndim != 2:
            raise ValueError('Invalid shape of sampling reference map.')

        timestamps = np.array(eopatch.timestamp)
        np.random.seed(seed)
        eops_out = []

        sampled_types = FeatureTypeSet.RASTER_TYPES.intersection(FeatureTypeSet.SPATIAL_TYPES)
        features = [(feature_type, feature_name)
                    for feature_type, feature_name in self.sample_features(eopatch)
                    if feature_type in sampled_types]

        for patchlet_num in range(0, self.num_samples):
            location = self._sample_location(mask)
            if location is None:
                print(f'Could not determine an area with good enough ratio of valid sampled pixels for '
                               f'patchlet number: {patchlet_num}')
                continue

            row, col = location
            rows = slice(row, row + self.patch_size)
            cols = slice(col, col + self.patch_size)

            new_eopatch = EOPatch(timestamp=eopatch.timestamp)
            clear_idxs = None  # computed lazily, once per patchlet, on the first time-dependent feature
            sampled_any = False

            for feature_type, feature_name in features:
                feature_data = eopatch[feature_type][feature_name]
                if feature_type.is_time_dependent():
                    sampled_data = feature_data[:, rows, cols, :]
                    if self.filter_cloudy:
                        if clear_idxs is None:
                            clp_patchlet = eopatch.data['CLP'][:, rows, cols, :]
                            valid_patchlet = eopatch.mask['IS_DATA'][:, rows, cols, :]
                            clear_idxs = self._get_clear_indices(clp_patchlet, valid_patchlet)
                            new_eopatch.timestamp = list(timestamps[clear_idxs])
                        sampled_data = sampled_data[clear_idxs]
                else:
                    sampled_data = feature_data[rows, cols, :]

                # Only the new eopatch is assigned to; the features of the input eopatch are not reassigned.
                new_eopatch[feature_type][feature_name] = sampled_data
                sampled_any = True

            # Keep the original behaviour of emitting nothing when no raster spatial feature was sampled.
            if not sampled_any:
                continue

            # Auxiliary features are attached once per patchlet, and the patchlet is appended exactly once,
            # regardless of how many features were sampled.
            new_eopatch[FeatureType.SCALAR_TIMELESS]['PATCHLET_LOC'] = np.array([row, col, self.patch_size])
            new_eopatch[FeatureType.MASK_TIMELESS]['EXTENT'] = eopatch.mask_timeless['EXTENT'][rows, cols]
            new_eopatch[FeatureType.MASK_TIMELESS]['BOUNDARY'] = eopatch.mask_timeless['BOUNDARY'][rows, cols]
            new_eopatch[FeatureType.DATA_TIMELESS]['DISTANCE'] = eopatch.data_timeless['DISTANCE'][rows, cols]
            eops_out.append(new_eopatch)
        return eops_out

In [None]:
# Sampling task: EXTENT is the reference mask, BANDS is the feature cropped into each patchlet.
task = SamplePatchlets(
    feature=(FeatureType.MASK_TIMELESS, 'EXTENT'),
    buffer=0,
    patch_size=256,
    num_samples=10,
    max_retries=10,
    fraction_valid=0.4,
    sample_features=(FeatureType.DATA, 'BANDS'),
)

In [None]:
def create_and_save_patchlets(eop_path):
    """
    Load one eopatch, sample patchlets from it with the module-level `task` and save them under `OUT_PATH`.

    Each patchlet is stored as '<eopatch name>_<index>'. `KeyError` and `ValueError` raised along the way are
    reported with a print and otherwise swallowed, so one failing eopatch does not abort a batch run.

    :param eop_path: Path of the eopatch within `filesystem`
    """
    eop_name = os.path.basename(eop_path)
    print(f'Processing eop: {eop_name}')
    try:
        source_patch = EOPatch.load(eop_path, filesystem=filesystem, lazy_loading=True)
        for index, sampled in enumerate(task.execute(source_patch)):
            target_path = os.path.join(OUT_PATH, f'{eop_name}_{index}')
            sampled.save(target_path, filesystem=filesystem)
    except KeyError as missing_key:
        print(f'Key error. Could not find key: {missing_key}')
    except ValueError as bad_value:
        print(f'Value error. Value does not exist: {bad_value}')
        

In [None]:
# If a previous run failed part-way, skip the eopatches that already have patchlets so they are not processed twice. Hopefully this is not needed.

In [None]:
# Patchlets are saved as '<eopatch name>_<index>'. Strip only the trailing index with `rsplit`, so that
# eopatch names which themselves contain underscores are recovered intact.
# On a first run OUT_PATH may not exist yet, in which case nothing has been processed.
processed_eops = (
    {os.path.join(EOPATCHES_LOCATION, x.rsplit('_', 1)[0]) for x in filesystem.listdir(OUT_PATH)}
    if filesystem.isdir(OUT_PATH) else set()
)

In [None]:
# Eopatches that do not have any saved patchlet yet.
unprocessed_eops = set(EOPATCHES_PATHS) - processed_eops

In [None]:
# Sample and save patchlets for all remaining eopatches in parallel.
multiprocess(
    process_fun=create_and_save_patchlets,
    arguments=list(unprocessed_eops),
    max_workers=36,
)