In [None]:
import os
import numpy as np
%run ../../src/dsen2/utils/DSen2Net.py

import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
MDL_PATH = "../../src/dsen2/models/"

# Two network inputs: a 4-channel one and a 6-channel one, spatial dims left
# open (None). Presumably (channels, H, W) order — confirm against s2model.
input_shape = ((4, None, None), (6, None, None))
model = s2model(input_shape, num_layers=6, feature_size=128)

# Restore the pretrained checkpoint stored alongside the model code.
predict_file = os.path.join(MDL_PATH, 's2_032_lr_1e-04.hdf5')
model.load_weights(predict_file)

In [None]:
TEST_DIR = "../../data/test-raw/"
SAMPLES_PER_FILE = 5  # only the first patches of each file are evaluated

# os.listdir order is filesystem-dependent; sort so that sample indices used
# later in the notebook (e.g. test_data[135]) refer to the same patch everywhere.
files = sorted(x for x in os.listdir(TEST_DIR) if x.endswith(".npy"))
if not files:
    raise FileNotFoundError(f"no .npy files found in {TEST_DIR}")

test_data = np.concatenate(
    [np.load(os.path.join(TEST_DIR, file))[:SAMPLES_PER_FILE] for file in files],
    axis=0,
)
# 65535 is the uint16 maximum; presumably the raw files hold 16-bit counts, so
# this scales them to [0, 1] — TODO confirm the raw dtype.
test_data = np.float32(test_data) / 65535
print(test_data.shape)

In [None]:
from skimage.transform import resize

def downPixelAggr(img, SCALE=2, in_shape=(24, 24, 6), out_shape=(48, 48, 6)):
    """Degrade ``img`` by blur + pixel aggregation and return it on a finer grid.

    Steps: nearest-neighbour resize to ``in_shape``; Gaussian blur of each
    channel with sigma ``1 / SCALE``; ``SCALE`` x ``SCALE`` mean pooling of each
    channel; nearest-neighbour resize to ``out_shape``.

    Parameters
    ----------
    img : ndarray
        Image indexed as (H, W, C) (channels last).
    SCALE : int
        Pooling factor; also sets the blur width (sigma = 1 / SCALE).
    in_shape, out_shape : tuple of int
        Targets of the initial and final resizes. The defaults reproduce the
        original hard-coded 24x24x6 -> 48x48x6 behaviour.

    Returns
    -------
    ndarray
        The degraded image, passed through ``np.squeeze``.
    """
    import skimage.measure
    # Import from scipy.ndimage directly; the scipy.ndimage.filters namespace
    # is deprecated.
    from scipy.ndimage import gaussian_filter

    # Nearest-neighbour (order=0) resize to the working shape.
    img = resize(img, in_shape, 0)

    # Blur every channel independently before pooling.
    img_blur = np.zeros(img.shape)
    for ch in range(img.shape[2]):
        img_blur[:, :, ch] = gaussian_filter(img[:, :, ch], 1 / SCALE)

    # SCALE x SCALE average pooling ("pixel aggregation"), channel by channel.
    img_lr = np.stack(
        [
            skimage.measure.block_reduce(img_blur[:, :, ch], (SCALE, SCALE), np.mean)
            for ch in range(img_blur.shape[2])
        ],
        axis=-1,
    )

    # Nearest-neighbour resize of the pooled image to the output shape.
    img_lr = resize(img_lr, out_shape, 0)
    return np.squeeze(img_lr)

def make_input_data(data):
    """Build degraded network inputs, an interpolation baseline and targets.

    ``data`` is indexed as (N, H, W, bands). Bands ``:4`` are kept unchanged and
    bands ``4:`` are degraded with ``downPixelAggr``. That function (with its
    defaults) returns 48x48x6, so in practice ``data`` must be (N, 48, 48, 10).

    Returns
    -------
    bilinear_upsample : ndarray, (N, 48, 48, 6)
        The coarse (12x12) degraded bands upsampled bilinearly to 48x48; a
        plain interpolation baseline.
    input_data : ndarray, (N, 48, 48, 10)
        Bands ``:4`` concatenated with the degraded bands ``4:``.
    labels : ndarray, (N, 48, 48, 6)
        Copy of the original bands ``4:`` (the reconstruction target).
    """
    labels = np.copy(data[..., 4:])
    n, h, w, c = labels.shape
    tenm = data[..., :4]

    # 2x2 mean pooling of bands 4: (48x48 -> 24x24). With the leading batch
    # axis, the size-2 block axes are 2 and 4.
    twentym = np.reshape(labels, (n, h // 2, 2, w // 2, 2, c)).mean(axis=(2, 4))

    fourty_m = np.zeros_like(data[..., 4:])
    for sample in range(n):
        fourty_m[sample] = downPixelAggr(twentym[sample])

    # Resizing fourty_m to its own shape would be a no-op, not an upsample.
    # downPixelAggr pools 24x24 -> 12x12 and nearest-neighbour resizes that to
    # 48x48 (factor 4), so average the 4x4 blocks to recover the coarse grid,
    # then upsample it bilinearly (order=1).
    coarse = np.reshape(fourty_m, (n, h // 4, 4, w // 4, 4, c)).mean(axis=(2, 4))
    bilinear_upsample = resize(coarse, fourty_m.shape, 1)

    input_data = np.concatenate([tenm, fourty_m], axis=-1)
    return bilinear_upsample, input_data, labels

def test_rmse(inp, net=None):
    """Per-band RMSE of super-resolving a synthetically degraded patch.

    Bands ``4:`` of ``inp`` are 2x2 mean-pooled and degraded with
    ``downPixelAggr``. The result is written into a *copy* of ``inp``, which
    is passed through ``superresolve``. Bands ``4:`` of the prediction are then
    compared with the original bands ``4:``.

    Parameters
    ----------
    inp : ndarray, (48, 48, 10)
        Single patch, bands last (the reshape below fixes 48x48x6 for bands
        ``4:``). It is NOT modified.
    net : optional
        Network passed to ``superresolve``; defaults to the notebook-level
        ``model``.

    Returns
    -------
    ndarray
        RMSE per band, averaged over the two spatial axes.
    """
    if net is None:
        net = model

    label = np.copy(inp[..., 4:])
    # 2x2 mean pooling (48x48 -> 24x24). Without a batch axis the size-2 block
    # axes are 1 and 3. Averaging over (2, 4) would collapse a spatial axis and
    # the band axis instead.
    pooled = np.reshape(label, (24, 2, 24, 2, 6)).mean(axis=(1, 3))

    # Work on a copy: writing into `inp` would silently corrupt the caller's
    # array (e.g. rows of test_data) and make repeated evaluations diverge.
    degraded = np.array(inp, copy=True)
    degraded[..., 4:] = downPixelAggr(pooled)

    supered = np.squeeze(superresolve(degraded[np.newaxis], net))
    mse = np.mean((supered[..., 4:] - label) ** 2, axis=(0, 1))
    return np.sqrt(mse)
    

In [None]:
# Per-sample, per-band RMSE over the whole test set (one row per sample).
rmses = np.empty((len(test_data), 6))
for i, sample in enumerate(test_data):
    rmse = test_rmse(sample)
    print(i, rmse)
    rmses[i] = rmse

In [None]:
# Average RMSE of each band across all test samples.
rmses.mean(axis=0)

In [None]:
# `test` is not defined anywhere in this notebook (NameError on a fresh
# kernel); evaluate the sample visualised in the following cells instead.
# Pass a copy so test_data stays untouched whatever test_rmse does to its input.
test_rmse(test_data[135].copy())

In [None]:
# Heatmap of channel 4 (the first of the bands 4:) of sample 135.
fig, ax = plt.subplots(figsize=(10, 7.5))
sns.heatmap(test_data[135, :, :, 4], ax=ax)
ax.set_title("test_data[135], band index 4");

In [None]:
%time x = superresolve(test_data[135][np.newaxis], model)
plt.figure(figsize=(10,7.5))
sns.heatmap(x[0, ..., 4])