In [None]:
import os

import numpy as np
from eolearn.core import EOPatch, EOTask
from fs_s3fs import S3FS
from matplotlib import pyplot as plt
from sentinelhub import BBox, SHConfig

from sg_utils.processing import multiprocess

# Config

In [None]:
# Sentinel Hub configuration object; the AWS key pair set here is reused below to open the S3 bucket.
# NOTE(review): the values are empty strings here - presumably placeholders to fill in before running.
config = SHConfig()
config.instance_id = ''
config.aws_access_key_id = ''
config.aws_secret_access_key = ''

In [None]:
# S3-backed filesystem object used for every EOPatch load/save and directory listing in this notebook.
# NOTE(review): bucket_name is an empty string here - presumably a placeholder to fill in before running.
filesystem = S3FS(bucket_name='',
                  aws_access_key_id=config.aws_access_key_id,
                  aws_secret_access_key=config.aws_secret_access_key)

# Execute sampling

In [None]:
class SamplePatchlets(EOTask):
    """Randomly sample co-located square patchlets from a Sentinel-2 EOPatch and a Deimos EOPatch.

    Window positions are drawn on the Sentinel-2 pixel grid; the matching Deimos window is obtained by
    multiplying the row, column and size by ``MS4_DEIMOS_SCALING``.

    :param s2_patchlet_size: Side length of a patchlet, in Sentinel-2 pixels.
    :param num_samples: Number of patchlet pairs produced per call to ``execute``.
    """

    # Factor applied to Sentinel-2 pixel offsets/sizes to address the same window in the Deimos arrays.
    MS4_DEIMOS_SCALING = 4

    # Features cropped from each source, as (EOPatch attribute, feature name) pairs, in copy order.
    _S2_FEATURES = (('data', 'CLP'), ('mask', 'CLM'), ('mask', 'IS_DATA'), ('data', 'BANDS'))
    _DEIMOS_FEATURES = (('data', 'BANDS-DEIMOS'), ('mask', 'CLM'), ('mask', 'IS_DATA'))

    def __init__(self, s2_patchlet_size: int, num_samples: int):
        self.s2_patchlet_size = s2_patchlet_size
        self.num_samples = num_samples

    def _calculate_sampled_bbox(self, bbox: BBox, r: int, c: int, s: int, resolution: float) -> BBox:
        """Return the bounding box of an ``s`` x ``s`` pixel window whose top-left pixel is (``r``, ``c``).

        Rows are counted downwards from ``bbox.max_y`` and columns rightwards from ``bbox.min_x``;
        ``resolution`` is the pixel size expressed in the units of ``bbox.crs``.
        """
        lower_left = (bbox.min_x + resolution * c, bbox.max_y - resolution * (r + s))
        upper_right = (bbox.min_x + resolution * (c + s), bbox.max_y - resolution * r)
        return BBox((lower_left, upper_right), bbox.crs)

    def _sample(self, eop: EOPatch, row: int, col: int, size: int, resolution: float, features) -> EOPatch:
        """Build a new EOPatch holding the ``size`` x ``size`` window at (``row``, ``col``) of ``features``.

        Timestamps, scalar features and meta info are passed through from ``eop``; the window location is
        stored in ``scalar_timeless['PATCHLET_LOC']`` and the bounding box is recomputed for the window.
        """
        # NOTE(review): meta_info is handed over as-is, so the size_x/size_y writes below may also be visible
        # on the source EOPatch if EOPatch does not copy the dictionary - confirm against eo-learn.
        sampled_eop = EOPatch(timestamp=eop.timestamp, scalar=eop.scalar, meta_info=eop.meta_info)

        for attr_name, feature_name in features:
            source = getattr(eop, attr_name)[feature_name]
            # NumPy slicing never raises for an out-of-range window; it silently returns a smaller array.
            getattr(sampled_eop, attr_name)[feature_name] = source[:, row:row + size, col:col + size, :]

        sampled_eop.scalar_timeless['PATCHLET_LOC'] = np.array([row, col, size])
        sampled_eop.bbox = self._calculate_sampled_bbox(eop.bbox, r=row, c=col, s=size, resolution=resolution)
        sampled_eop.meta_info['size_x'] = size
        sampled_eop.meta_info['size_y'] = size
        return sampled_eop

    def _sample_s2(self, eop: EOPatch, row: int, col: int, size: int, resolution: float = 10):
        """Crop the CLP, CLM, IS_DATA and BANDS features of a Sentinel-2 EOPatch to the given window."""
        return self._sample(eop, row, col, size, resolution, self._S2_FEATURES)

    def _sample_deimos(self, eop: EOPatch, row: int, col: int, size: int, resolution: float = 2.5):
        """Crop the BANDS-DEIMOS, CLM and IS_DATA features of a Deimos EOPatch to the given window."""
        return self._sample(eop, row, col, size, resolution, self._DEIMOS_FEATURES)

    def execute(self, eopatch_s2, eopatch_deimos, buffer=20, seed=42):
        """Sample ``num_samples`` pairs of matching Sentinel-2 / Deimos patchlets.

        :param eopatch_s2: Sentinel-2 EOPatch; the shape of ``data['BANDS']`` defines the sampling grid.
        :param eopatch_deimos: Deimos EOPatch covering the same area.
        :param buffer: Margin, in Sentinel-2 pixels, kept free of patchlets along every edge.
        :param seed: Seed of the random generator; the same seed always yields the same window positions.
        :return: List of ``(sampled_s2, sampled_deimos)`` EOPatch tuples.
        :raises ValueError: If the EOPatch is too small to fit a patchlet inside the buffered area.
        """
        _, n_rows, n_cols, _ = eopatch_s2.data['BANDS'].shape
        if self.num_samples <= 0:
            return []

        # Upper bounds are exclusive in randint, matching the positions drawn by earlier runs.
        max_row = n_rows - self.s2_patchlet_size - buffer
        max_col = n_cols - self.s2_patchlet_size - buffer
        if max_row <= buffer or max_col <= buffer:
            raise ValueError(f'EOPatch of {n_rows}x{n_cols} pixels is too small to sample patchlets of size '
                             f'{self.s2_patchlet_size} with a buffer of {buffer}')

        # A local generator draws the same sequence as np.random.seed(seed) followed by np.random.randint,
        # but leaves the process-wide NumPy random state untouched.
        rng = np.random.RandomState(seed)
        eops_out = []

        for _ in range(self.num_samples):
            row = rng.randint(buffer, max_row)
            col = rng.randint(buffer, max_col)
            sampled_s2 = self._sample_s2(eopatch_s2, row, col, self.s2_patchlet_size)
            sampled_deimos = self._sample_deimos(eopatch_deimos,
                                                 row * self.MS4_DEIMOS_SCALING,
                                                 col * self.MS4_DEIMOS_SCALING,
                                                 self.s2_patchlet_size * self.MS4_DEIMOS_SCALING)
            eops_out.append((sampled_s2, sampled_deimos))
        return eops_out

In [None]:
def sample_patch(eop_path_s2: str, eop_path_deimos: str,
                 sampled_folder_s2: str, sampled_folder_deimos: str,
                 s2_patchlet_size: int, num_samples: int, filesystem, buffer: int = 20) -> None:
    """Load a Sentinel-2 / Deimos EOPatch pair, sample patchlets from it and save them.

    Patchlet ``i`` is saved as ``<name>_<i>`` in both output folders, where ``<name>`` is the last path
    component of ``eop_path_s2``. ``KeyError`` and ``ValueError`` raised while loading, sampling or saving
    are reported with ``print`` instead of being propagated, so one bad EOPatch does not stop a batch run.

    :param eop_path_s2: Path of the Sentinel-2 EOPatch on ``filesystem``.
    :param eop_path_deimos: Path of the Deimos EOPatch on ``filesystem``.
    :param sampled_folder_s2: Folder receiving the sampled Sentinel-2 patchlets.
    :param sampled_folder_deimos: Folder receiving the sampled Deimos patchlets.
    :param s2_patchlet_size: Side length of a patchlet, in Sentinel-2 pixels.
    :param num_samples: Number of patchlet pairs to sample.
    :param filesystem: Filesystem object passed to ``EOPatch.load`` and ``EOPatch.save``.
    :param buffer: Margin, in Sentinel-2 pixels, kept free of patchlets along every edge.
    """
    task = SamplePatchlets(s2_patchlet_size=s2_patchlet_size, num_samples=num_samples)
    # Drop trailing separators first: basename('folder/eopatch-0000/') is '', which would make every
    # EOPatch save its patchlets under the same '_<i>' names and overwrite each other.
    eop_name = os.path.basename(eop_path_s2.rstrip('/'))
    try:
        eop_s2 = EOPatch.load(eop_path_s2, filesystem=filesystem, lazy_loading=True)
        eop_deimos = EOPatch.load(eop_path_deimos, filesystem=filesystem, lazy_loading=True)
        patchlets = task.execute(eop_s2, eop_deimos, buffer=buffer)
        for i, (patchlet_s2, patchlet_deimos) in enumerate(patchlets):
            patchlet_name = f'{eop_name}_{i}'
            patchlet_s2.save(os.path.join(sampled_folder_s2, patchlet_name),
                             filesystem=filesystem)
            patchlet_deimos.save(os.path.join(sampled_folder_deimos, patchlet_name),
                                 filesystem=filesystem)

    # Name the failing EOPatch: with many patches processed in parallel the bare error is not traceable.
    except KeyError as e:
        print(f'Key error. Could not find key: {e} (EOPatch: {eop_name})')
    except ValueError as e:
        print(f'Value error. Value does not exist: {e} (EOPatch: {eop_name})')

In [None]:
# Folders on `filesystem` holding the source EOPatches
EOPS_S2 = ''
EOPS_DEIMOS = ''

# Folders on `filesystem` receiving the sampled patchlets
SAMPLED_S2_PATH = ''
SAMPLED_DEIMOS_3M_PATH = ''

# Sampling parameters
S2_PATCHLET_SIZE = 32  # patchlet side length, in Sentinel-2 pixels
NUM_SAMPLES = 140      # patchlet pairs sampled per EOPatch
BUFFER = 20            # margin kept free of patchlets along every edge, in Sentinel-2 pixels
MAX_WORKERS = 16       # workers handed to `multiprocess`


eop_names = filesystem.listdir(EOPS_DEIMOS)  # Both folders have the same EOPatches


def sample_single(eop_name):
    """Sample and save the patchlets of the EOPatch called ``eop_name`` in both source folders."""
    path_s2 = os.path.join(EOPS_S2, eop_name)
    path_deimos = os.path.join(EOPS_DEIMOS, eop_name)

    sample_patch(path_s2, path_deimos, SAMPLED_S2_PATH, SAMPLED_DEIMOS_3M_PATH,
                 s2_patchlet_size=S2_PATCHLET_SIZE, num_samples=NUM_SAMPLES,
                 filesystem=filesystem, buffer=BUFFER)


multiprocess(sample_single, eop_names, max_workers=MAX_WORKERS)

# Look at an example

In [None]:
# Load one pair of sampled patchlets, stored under the same name in both output folders, for a visual check.
sampled_s2 = EOPatch.load(os.path.join(SAMPLED_S2_PATH, 'eopatch-0000_122'), filesystem=filesystem)
sampled_deimos = EOPatch.load(os.path.join(SAMPLED_DEIMOS_3M_PATH, 'eopatch-0000_122'), filesystem=filesystem)

In [None]:
def _get_closest_timestamp_idx(eop, ref_timestamp):
    closest_idx = 0
    for i, ts in enumerate(eop.timestamp):
        if abs((ts - ref_timestamp).days) < abs((eop.timestamp[closest_idx] - ref_timestamp).days):
            closest_idx = i
    return closest_idx

In [None]:
fig, ax = plt.subplots(ncols=2, figsize=(15, 15))
idx_deimos = 1
ts_deimos = sampled_deimos.timestamp[idx_deimos]
closest_idx = _get_closest_timestamp_idx(sampled_s2, ts_deimos)

# Band indices [2, 1, 0] are shown as the RGB channels (presumably red, green, blue - confirm band order).
# imshow expects float RGB values in [0, 1]; clipping explicitly gives the same picture without
# matplotlib's "Clipping input data" warning for the brightened pixels.
s2_rgb = np.clip(sampled_s2.data['BANDS'][closest_idx][..., [2, 1, 0]] / 10000*3.5, 0, 1)
deimos_rgb = np.clip(sampled_deimos.data['BANDS-DEIMOS'][idx_deimos][..., [2, 1, 0]] / 12000, 0, 1)

ax[0].imshow(s2_rgb)
ax[0].set_title(f'Sentinel-2: {sampled_s2.timestamp[closest_idx]}')
ax[1].imshow(deimos_rgb)
ax[1].set_title(f'Deimos: {ts_deimos}')
plt.show()