In [None]:
# Enable IPython's autoreload extension; mode 2 reloads all modules before each
# cell runs, so edits to project modules are picked up without a kernel restart.
%reload_ext autoreload
%autoreload 2

# Change the working directory to the repository checkout (presumably so the
# project-local packages imported below, e.g. `hrnet` and `sr`, are importable).
# NOTE(review): hardcoded absolute path — adjust for the machine running this.
%cd /home/ubuntu/dione-sr/

In [None]:
import os

In [None]:
import imageio
import numpy as np
import pandas as pd
from fs_s3fs import S3FS
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

In [None]:
import cv2 as cv
from hrnet.src.train import resize_batch_images
from sr.data_loader import ImagesetDataset
from sr.metrics import METRICS, minshift_loss
from torch.utils.data import DataLoader

In [None]:
# AWS credentials for the S3 bucket. They are read from the standard AWS environment
# variables so secrets never have to be typed into (and committed with) the notebook.
# Falls back to the empty string, as before, when a variable is not set.
aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID', '')
aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY', '')

In [None]:
# Handle on the project's S3 bucket, used below for all metadata and patchlet reads.
# The bucket name can be supplied via the S3_BUCKET_NAME environment variable instead
# of editing this cell; it defaults to the original blank placeholder.
filesystem = S3FS(bucket_name=os.environ.get('S3_BUCKET_NAME', ''),
                  aws_access_key_id=aws_access_key_id,
                  aws_secret_access_key=aws_secret_access_key)

In [None]:
def _read_npz(path):
    """Return every array of the ``.npz`` file at ``path`` on ``filesystem`` as a dict.

    Both the remote file handle and the ``NpzFile`` are closed on exit. The arrays
    are copied into the dict first because ``NpzFile`` reads members on access.
    """
    with filesystem.openbin(path) as fp, np.load(fp) as npz:
        return {key: value for key, value in npz.items()}


def _read_parquet(path):
    """Read the parquet file at ``path`` on ``filesystem``, closing the handle afterwards."""
    with filesystem.openbin(path) as fp:
        return pd.read_parquet(fp)


# Min/max normalisation statistics for Deimos and Sentinel-2 inputs.
norm_deimos = _read_npz('metadata/deimos_min_max_norm.npz')
norm_s2 = _read_npz('metadata/s2_min_max_norm.npz')

# Patchlet index and per-country S2 normalisation table.
data_df = _read_parquet('metadata/npz_info_small.pq')
country_norm_df = _read_parquet('metadata/s2_norm_per_country.pq')

In [None]:
# Folder holding the patchlet .npz imagesets; passed as `imset_dir` to the dataset.
# Left blank here — presumably a path on the S3 filesystem (TODO confirm).
NPZ_FOLDER = ""

In [None]:
# Preview the first rows of the patchlet index.
data_df.head(n=5)

In [None]:
# One dataset item per npz file listed in `data_df`, normalised with the
# statistics loaded above.
dataset = ImagesetDataset(
    imset_dir=NPZ_FOLDER,
    imset_npz_files=data_df.singleton_npz_filename.values,
    filesystem=filesystem,
    time_first=True,
    normalize=True,
    norm_deimos_npz=norm_deimos,
    norm_s2_npz=norm_s2,
    country_norm_df=country_norm_df,
)

# Fixed ordering (no shuffling), so positions inside a batch line up with `batch['name']`.
dataloader = DataLoader(
    dataset,
    shuffle=False,
    batch_size=256,
    pin_memory=True,
    num_workers=16,
)

In [None]:
# Value forwarded as `shifts=` to `minshift_loss`; also used to trim the border
# of the upsampled image before visual comparison.
SHIFTS: int = 6

### Test run on a single batch

In [None]:
# Take only the first batch as a quick check of the data pipeline; the iterator
# is discarded immediately afterwards.
first_batch_iter = iter(dataloader)
batch = next(first_batch_iter)
del first_batch_iter

In [None]:
# Show which fields the collated batch provides (the next cell reads 'lr', 'hr', 'name', 'alphas').
batch.keys()

In [None]:
# Pull the fields used by the following cells out of the first batch.
lrs, hrs, names, alphas = (batch[key] for key in ('lr', 'hr', 'name', 'alphas'))

In [None]:
# Bicubic x3 upsampling (both directions) of the last entry along axes 1 and 2 of
# the LR tensor — presumably the last time frame and last band (confirm); the list
# index [-1] keeps that axis as a singleton.
interpolated = resize_batch_images(
    lrs[:, -1, [-1], ...],
    interpolation=cv.INTER_CUBIC,
    fx=3,
    fy=3,
)

In [None]:
# Compare the last HR channel with the bicubic upsample. Both are cast with
# `.float()` once, exactly as the full-dataset scoring cell does, so the
# single-batch numbers are computed the same way as the cached ones.
hr_last = hrs[:, [-1], ...].float()
sr_bicubic = interpolated.float()

# Plain MSE, best-shift MSE, and best-shift MSE with correction; the *_ids
# tensors hold the crop window chosen by the shift search for each sample.
mse = METRICS['MSE'](hr_last, sr_bicubic)
mse_shift, mse_ids = minshift_loss(hr_last, sr_bicubic,
                                   shifts=SHIFTS, metric='MSE')
mse_shift_c, mse_ids_c = minshift_loss(hr_last, sr_bicubic,
                                       metric='MSE', shifts=SHIFTS, apply_correction=True)

In [None]:
# Per-sample comparison on the first batch: shifted MSE on x, and on y either the
# shift-corrected MSE or the plain (unshifted) MSE; the black line is y = x.
fig, ax = plt.subplots(figsize=(10, 10))
ax.scatter(mse_shift.numpy(), mse_shift_c.numpy(), alpha=.3, label='MSE shifted corrected')
ax.scatter(mse_shift.numpy(), mse.numpy(), alpha=.3, label='MSE')
ax.plot([0, 1], [0, 1], 'k', label='y = x')
ax.grid()
ax.legend()
ax.set_xlabel('MSE shifted')
ax.set_ylabel('MSE shifted corrected / MSE')
ax.set_title('Bicubic baseline on the first batch');

In [None]:
# Batch positions whose shift-corrected MSE is above 0.35.
np.nonzero(mse_shift_c.numpy() > .35)

In [None]:
# Batch positions whose shift-corrected MSE is below 0.02.
(mse_shift_c.numpy() < .02).nonzero()

In [None]:
# Position (within the first batch) of the sample to visualise.
idx = 224

# Last HR channel and the bicubic upsample for that sample, singleton axes removed.
img_de = np.squeeze(hrs[idx, [-1], ...].numpy())
img_s2 = np.squeeze(interpolated[idx].numpy())

In [None]:
def _to_uint8(img):
    """Min-max rescale ``img`` to [0, 255] and cast to uint8.

    A constant image (max == min) maps to all zeros instead of dividing by zero.
    """
    lo, hi = img.min(), img.max()
    if hi == lo:
        return np.zeros(img.shape, dtype=np.uint8)
    return (255 * (img - lo) / (hi - lo)).astype(np.uint8)


# Crop window picked by the shift-corrected search for this sample.
# NOTE(review): presumably [row_start, row_end, col_start, col_end] — confirm in sr.metrics.
# A wide integer type is used: uint8 would silently wrap any index above 255.
ids = mse_ids_c[idx, :].numpy().astype(np.int64)
print(ids)

# Trim SHIFTS//2 pixels at the top/left and the remaining ceil(SHIFTS/2) at the
# bottom/right — the same window as the old `SHIFTS//2:-SHIFTS//2` slice — but
# written so SHIFTS == 0 keeps the whole image instead of yielding an empty `0:0` slice.
lo_trim = SHIFTS // 2
hi_trim = SHIFTS - lo_trim
# New names keep `img_s2` / `img_de` untouched, so re-running this cell does not crop twice.
s2_frame = img_s2[lo_trim:img_s2.shape[0] - hi_trim, lo_trim:img_s2.shape[1] - hi_trim]
de_frame = img_de[ids[0]:ids[1], ids[2]:ids[3]]

giffile = f's2-deimos-{names[idx]}.gif'
# NOTE(review): newer imageio releases interpret GIF `duration` in milliseconds
# rather than seconds — confirm the installed version gives the intended frame time.
imageio.mimsave(giffile,
                [_to_uint8(s2_frame), _to_uint8(de_frame)],
                duration=0.5)

## Compute scores on the entire dataset of patchlets

In [None]:
pq_filename = 'scores-bicubic-32x32.pq'

# Metric columns stored per sample, in the order they are computed below.
SCORE_COLUMNS = ['MSE', 'MSE_s', 'MSE_s_c', 'PSNR_s_c', 'SSIM_s_c']


def compute_bicubic_scores(loader, shifts, scale=3):
    """Score the bicubic-upsampling baseline on every sample yielded by ``loader``.

    For each batch, ``sample['lr'][:, -1, [-1], ...]`` is upsampled by ``scale``
    (bicubic) and compared with ``sample['hr'][:, [-1], ...]``, both cast with
    ``.float()``. Working inside a function keeps loop variables out of the
    notebook namespace, so the single-batch results above are not overwritten.

    Args:
        loader: iterable of batches providing 'lr', 'hr' and 'name' entries.
        shifts: forwarded as ``shifts=`` to ``minshift_loss``.
        scale: upsampling factor applied in both directions.

    Returns:
        DataFrame with a 'name' column plus one float32 column per entry of
        ``SCORE_COLUMNS``, one row per sample. It keeps those columns even when
        ``loader`` yields nothing.
    """
    records = []
    for sample in tqdm(loader):
        # Cast once; every metric below compares the same two tensors.
        hr = sample['hr'][:, [-1], ...].float()
        sr = resize_batch_images(sample['lr'][:, -1, [-1], ...],
                                 fx=scale, fy=scale, interpolation=cv.INTER_CUBIC).float()

        mse = METRICS['MSE'](hr, sr)
        mse_s, _ = minshift_loss(hr, sr, metric='MSE', shifts=shifts)
        mse_s_c, _ = minshift_loss(hr, sr, metric='MSE', shifts=shifts, apply_correction=True)
        psnr_s_c, _ = minshift_loss(hr, sr, metric='PSNR', shifts=shifts, apply_correction=True)
        ssim_s_c, _ = minshift_loss(hr, sr, metric='SSIM', shifts=shifts, apply_correction=True)

        for name, *values in zip(sample['name'], mse, mse_s, mse_s_c, psnr_s_c, ssim_s_c):
            # float() extracts each per-sample value from its tensor, so the frame
            # gets numeric columns rather than object columns of 0-d arrays.
            records.append({'name': name, **dict(zip(SCORE_COLUMNS, map(float, values)))})

    return (pd.DataFrame(records, columns=['name', *SCORE_COLUMNS])
            .astype({column: np.float32 for column in SCORE_COLUMNS}))


# NOTE(review): the cache key is only the filename — delete the file to recompute
# after changing SHIFTS, the dataset, or the interpolation.
if os.path.exists(pq_filename):
    df = pd.read_parquet(pq_filename)
else:
    df = compute_bicubic_scores(dataloader, shifts=SHIFTS)
    print(len(df))
    df.to_parquet(pq_filename)

In [None]:
# Number of scored samples.
df.shape[0]

In [None]:
# First few score rows.
df.iloc[:5]

In [None]:
# Overlaid distributions of the three MSE variants over all scored samples.
# Values outside [0, 1] are not shown because of `range=(0, 1)`.
fig, ax = plt.subplots(figsize=(15, 10))
for column in ('MSE', 'MSE_s', 'MSE_s_c'):
    df[column].hist(ax=ax, alpha=.3, bins=50, range=(0, 1), label=column)
ax.set_xlabel('score value')
ax.set_ylabel('number of samples')
ax.set_title('Bicubic baseline: MSE variants')
ax.legend();

In [None]:
# Relationship between the shift-corrected MSE and SSIM, one point per sample.
fig, ax = plt.subplots(figsize=(15, 10))
ax.scatter(df.MSE_s_c, df.SSIM_s_c, alpha=.1)
ax.set_xlabel('MSE_s_c (shifted, corrected)')
ax.set_ylabel('SSIM_s_c (shifted, corrected)')
ax.set_title('Bicubic baseline: MSE vs SSIM');

In [None]:
# Relationship between the shift-corrected PSNR and SSIM, one point per sample.
fig, ax = plt.subplots(figsize=(15, 10))
ax.scatter(df.PSNR_s_c, df.SSIM_s_c, alpha=.1)
ax.set_xlabel('PSNR_s_c (shifted, corrected)')
ax.set_ylabel('SSIM_s_c (shifted, corrected)')
ax.set_title('Bicubic baseline: PSNR vs SSIM');

In [None]:
# Rename the filename column to `name`, the key used by the score table.
data_df = data_df.rename(columns={'singleton_npz_filename': 'name'})

In [None]:
# Attach the patchlet metadata to each score row (inner join on the npz filename).
# The rename happens inline, so this cell does not rely on another cell having
# renamed `data_df` in place; it is a no-op if the column is already called 'name'.
scores_df = pd.merge(df,
                     data_df.rename(columns={'singleton_npz_filename': 'name'}),
                     on='name')

In [None]:
# Preview the merged scores + metadata table.
scores_df.head(n=5)