Warning! This notebook requires at least 90 GB of RAM.

In [None]:
import os

In [None]:
import numpy as np
import pandas as pd
from dateutil.parser import parse
from eolearn.core import EOPatch
from fs_s3fs import S3FS
from sentinelhub import SHConfig

In [None]:
from sg_utils.processing import multiprocess

In [None]:
# Sentinel Hub configuration carrying the AWS credentials used for S3 access below.
# Credentials are read from the standard AWS environment variables instead of being
# typed into the notebook, where they would be saved and easily committed.
# When the variables are unset, the keys fall back to '' (the previous placeholders).
config = SHConfig()
config.aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID', '')
config.aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY', '')

In [None]:
# PyFilesystem handle on the S3 bucket that holds the sampled EOPatches, the
# metadata parquet files and the exported npz files (bucket name left as a placeholder).
_s3_credentials = {
    'aws_access_key_id': config.aws_access_key_id,
    'aws_secret_access_key': config.aws_secret_access_key,
}
filesystem = S3FS(bucket_name='', **_s3_credentials)

In [None]:
# Placeholder S3 prefixes for the sampled data (fill in before running).
# NOTE: DIR_SAMPLED_S2 is assigned again in a later cell, and DIR_SAMPLED_DEIMOS_1M
# is not referenced anywhere else in this notebook.
DIR_SAMPLED_S2, DIR_SAMPLED_DEIMOS_1M = '', ''

In [None]:
# Per-acquisition DEIMOS MS4 metadata. The S3 file handle is closed as soon as the
# parquet file has been read, instead of being left to garbage collection.
with filesystem.openbin('metadata/deimos_ms4_metadata.pq') as _metadata_file:
    metadata_ms4 = pd.read_parquet(_metadata_file)

# ! Warning: the country is inferred from the projection WKT alone. Any projection
# containing '34N' (presumably UTM zone 34N) is labelled 'Lithuania' and everything
# else 'Cyprus', so this only works for a dataset that covers exactly these two AOIs.
metadata_ms4['Country'] = metadata_ms4.Projection_OGCWKT.apply(
    lambda wkt: 'Lithuania' if '34N' in wkt else 'Cyprus')

# Acquisition START_TIME -> country lookup, used when labelling the samples.
timestamp_country_map = {ts: country for ts, country in metadata_ms4[['START_TIME', 'Country']].values}

In [None]:
# Maximum fraction of cloudy pixels (mean of the CLM mask) for a frame to be kept.
MAX_CC = 0.05
# Window size, in days, passed to _idxs_within_n_days when pairing S2 frames with a
# DEIMOS acquisition.
N_DAYS = 60

# Divisor applied to the S2 BANDS when building the features.
S2_FACTOR = 1e4

In [None]:
def normalize_deimos(eop, pan=False):
    """Convert DEIMOS digital numbers to radiance and standardise each band.

    For every timestamp ``ts`` and each of the first four channels, the band is
    converted with that acquisition's physical gain/bias and then standardised with
    the median and standard deviation stored in
    ``eop.meta_info['metadata'][ts]['MS4']``::

        out = ((dn * gain + bias) - median) / std

    The median/std are always read from ``CLM_RADIANCE_BAND_STATS_PANSHARPENED``,
    whichever feature is selected by ``pan``.

    :param eop: EOPatch-like object with ``data``, ``timestamp`` and ``meta_info``.
    :param pan: normalise ``'PANSHARPENED-DEIMOS'`` when True, else ``'BANDS-DEIMOS'``.
    :return: the same ``eop``, with the selected feature replaced by a normalised
        float32 copy (the input array itself is not modified).
    """
    bname = 'PANSHARPENED-DEIMOS' if pan else 'BANDS-DEIMOS'

    # Cast once. The cast copies the whole array, and it used to be repeated for every
    # timestamp/channel pair inside the loop.
    bands = eop.data[bname].astype(np.float32)
    for i, ts in enumerate(eop.timestamp):
        ms4 = eop.meta_info['metadata'][ts]['MS4']
        stats = ms4['CLM_RADIANCE_BAND_STATS_PANSHARPENED']
        physical = ms4['PHYSICAL_INFO']
        for chnl in range(4):
            band_no = chnl + 1  # metadata keys number the bands from 1
            median = float(stats[f'STX_CLM_RADIANCE_MEDIAN_PANSHARPENED_{band_no}'])
            std = float(stats[f'STX_CLM_RADIANCE_STDV_PANSHARPENED_{band_no}'])
            gain = float(physical[f'PHYSICAL_GAIN_{band_no}'])
            bias = float(physical[f'PHYSICAL_BIAS_{band_no}'])

            bands[i, ..., chnl] = ((bands[i, ..., chnl] * gain + bias) - median) / std

    eop.data[bname] = bands
    return eop


def _valid_idxs_deimos(eop, max_cc, clm_band=0, threshold=95):

    idxs = []
    for i, ts in enumerate(eop.timestamp):

        float(eop.meta_info['metadata'][ts]['MS4']['PHYSICAL_INFO'][f'PHYSICAL_GAIN_{clm_band+1}'])
        float(eop.meta_info['metadata'][ts]['MS4']['PHYSICAL_INFO'][f'PHYSICAL_BIAS_{clm_band+1}'])

        # cloud_coverage = ((eop.data['BANDS-DEIMOS'][i, ..., clm_band]*gain + bias) > threshold).mean()
        cloud_coverage = eop.mask['CLM'][i].mean()
        if cloud_coverage <= max_cc and (eop.mask['IS_DATA'].mean() == 1):
            idxs.append(i)
    return idxs


def _filter_cloudy_s2(eop, max_cc):
    idxs = []
    for i, _ in enumerate(eop.timestamp):
        if (eop.mask['CLM'][i, ...].mean() <= max_cc) and (eop.mask['IS_DATA'].mean() == 1):
            idxs.append(i)
    eop.data['BANDS'] = eop.data['BANDS'][idxs, ...]
    eop.data['CLP'] = eop.data['CLP'][idxs, ...]
    eop.mask['CLM'] = eop.mask['CLM'][idxs, ...]
    eop.mask['IS_DATA'] = eop.mask['IS_DATA'][idxs, ...]
    eop.scalar['NORM_FACTORS'] = eop.scalar['NORM_FACTORS'][idxs, ...]

    eop.timestamp = list(np.array(eop.timestamp)[idxs])
    return eop


def _get_closest_timestamp_idx(eop, ref_timestamp):
    closest_idx = 0
    for i, ts in enumerate(eop.timestamp):
        if abs((ts - ref_timestamp).days) < abs((eop.timestamp[closest_idx] - ref_timestamp).days):
            closest_idx = i
    return closest_idx


def _idxs_within_n_days(eop, ref_ts, n_days=60):
    idxs = []
    for i, ts in enumerate(eop.timestamp):
        if 0 < (ref_ts - ts).days < 60:
            idxs.append(i)
    return idxs

In [None]:
# Placeholder S3 prefixes of the sampled S2 and DEIMOS EOPatches that
# construct_features_labels loads (fill in before running).
DIR_SAMPLED_S2, DIR_SAMPLED_DEIMOS = '', ''

In [None]:
def construct_features_labels(eop_name):
    """Build (S2 time series, DEIMOS image) training pairs for one sampled patchlet.

    The S2 and DEIMOS EOPatches named ``eop_name`` are loaded from the sampled
    directories. Cloudy or invalid S2 frames are dropped, and the usable DEIMOS frames
    are selected and normalised. Each usable DEIMOS frame becomes one sample:
    - features: the S2 ``BANDS`` selected by ``_idxs_within_n_days`` (with
      ``N_DAYS``), divided by ``S2_FACTOR``;
    - label: the normalised DEIMOS image.

    Any exception is printed, and an empty result (same keys, empty lists) is
    returned, so that a single bad patchlet does not abort the parallel run.

    :return: dict of parallel sequences ``features``, ``labels``, ``patchlet_name``,
        ``timestamps_deimos``, ``timestamps_s2`` and ``countries``.
    """
    try:
        s2_patch = EOPatch.load(os.path.join(DIR_SAMPLED_S2, eop_name),
                                filesystem=filesystem, lazy_loading=True)
        deimos_patch = EOPatch.load(os.path.join(DIR_SAMPLED_DEIMOS, eop_name),
                                    filesystem=filesystem, lazy_loading=True)
        s2_patch = _filter_cloudy_s2(s2_patch, MAX_CC)
        usable = _valid_idxs_deimos(deimos_patch, MAX_CC)
        deimos_times = np.array(deimos_patch.timestamp)[usable]
        deimos_images = normalize_deimos(deimos_patch, pan=False).data['BANDS-DEIMOS'][usable, ...]

        feature_stacks, label_images, s2_time_stacks = [], [], []
        for deimos_time, deimos_image in zip(deimos_times, deimos_images):
            window = _idxs_within_n_days(s2_patch, deimos_time, N_DAYS)
            s2_time_stacks.append(np.array(s2_patch.timestamp)[window])
            feature_stacks.append(s2_patch.data['BANDS'][window, ...] / S2_FACTOR)
            label_images.append(deimos_image)

        country_labels = [timestamp_country_map[deimos_time] for deimos_time in deimos_times]
        return {
            'features': feature_stacks,
            'labels': label_images,
            'patchlet_name': [eop_name] * len(feature_stacks),
            'timestamps_deimos': deimos_times,
            'timestamps_s2': s2_time_stacks,
            'countries': country_labels,
        }
    except Exception as e:
        print(f"Failed for {eop_name} with error: {e}")
        return {key: [] for key in ('features', 'labels', 'patchlet_name',
                                    'timestamps_deimos', 'timestamps_s2', 'countries')}

In [None]:
# Patchlet names to process. The next cell consumes `sampled_list`; with this
# assignment commented out, that cell raised NameError.
sampled_list = filesystem.listdir(DIR_SAMPLED_S2)

In [None]:
# Build the samples for every patchlet in parallel. Presumably this returns one
# construct_features_labels dict per patchlet (sg_utils.multiprocess is not visible
# here; confirm).
results = multiprocess(construct_features_labels, sampled_list, max_workers=47)

In [None]:
# Create the npz output directory (path left as a placeholder).
# recreate=True makes the cell safe to re-run: PyFilesystem's makedirs raises
# DirectoryExists for an existing directory unless recreate is set.
filesystem.makedirs('', recreate=True)

In [None]:
def save_npz(result):
    """Write each non-empty sample of one construct_features_labels result to its own npz.

    Sample ``i`` is written to ``/data_<patchlet>_<i>.npz``, where ``i`` is the
    sample's position in ``result``, so skipped samples leave gaps in the numbering.
    Samples with an empty S2 feature stack are skipped.

    NOTE: the npz key ``timetamps_deimos`` is misspelled. It is kept as-is because
    create_info reads the files back using that exact key.

    :return: DataFrame with one bookkeeping row per written file.
    """
    rows = []
    per_sample = zip(result['features'], result['labels'], result['patchlet_name'],
                     result['timestamps_deimos'], result['timestamps_s2'], result['countries'])
    for idx, sample in enumerate(per_sample):
        feats, labels, patch_name, ts_deim, ts_s2, ts_country = sample
        if not len(feats):
            continue

        filename = f'data_{patch_name}_{idx}.npz'
        rows.append({
            'patchlet': patch_name,
            'eopatch': patch_name.split('_')[0],
            'countries': ts_country,
            'timestamp_deimos': ts_deim,
            'timestamps_s2': ts_s2,
            'singleton_npz_filename': filename,
        })
        with filesystem.openbin(f'/{filename}', 'wb') as f:
            np.savez(f, features=feats, labels=labels, patchlet=patch_name,
                     timetamps_deimos=ts_deim, timestamps_s2=ts_s2, countries=ts_country)
    return pd.DataFrame(rows)

In [None]:
# Write the npz files for every patchlet result in parallel. Presumably this returns
# one save_npz DataFrame per result (sg_utils.multiprocess is not visible here; confirm).
dfs = multiprocess(save_npz, results, max_workers=47)

In [None]:
# Entry names in the (placeholder) output directory, presumably the npz files
# written by save_npz.
npz_files = filesystem.listdir('')

In [None]:
def create_info(filename):
    """Rebuild the bookkeeping row for one exported npz file by reading it back.

    :param filename: name of an npz written by save_npz; it is opened at ``/<filename>``.
    :return: dict with the same keys as the rows produced by save_npz. Values are
        the arrays stored in the npz, except ``eopatch`` (a str).
    """
    # Both the S3 file handle and the NpzFile are closed deterministically, instead
    # of staying open until garbage collection.
    # allow_pickle=True is required for object arrays (presumably the timestamp
    # fields). Only load npz files from this trusted bucket: unpickling can execute code.
    with filesystem.openbin(f'/{filename}') as fileobj, np.load(fileobj, allow_pickle=True) as npz:
        patchlet_name = npz['patchlet']
        # save_npz spells this key 'timetamps_deimos'; the two must stay in sync.
        timestamp_deimos = npz['timetamps_deimos']
        timestamps_s2 = npz['timestamps_s2']
        countries = npz['countries']
    eopatch_name = str(patchlet_name).split('_')[0]
    return dict(patchlet=patchlet_name,
                eopatch=eopatch_name,
                countries=countries,
                timestamp_deimos=timestamp_deimos,
                timestamps_s2=timestamps_s2,
                singleton_npz_filename=filename)

In [None]:
# Rebuild the info rows by reading every npz back with create_info.
# NOTE(review): `dicts` is not used by any later cell in this notebook; the exported
# table is built from `dfs`.
dicts = multiprocess(create_info, npz_files, max_workers=16)

In [None]:
# Stack the per-patchlet info tables. ignore_index=True gives the result a fresh
# 0..n-1 index. Without it, every chunk's own 0-based index is kept, producing
# duplicate labels that to_parquet would also write out as an extra index column.
df_concated = pd.concat(dfs, ignore_index=True)

In [None]:
# `df` is an alias of df_concated (same object, not a copy), so the column edits made
# through `df` below also modify df_concated.
df = df_concated

In [None]:
# Serialise each sample's S2 timestamps into one '|'-separated string. This string
# form, not the raw sequence, is what gets exported to parquet at the end.
df['timestamps_s2_str'] = df['timestamps_s2'].apply(lambda stamps: '|'.join(map(str, stamps)))

In [None]:
# Explicit cast to str. The values are already strings produced by str.join, so this
# is only a safeguard.
df['timestamps_s2_str'] = df['timestamps_s2_str'].astype(str)

In [None]:
# Number of S2 frames paired with each DEIMOS sample.
df['num_tstamps'] = df['timestamps_s2'].apply(len)

In [None]:
# Notebook display: the dtypes of most of the columns exported at the end.
df[['patchlet', 'eopatch', 'countries', 'timestamp_deimos',
    'singleton_npz_filename', 'timestamps_s2_str']].dtypes

In [None]:
# Re-parse every DEIMOS timestamp from its string form with dateutil.
df['timestamp_deimos'] = df['timestamp_deimos'].apply(lambda value: parse(str(value)))

In [None]:
# Force the text columns to str before exporting to parquet.
for _column in ('countries', 'patchlet'):
    df[_column] = df[_column].astype(str)

In [None]:
# Notebook display of the final bookkeeping table.
df

In [None]:
# Export the bookkeeping table (one row per written npz file) to the bucket.
_export_columns = ['patchlet', 'eopatch', 'countries', 'timestamp_deimos',
                   'singleton_npz_filename', 'timestamps_s2_str', 'num_tstamps']
with filesystem.openbin('metadata/npz_info_small.pq', 'wb') as out_file:
    df[_export_columns].to_parquet(out_file)