In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
import hickle as hkl
import os

In [None]:
%run ../../src/downloading/utils.py
%run ../../src/models/utils.py

In [None]:
# Synthetic stack: 7 time steps of 16 x 16 x 1 pixels, where step i is filled
# with the integer i, fed through calculate_and_save_best_images.
bands = np.stack([np.full((16, 16, 1), step) for step in range(7)])
image_dates = np.array([0, 22, 105, 232, 295, 310, 330])

bands, max_time = calculate_and_save_best_images(bands, image_dates)
# One mean per entry along axis 0 of the returned stack.
band_means = np.mean(bands, axis = (1, 2, 3))

In [None]:
band_means[2]  # mean of the entry at index 2 along axis 0 of the stack returned above

In [None]:
# Raw cloud layer (clouds_0_1.hkl) for the senegal-tucker-a site, stored with hickle.
CLOUDS_A_PATH = "../../tile_data/senegal-tucker-a/raw/clouds/clouds_0_1.hkl"
clouds_a = hkl.load(CLOUDS_A_PATH)

In [None]:
# Index along axis 0 whose spatial mean (over axes 1 and 2) is the largest.
np.mean(clouds_a, axis = (1, 2)).argmax()

In [None]:
# Heatmap of the top-left 300 x 300 px of the step with the highest mean cloud value.
cloudiest_step = np.argmax(np.mean(clouds_a, axis = (1, 2)))
fig, ax = plt.subplots()
sns.heatmap(clouds_a[cloudiest_step, :300, :300], ax = ax)
ax.set(title = "clouds_a, step {} (top-left 300 x 300 px)".format(cloudiest_step),
       xlabel = "column (px)", ylabel = "row (px)");

# Cloud and cloud shadow interpolation

In [None]:
def remove_cloud_and_shadows(tiles, probs, shadows, image_dates, wsize = 9):
    """ Interpolates clouds and shadows for each time step with a linear
        combination of the proximal clean time steps, window by window.

        A wsize x wsize window slides over the image with stride 1. Inside
        each window, a time step is treated as cloudy when its flagged-pixel
        count reaches wsize*wsize/10. Each cloudy step is replaced by a blend
        of a clean step before and a clean step after it, weighted by how
        close their dates are. That blend is feathered into the original
        pixels with a Gaussian weight that is 1 at the window centre.

        Parameters:
         tiles (arr): image stack indexed (step, x, y, band). Modified in
             place.
         probs (arr): per-pixel values indexed (step, x, y). After removing
             each pixel's minimum over time, values >= 0.33 are flagged.
             Presumably cloud probabilities in [0, 1] -- confirm with caller.
         shadows (arr): indexed (step, x, y); added to the cloud flags, and
             the sum is clipped at 1. Presumably a binary shadow mask.
         image_dates (list): one date per step. Absolute differences are used
             as day gaps for the temporal weights.
         wsize (int): side length of the sliding window.

        Returns:
         tiles (arr): the same array object that was passed in, interpolated.
    """

    def _fspecial_gauss(size, sigma):
        # Unnormalised 2-D Gaussian kernel of shape (size, size); value 1 at offset (0, 0).
        x, y = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1]
        return np.exp(-((x**2 + y**2)/(2.0*sigma**2)))

    n_steps = tiles.shape[0]

    # Feathering weights: c_arr is the share of the interpolated candidate,
    # o_arr the share of the original pixels. Shape (wsize, wsize, 1) so they
    # broadcast over the band axis.
    c_arr = np.reshape(_fspecial_gauss(wsize, ((wsize/2) - 1 ) / 2), (wsize, wsize, 1))
    o_arr = 1 - c_arr

    # Flag pixels at least 0.33 above their own temporal minimum, then add
    # shadows and clip the combined flag at 1.
    c_probs = probs - np.min(probs, axis = 0)
    c_probs[c_probs >= 0.33] = 1.
    c_probs[c_probs < 0.33] = 0.
    c_probs += shadows
    c_probs[c_probs >= 1.] = 1.

    threshold = (wsize * wsize) / 10
    n_interp = 0

    for x in range(tiles.shape[1] - wsize + 1):
        for y in range(tiles.shape[2] - wsize + 1):
            subs = c_probs[:, x:x + wsize, y:y + wsize]
            # Flag count per step, computed once so that the "clean" list and
            # the per-step "cloudy" test always agree.
            flagged = np.sum(subs, axis = (1, 2))
            satisfactory = np.argwhere(flagged < threshold)
            for date in range(n_steps):
                if flagged[date] < threshold:
                    continue
                # NOTE(review): calculate_proximal_steps is not defined in
                # this notebook (presumably provided by the %run utils
                # scripts); it appears to return offsets relative to `date`.
                before, after = calculate_proximal_steps(date, satisfactory)
                before = date + before
                after = date + after
                if before < 0 and after >= n_steps:
                    # No usable neighbour on either side: leave this window
                    # untouched instead of indexing with a wrapped negative step.
                    continue
                after = before if after >= n_steps else after
                before = after if before < 0 else before
                n_interp += 1

                before_array = tiles[before, x:x + wsize, y:y + wsize, :]
                after_array = tiles[after, x:x + wsize, y:y + wsize, :]
                original_array = tiles[date, x:x + wsize, y:y + wsize, :]

                n_days_before = abs(image_dates[date] - image_dates[before])
                n_days_after = abs(image_dates[date] - image_dates[after])
                total_days = n_days_before + n_days_after
                if total_days > 0:
                    before_weight = 1 - n_days_before / total_days
                else:
                    # Neighbours share the cloudy step's date: split evenly
                    # rather than computing 0 / 0.
                    before_weight = 0.5
                after_weight = 1 - before_weight

                candidate = before_weight * before_array + after_weight * after_array
                candidate = candidate * c_arr + original_array * o_arr
                tiles[date, x:x + wsize, y:y + wsize, :] = candidate

    print("Interpolated {} px".format(n_interp))
    return tiles

In [None]:
# Synthetic 3-step scene: only step 1 carries clouds, copied from
# clouds_a[32] (presumably a cloudy step found above -- confirm), then
# binarised at 0.33. No shadows.
SCENE_SHAPE = (3, 632, 632)
clouds = np.zeros(SCENE_SHAPE)
shadows = np.zeros(SCENE_SHAPE)
image_dates = np.array([3, 15, 30])

clouds[1] = clouds_a[32]
clouds[clouds >= 0.33] = 1.
clouds[clouds < 0.33] = 0.

# Disabled alternatives: hand-drawn cloud patches on step 1.
#clouds[1, 18:28, 0:12] = 1.
#clouds[1, 10:20, 8:15] = 1.
#clouds[1, 18:20, 26:28] = 1.
#clouds[1, 2:5, 26:29] = 1.

# Things to watch for when debugging the interpolation of step 1:
# before, after = 0, 2
# mean(bef) = 0
# mean(aft) = 2.0

In [None]:
# Three constant 632 x 632 x 1 frames, frame i filled with float(i), so that
# interpolated values can be read directly off the output.
bands = np.stack([np.full((632, 632, 1), float(step)) for step in range(3)])
# remove_cloud_and_shadows edits its first argument in place and returns it,
# so `tiles` and `bands` are the same array afterwards.
tiles = remove_cloud_and_shadows(bands, clouds, shadows, image_dates)

In [None]:
# Interpolated step 1, band 0, top-left 200 x 200 px.
fig, ax = plt.subplots()
sns.heatmap(tiles[1, :200, :200, 0], ax = ax)
ax.set(title = "Interpolated step 1, band 0 (top-left 200 x 200 px)",
       xlabel = "column (px)", ylabel = "row (px)");

In [None]:
# Cloud mask used for step 1, top-left 200 x 200 px, for comparison with the plot above.
fig, ax = plt.subplots()
sns.heatmap(clouds[1, :200, :200], ax = ax)
ax.set(title = "Cloud mask, step 1 (top-left 200 x 200 px)",
       xlabel = "column (px)", ylabel = "row (px)");

# Tiling validation

In [None]:
# Seeded generator so the tiling validation further down is reproducible on re-run.
rng = np.random.default_rng(42)
arr = rng.random((1, 128, 128))
tile_arr = np.stack(tile_images(arr))

In [None]:
# Shape of the stacked tiles; axis 0 counts the tiles returned by tile_images.
np.shape(tile_arr)

In [None]:
# Tile counts for the four groups that validate_tiling expects from tile_images:
# SIZE x SIZE, (SIZE - 1) x (SIZE - 1), and two groups of SIZE x (SIZE - 1).
SIZE = 9
SIZE_N = SIZE ** 2
SIZE_UR = (SIZE - 1) ** 2
SIZE_R = SIZE * (SIZE - 1)
SIZE_U = SIZE_R
TOTAL = SIZE_N + SIZE_UR + SIZE_R + SIZE_U

def validate_tiling(arr):
    """Reassemble the tiles from `tile_images(arr)` into one summed mosaic.

    Each tile has its first channel taken and a 1-pixel border cropped. The
    tiles are then split into four consecutive groups, each stitched row by
    row into a mosaic:
      * SIZE rows of SIZE tiles, unpadded;
      * SIZE - 1 rows of SIZE - 1 tiles, zero-padded by 7 on both axes;
      * SIZE - 1 rows of SIZE tiles, zero-padded by 7 rows top and bottom;
      * SIZE rows of SIZE - 1 tiles, zero-padded by 7 columns left and right.
    The four mosaics are added together and returned.

    NOTE(review): `tile_images` is not defined in this notebook (presumably
    loaded via the %run utils scripts); the 7-pixel padding assumes its tile
    size and overlap -- confirm.
    """
    crops = np.stack(tile_images(arr))[:, 0, 1:-1, 1:-1]

    def _mosaic(first, last, per_row):
        # Concatenate crops[first:last] into rows of `per_row` tiles, then stack the rows vertically.
        rows = [np.concatenate(crops[start:start + per_row], axis = 1)
                for start in range(first, last, per_row)]
        return np.concatenate(rows, axis = 0)

    grid_end = SIZE_N
    diagonal_end = grid_end + SIZE_UR
    up_end = diagonal_end + SIZE_R

    grid = _mosaic(0, grid_end, SIZE)
    diagonal = np.pad(_mosaic(grid_end, diagonal_end, SIZE - 1),
                      (7, 7), 'constant', constant_values = 0)
    up = np.pad(_mosaic(diagonal_end, up_end, SIZE),
                ((7, 7), (0, 0)), 'constant', constant_values = 0)
    right = np.pad(_mosaic(up_end, TOTAL, SIZE - 1),
                   ((0, 0), (7, 7)), 'constant', constant_values = 0)

    return grid + diagonal + right + up

In [None]:
validated = validate_tiling(arr)

# Inside [7:-7, 7:-7] none of the four mosaics is padding, so every pixel is
# the sum of four tiles; dividing by 4 should give back the input. The
# comparison window in `arr` is offset one pixel further (8), presumably for
# the 1-px border validate_tiling crops from each tile.
# Use the absolute error: a signed sum lets positive and negative mistakes cancel out.
tiling_error = np.abs(validated[7:-7, 7:-7] / 4 - arr[0, 8:-8, 8:-8])
assert np.allclose(tiling_error, 0), "tiling mismatch, max abs error {}".format(tiling_error.max())
tiling_error.max()