# Tile test data for smooth interpolation prediction

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
import os

%run ../src/preprocessing/slope.py
%run ../src/preprocessing/indices.py
%run ../src/dsen2/utils/DSen2Net.py

In [None]:
from scipy.sparse.linalg import splu
import scipy.sparse as sparse
from skimage.transform import resize

In [None]:
MDL_PATH = "../src/dsen2/models/"
# DSen2 takes two inputs, channels first, with open spatial dimensions:
INPUT_SHAPE = (
    (4, None, None),  # 4 bands at 10 m
    (6, None, None),  # 6 bands at 20 m
)
PREDICT_FILE = MDL_PATH + 's2_032_lr_1e-04.hdf5'
MODEL = s2model(INPUT_SHAPE, num_layers=6, feature_size=128)
MODEL.load_weights(PREDICT_FILE)

In [None]:
# 5x5 blending kernel used when an interpolated window is written back into
# a tile: weight 3/3 at the centre pixel, 2/3 on the next ring, 1/3 on the
# border. o_arr is the complementary weight kept for the original pixels.
_rows, _cols = np.indices((5, 5))
_ring = np.minimum(np.minimum(_rows, _cols), np.minimum(4 - _rows, 4 - _cols)) + 1
c_arr = _ring / 3
o_arr = 1 - c_arr
# Repeat along a new last axis (11 channels) so the kernels broadcast
# against (5, 5, 11) windows.
c_arr = np.repeat(c_arr[:, :, np.newaxis], 11, axis=2)
o_arr = np.repeat(o_arr[:, :, np.newaxis], 11, axis=2)
del _rows, _cols, _ring

In [None]:
# Files present in both the Sentinel-2 and Sentinel-1 test folders. The S1
# listing is read once instead of once per S2 file.
s1_names = set(os.listdir("../data/test-s1"))
test_files = [x for x in os.listdir("../data/test-s2/") if x in s1_names]
# NOTE(review): names come from ../data/test-s2/ but the array is read from
# ../data/raw/test-raw/ — confirm the file exists there too.
data = np.load("../data/raw/test-raw/{}".format(test_files[0]))

In [None]:
# Quick look at band 2 of the first time step.
sns.heatmap(data[0][:, :, 2])

In [None]:
# Start corner [x_start, y_start] of each 16x16 sub-tile cut by make_5d_array.
# Convention taken from the edge tiles: 'left' has x=0, 'right' x=16,
# 'down' y=0, 'up' y=16. The corners combine them.
tile_lookups = { #'x_start, y_start'
    'center': [8, 8],
    'left': [0, 8],
    'right': [16, 8],
    'up': [8, 16],
    'down': [8, 0],
    'ul': [0, 16],  # left (x=0) + up (y=16); previously duplicated 'dr'
    'ur': [16, 16],
    'dl': [0, 0],
    'dr': [16, 0],
}

IMSIZE = 32

In [None]:
# Tile names, in the order make_5d_array stacks them.
list(tile_lookups)

In [None]:
def remove_missed_clouds(img):
    """ Flags time steps that are likely missed clouds or shadows.

        A pixel is an outlier when its near-infrared value (band index 3)
        is more than two interquartile ranges above the 75th percentile or
        below the 25th percentile of that band over the whole stack. A
        time step is flagged when more than 15% of its pixels are outliers.

        Parameters:
         img (arr): (time, x, y, band) array; band 3 is the NIR band

        Returns:
         to_remove (arr): (n, 1) integer array of flagged time-step
                          indices, as returned by ``np.argwhere``
    """
    nir = img[:, :, :, 3]
    q25, q75 = np.percentile(nir, [25, 75])
    iqr = q75 - q25
    thresh_t = q75 + iqr*2
    thresh_b = q25 - iqr*2
    # Pixels per step come from the array (was hard-coded as 48*48)
    n_pixels = img.shape[1] * img.shape[2]
    above = np.sum(nir > thresh_t, axis = (1, 2))
    below = np.sum(nir < thresh_b, axis = (1, 2))
    outlier_percs = (100 * ((above + below) / n_pixels)).tolist()
    print(outlier_percs)
    to_remove = np.argwhere(np.array(outlier_percs) > 15)
    return to_remove

def remove_bad_steps(img, probs, shadows, image_dates):
    """ Drops cloudy, shadowed or incomplete time steps, then drops the
        steps flagged by ``remove_missed_clouds``.

        Parameters:
         img (arr): (time, x, y, band) imagery
         probs (arr): cloud probabilities, one entry per time step on axis 0
         shadows (arr): (time, x, y) shadow mask
         image_dates (arr): one date per time step

        Returns:
         img, probs, image_dates, shadows with the flagged steps removed
    """
    # Pixels per time step (was hard-coded as 48*48)
    n_pixels = img.shape[1] * img.shape[2]
    img_axes = tuple(range(1, img.ndim))

    # More than a third of the pixels are shadowed
    shadow_sums = np.sum(shadows, axis = (1, 2))
    shadow_steps = np.argwhere(shadow_sums > (n_pixels/3))
    # More than a fifth of the pixels have cloud probability > 0.3
    cloudy_px = np.sum(probs > 0.3, axis = tuple(range(1, probs.ndim)))
    dirty_steps = np.argwhere(cloudy_px > n_pixels / 5)
    # At least 25 values (over all bands) that are exactly 0 or >= 1
    bad_values = np.sum(img == 0.0, axis = img_axes) + np.sum(img >= 1, axis = img_axes)
    missing_images = list(np.argwhere(bad_values >= 25))
    # dtype=int: with nothing flagged, an empty float array would make
    # np.delete raise IndexError
    to_remove = np.unique(np.array(list(dirty_steps) + missing_images + list(shadow_steps), dtype = int))

    # Remove null steps
    print("There are {}/{} dirty steps: {}"
          " cloud, {} missing, {} shadow".format(len(to_remove),
                                                 len(img), len(dirty_steps),
                                                 len(missing_images),
                                                 len(shadow_steps)))

    img = np.delete(img, to_remove, 0)
    probs = np.delete(probs, to_remove, 0)
    image_dates = np.delete(image_dates, to_remove)
    shadows = np.delete(shadows, to_remove, 0)

    # Second pass: NIR-outlier detection of clouds/shadows missed above
    to_remove = remove_missed_clouds(img)
    img = np.delete(img, to_remove, 0)
    probs = np.delete(probs, to_remove, 0)
    image_dates = np.delete(image_dates, to_remove)
    shadows = np.delete(shadows, to_remove, 0)
    print("Removing {} steps based on ratio".format(len(to_remove)))
    return img, probs, image_dates, shadows
    
def DSen2(d10, d20):
    """Super-resolve the 20 m bands to 10 m with the DSen2 network.

       Follows Lanaras et al. 2018 (https://github.com/lanha/DSen2) and
       runs the module-level ``MODEL`` through ``_predict``.

        Parameters:
         d10 (arr): channels-first 10 m input with 4 bands; superresolve
                    passes it as (time, 4, X, Y)
         d20 (arr): channels-first 20 m input with 6 bands, (time, 6, X, Y)

        Returns:
         prediction (arr): channels-first 6-band output at 10 m resolution
    """
    # The network takes both resolutions together as a two-element input.
    return _predict([d10, d20], ((4, None, None), (6, None, None)), deep=False)

def _predict(test, input_shape, model = MODEL, deep=False, run_60=False):
    """Return ``model.predict(test)`` with progress output enabled.

    ``input_shape``, ``deep`` and ``run_60`` are accepted but not used.
    """
    return model.predict(test, verbose=1)

def calculate_proximal_steps_index(date, satisfactory):
    """Returns offsets to the nearest cloud and shadow free steps around date

         Parameters:
          date (int): current time step index
          satisfactory (array-like of int): indices of time steps with no
                                            clouds or shadows

         Returns:
          arg_before (int): offset (negative) from ``date`` to the closest
                            earlier clean step
          arg_after (int): offset (positive) from ``date`` to the closest
                           later clean step

         If only one side has a clean step, both offsets point to it. If
         ``satisfactory`` holds nothing but ``date`` itself, (0, 0) is returned.

         Raises:
          ValueError: if ``satisfactory`` is empty
    """
    satisfactory = np.asarray(satisfactory)
    if satisfactory.size == 0:
        raise ValueError("No cloud and shadow free time step to interpolate from")
    offsets = satisfactory - date
    earlier = offsets[offsets < 0]
    later = offsets[offsets > 0]
    # Explicit None checks: a truthiness test would treat offset 0 as missing
    arg_before = int(earlier.max()) if earlier.size else None
    arg_after = int(later.min()) if later.size else None
    if arg_before is None and arg_after is None:
        return 0, 0
    if arg_before is None:
        arg_before = arg_after
    if arg_after is None:
        arg_after = arg_before
    return arg_before, arg_after

def speyediff(N, d, format = 'csc'):
    """Build the sparse d-th order finite-difference matrix.

         Row k holds the d-th difference stencil starting at column k
         (e.g. [1, -2, 1] for d = 2). Multiplying a length-N vector by it
         gives that vector's N - d d-th order differences.

         Parameters:
          N (int): input length
          d (int): difference (smoothing) order
          format (str): scipy.sparse storage format of the result

         Returns:
          spmat (sparse matrix): (N - d, N) difference matrix
    """
    # Differencing a unit impulse d times yields the stencil. np.diff uses
    # the opposite sign convention, hence the (-1) ** d factor.
    impulse = np.zeros(2 * d + 1)
    impulse[d] = 1.
    stencil = (-1) ** d * np.diff(impulse, n = d)
    return sparse.diags(stencil, np.arange(d + 1), (N - d, N), format = format)

def smooth(y, lmbd, d = 2):
    """Apply the Whittaker smoother to a 1-D series.

         Solves (I + lmbd * D^H D) z = y, where D is the d-th order
         difference matrix from ``speyediff``.

         Parameters:
          y (arr): 1-dimensional input array
          lmbd (float): smoothing strength, higher is smoother
          d (int): order of the difference penalty

         Returns:
          z (arr): smoothed version of y
    """
    n = len(y)
    diff_mat = speyediff(n, d, format = 'csc')
    penalty = diff_mat.conj().T.dot(diff_mat)
    system = sparse.eye(n, format = 'csc') + lmbd * penalty
    return splu(system).solve(y)


def superresolve(arr):
    """Overwrite bands 4-9 of ``arr`` with their DSen2 super-resolved version.

    Bands 0-3 are fed to DSen2 as the 10 m input and bands 4-9 as the 20 m
    input. ``arr`` is modified in place and also returned.

    Parameters:
     arr (arr): (time, x, y, band) array with at least 10 bands

    Returns:
     arr (arr): the same array with bands 4-9 replaced
    """
    print("Shape before super: {}".format(arr.shape))

    # DSen2 works channels-first: (time, band, x, y)
    channels_first = (0, 3, 1, 2)
    d10 = np.transpose(arr[:, :, :, 0:4], channels_first)
    d20 = np.transpose(arr[:, :, :, 4:10], channels_first)
    prediction = DSen2(d10, d20)
    # Move the band axis of the prediction back to the end
    prediction = np.transpose(prediction, (0, 2, 3, 1))
    print(prediction.shape)
    print(arr.shape)

    # NOTE(review): original note says these are band IDXs 3, 4, 5, 7, 8, 9
    arr[:, :, :, 4:10] = prediction
    print("Shape after super: {}".format(arr.shape))
    return arr

def _bracketing_offsets(offsets, max_gap = 240):
    """Return (before, after) signed day offsets of the nearest images.

    ``before`` is the largest negative and ``after`` the smallest positive
    entry of ``offsets``. A side with no image, or whose image is more than
    ``max_gap`` days away, borrows the other side's offset. A far image is
    kept only when it is the only one available.
    """
    earlier = offsets[offsets < 0]
    later = offsets[offsets > 0]
    before = earlier.max() if earlier.size else None
    after = later.min() if later.size else None
    if before is None or (after is not None and abs(before) > max_gap):
        before = after
    if after is None or abs(after) > max_gap:
        after = before
    return before, after


def calculate_and_save_best_images(img_bands, image_dates):
    """ Resample (time, X, Y, band) imagery onto a fixed 5-day grid.

        The grid is range(0, 360, 5), so the output is (72, X, Y, band).
        A grid day with an image within 8 days uses that image. Otherwise
        the nearest earlier and later images are averaged with equal weight.

        Parameters:
         img_bands (arr): (time, X, Y, band) imagery
         image_dates (list): day number of each time step

        Returns:
         keep_steps (arr): (72, X, Y, band) resampled imagery
         max_distance (int): largest gap in days between the two images
                             averaged for any grid day
    """
    biweekly_dates = [day for day in range(0, 360, 5)] # target grid: one date every 5 days
    all_dates = np.asarray(image_dates)

    # Clouds have been removed at this step, so all steps are satisfactory
    satisfactory_dates = all_dates[:img_bands.shape[0]]

    selected_images = {}
    for i in biweekly_dates:
        offsets = satisfactory_dates - i
        distances = np.abs(offsets)
        closest_id = int(np.argmin(distances))
        # If there is imagery within 8 days, select it
        if distances[closest_id] < 8:
            date = satisfactory_dates[closest_id]
            image_idx = int(np.flatnonzero(all_dates == date)[0])
            selected_images[i] = {'image_date': [date], 'image_ratio': [1], 'image_idx': [image_idx]}
        # Otherwise bracket the grid day with the nearest earlier / later imagery
        else:
            before, after = _bracketing_offsets(offsets)
            if before != after:
                after_ratio = before / (before - after)
                before_ratio = 1 - after_ratio
            else:
                before_ratio = after_ratio = 0.5

            # Image date and index of each bracketing image
            before_date = i + before
            before_image_idx = int(np.flatnonzero(all_dates == before_date)[0])
            after_date = i + after
            after_image_idx = int(np.flatnonzero(all_dates == after_date)[0])

            selected_images[i] = {'image_date': [before_date, after_date],
                                  'image_ratio': [before_ratio, after_ratio],
                                  'image_idx': [before_image_idx, after_image_idx]}

    max_distance = 0
    for info in selected_images.values():
        if len(info['image_date']) == 2:
            dist = info['image_date'][1] - info['image_date'][0]
            if dist > max_distance:
                max_distance = dist

    print("Maximum time distance: {}".format(max_distance))

    # Equal-weight average of the selected imagery for each grid day
    # (image_ratio is stored but not applied here)
    keep_steps = []
    for info in selected_images.values():
        idx = info['image_idx']
        if len(idx) == 1:
            step = img_bands[idx[0]]
        else:
            step = img_bands[idx[0]] * 0.5 + img_bands[idx[1]] * 0.5
        keep_steps.append(step)

    keep_steps = np.stack(keep_steps)
    return keep_steps, max_distance

def remove_cloud_and_shadows(tiles, probs, shadows, image_dates, wsize = 5):
    """ Interpolates clouds and shadows for each time step with a
        linear combination of proximal clean time steps for each
        sliding window of the specified size

        Parameters:
         tiles (arr): (time, x, y, band) imagery, modified in place
         probs (arr): cloud probabilities, same time/x/y layout as tiles
         shadows (arr): shadow mask added to the cloud mask
         image_dates (list): date of each time step
         wsize (int): window edge length; the module-level blending
                      kernels c_arr / o_arr are 5x5, so only 5 works as-is

        Returns:
         tiles (arr): the same array with cloudy windows interpolated
    """
    n_steps, width, height = tiles.shape[0], tiles.shape[1], tiles.shape[2]

    # Binary cloud mask relative to each pixel's lowest probability over time
    c_probs = np.copy(probs)
    c_probs = c_probs - np.min(c_probs, axis = 0)
    c_probs[np.where(c_probs > 0.33)] = 1.
    c_probs[np.where(c_probs < 0.33)] = 0.
    # Count cloudy pixels per 8x8 block, upsample back (order 0 = nearest)
    # and flag blocks with at least 12 cloudy pixels
    c_probs = np.reshape(c_probs, [c_probs.shape[0], width // 8, 8, height // 8, 8])
    c_probs = np.sum(c_probs, (2, 4))
    c_probs = resize(c_probs, (c_probs.shape[0], width, height), 0)
    c_probs[np.where(c_probs < 12)] = 0.
    c_probs[np.where(c_probs >= 12)] = 1.
    c_probs += shadows
    c_probs[np.where(c_probs >= 1.)] = 1.

    n_interp = 0
    # "+ 1" so that the last window also reaches the final row / column
    for cval in range(0, width - wsize + 1):
        for rval in range(0, height - wsize + 1):
            subs = c_probs[:, cval:cval + wsize, rval:rval + wsize]
            window_sums = np.sum(subs, axis = (1, 2))
            satisfactory = np.flatnonzero(window_sums < 10)
            if satisfactory.size == 0:
                # No clean step for this window: nothing to interpolate from
                continue
            for date in range(0, n_steps):
                if window_sums[date] > 10:
                    n_interp += 1
                    before, after = calculate_proximal_steps_index(date, satisfactory)
                    before = date + before
                    after = date + after
                    if after >= n_steps:
                        after = before
                    if before < 0:
                        before = after
                    bef = tiles[before, cval:cval+wsize, rval:rval+wsize, : ]
                    aft = tiles[after, cval:cval+wsize, rval:rval+wsize, : ]
                    before_diff = abs(image_dates[date] - image_dates[before])
                    after_diff = abs(image_dates[date] - image_dates[after])
                    # The clean step closer in time gets the larger weight
                    bef_wt = 1 - before_diff / (before_diff + after_diff)
                    aft_wt = 1 - bef_wt
                    candidate = bef_wt*bef + aft_wt*aft
                    # Feather: centre pixel fully replaced, border pixels keep
                    # 2/3 of their current value (see c_arr / o_arr)
                    candidate = candidate*c_arr + tiles[date, cval:cval+wsize, rval:rval+wsize, : ]*o_arr
                    tiles[date, cval:cval+wsize, rval:rval+wsize, : ] = candidate
    print("Interpolated {} px".format(n_interp))
    return tiles


def process_array(plot_id):
    """ Builds the fused Sentinel-1 / Sentinel-2 array for one test plot and
        saves it to ../data/raw/test-processed/{plot_id}.npy

        Steps, in order:
         - convert acquisition dates to day numbers relative to Jan 1 of YEAR;
         - drop bad steps and append the DEM;
         - interpolate clouds and shadows;
         - super-resolve and add the index bands;
         - crop the centre 32x32 and Whittaker-smooth each pixel;
         - resample to a 5-day grid and keep every 15th day;
         - put the Sentinel-1 bands first.

        Parameters:
         plot_id (str): file stem of the plot's .npy files

        Raises:
         ValueError: if an acquisition date is not in YEAR - 1, YEAR or YEAR + 1
    """
    # Day count at the start of each month (leap days are not handled)
    starting_days = np.cumsum([0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30])
    YEAR = 2019
    s2 = np.load('../data/raw/test-raw/{}.npy'.format(plot_id))
    shadows = np.load('../data/raw/test-shadows/{}.npy'.format(plot_id))
    clouds = np.load('../data/raw/test-clouds/{}.npy'.format(plot_id))
    dates = np.load('../data/raw/test-dates/{}.npy'.format(plot_id), allow_pickle = True)
    s1 = np.load("../data/raw/test-s1/{}.npy".format(plot_id))
    dem = np.load("../data/raw/test-dem/{}.npy".format(plot_id))

    print("BEF", s2.shape)
    print("BEF", clouds.shape)
    print("BEF", dates.shape)
    # Day numbers: negative for YEAR - 1, > 365 for YEAR + 1. Other years
    # used to be skipped silently, leaving fewer dates than time steps.
    image_dates = []
    for date in dates:
        year_offset = date.year - YEAR
        if year_offset not in (-1, 0, 1):
            raise ValueError("Unexpected acquisition date {} for plot {}".format(date, plot_id))
        image_dates.append(365 * year_offset + starting_days[(date.month-1)] + date.day)
    dates = np.array(image_dates)

    print("BEF", s2.shape)
    print("BEF", clouds.shape)
    # subset time steps
    s2, clouds, dates, shadows = remove_bad_steps(s2, clouds, shadows, dates)
    print("AFT", s2.shape)
    print("AFT", clouds.shape)

    # Concatenate DEM as an extra last band, scaled by 1/90
    dem = np.tile(dem.reshape((1, 48, 48, 1)), (s2.shape[0], 1, 1, 1))
    s2 = np.concatenate([s2, dem], axis = -1)
    s2[:, :, :, -1] /= 90

    # remove clouds and shadows
    print(s2.shape)
    print(clouds.shape)
    print(shadows.shape)
    s2 = remove_cloud_and_shadows(s2, clouds, shadows, dates)

    # super resolve
    s2 = superresolve(s2)

    # indices (helpers presumably come from the %run of indices.py)
    s2, _ = evi(s2, True)
    s2 = bi(s2, True)
    s2 = msavi2(s2, True)
    s2 = si(s2, True)

    # keep the centre 32x32 pixels
    s2 = s2[:, 8:40, 8:40, :]

    # Whittaker-smooth every pixel's time series. Band 10 is skipped —
    # presumably the DEM appended above; TODO confirm the band layout.
    bands = [x for x in range(0, 15) if x != 10]
    for row in range(0, 32):
        for column in range(0, 32):
            for band in bands:
                sm = smooth(s2[:, row, column, band], 800, d = 2)
                s2[:, row, column, band] = sm

    s2, _ = calculate_and_save_best_images(s2, dates)
    # Keep only the grid days that are multiples of 15
    biweekly_dates = np.array([day for day in range(0, 360, 5)])
    to_remove = np.argwhere(biweekly_dates % 15 != 0)
    s2 = np.delete(s2, to_remove, 0)
    fused = np.concatenate([s1, s2], axis = -1)
    # save fused data (nothing is returned)
    np.save("../data/raw/test-processed/{}.npy".format(plot_id), fused)

def make_5d_array(arr, tile_lookups, n_steps = 24, tile_size = 16):
    """Cut square, possibly overlapping tiles out of a (time, x, y, band) array.

    Parameters:
     arr (arr): (time, x, y, band) array with at least ``n_steps`` steps
     tile_lookups (dict): tile name -> [x_start, y_start]
     n_steps (int): number of leading time steps to keep
     tile_size (int): edge length of each tile

    Returns:
     arr_5d (arr): (n_tiles, n_steps, tile_size, tile_size, band) float64
                   array, tiles ordered like ``tile_lookups``
    """
    print(arr.shape)
    # Sized from the inputs instead of the hard-coded (9, 24, 16, 16, 17)
    arr_5d = np.empty((len(tile_lookups), n_steps, tile_size, tile_size, arr.shape[-1]))
    for i, start in enumerate(tile_lookups.values()):
        start_x, start_y = start[0], start[1]
        arr_5d[i] = arr[:n_steps, start_x:start_x + tile_size, start_y:start_y + tile_size, :]
    return arr_5d

def reconstruct_array(arr):
    """Average the centre tile with the edge tiles that overlap it.

    ``arr`` stacks tiles in ``tile_lookups`` order: centre, left, right,
    up, down, then four corners (not used here). Each tile is
    (time, x, y, band). Each edge neighbour overlaps half of the centre
    tile. Every centre pixel is the mean of the centre value, the covering
    x-neighbour (left or right) and the covering y-neighbour (up or down).

    Parameters:
     arr (arr): (n_tiles >= 5, time, x, y, band) array, e.g. from make_5d_array

    Returns:
     out (arr): (time, x, y, band) blended centre tile
    """
    center, left, right, up, down = arr[0], arr[1], arr[2], arr[3], arr[4]
    half_x = center.shape[1] // 2
    half_y = center.shape[2] // 2
    total = np.array(center, dtype = float)
    # x-neighbours: left overlaps the first half of the centre, right the second
    total[:, :half_x] += left[:, half_x:]
    total[:, half_x:] += right[:, :half_x]
    # y-neighbours: down overlaps the first half, up the second
    total[:, :, half_y:] += up[:, :, :half_y]
    total[:, :, :half_y] += down[:, :, half_y:]
    # Each pixel received exactly three contributions
    return total / 3

In [None]:
# Process every plot with raw Sentinel-2 and Sentinel-1 data that has not
# been processed yet. Listings are read once into sets.
s1_available = set(os.listdir("../data/raw/test-s1/"))
already_processed = set(os.listdir("../data/raw/test-processed/"))
to_process = [str(x[:-4]) for x in os.listdir("../data/raw/test-raw/")
              if x in s1_available and x not in already_processed]
for plot_id in to_process:
    process_array(plot_id)

In [None]:
def reconstruct_images(plot_id):
    '''Build the label grid of one plot from the global ``df``.

       The rows of ``df`` whose PLOT_ID equals ``plot_id`` are grouped by
       LAT in descending order. Within each LAT value the TREE labels are
       ordered by ascending LON.

        Parameters:
          plot_id: value matched against df['PLOT_ID']

         Returns:
          rows (list of list): one list of TREE values per LAT value
    '''
    plot_rows = df[df['PLOT_ID'] == plot_id]
    grid = []
    for lat in sorted(plot_rows['LAT'].unique(), reverse = True):
        lat_rows = plot_rows[plot_rows['LAT'] == lat].sort_values('LON', axis = 0)
        grid.append(list(lat_rows['TREE']))
    return grid

source = 'test'
sentinel_1 = True
s2_path = "../data/raw/test-processed/"
s1_path = "../data/{}-s1/".format(source)
csv_path = "../data/{}-csv/".format(source)
output_path = "../data/{}-processed/".format(source)

# For either train or test data, loop through each plot and determine whether there is
# labelled Y data for it -- returning one dataframe for the entire data set

# Columns that only some of the per-country CSVs contain
optional_columns = ['PL_PLOTID', 'STACKINGPROFILEDG', 'IMAGERYYEARDG']

dfs = []
for i in os.listdir(csv_path):
    if ".csv" in i:
        print(i)
        df = pd.read_csv(csv_path + i).drop('IMAGERY_TITLE', axis = 1)
        df['country'] = i.split(".")[0]
        dfs.append(df.drop(columns = optional_columns, errors = 'ignore'))

df = pd.concat(dfs, ignore_index = True)
df = df.dropna(axis = 0)

# Plots with processed S2 data and raw S1 and S2 test files; each directory
# is listed once instead of once per file
s1_files = set(os.listdir("../data/test-s1/"))
s2_files = set(os.listdir("../data/test-s2/"))
existing = [int(x[:-4]) for x in os.listdir(s2_path)
            if ".DS" not in x and x in s1_files and x in s2_files]
df = df[df['PLOT_ID'].isin(existing)]
plot_ids = sorted(df['PLOT_ID'].unique())
print(len(plot_ids))

In [None]:
# match these up with the 2-preprocessing-notebook

processed_files = set(os.listdir(s2_path))
test_x = []
test_y = []
for plot_id in plot_ids:
    # NOTE(review): plot 136077593 is excluded; the reason is not recorded here
    if str(plot_id) + ".npy" in processed_files and plot_id != 136077593:
        x = np.load("../data/raw/test-processed/" + str(plot_id) + ".npy")
        # Move the two leading (Sentinel-1) bands to the end
        x = np.concatenate([x[:, :, :, 2:], x[:, :, :, :2]], axis = -1)
        test_x.append(x)
        test_y.append(reconstruct_images(plot_id))
test_x = np.stack(test_x)
test_y = np.stack(test_y)
print(test_x.shape)

In [None]:
# Persist the stacked test features and labels.
for _out_path, _out_arr in (("../tile_data/processed/test_x_tile.npy", test_x),
                            ("../tile_data/processed/test_y_tile.npy", test_y)):
    np.save(_out_path, _out_arr)

In [None]:
# Band 2 of the first time step, restricted to the 16x16 centre tile.
_cx, _cy = tile_lookups['center']
sns.heatmap(data[0, _cx:_cx + 16, _cy:_cy + 16, 2])