In [None]:
import numpy as np
import hickle as hkl
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

%run ../src/downloading/utils.py

In [None]:
%run ../src/pca/pca_filter.py

In [None]:
def pif_pca(input_year, reference):
    """Find pixels that pass the PCA filter for every tested date/band pair.

    For every second date index in ``range(0, 24, 2)`` and each of the first
    five bands, the flattened ``input_year`` and ``reference`` slices are
    handed to ``pca_fit_and_filter_pixel_list`` (presumably provided by the
    ``%run ../src/pca/pca_filter.py`` cell) and the returned per-pixel mask is
    AND-ed into a running mask. The running count of surviving pixels is
    printed after every band.

    Args:
        input_year: array indexed as ``[date, ..., band]``.
        reference: array indexed the same way as ``input_year``.

    Returns:
        The ``np.argwhere`` result for the final mask, i.e. one row per pixel
        that survived every date/band filter.
    """
    # NOTE(review): the mask length assumes a 5x5 grid of 142x142 tiles; it
    # must match the length of the mask returned by the helper - confirm.
    # The builtin ``bool`` is used because the ``numpy.bool`` alias was
    # removed in NumPy 1.24, and ``np`` is the alias this notebook imports.
    pif_mask = np.ones(142 * 142 * 5 * 5, dtype=bool)
    for date in range(0, 24, 2):
        for band in range(0, 5):
            # NOTE(review): the meaning of 660 is defined by the helper.
            pif_band_mask = pca_fit_and_filter_pixel_list(
                input_year[date, ..., band].flatten(),
                reference[date, ..., band].flatten(),
                660,
            )
            pif_mask = np.logical_and(pif_mask, pif_band_mask)
            print(np.sum(pif_mask))
    # logical_and always yields a boolean array, so no "== True" is needed.
    return np.argwhere(pif_mask)

In [None]:
# Indices of the mask entries that compare equal to True.
# NOTE(review): in the visible cells `pif_mask` is only assigned inside
# pif_pca(); this cell needs a notebook-level `pif_mask` - confirm where it
# is supposed to come from.
pifs = np.argwhere(pif_mask == True)

In [None]:
# Scratch check: pair the elements of two 1-D arrays as rows (a_i, b_i).
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
np.stack((a, b), axis=-1)

In [None]:
# Name of the landscape to analyse; selects the sub-directory of ../tile_data/.
landscape = 'pilot-kenya'
# Base directory of that landscape. The loader functions in this notebook
# append "<year>/processed/<x>/<y>.hkl" to it.
path = f"../tile_data/{landscape}/"

In [None]:
from tqdm import tnrange, tqdm_notebook
from sklearn.linear_model import LinearRegression
from scipy.stats import pearsonr
from sklearn.cross_decomposition import CCA



def load_years(x, y):
    """Load the processed tile at grid position (x, y) for 2017, 2018 and 2019.

    Each file is read with hickle from ``<path><year>/processed/<x>/<y>.hkl``,
    where ``path`` is the notebook-level landscape directory.

    Args:
        x: first tile index, used as the directory name.
        y: second tile index, used as the file stem.

    Returns:
        Tuple ``(x17, x18, x19)`` with the loaded objects in year order.
    """
    tile_suffix = f"/processed/{str(x)}/{str(y)}.hkl"
    per_year = [hkl.load(path + year + tile_suffix) for year in ("2017", "2018", "2019")]
    return tuple(per_year)

def load_reference_years(input_year, target_year):
    """Load every processed tile of two years and flatten the pixel axes.

    Tiles are read with hickle for each (x, y) of the 5 x 5 grid from
    ``<path><year>/processed/<x>/<y>.hkl`` and written into a
    ``(24, 5, 5, 142, 142, 17)`` buffer, which is then reshaped so that the
    grid and tile axes collapse into one pixel axis.

    Args:
        input_year: year (directory name) of the data to be adjusted.
        target_year: year (directory name) of the reference data.

    Returns:
        Tuple ``(inp, targ)`` of arrays shaped ``(24, 5*5*142*142, 17)``.
    """
    stack_shape = (24, 5, 5, 142, 142, 17)
    inp = np.empty(stack_shape)
    targ = np.empty(stack_shape)
    for x in tnrange(0, 5):
        for y in range(0, 5):
            tile_file = f"/processed/{str(x)}/{str(y)}.hkl"
            inp[:, x, y, ...] = hkl.load(path + str(input_year) + tile_file)
            targ[:, x, y, ...] = hkl.load(path + str(target_year) + tile_file)
    # Collapse (grid_x, grid_y, row, col) into a single pixel axis.
    return inp.reshape(24, -1, 17), targ.reshape(24, -1, 17)

#inp, targ = load_reference_years(2017, 2019)
# Work on the single tile at grid position (3, 2) for all three years.
x17, x18, x19 = load_years(3, 2)
#x17 = np.delete(x17, 10, -1)
#x19 = np.delete(x19, 10, -1)

# Per-date spatial mean of band 3 for each year. x and y are passed by
# keyword because seaborn >= 0.12 no longer accepts them positionally.
sns.lineplot(x=list(range(x17.shape[0])), y=np.mean(x17[..., 3], axis=(1, 2)))
sns.scatterplot(x=list(range(x18.shape[0])), y=np.mean(x18[..., 3], axis=(1, 2)))
sns.scatterplot(x=list(range(x19.shape[0])), y=np.mean(x19[..., 3], axis=(1, 2)))

In [None]:
# Concatenate the three years along the date axis and plot the per-date
# spatial mean of band 15.
all_data = np.concatenate([x17, x18, x19], axis=0)
# x and y are keyword arguments: seaborn >= 0.12 rejects positional x/y.
g = sns.scatterplot(x=list(range(all_data.shape[0])), y=np.mean(all_data[..., 15], axis=(1, 2)))
#g.set(ylim=(15000, 21000))
plt.show()

In [None]:
# Same three-year concatenation, this time plotting the per-date spatial
# mean of band 3.
all_data = np.concatenate([x17, x18, x19], axis=0)
# x and y are keyword arguments: seaborn >= 0.12 rejects positional x/y.
sns.scatterplot(x=list(range(all_data.shape[0])), y=np.mean(all_data[..., 3], axis=(1, 2)))

In [None]:
from sklearn.decomposition import PCA
# Scratch experiment: fit a 2-component PCA on the first 10 bands of a single
# date (tnrange(0, 1) only yields date 0).
# `diffs_all` is allocated here but not updated anywhere in this cell.
diffs_all = np.zeros((142, 142))
for date in tnrange(0, 1):
    # One row per pixel of the 142x142 tile, one column per band.
    x = x17[date, ..., :10].reshape(142*142, 10)
    y = x19[date, ..., :10].reshape(142*142, 10)
    pca = PCA(n_components=2)
    # NOTE(review): scikit-learn documents the `y` argument of PCA.fit as
    # ignored, so only `x` (the 2017 tile) determines the components.
    pca.fit(x, y)
    #xs = cca.transform(x)
    

In [None]:
# For each of the 24 dates, fit a 2-component CCA between the 2017 and 2019
# tiles (first 10 bands), pass both years through cca.transform and add each
# pixel's squared distance between the two projections to a running total.
diffs_all = np.zeros((142, 142))
for date in tnrange(0, 24):
    x = np.reshape(x17[date, ..., :10], (142 * 142, 10))
    y = np.reshape(x19[date, ..., :10], (142 * 142, 10))
    cca = CCA(n_components=2)
    cca.fit(x, y)
    xs = cca.transform(x)
    ys = cca.transform(y)
    diffs = ((xs - ys) ** 2).sum(axis=1).reshape(142, 142)
    diffs_all += diffs
sns.heatmap(diffs_all)
# Flatten so the selected indices address the flattened tile.
diffs_all = diffs_all.reshape(-1)
# Keep the pixels whose accumulated distance is below the 1st percentile.
pifs = np.argwhere(diffs_all < np.percentile(diffs_all, 1))

In [None]:
# Adjust the 2017 tile towards 2019: for every date and every band in 0..14
# except band 10, fit a linear regression on the PIF pixels only and apply it
# to the whole tile.
output = x17.copy()

for date in range(0, 24):
    for band in [b for b in range(15) if b != 10]:
        # Band time series restricted to the PIF pixels.
        input_ = x17[..., band].reshape(24, 142*142)[:, pifs]
        target = x19[..., band].reshape(24, 142*142)[:, pifs]
        # Column vectors of PIF values for this date.
        input_date = input_[date].squeeze()[:, np.newaxis]
        target_date = target[date].squeeze()[:, np.newaxis]
        reg = LinearRegression().fit(input_date, target_date)
        input_updated = reg.predict(x17[date, ..., band].reshape(142*142, 1))

        output[date, ..., band] = input_updated.reshape(1, 142, 142)

In [None]:
# Inspect the shape of the coefficient matrix of the most recently fitted
# notebook-level `cca` object.
cca.coef_.shape

In [None]:
# One CCA over whole time series: every pixel becomes a row holding its
# 24 dates x 4 bands, for 2017 (x) and 2019 (y).
x = x17[..., :4].reshape(24, 142*142, 4)
x = x.swapaxes(0, 1).reshape(142*142, 24*4)
y = x19[..., :4].reshape(24, 142*142, 4)
y = y.swapaxes(0, 1).reshape(142*142, 24*4)
cca = CCA(n_components=3)
cca.fit(x, y)
xs = cca.transform(x)
ys = cca.transform(y)
# Per-pixel sum of absolute differences across the three components,
# laid back out on the tile grid.
diffs = np.sum(abs(xs - ys), axis=1).reshape(142, 142)
# Binary map of the pixels below the 0.1th percentile of the distances.
cutoff = np.percentile(diffs, 0.1)
new = np.zeros_like(diffs)
new[diffs < cutoff] = 1.
sns.heatmap(diffs)

In [None]:
# Inspect the x-side weights of the most recently fitted notebook-level `cca`.
cca.x_weights_

In [None]:
# Per-pixel Pearson correlation between the 2017 and 2019 time series of
# band 3, shown as a heatmap on the tile grid.
input_year = x17[..., 3].reshape(24, 142*142)
reference = x19[..., 3].reshape(24, 142*142)
pearsons = np.empty(142*142)

for i in tnrange(pearsons.shape[0]):
    # pearsonr returns (correlation, p-value); only the correlation is kept.
    pearsons[i] = pearsonr(input_year[:, i], reference[:, i])[0]
pearsons = pearsons.reshape(142, 142)
sns.heatmap(pearsons)

In [None]:
from sklearn.linear_model import LinearRegression
from scipy.stats import pearsonr

from sklearn.cross_decomposition import CCA

def identify_pifs_pairs(input_year, reference):
    """Select pseudo-invariant pixels from per-date CCA distances.

    For each of the first 24 dates a 2-component CCA is fitted between the
    first 10 bands of ``input_year`` and ``reference``; both arrays are then
    passed through ``cca.transform`` and the per-pixel sum of absolute
    differences is accumulated over the dates.

    Args:
        input_year: array indexed as ``[date, ..., band]``; all axes between
            date and band are flattened into one pixel axis.
        reference: array laid out like ``input_year``.

    Returns:
        ``np.argwhere`` result (one row per pixel) for the pixels whose
        accumulated distance lies below the 0.5th percentile of the
        accumulated distances.
    """
    # Pixel count is taken from the data instead of a hard-coded 5x5 grid of
    # 142x142 tiles, so single tiles work as well.
    diffs_all = np.zeros(input_year[0, ..., 0].size)
    for date in tnrange(0, 24):
        x = input_year[date, ..., :10].reshape(-1, 10)
        y = reference[date, ..., :10].reshape(-1, 10)
        cca = CCA(n_components=2)
        cca.fit(x, y)
        xs = cca.transform(x)
        ys = cca.transform(y)
        diffs_all += np.sum(abs(xs - ys), axis=1)
    # The threshold must come from the accumulated distances; the original
    # took the percentile of the last date's distances only.
    return np.argwhere(diffs_all < np.percentile(diffs_all, 0.5))

def identify_pifs_new(input_year, reference):
    """Select pseudo-invariant pixels with one CCA over whole time series.

    Every pixel becomes one sample whose features are its first 10 bands at
    every date. A 2-component CCA is fitted between the two years, and the
    squared distance between the paired scores returned by ``fit_transform``
    is standardised and thresholded.

    Args:
        input_year: array indexed as ``[date, ..., band]``; all axes between
            date and band are flattened into one pixel axis, so both
            ``(dates, pixels, bands)`` and ``(dates, rows, cols, bands)``
            layouts are accepted.
        reference: array laid out like ``input_year``.

    Returns:
        ``np.argwhere`` result (one row per pixel) for the pixels whose
        standardised distance lies below the 1st percentile.
    """
    n_dates = input_year.shape[0]
    # Slice the LAST axis for both inputs (the original sliced axis 2 of
    # input_year, which is not the band axis for 4-D tiles), then flatten
    # the spatial axes without assuming a fixed pixel count.
    x = input_year[..., :10].reshape(n_dates, -1, 10)
    x = np.swapaxes(x, 0, 1).reshape(-1, n_dates * 10)
    y = reference[..., :10].reshape(n_dates, -1, 10)
    y = np.swapaxes(y, 0, 1).reshape(-1, n_dates * 10)
    cca = CCA(n_components=2)
    xs, ys = cca.fit_transform(x, y)
    diffs = np.sum((xs - ys) ** 2, axis=1)
    diffs = (diffs - np.mean(diffs)) / np.std(diffs)
    return np.argwhere(diffs < np.percentile(diffs, 1))

def identify_pifs(input_year, reference):
    """Select pseudo-invariant pixels by temporal correlation of band 3.

    The Pearson correlation between the band-3 time series of ``input_year``
    and ``reference`` is computed for every pixel; pixels above the 99th
    percentile of those correlations are returned.

    Args:
        input_year: array indexed as ``[date, ..., band]``; all axes between
            date and band are flattened into one pixel axis.
        reference: array laid out like ``input_year``.

    Returns:
        Tuple ``(pifs, pearsons)``: the ``np.argwhere`` result for the
        selected pixels and the 1-D array of per-pixel correlations.
    """
    n_dates = input_year.shape[0]
    # Flatten whatever spatial layout is passed in instead of assuming a
    # 5x5 grid of 142x142 tiles.
    input_year = input_year[..., 3].reshape(n_dates, -1)
    reference = reference[..., 3].reshape(n_dates, -1)
    pearsons = np.empty(input_year.shape[1])

    for i in tnrange(pearsons.shape[0]):
        # pearsonr returns (correlation, p-value); only the correlation is kept.
        pearsons[i] = pearsonr(input_year[:, i], reference[:, i])[0]

    # NOTE(review): a NaN correlation (e.g. from a constant series) would
    # make this percentile NaN and the selection empty - confirm the inputs
    # cannot contain constant pixels.
    pifs = np.argwhere(pearsons > np.percentile(pearsons, 99))
    return pifs, pearsons


def linear_adjust_pif(pifs, small_input, small_target, large_input, large_target):
    """Linearly adjust ``small_input`` using regressions fitted on PIF pixels.

    For every date and every band in 0..14 except band 10, a linear
    regression from ``large_input`` to ``large_target`` is fitted on the PIF
    pixels only and then applied to all pixels of ``small_input``.

    Args:
        pifs: integer indices into the flattened pixel axis of the large
            arrays; any shape (e.g. the ``(k, 1)`` output of ``np.argwhere``).
        small_input: array indexed as ``[date, ..., band]`` to be adjusted.
        small_target: unused; kept so existing calls keep working.
        large_input: array indexed as ``[date, ..., band]`` from which the
            PIF samples of the input year are taken.
        large_target: array laid out like ``large_input`` for the target year.

    Returns:
        A copy of ``small_input`` in which the processed bands are replaced
        by the regression predictions.
    """
    output = small_input.copy()
    # Flat index vector; also covers the single-PIF case, where the original
    # squeeze() collapsed the samples to a scalar and then failed.
    pif_idx = np.asarray(pifs).reshape(-1)
    bands = [b for b in range(15) if b != 10]

    for date in range(large_input.shape[0]):
        for band in bands:
            # Gather only this date's PIF samples, as (n_pifs, 1) columns.
            input_date = large_input[date, ..., band].reshape(-1)[pif_idx][:, np.newaxis]
            target_date = large_target[date, ..., band].reshape(-1)[pif_idx][:, np.newaxis]
            reg = LinearRegression().fit(input_date, target_date)

            tile = small_input[date, ..., band]
            input_updated = reg.predict(tile.reshape(-1, 1))
            output[date, ..., band] = input_updated.reshape(tile.shape)
    return output

#pifs = identify_pifs_new(inp, targ)
# PIFs are derived here from the single-tile arrays x17 / x19.
# NOTE(review): the commented-out alternative above uses the full-landscape
# stacks (inp, targ). Pixel indices computed from one pair presumably do not
# address the other - confirm which pairing the following cell should use.
pifs = identify_pifs_new(x17, x19)
#sns.scatterplot([x for x in range(len(diffs))], sorted(diffs))


In [None]:
# Adjust the 2017 tile using regressions fitted on the PIF pixels.
# NOTE(review): in the visible cells `inp` and `targ` are only produced by the
# commented-out load_reference_years(2017, 2019) call, so this cell appears to
# fail on a fresh kernel unless that call is re-enabled - confirm.
output = linear_adjust_pif(pifs, x17, x19, inp, targ)

In [None]:
# Band-5 time series of one flattened pixel index: 2017 (points), 2019 and
# the adjusted output (lines).
idx = 1212
# x and y are keyword arguments: seaborn >= 0.12 rejects positional x/y.
sns.scatterplot(x=list(range(24)), y=x17[..., 5].reshape(24, 142*142)[:, idx])
sns.lineplot(x=list(range(24)), y=x19[..., 5].reshape(24, 142*142)[:, idx])
sns.lineplot(x=list(range(24)), y=output[..., 5].reshape(24, 142*142)[:, idx])

In [None]:
# Per-date spatial mean of band 3: 2017 (points), adjusted output and 2019
# (lines). x and y are keyword arguments: seaborn >= 0.12 rejects positional
# x/y.
sns.scatterplot(x=list(range(0, 24)), y=np.mean(x17[..., 3], axis=(1, 2)))
sns.lineplot(x=list(range(0, 24)), y=np.mean(output[..., 3], axis=(1, 2)))
sns.lineplot(x=list(range(0, 24)), y=np.mean(x19[..., 3], axis=(1, 2)))

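# Disabled alternative: pass each 2017 / 2018 date through hist_match against
# the same date of 2019.
# for date in range(0, 24):
#     x17[date] = hist_match(x17[date], x19[date])
#     x18[date] = hist_match(x18[date], x19[date])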

In [None]:
# Per-date spatial mean of band 9 for 2017, 2018 and 2019.
# x and y are keyword arguments: seaborn >= 0.12 rejects positional x/y.
sns.scatterplot(x=list(range(0, 24)), y=np.mean(x17[..., 9], axis=(1, 2)))
sns.scatterplot(x=list(range(0, 24)), y=np.mean(x18[..., 9], axis=(1, 2)))
sns.scatterplot(x=list(range(0, 24)), y=np.mean(x19[..., 9], axis=(1, 2)))