In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os

import numpy as np
import pandas as pd
import yaml
from eolearn.core import EOPatch, OverwritePermission
from fs_s3fs import S3FS
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

import torch
import wandb
from hrnet.src.predict import Model

## 1.0 Configuration

In [None]:
# ! wandb login 

In [None]:
# AWS credentials are taken from the environment when available so that secrets never have
# to be typed into (and saved with) the notebook. The empty-string fallback keeps the
# previous behaviour when the variables are not set.
aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID', '')
aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY', '')

# Handle to the S3 bucket used for every remote read/write in this notebook.
filesystem = S3FS(
    bucket_name='',
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key,
    region='eu-central-1')


# 'LOCAL': checkpoint and config are read from the local directory LOCATION.
# 'WANDB': they are restored from wandb, with LOCATION passed as the run path.
MODEL_LOCATION = 'LOCAL'
MODEL_NAME = ''  # NOTE(review): not referenced by any cell visible here
MODEL_PREFIX = ''  # names the output feature (SR-<PREFIX>) and the output folder
MATCHES_S2 = True  # NOTE(review): not referenced by any cell visible here
LOCATION = 'wandb/latest-run/files/'

In [None]:
# Path of the local parquet cache that maps each EOPatch name to its country.
# NOTE(review): DIONE_DIR is not defined in any cell visible here — it has to be set
# (e.g. in the configuration cell) before this cell can run on a fresh kernel.
EOP_COUNTRIES_PQ = f'{DIONE_DIR}/eop-countries_overlapped.pq'

In [None]:
# Build the EOPatch -> country lookup only when the cached parquet file is missing.
if not os.path.exists(EOP_COUNTRIES_PQ):
    eops_countries = []
    for eopfname in filesystem.listdir(''):
        # Only the bounding box is read here; the patch is opened with lazy_loading=True.
        eop = EOPatch.load(os.path.join('', eopfname), filesystem=filesystem, lazy_loading=True)
        # The CRS is used as a proxy for the country: EPSG:32634 -> Lithuania, anything else -> Cyprus.
        eops_countries.append({'country': 'Lithuania' if str(eop.bbox.crs) == 'EPSG:32634' else 'Cyprus',
                               'eopatch': eopfname})

    # Write the cache once, after the listing is complete. Writing inside the loop rewrote
    # the file on every iteration and left a truncated cache behind if the loop was
    # interrupted, which later runs would silently reuse.
    if eops_countries:
        pd.DataFrame(eops_countries).to_parquet(EOP_COUNTRIES_PQ)

In [None]:
# File name of the model checkpoint, looked up under LOCATION (local directory or wandb run path).
checkpoint_filename = 'HRNet.pth'

In [None]:
# Resolve the checkpoint and the run configuration from wandb or from the local run directory.
if MODEL_LOCATION == 'WANDB':
    # wandb.restore downloads the file and returns a text-mode handle. The weights have to be
    # read as bytes, so that handle is closed and the downloaded path reopened in binary mode.
    restored_checkpoint = wandb.restore(checkpoint_filename, run_path=LOCATION, replace=True)
    restored_checkpoint.close()
    # Left open on purpose: the handle is consumed later by model.load_checkpoint.
    model_checkpoint = open(restored_checkpoint.name, 'rb')

    with wandb.restore('config.yaml', run_path=LOCATION, replace=True) as config_file:
        # safe_load: yaml.load without an explicit Loader is rejected by PyYAML >= 6 and can
        # execute arbitrary tags on older versions.
        model_config_yaml = yaml.safe_load(config_file)
elif MODEL_LOCATION == 'LOCAL':
    model_checkpoint = os.path.join(LOCATION, checkpoint_filename)
    if not os.path.isfile(model_checkpoint):
        raise FileNotFoundError(f'Model checkpoint not found: {model_checkpoint}')

    with open(os.path.join(LOCATION, 'config.yaml')) as config_file:
        model_config_yaml = yaml.safe_load(config_file)
else:
    # Fail here rather than with a NameError in a later cell.
    raise ValueError(f"MODEL_LOCATION must be 'WANDB' or 'LOCAL', got {MODEL_LOCATION!r}")

In [None]:
# Flatten the {name: {'value': ...}} layout of the run configuration into {name: value},
# leaving out every entry whose name contains 'wandb'.
config = {
    name: entry['value']
    for name, entry in model_config_yaml.items()
    if 'wandb' not in name
}

In [None]:
# Every remote file is opened in a `with` block so the S3 handle is closed as soon as the
# content has been read into memory.

# Per-country normalisation statistics for Sentinel-2.
with filesystem.openbin('metadata/s2_norm_per_country.pq') as remote_file:
    country_norm_df = pd.read_parquet(remote_file)

# The .npz archives are materialised into plain dicts while the file is still open,
# because the archive reads its arrays from the handle on access.
with filesystem.openbin('metadata/deimos_min_max_norm.npz') as remote_file:
    norm_deimos = dict(np.load(remote_file).items())

with filesystem.openbin('metadata/s2_min_max_norm.npz') as remote_file:
    norm_s2 = dict(np.load(remote_file).items())

# reset_index() returns a new frame; this avoids mutating the freshly read frame in place.
with filesystem.openbin('metadata/npz_info_small.pq') as remote_file:
    data_df = pd.read_parquet(remote_file).reset_index()

In [None]:
# Instantiate the model from the flattened run configuration, then load the checkpoint
# (`model_checkpoint` is a local path or an open binary file, depending on MODEL_LOCATION).
model = Model(config)
model.load_checkpoint(checkpoint_file=model_checkpoint)

## 2.0 Predict on EOPatches

In [None]:
def _filter_cloudy_s2(eop, max_cc):
    idxs = []
    for i, _ in enumerate(eop.timestamp):
        if (eop.mask['CLM'][i, ...].mean() <= max_cc) and (eop.mask['IS_DATA'].mean() == 1):
            idxs.append(i)
    eop.data['BANDS'] = eop.data['BANDS'][idxs, ...]
    eop.data['CLP'] = eop.data['CLP'][idxs, ...]
    eop.mask['CLM'] = eop.mask['CLM'][idxs, ...]
    eop.mask['IS_DATA'] = eop.mask['IS_DATA'][idxs, ...]
    eop.scalar['NORM_FACTORS'] = eop.scalar['NORM_FACTORS'][idxs, ...]

    eop.timestamp = list(np.array(eop.timestamp)[idxs])
    return eop


def _timestamps_within_date(timestamps, start_date, end_date):
    return [i for i, ts in enumerate(timestamps) if ts >= start_date and ts < end_date]


def predict_sr_images(eopatch_name: str,
                      model: Model,
                      model_prefix: str,
                      scale_factor: int = 4,
                      filesystem: S3FS = None,
                      normalize: bool = True,
                      country_norm_df: pd.DataFrame = None,
                      norm_s2_npz: np.lib.npyio.NpzFile = None,
                      max_cc: float = 0.05,
                      n_views: int = 8,
                      padding: str = 'zeros'):
    """ Predict an SR image at the EOPatch level for all timeframes available.

    For every non-cloudy frame the model receives that frame together with up to
    `n_views - 1` preceding frames. While fewer than `n_views` frames are available the
    stack is padded with zeros and the padded slots are flagged with 0 in `alphas`.

    :param eopatch_name: Path of the EOPatch to load.
    :param model: Callable taking a dict with the keys `lr`, `alphas` and `name`.
    :param model_prefix: Upper-cased and used to name the output feature `SR-<PREFIX>`.
    :param scale_factor: Ratio between the output and input height/width.
    :param filesystem: Filesystem handed to `EOPatch.load`.
    :param normalize: If True, inputs are normalised with the per-country median/std and the
        `p1`/`p99` values, and the prediction is denormalised with the same statistics. If
        False, the reflectances are fed to the model as they are and the output is used as is.
    :param country_norm_df: Frame with a `country` column and `median_0..3` / `std_0..3`
        columns. Required when `normalize` is True.
    :param norm_s2_npz: Mapping with `p1` and `p99` entries. Required when `normalize` is True.
    :param max_cc: Cloud-coverage threshold passed to `_filter_cloudy_s2`.
    :param n_views: Number of frames stacked per prediction.
    :param padding: Only validated; zero padding is always applied ('repeat' is accepted but
        not implemented).
    :return: New EOPatch holding the `SR-<PREFIX>` and `S2` data features (both uint16) and
        the `N_VIEWS` scalar feature with the number of real frames used per prediction.
    :raises ValueError: If `normalize` is True and `country_norm_df` has no row for the
        country derived from the EOPatch CRS.
    """
    assert padding in ['zeros', 'repeat']

    eopatch = EOPatch.load(eopatch_name,
                           filesystem=filesystem,
                           lazy_loading=True)
    noncloudy = _filter_cloudy_s2(eopatch, max_cc=max_cc)
    features = noncloudy.data['BANDS'] / 10000

    if normalize:
        # The CRS is used as a proxy for the country: EPSG:32634 -> Lithuania, otherwise Cyprus.
        country = 'Lithuania' if str(eopatch.bbox.crs) == 'EPSG:32634' else 'Cyprus'
        country_stats = country_norm_df[country_norm_df.country == str(country)]
        if country_stats.empty:
            # Without this check the subtraction below fails with an opaque broadcasting error.
            raise ValueError(f'No normalisation statistics found for country {country}')

        norm_median = country_stats[['median_0', 'median_1', 'median_2', 'median_3']].values
        norm_std = country_stats[['std_0', 'std_1', 'std_2', 'std_3']].values

        features = (features - norm_median) / norm_std

        s2_p1 = norm_s2_npz['p1']
        s2_p99 = norm_s2_npz['p99']
        s2_range = s2_p99 - s2_p1

        features = (features - s2_p1) / s2_range

    n_frames, height, width, nch = features.shape
    super_resolved = np.empty((n_frames,
                               height*scale_factor,
                               width*scale_factor,
                               nch), dtype=np.uint16)
    # Frame k (0-based) can use at most k + 1 real views, capped at n_views.
    actual_n_views = np.minimum(n_views, np.arange(1, n_frames + 1)).astype(np.uint8)

    for nfr in range(n_frames):
        if nfr < n_views:
            # Not enough history yet: real frames first, then zero frames up to n_views.
            inarr = np.concatenate([features[:nfr+1],
                                    np.zeros((n_views-nfr-1, height, width, nch),
                                             dtype=np.float32)],
                                   axis=0)
            alphas = np.zeros(n_views, dtype=np.uint8)
            alphas[:nfr+1] = 1
        else:
            # Sliding window of the n_views most recent frames, ending at the current one.
            inarr = features[nfr-n_views+1:nfr+1]
            alphas = np.ones(n_views, dtype=np.uint8)

        # channels first: (T, H, W, C) -> (T, C, H, W)
        inarr = np.moveaxis(inarr, -1, 1)

        sr = model({'lr': torch.from_numpy(inarr.copy()),
                    'alphas': torch.from_numpy(alphas),
                    'name': eopatch_name})

        # channels back to last
        # NOTE(review): assumes the squeezed model output is (C, H, W) — confirm against Model.
        sr = np.moveaxis(sr.squeeze(), 0, 2)

        if normalize:
            # Denormalise. The statistics only exist when normalisation was applied, so this
            # step has to be skipped for normalize=False (it previously raised NameError).
            sr = (sr * s2_range + s2_p1) * norm_std + norm_median

        # Clipping to [0, 3] keeps the scaled values (at most 30000) inside the uint16 range.
        super_resolved[nfr] = (np.clip(sr, 0, 3)*10000).astype(np.uint16)

    eop_sr = EOPatch(bbox=eopatch.bbox, timestamp=noncloudy.timestamp)
    eop_sr.data[f'SR-{model_prefix.upper()}'] = super_resolved
    eop_sr.data['S2'] = noncloudy.data['BANDS'].astype(np.uint16)
    eop_sr.scalar['N_VIEWS'] = actual_n_views[..., np.newaxis]

    return eop_sr

In [None]:
# Folder prefix of the input EOPatches; prepended to each EOPatch name when loading.
eops_folder = ''
# NOTE(review): not referenced by any cell visible here.
deimos_eops_folder = ''

In [None]:
# Load the EOPatch-name / country lookup table from the local parquet cache.
eop_countries = pd.read_parquet(EOP_COUNTRIES_PQ)

In [None]:
# Names of the EOPatches labelled 'Lithuania' in the lookup table, duplicates removed.
eopatch_names = eop_countries[eop_countries.country == 'Lithuania'].eopatch.unique()

In [None]:
# Run the prediction on a single EOPatch (the first name in `eopatch_names`) so the result
# can be inspected before all patches are processed. The number of views is taken from the
# training section of the run configuration.
eop_sr = predict_sr_images(f'{eops_folder}/{eopatch_names[0]}',
                           model,
                           MODEL_PREFIX,
                           scale_factor=4,
                           country_norm_df=country_norm_df,
                           filesystem=filesystem,
                           normalize=True,
                           norm_s2_npz=norm_s2,
                           max_cc=0.05,
                           n_views=config['training']['n_views'])

In [None]:
# Display the summary of the predicted EOPatch.
eop_sr

In [None]:
# Side-by-side comparison of every Sentinel-2 frame (left) with its super-resolved
# prediction (right). The grid is sized from the data rather than from a fixed row count,
# so it fits any number of frames.
sr_feature = f'SR-{MODEL_PREFIX.upper()}'
n_rows = max(len(eop_sr.timestamp), 1)  # plt.subplots needs at least one row

# squeeze=False keeps `axs` two-dimensional even when there is a single row.
fig, axs = plt.subplots(ncols=2, nrows=n_rows, figsize=(15, n_rows*7.5), squeeze=False)

for ni, (s2, sr) in enumerate(zip(eop_sr.data['S2'], eop_sr.data[sr_feature])):
    # Bands [2, 1, 0] are shown as RGB, brightened by 2.5 and clipped to the [0, 1] range
    # that imshow expects for float RGB images.
    axs[ni][0].imshow(np.clip(2.5*s2[..., [2, 1, 0]]/10000, 0, 1))
    axs[ni][1].imshow(np.clip(2.5*sr[..., [2, 1, 0]]/10000, 0, 1))
    axs[ni][0].set_title(f'S2 - {eop_sr.timestamp[ni]}')
    axs[ni][1].set_title(f'SR - {eop_sr.scalar["N_VIEWS"][ni][0]} actual views')

fig.tight_layout()

In [None]:
# Destination folder for the predicted EOPatches, named after the model prefix.
# NOTE(review): the trailing '/' combines with the '/' added when the save path is built,
# giving a double slash — presumably normalised by the filesystem; confirm.
eops_sr_folder = f'eopatches-{MODEL_PREFIX}/'

In [None]:
# Display the destination folder so it can be checked before anything is written.
eops_sr_folder

In [None]:
# Predict and store the super-resolved EOPatch for every selected patch. A RuntimeError in
# one patch does not stop the run: the patch is recorded in `failed_eopatches` together
# with a printed error message, and processing continues with the next one.
failed_eopatches = []

for eopatch_name in tqdm(eopatch_names):
    try:
        # A dedicated name is used for the per-patch result so that freeing it does not
        # delete the `eop_sr` preview that the inspection and plotting cells rely on.
        sr_patch = predict_sr_images(f'{eops_folder}/{eopatch_name}',
                                     model,
                                     MODEL_PREFIX,
                                     scale_factor=4,
                                     country_norm_df=country_norm_df,
                                     filesystem=filesystem,
                                     normalize=True,
                                     norm_s2_npz=norm_s2,
                                     max_cc=0.05,
                                     n_views=config['training']['n_views'])
        sr_patch.save(f'{eops_sr_folder}/{eopatch_name}',
                      filesystem=filesystem,
                      overwrite_permission=OverwritePermission.OVERWRITE_FEATURES)
        # Release the arrays before the next patch is loaded.
        del sr_patch
    except RuntimeError as error:
        failed_eopatches.append(eopatch_name)
        print(f'Error in {eopatch_name}: {error}')