In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
from datetime import datetime

import numpy as np
import pandas as pd
import yaml
from eolearn.core import EOPatch, FeatureType, OverwritePermission
from eolearn.io import ExportToTiff
from fs_s3fs import S3FS
from matplotlib import pyplot as plt
from skimage.exposure import match_histograms
from tqdm.auto import tqdm

import torch
import wandb
from cv2 import INTER_CUBIC, GaussianBlur, resize
from hrnet.src.predict import Model
from hrnet.src.train import resize_batch_images
from sr.data_loader import EopatchPredictionDataset, ImagesetDataset
from torch.utils.data import DataLoader

In [None]:
from sr.metrics import minshift_loss

## 1.0 Configuration

In [None]:
# ! wandb login <WANDB KEY>

In [None]:
# S3 access. Credentials are read from the standard AWS environment variables instead of
# being typed into the notebook, so they cannot leak into a saved or committed copy.
# When the variables are unset this falls back to '' exactly like the old placeholders.
aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID', '')
aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY', '')

filesystem = S3FS(
    bucket_name='',  # TODO: fill in the bucket holding the EOPatches and the metadata/ files
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key,
    region='eu-central-1')


# If 'LOCAL' it will be loaded from local wandb storage, if 'WANDB' from online storage
MODEL_LOCATION = 'LOCAL'

# NOTE(review): MODEL_NAME, MODEL_PREFIX and MATCHES_S2 are not used by the cells shown here.
MODEL_NAME = ''
MODEL_PREFIX = ''
MATCHES_S2 = True
# Run directory containing the checkpoint and config.yaml (also used as `run_path` for wandb.restore).
LOCATION = 'wandb/latest-run/files/'

In [None]:
# Local parquet cache for the EOPatch -> country table built in the next cell.
EOP_COUNTRIES_PQ = 'eop-countries_overlapped.pq'

In [None]:
# Build (once) a table mapping every EOPatch in the bucket root to its country. The
# result is cached locally as EOP_COUNTRIES_PQ, and this cell is a no-op once the cache exists.
if not os.path.exists(EOP_COUNTRIES_PQ):
    eops_countries = []
    for eopfname in filesystem.listdir(''):
        eop = EOPatch.load(eopfname, filesystem=filesystem, lazy_loading=True)
        # NOTE(review): the country is inferred from the CRS alone. EPSG:32634 is labelled
        # Lithuania and any other CRS Cyprus; revisit if more countries are added.
        eops_countries.append({'country': 'Lithuania' if str(eop.bbox.crs) == 'EPSG:32634' else 'Cyprus',
                               'eopatch': eopfname})
    # Write once, after the full listing is built. Writing inside the loop rewrote the file
    # on every iteration (quadratic I/O). It also meant an interrupted run left a partial
    # cache that the exists() guard above would treat as complete.
    if eops_countries:
        pd.DataFrame(eops_countries).to_parquet(EOP_COUNTRIES_PQ)

In [None]:
# File name of the trained HRNet weights inside the run directory / wandb run.
checkpoint_filename = 'HRNet.pth'

In [None]:
# Resolve the checkpoint path and parse the run's config.yaml. They come either from a
# wandb run (downloaded via wandb.restore) or from the local wandb run directory.
# yaml.safe_load replaces yaml.load(): without an explicit Loader that call is unsafe and
# raises TypeError on PyYAML >= 6. Every file handle is closed after use.
if MODEL_LOCATION == 'WANDB':
    # wandb.restore downloads the file and hands back an open handle. For the checkpoint
    # only its local path is needed, so the handle is closed and the path is passed on
    # (the LOCAL branch passes a path too). The old version re-opened the name from cwd.
    with wandb.restore(checkpoint_filename, run_path=LOCATION, replace=True) as restored:
        model_checkpoint = restored.name
    with wandb.restore('config.yaml', run_path=LOCATION, replace=True) as config_file:
        model_config_yaml = yaml.safe_load(config_file)
elif MODEL_LOCATION == 'LOCAL':
    model_checkpoint = os.path.join(LOCATION, checkpoint_filename)
    with open(os.path.join(LOCATION, 'config.yaml')) as config_file:
        model_config_yaml = yaml.safe_load(config_file)
else:
    # Fail here rather than with a NameError on `model_config_yaml` a few cells later.
    raise ValueError(f"MODEL_LOCATION must be 'WANDB' or 'LOCAL', got {MODEL_LOCATION!r}")

assert os.path.isfile(model_checkpoint), f'checkpoint not found: {model_checkpoint}'

In [None]:
# Flatten the wandb config ({key: {'value': ...}, ...}) into a plain {key: value} dict,
# skipping every entry whose key mentions 'wandb'.
config = {}
for key, entry in model_config_yaml.items():
    if 'wandb' in key:
        continue
    config[key] = entry['value']

In [None]:
# Normalisation statistics and the imageset index, read from the S3 bucket.
# Each handle is opened in a `with` block so it is closed once read; previously all four
# were left open. np.load returns a lazy NpzFile, so its arrays are materialised into
# a plain dict while the file is still open.
with filesystem.openbin('metadata/s2_norm_per_country.pq') as fh:
    country_norm_df = pd.read_parquet(fh)

with filesystem.openbin('metadata/deimos_min_max_norm.npz') as fh, np.load(fh) as npz:
    norm_deimos = dict(npz.items())
with filesystem.openbin('metadata/s2_min_max_norm.npz') as fh, np.load(fh) as npz:
    norm_s2 = dict(npz.items())

# Non-inplace reset_index: same result as reset_index(inplace=True), expressed as one chain.
with filesystem.openbin('metadata/npz_info_small.pq') as fh:
    data_df = pd.read_parquet(fh).reset_index()

In [None]:
# Attach the per-imageset bicubic scores to the index, keyed by npz file name.
# The `with` block closes the S3 handle, which was previously leaked.
with filesystem.openbin('scores-bicubic-32x32.pq') as fh:
    scores_df = pd.read_parquet(fh).rename(columns={'name': 'singleton_npz_filename'})

# NOTE(review): pd.merge defaults to an inner join, so imagesets without a score row are
# dropped. This cell also rebinds `data_df`: re-running it without re-running the loading
# cell merges the scores twice (suffixed duplicate columns).
data_df = pd.merge(data_df, scores_df, on='singleton_npz_filename')
data_df['MSE_ratio'] = data_df['MSE_s']/data_df['MSE_s_c']

In [None]:
# Keep non-shadowed Lithuanian imagesets with more than one timestamp whose bicubic
# scores pass the quality thresholds below.
quality_mask = (
    data_df['SSIM_s_c'].gt(.2)
    & data_df['PSNR_s_c'].gt(10)
    & data_df['MSE_ratio'].lt(10)
    & data_df['is_shadow_v2'].eq(False)
    & data_df['countries'].eq('Lithuania')
    & data_df['num_tstamps'].gt(1)
)
filtered_data = data_df[quality_mask]

In [None]:
# Build the HRNet wrapper from the run configuration, then load the trained weights into it.
model = Model(config)
model.load_checkpoint(
    checkpoint_file=model_checkpoint,
)

## 1.2 Load data

In [None]:
# Fixed-size, reproducible random subset of the validation imagesets.
# random_state pins the draw so the metrics below are comparable across runs. The size is
# capped at the number of available rows, because DataFrame.sample raises when n exceeds
# the row count (sampling without replacement).
N_TEST_SAMPLES = 2000
TEST_SAMPLE_SEED = 42

validation_df = filtered_data[filtered_data.train_test_validation == 'validation']
test_samples = validation_df.sample(
    n=min(N_TEST_SAMPLES, len(validation_df)),
    random_state=TEST_SAMPLE_SEED,
).singleton_npz_filename.values

test_dataset = ImagesetDataset(
    imset_dir=config['paths']['prefix'],
    imset_npz_files=test_samples,
    country_norm_df=country_norm_df,
    normalize=True,
    norm_deimos_npz=norm_deimos,
    norm_s2_npz=norm_s2,
    channels_labels=config['training']['channels_labels'],
    channels_feats=config['training']['channels_features'],
    time_first=True,
    n_views=config['training']['n_views'],
    histogram_matching=config['training']['histogram_matching']
)

In [None]:
def normalise_bands(eop, bands_name, eop_name, norm_df, mean_cols=None, std_cols=None):
    """Normalise a data feature of an EOPatch with per-month statistics.

    Every time frame ``t`` is transformed as ``(band_t - mean_m) / std_m``, where ``m`` is
    the month of ``eop.timestamp[t]``. If several rows of ``norm_df`` share a month for
    this patch, their values are averaged.

    Args:
        eop: EOPatch-like object exposing ``data[bands_name]`` (indexed first by time)
            and ``timestamp`` (one datetime per frame).
        bands_name: key of the data feature to normalise.
        eop_name: value of ``norm_df.eopatch`` selecting this patch's statistics.
        norm_df: statistics table with ``eopatch`` and ``month`` columns plus the
            mean/std columns; ``month`` values are matched against
            ``timestamp.strftime('%Y-%m')``.
        mean_cols: list of columns holding the per-band means. Defaults to the
            notebook-level ``cols_mean`` (not defined in the cells shown here — TODO confirm).
        std_cols: list of columns holding the per-band standard deviations. Defaults
            to the notebook-level ``cols_std``.

    Returns:
        float32 array with the same shape as ``eop.data[bands_name]``.

    Raises:
        KeyError: if a frame's month has no statistics for ``eop_name``.
    """
    # Backward compatible: fall back to the globals the original version relied on.
    if mean_cols is None:
        mean_cols = cols_mean
    if std_cols is None:
        std_cols = cols_std

    # Filter and group once, and aggregate only the statistic columns. Calling .mean() on
    # the whole group would also try to average the string 'eopatch' column, which raises
    # TypeError on pandas >= 2.0.
    monthly = norm_df[norm_df.eopatch == eop_name].groupby('month')
    df_means = monthly[mean_cols].mean()
    df_std = monthly[std_cols].mean()

    bands = eop.data[bands_name]

    normalised = np.empty(bands.shape, dtype=np.float32)

    for nb, (band, ts) in enumerate(zip(bands, eop.timestamp)):
        month = ts.strftime('%Y-%m')
        means = df_means.loc[month].values
        stds = df_std.loc[month].values

        normalised[nb] = (band - means) / stds

    return normalised

In [None]:
# Batched, unshuffled loader over the test imagesets. shuffle=False keeps the per-sample
# metrics in dataset order (presumably the order of `test_samples` — TODO confirm).
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=128,
    num_workers=8,
    shuffle=False,
    pin_memory=True,
)

In [None]:
# A single imageset used for the visual sanity check of the degradation baselines below.
SAMPLE_INDEX = 0
sample = test_dataset[SAMPLE_INDEX]

In [None]:
# Channels-last view of the HR target (channel axis moved from first to last) for OpenCV/matplotlib.
hr = sample['hr'].numpy().transpose(1, 2, 0)

# Synthetic low-res: 7x7 Gaussian blur (sigmaX=4), then 4x downscale with cv2's default interpolation.
hr_blurred = GaussianBlur(hr, ksize=(7, 7), sigmaX=4)
hr_ = resize(hr_blurred, None, fx=1/4, fy=1/4)

# ...and back up 4x with bicubic interpolation.
hr__ = resize(hr_, None, interpolation=INTER_CUBIC, fx=4, fy=4)

In [None]:
# Last valid LR view: `alphas` is summed to count the usable views (presumably a 0/1
# validity mask — TODO confirm), and that count minus one indexes the view axis.
n_valid_views = int(sample['alphas'].int().sum())
lr = sample['lr'][n_valid_views - 1].numpy().transpose(1, 2, 0)

In [None]:
# Visual check of the degradation pipeline, one titled panel per image so the figure
# reads on its own. Assumes channels 0-2 are B, G, R (TODO confirm); [2, 1, 0] shows them as RGB.
panels = [
    ('HR', hr),
    ('HR blurred, 4x down', hr_),
    ('LR (last valid view)', lr),
    ('HR blurred, 4x down, bicubic 4x up', hr__),
]
fig, axs = plt.subplots(ncols=len(panels), figsize=(15, 7.5))
for ax, (title, image) in zip(axs, panels):
    ax.imshow(image[..., [2, 1, 0]])
    ax.set_title(title)
plt.show()

In [None]:
# Histogram-match the LR view to the blurred/downscaled HR, then upsample it 4x bicubically.
# NOTE(review): `multichannel=True` is deprecated since scikit-image 0.19 in favour of
# `channel_axis=-1`; switch once the pinned scikit-image version allows it.
lr_ = match_histograms(lr, hr_, multichannel=True)
lr__ = resize(lr_, None, interpolation=INTER_CUBIC, fx=4, fy=4)

In [None]:
# Same check with the histogram-matched LR, one titled panel per image.
# Assumes channels 0-2 are B, G, R (TODO confirm); [2, 1, 0] shows them as RGB.
panels = [
    ('HR', hr),
    ('HR blurred, 4x down', hr_),
    ('LR histogram-matched', lr_),
    ('LR histogram-matched, bicubic 4x up', lr__),
]
fig, axs = plt.subplots(ncols=len(panels), figsize=(15, 7.5))
for ax, (title, image) in zip(axs, panels):
    ax.imshow(image[..., [2, 1, 0]])
    ax.set_title(title)
plt.show()

In [None]:
# Evaluate the model and two bicubic baselines on every test batch, collecting SSIM/PSNR
# per batch.
# NOTE(review): this loop rebinds the notebook-level `sample` and `hr` used by the
# visual-check cells above; re-run those from the `sample = test_dataset[...]` cell
# before re-plotting.
ssims_bi_de, psnrs_bi_de = [], []
ssims_bi_s2, psnrs_bi_s2 = [], []
ssims_sr, psnrs_sr = [], []

for sample in tqdm(test_dataloader):
    # Super-resolved prediction (the model wrapper returns a numpy array).
    sr = torch.from_numpy(model(sample))
    alphas = sample['alphas'].float()
    # Last valid LR view of each imageset: view index = (number of valid views) - 1.
    lrs = sample['lr'][torch.arange(len(alphas)),
                       torch.sum(alphas, dim=1, dtype=torch.int64) - 1]
    hr = sample['hr'].float()

    # S2 baseline: LR histogram-matched to its HR target, then bicubic 4x upsampling.
    # np.stack + torch.from_numpy builds the batch in a single copy. torch.tensor() on a
    # list of ndarrays is much slower (and warns) but produces the same values and dtype.
    # NOTE(review): `multichannel` is deprecated in scikit-image >= 0.19 (use channel_axis=-1).
    lrs_hm = torch.from_numpy(np.stack([
        match_histograms(np.moveaxis(lri.numpy(), 0, 2),
                         np.moveaxis(hri.numpy(), 0, 2),
                         multichannel=True)
        for lri, hri in zip(lrs, hr)
    ]))
    lrs_hm = lrs_hm.permute([0, 3, 1, 2])
    baseline_s2 = resize_batch_images(lrs_hm, fx=4, fy=4).float()

    # Degraded-HR baseline: blur (7x7, sigmaX=4), 4x downscale, then bicubic 4x upscale.
    baseline_de = torch.from_numpy(np.stack([
        resize(resize(GaussianBlur(np.moveaxis(hr_img.numpy(), 0, 2),
                                   ksize=(7, 7), sigmaX=4),
                      None, fx=1/4, fy=1/4),
               None, fx=4, fy=4, interpolation=INTER_CUBIC)
        for hr_img in hr
    ]))
    baseline_de = baseline_de.permute([0, 3, 1, 2])

    # Score each prediction against the HR target. [0] keeps the first element of
    # minshift_loss's return value (presumably the per-sample scores — TODO confirm).
    for pred, ssims, psnrs in ((sr, ssims_sr, psnrs_sr),
                               (baseline_de, ssims_bi_de, psnrs_bi_de),
                               (baseline_s2, ssims_bi_s2, psnrs_bi_s2)):
        ssims.append(minshift_loss(hr, pred, metric='SSIM', apply_correction=False)[0])
        psnrs.append(minshift_loss(hr, pred, metric='PSNR', apply_correction=False)[0])

In [None]:
def _flatten_batches(batch_scores):
    """Flatten a list of per-batch metric tensors into one numpy array (dataloader order)."""
    return np.array([score for batch in batch_scores for score in batch.numpy()])


ssim_bi_de = _flatten_batches(ssims_bi_de)
ssim_bi_s2 = _flatten_batches(ssims_bi_s2)
ssim_sr = _flatten_batches(ssims_sr)

psnr_bi_de = _flatten_batches(psnrs_bi_de)
psnr_bi_s2 = _flatten_batches(psnrs_bi_s2)
psnr_sr = _flatten_batches(psnrs_sr)