In [None]:
import os
import posixpath
from datetime import datetime

In [None]:
import numpy as np
import pandas as pd
from eolearn.core import FeatureType
from eolearn.io import ImportFromTiff
from fs_s3fs import S3FS
from sentinelhub import SHConfig
from tqdm.auto import tqdm

In [None]:
# Credentials come from the standard AWS environment variables rather than being pasted into the
# notebook source (where they would end up in version control). Empty string when unset.
AWS_ACCESS_KEY_ID = os.environ.get('AWS_ACCESS_KEY_ID', '')
AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY', '')
# Bucket holding the tiles and the metadata parquet files, and the prefix inside it that contains
# one folder per sensing time.
BUCKET_NAME = ''
LOC_ON_BUCKET = ''

In [None]:
# S3 filesystem rooted at BUCKET_NAME; used below to list the tile folders and to read/write the
# metadata parquet files.
filesystem = S3FS(
    bucket_name=BUCKET_NAME,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    aws_access_key_id=AWS_ACCESS_KEY_ID,
)

In [None]:
# sentinelhub configuration carrying the same AWS credentials; it is handed to ImportFromTiff
# when reading the tiles.
config = SHConfig()
config.aws_secret_access_key = AWS_SECRET_ACCESS_KEY
config.aws_access_key_id = AWS_ACCESS_KEY_ID

In [None]:
# Radiometric calibration coefficients: radiance = DN * gain + bias.
# MS4 coefficients are keyed by an integer band/channel index. NOTE: the insertion order
# (3, 2, 1, 0) is kept as-is because code may iterate `.values()` in that order.
BAND_GAIN = {
    3: 0.006800104616,
    2: 0.011123248049,
    1: 0.013184818227,
    0: 0.014307912429,
}
BAND_BIAS = {
    3: -0.00680010461,
    2: -0.01112324804,
    1: -0.01318481822,
    0: -0.01430791242,
}
PAN_GAIN, PAN_BIAS = 0.011354020831, -0.01135402083

In [None]:
# Radiance thresholds above which a pixel is flagged as cloud
# (MS4: radiance of the CLM_MASK_BAND channel; PAN: radiance of the PAN band).
MS4_THRESHOLD = PAN_THRESHOLD = 100

In [None]:
# Load the MS4 and PAN metadata tables. pandas does not close a file object it is given, so the
# S3 handles are opened in `with` blocks to release them once the table has been read.
with filesystem.openbin('metadata/deimos_ms4_metadata.pq') as f:
    metadata_ms4 = pd.read_parquet(f)
with filesystem.openbin('metadata/deimos_pan_metadata.pq') as f:
    metadata_pan = pd.read_parquet(f)

In [None]:
# Inspect the available metadata columns; the statistics are joined onto this table via START_TIME below.
metadata_ms4.columns

In [None]:
# MS4 channel index whose radiance is thresholded (against MS4_THRESHOLD) to build the cloud mask.
# NOTE(review): presumably channel 0 is the first file loaded, B04.tiff — confirm against the
# ImportFromTiff filename order used when reading the tiles.
CLM_MASK_BAND = 0  # Blue

In [None]:
# Placeholders only: both names are rebound later from the per-tile processing results.
stats_ms4, stats_pan = [], []


def calculate_stats(data, sensing_time):
    """Return per-band median, standard deviation and mean of raw pixel values.

    Args:
        data: 2-D array of shape (n_pixels, n_bands); statistics are taken over axis 0.
        sensing_time: Value stored unchanged under the ``'sensing_time'`` key.

    Returns:
        dict with ``'sensing_time'`` followed, for each band k = 1..n_bands, by
        ``STX_CLM_MEDIAN_k``, ``STX_CLM_STDV_k`` and ``STX_CLM_MEAN_k`` (in that order).
    """
    per_band = zip(np.median(data, axis=0), np.std(data, axis=0), np.mean(data, axis=0))

    stats = {'sensing_time': sensing_time}
    for band_no, (band_median, band_std, band_mean) in enumerate(per_band, start=1):
        stats[f'STX_CLM_MEDIAN_{band_no}'] = band_median
        stats[f'STX_CLM_STDV_{band_no}'] = band_std
        stats[f'STX_CLM_MEAN_{band_no}'] = band_mean
    return stats


def calculate_stats_radiance(data, sensing_time, band_gain=None, band_bias=None, pan_gain=None, pan_bias=None):
    """Convert DNs to radiance and return per-band median, standard deviation and mean.

    Radiance is ``DN * gain + bias``. A single-channel input is treated as PAN, a four-channel
    input as MS4. MS4 coefficients are looked up by channel index (channel ``i`` uses
    ``band_gain[i]`` / ``band_bias[i]``), the same way ``calculate_cloudfree_stats`` indexes them;
    dict insertion order is deliberately NOT used (BAND_GAIN is declared as 3, 2, 1, 0, so
    iterating ``.values()`` would pair channel 0 with the key-3 coefficients).

    Args:
        data: 2-D array of shape (n_pixels, n_channels), n_channels being 1 or 4.
        sensing_time: Value stored unchanged under the ``'sensing_time'`` key.
        band_gain, band_bias: Mappings MS4 channel index -> gain / bias.
            Default to the module-level ``BAND_GAIN`` / ``BAND_BIAS``.
        pan_gain, pan_bias: PAN gain / bias. Default to ``PAN_GAIN`` / ``PAN_BIAS``.

    Returns:
        dict with ``'sensing_time'`` followed, for each band k = 1..n_channels, by
        ``STX_CLM_RADIANCE_MEDIAN_k``, ``STX_CLM_RADIANCE_STDV_k`` and ``STX_CLM_RADIANCE_MEAN_k``.

    Raises:
        ValueError: if ``data`` has neither 1 nor 4 channels.
    """
    _, n_channels = data.shape
    if n_channels == 1:
        gain = PAN_GAIN if pan_gain is None else pan_gain
        bias = PAN_BIAS if pan_bias is None else pan_bias
    elif n_channels == 4:
        gains = BAND_GAIN if band_gain is None else band_gain
        biases = BAND_BIAS if band_bias is None else band_bias
        # Align coefficients with channels by key, not by dict insertion order.
        gain = np.array([gains[i] for i in range(n_channels)])
        bias = np.array([biases[i] for i in range(n_channels)])
    else:
        raise ValueError("Wrong number of channels.")

    radiance = data * gain + bias

    per_band = zip(np.median(radiance, axis=0), np.std(radiance, axis=0), np.mean(radiance, axis=0))
    stats = {'sensing_time': sensing_time}
    for band_no, (band_median, band_std, band_mean) in enumerate(per_band, start=1):
        stats[f'STX_CLM_RADIANCE_MEDIAN_{band_no}'] = band_median
        stats[f'STX_CLM_RADIANCE_STDV_{band_no}'] = band_std
        stats[f'STX_CLM_RADIANCE_MEAN_{band_no}'] = band_mean
    return stats


def calculate_cloudfree_stats(tile_folder, config, clm_mask_band, band_gain, band_bias, ms4_thr, pan_gain, pan_bias,
                              pan_thr, calculate_stats_func, sensing_time=None, min_coverage=0.1):
    """Compute MS4 and PAN statistics for one tile, optionally excluding pixels flagged as cloud.

    The MS4 bands (B04, B03, B02, B01 — stacked in that order) and the PAN band are read from
    ``tile_folder``. A pixel is flagged as cloud when its radiance (``DN * gain + bias``) exceeds
    the threshold: for MS4 the radiance of channel ``clm_mask_band`` is compared with ``ms4_thr``,
    for PAN the PAN radiance with ``pan_thr``. Pixels with DN == 0 are treated as no-data.
    When the MS4 cloud fraction exceeds ``min_coverage`` the statistics use only the clear
    valid pixels; otherwise they use all valid pixels.

    Args:
        tile_folder: Folder (S3 URL) containing B01.tiff..B04.tiff and PAN.tiff.
        config: sentinelhub config passed to ``ImportFromTiff``.
        clm_mask_band: MS4 channel index used to build the cloud mask.
        band_gain, band_bias: Mappings MS4 channel index -> calibration gain / bias.
        ms4_thr: MS4 radiance cloud threshold.
        pan_gain, pan_bias: PAN calibration gain / bias.
        pan_thr: PAN radiance cloud threshold.
        calculate_stats_func: Callable ``(pixels, timestamp) -> dict`` taking an
            (n_pixels, n_channels) array.
        sensing_time: Sensing-time string in ``'%Y-%m-%d_%H-%M-%S'`` format. Defaults to the
            last path component of ``tile_folder`` (tile folders are named after it).
        min_coverage: Cloud fraction above which cloudy pixels are excluded from the stats.

    Returns:
        ``(ms4_stats, pan_stats)`` as produced by ``calculate_stats_func``, or ``(None, None)``
        if anything fails (the error is printed; this is a best-effort batch step).
    """
    if sensing_time is None:
        # Derive from the folder instead of relying on a module-level `sensing_time`.
        sensing_time = os.path.basename(tile_folder.rstrip('/\\'))

    try:
        eop_ms4 = ImportFromTiff((FeatureType.DATA, 'MS4'), folder=tile_folder, config=config).execute(
            filename=['B04.tiff', 'B03.tiff', 'B02.tiff', 'B01.tiff'])
        eop = ImportFromTiff((FeatureType.DATA, 'PAN'), folder=tile_folder,
                             config=config).execute(eop_ms4, filename='PAN.tiff')
        eop.timestamp = [datetime.strptime(sensing_time, '%Y-%m-%d_%H-%M-%S')]
        timestamp = eop.timestamp[0]

        # MS4 cloud mask: 1 = cloud, 0 = clear, NaN = no data (DN == 0 in channel 0).
        ms4 = eop.data['MS4']
        ms4_radiance = ms4[..., clm_mask_band] * band_gain[clm_mask_band] + band_bias[clm_mask_band]
        ms4_mask = (ms4_radiance > ms4_thr).astype(np.float32)
        ms4_mask[ms4[..., 0] == 0] = np.nan

        # Fraction of valid pixels flagged as cloud. NaN (which is never > min_coverage) when the
        # tile has no valid pixels, instead of dividing by zero.
        n_valid = np.count_nonzero(~np.isnan(ms4_mask))
        coverage = np.count_nonzero(ms4_mask == 1) / n_valid if n_valid else np.nan

        # NOTE(review): below `min_coverage` the few cloudy pixels are kept rather than masked out;
        # the reason is not recorded here — confirm this is intended.
        exclude_clouds = coverage > min_coverage

        if exclude_clouds:
            ms4_pixels = ms4[ms4_mask == 0, :]
        else:
            ms4_pixels = ms4[ms4[..., 0] > 0, :]
        ms4_stats = calculate_stats_func(ms4_pixels, timestamp)

        # PAN cloud mask, same encoding as the MS4 one.
        pan = eop.data['PAN'].squeeze()
        pan_mask = ((eop.data['PAN'] * pan_gain + pan_bias) > pan_thr).squeeze().astype(np.float32)
        pan_mask[pan == 0] = np.nan

        # NOTE(review): the PAN branch reuses the MS4 coverage decision — confirm intended.
        pan_pixels = pan[pan_mask == 0] if exclude_clouds else pan[pan > 0]
        pan_stats = calculate_stats_func(np.expand_dims(pan_pixels, -1), timestamp)
        return ms4_stats, pan_stats

    except Exception as e:
        print(f'Failed for sensing time {sensing_time} with error: {e}')
        return None, None


# Process every sensing-time folder under LOC_ON_BUCKET; failed tiles contribute (None, None).
results = []
for sensing_time in tqdm(filesystem.listdir(LOC_ON_BUCKET), desc='sensing times'):
    # S3 URLs always use '/', so join with posixpath rather than the OS-dependent os.path
    # (on Windows os.path.join would insert backslashes into the URL).
    tile_folder = posixpath.join('s3://', BUCKET_NAME, LOC_ON_BUCKET, sensing_time)
    results.append(calculate_cloudfree_stats(tile_folder=tile_folder,
                                             config=config,
                                             clm_mask_band=CLM_MASK_BAND,
                                             band_gain=BAND_GAIN,
                                             band_bias=BAND_BIAS,
                                             ms4_thr=MS4_THRESHOLD,
                                             pan_gain=PAN_GAIN,
                                             pan_bias=PAN_BIAS,
                                             pan_thr=PAN_THRESHOLD,
                                             calculate_stats_func=calculate_stats,
                                             ))

In [None]:
# Transpose the list of (ms4_stats, pan_stats) pairs into two parallel tuples.
stats_ms4, stats_pan = zip(*results)

In [None]:
# Quick look at the raw MS4 statistics; None entries mark tiles whose processing failed.
stats_ms4

In [None]:
# Attach the MS4 statistics to the MS4 metadata (matched on START_TIME) and write the table back.
ms4 = pd.DataFrame([x for x in stats_ms4 if x is not None])  # None marks tiles that failed
assert not ms4.empty, 'No MS4 statistics were computed; not overwriting the metadata file.'
# datetime -> 'YYYY-MM-DDTHH:MM:SS'; presumably the format used in START_TIME — TODO confirm.
ms4['sensing_time'] = ms4['sensing_time'].apply(lambda x: str(x).replace(' ', 'T'))
# The metadata file is overwritten with the joined result, so after a previous run it already holds
# these stats columns; drop them first, otherwise DataFrame.join fails on overlapping column names.
metadata_ms4_stats = (
    metadata_ms4
    .drop(columns=ms4.columns.drop('sensing_time'), errors='ignore')
    .set_index('START_TIME')
    .join(ms4.set_index('sensing_time'))
    .reset_index()
)
with filesystem.openbin('metadata/deimos_ms4_metadata.pq', 'wb') as f:
    metadata_ms4_stats.to_parquet(f)

In [None]:
# Attach the PAN statistics to the PAN metadata (matched on START_TIME) and write the table back.
pan = pd.DataFrame([x for x in stats_pan if x is not None])  # None marks tiles that failed
assert not pan.empty, 'No PAN statistics were computed; not overwriting the metadata file.'
# datetime -> 'YYYY-MM-DDTHH:MM:SS'; presumably the format used in START_TIME — TODO confirm.
pan['sensing_time'] = pan['sensing_time'].apply(lambda x: str(x).replace(' ', 'T'))
# The metadata file is overwritten with the joined result, so after a previous run it already holds
# these stats columns; drop them first, otherwise DataFrame.join fails on overlapping column names.
pan_stats = (
    metadata_pan
    .drop(columns=pan.columns.drop('sensing_time'), errors='ignore')
    .set_index('START_TIME')
    .join(pan.set_index('sensing_time'))
    .reset_index()
)
with filesystem.openbin('metadata/deimos_pan_metadata.pq', 'wb') as f:
    pan_stats.to_parquet(f)