In [1]:
import tensorflow as tf
import os
import hickle as hkl
import numpy as np

In [2]:

# Presumably defines the spectral-index helpers (evi, bi, msavi2, grndvi) used to
# build extra feature bands below; they are not defined anywhere else in this
# notebook — TODO confirm against indices.py.
%run ../src/preprocessing/indices.py

In [3]:
# Directory containing the frozen inference graph (predict_graph.pb) loaded in the next cell.
predict_model_path = "../models/may-avg-small-onethird/"

In [4]:
# Load the frozen inference graph into its own tf.Graph. tf.import_graph_def
# returns None when return_elements is not given, so passing its result to
# Session(graph=...) only worked by silently falling back to the default graph.
predict_graph_def = tf.compat.v1.GraphDef()
if not os.path.exists(predict_model_path):
    raise FileNotFoundError(f"The model path {predict_model_path} does not exist")

print(f"Loading model from {predict_model_path}")
# `with` closes the file handle once the protobuf has been parsed.
with tf.io.gfile.GFile(os.path.join(predict_model_path, "predict_graph.pb"), 'rb') as predict_file:
    predict_graph_def.ParseFromString(predict_file.read())

predict_graph = tf.Graph()
with predict_graph.as_default():
    tf.import_graph_def(predict_graph_def, name='predict')
predict_sess = tf.compat.v1.Session(graph=predict_graph)

# Tensor handles used for inference further down.
predict_logits = predict_sess.graph.get_tensor_by_name("predict/conv2d_13/Sigmoid:0")
predict_inp = predict_sess.graph.get_tensor_by_name("predict/Placeholder:0")
predict_length = predict_sess.graph.get_tensor_by_name("predict/PlaceholderWithDefault:0")

Loading model from ../models/may-avg-small-onethird/


In [5]:
# Per-band clip bounds used by preprocess_sample: band i is clipped to
# [min_all[i], max_all[i]] and then rescaled to [-1, 1].
# 17 entries = the 13 bands left after np.delete in the loading cell plus the
# 4 spectral indices appended in the feature cell.
# NOTE(review): presumably derived from the training data — confirm the source.
min_all = [0.006576638437476157, 0.0162050812542916, 0.010040436408026246, 0.013351644159609368, 
           0.01965362020294499, 0.014229037918669413, 0.015289539940489814, 0.011993591210803388, 
           0.008239871824216068, 0.006546120393682765, 0.0, 0.0, 0.0, -0.1409399364817101,
           -0.4973397113668104, -0.09731556326714398, -0.7193834232943873]

max_all = [0.2691233691920348, 0.3740291447318227, 0.5171435111009385, 0.6027466239414053, 
           0.5650263218127718, 0.5747005416952773, 0.5933928435187305, 0.6034943160143434,
           0.7472037842374304, 0.7000076295109483, 0.509269855802243, 0.948334642387533, 
           0.6729257769285485, 0.8177635298774327, 0.35768999002433816, 0.7545951919107605, 0.7602693339366691]

In [6]:

def convert_to_db(x: np.ndarray, min_db: int) -> np.ndarray:
    """ Converts unitless backscatter coefficient to dB and rescales it to [0, 1].

        The conversion is 10 * log10(x + 1/65535); the small offset avoids
        log10(0). Values are floored at -min_db dB and then mapped linearly,
        so that -min_db dB -> 0 and 0 dB -> 1. Anything above 0 dB is clipped to 1.

        Parameters:
         x (np.ndarray): unitless backscatter array of any shape,
             e.g. (T, X, Y, B); it is not modified
         min_db (int): positive magnitude of the dB floor, e.g. 22 means
             -22 dB. A negative value would invert the scaling and return
             all zeros, so it is rejected.

        Returns:
         np.ndarray: rescaled backscatter in [0, 1], same shape as x

        Raises:
         ValueError: if min_db is not positive
    """
    if min_db <= 0:
        raise ValueError(f"min_db must be a positive dB magnitude, got {min_db}")

    db = 10 * np.log10(x + 1/65535)
    # Floor at -min_db dB (same as the masked assignment; NaNs are preserved).
    db = np.maximum(db, -min_db)
    return np.clip((db + min_db) / min_db, 0, 1)

In [8]:
import pandas as pd
# Load the held-out test set. The plot-id table presumably has one row per
# sample, in the same order as x / test_y (TODO confirm).
# NOTE: `df` is rebound to a results table in the tree-cover analysis cell.
x = hkl.load("../data/test/test_x.hkl")
test_y = hkl.load("../data/test/test_y.hkl")
df = pd.read_csv("../data/test/test_plot_ids.csv")
# Drop band index 11 of the last axis; 13 bands remain (see the shape printed below).
# TODO: document why this band is removed.
x = np.delete(x, 11, -1)
x.shape

(1424, 12, 28, 28, 13)

In [11]:
# Build the model's input features.
# NOTE: this cell rebinds `x` from its own previous value, so it is NOT
# idempotent — re-running it without re-running the loading cell would rescale
# and re-convert the bands a second time.

# Assumes raw values are 16-bit integers (0-65535) — scale to [0, 1].
x = np.float32(x) / 65535

# The last two bands are treated as backscatter and converted to scaled dB.
x[..., -1] = convert_to_db(x[..., -1], 22)
x[..., -2] = convert_to_db(x[..., -2], 22)

# Four spectral indices appended as extra bands. The buffer shape follows x
# instead of hard-coding (N, 12, 28, 28, 4). It is float64 (np.empty default),
# which also promotes the concatenated x to float64.
indices = np.empty(x.shape[:-1] + (4,))
indices[..., 0] = evi(x)
indices[..., 1] = bi(x)
indices[..., 2] = msavi2(x)
indices[..., 3] = grndvi(x)

x = np.concatenate([x, indices], axis = -1)

In [12]:
# Paging index for the sample-by-sample visual check, plus five-decimal
# float rendering for every pandas display.
idx = 0
pd.options.display.float_format = '{:.5f}'.format

def preprocess_sample(test, idx, mins = None, maxs = None):
    """Builds one model input from a sample's time series.

    The per-pixel temporal median is appended as an extra time step. Every
    band is then clipped to its bounds and rescaled to [-1, 1].

    Parameters:
     test (np.ndarray): stack of samples; test[idx] is indexed as (T, X, Y, B)
     idx (int): which sample to prepare
     mins, maxs (sequence of float, optional): per-band clip bounds; band i
         uses mins[i] / maxs[i]. They default to the notebook's
         min_all / max_all. Extra trailing bounds are ignored.

    Returns:
     np.ndarray: (T + 1, X, Y, B) array. Its last time step is the median of
     the T input steps, computed before clipping.

    Raises:
     ValueError: if fewer bounds than bands are supplied
    """
    if mins is None:
        mins = min_all
    if maxs is None:
        maxs = max_all

    series = test[idx]
    n_bands = series.shape[-1]
    if len(mins) < n_bands or len(maxs) < n_bands:
        raise ValueError(
            f"Need clip bounds for {n_bands} bands, "
            f"got {len(mins)} mins and {len(maxs)} maxs")

    med = np.median(series, axis = 0)[np.newaxis]
    sample = np.concatenate([series, med], axis = 0)

    # Vectorised over bands: the bound vectors broadcast against the last axis.
    lo = np.asarray(mins[:n_bands], dtype = np.float64)
    hi = np.asarray(maxs[:n_bands], dtype = np.float64)
    midrange = (hi + lo) / 2
    half_range = (hi - lo) / 2
    scaled = (np.clip(sample, lo, hi) - midrange) / half_range
    # Keep the dtype of the concatenated sample, as the per-band in-place
    # assignment did.
    return scaled.astype(sample.dtype, copy = False)


In [13]:
# Reset the paging index consumed (and advanced) by the visual-check loop below.
idx = 0

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

print(idx)
for i in range(0, 50):
    print(idx, df.iloc[idx])
    sample = preprocess_sample(x, idx)
    batch_x = sample[np.newaxis]
    lengths = np.full((batch_x.shape[0]), 12)
    preds = predict_sess.run(predict_logits,
                          feed_dict={predict_inp:batch_x, 
                                     predict_length:lengths})

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (9, 4))

    sns.heatmap(preds.squeeze(), vmin = 0.0, vmax = 1, ax=ax1, cbar = False)
    sns.heatmap(test_y[idx], vmin = 0.0, vmax = 1, ax=ax2, cbar = False)
    plt.show()
    idx += 1
    

In [13]:
# NOTE(review): not referenced by any visible cell — calculate_metrics takes its
# own `al` argument. Candidate for removal.
al = 0.33

def make_evaluation_csv(data, thresh = 0.4):
    """Runs the model over every test sample and attaches per-sample metrics to `data`.

    Parameters:
     data (pd.DataFrame): one row per test sample, presumably in the same
         order as the global ``x`` / ``test_y`` arrays (TODO confirm). The new
         columns are added to it in place.
     thresh (float): probability at or above which a pixel counts as tree

    Returns:
     data (pd.DataFrame): `data` with 'error', 'tp', 'fp', 'fn' and
         'tree_cover' columns added
     preds_arr (np.ndarray): float32 array shaped like ``test_y``, holding
         the raw (un-thresholded) model outputs

    Relies on the notebook globals x, test_y, predict_sess, predict_logits,
    predict_inp, predict_length, preprocess_sample and
    compute_f1_score_at_tolerance.
    """
    n_samples = len(x)
    print(n_samples)

    # Always float32: if the labels are stored as integers,
    # np.zeros_like(test_y) would truncate every probability to 0.
    preds_arr = np.zeros_like(test_y, dtype = np.float32)
    errors = np.empty((n_samples, ))
    tps_relaxed = np.empty((n_samples, ))
    fps_relaxed = np.empty((n_samples, ))
    fns_relaxed = np.empty((n_samples, ))
    tree_cover = []

    for sample in range(n_samples):
        batch_x = preprocess_sample(x, sample)[np.newaxis]
        lengths = np.full((batch_x.shape[0]), 12)
        y = predict_sess.run(predict_logits,
                             feed_dict={predict_inp: batch_x,
                                        predict_length: lengths})
        prob = y.reshape((14, 14))
        preds_arr[sample] = prob

        true = test_y[sample].reshape((14, 14))
        pred = np.where(prob >= thresh, 1., 0.)

        true_s = np.sum(true)
        pred_s = np.sum(pred)
        # Signed (not absolute) difference in percentage points of the
        # 196-pixel plot, floored: 1.96 px == 1 %. Positive = under-prediction.
        errors[sample] = int((true_s - pred_s) // 1.96)
        tps_relaxed[sample], fps_relaxed[sample], fns_relaxed[sample] = (
            compute_f1_score_at_tolerance(true, pred))
        tree_cover.append(int((true_s * 100) // 196))

    # One summary line instead of one print per sample.
    tp, fp, fn = tps_relaxed.sum(), fps_relaxed.sum(), fns_relaxed.sum()
    precision_r = tp / (tp + fp) if tp + fp > 0 else float('nan')
    recall_r = tp / (tp + fn) if tp + fn > 0 else float('nan')
    if precision_r + recall_r > 0:
        f1_r = 2 * ((precision_r * recall_r) / (precision_r + recall_r))
    else:
        f1_r = float('nan')
    print(f"P: {precision_r:.3f} R: {recall_r:.3f} F1: {f1_r:.3f}")

    data['error'] = errors
    data['tp'] = tps_relaxed
    data['fp'] = fps_relaxed
    data['fn'] = fns_relaxed
    data['tree_cover'] = tree_cover

    return data, preds_arr

In [14]:
def _window_sum(arr, radius):
    """Sums `arr` over the (2*radius+1) x (2*radius+1) window centred on each pixel.

    The window is clipped at the array border; pixels outside the array count as 0.
    """
    arr = np.asarray(arr, dtype = np.float64)
    rows, cols = arr.shape
    padded = np.pad(arr, radius, mode = 'constant')
    total = np.zeros((rows, cols))
    for dr in range(2 * radius + 1):
        for dc in range(2 * radius + 1):
            total += padded[dr:dr + rows, dc:dc + cols]
    return total


def compute_f1_score_at_tolerance(true, pred, tolerance = 1):
    """Counts relaxed TP / FP / FN between two non-negative 2-D masks.

    Because of coregistration errors, a match may lie up to `tolerance`
    pixels away. The search window is (2*tolerance+1)^2, clipped at the border:

     - TP: a reference pixel (true == 1) with any prediction in its window
     - FN: a reference pixel with no prediction in its window
     - FP: a predicted pixel (pred == 1) with no reference pixel in its window

    A predicted pixel that is not itself a reference pixel, but has one
    nearby, is counted as neither TP nor FP.

    Parameters:
     true (np.ndarray): 2-D reference mask (0/1 values)
     pred (np.ndarray): 2-D predicted mask (0/1 values), same shape as true
     tolerance (int): window radius in pixels (default 1, i.e. a 3x3 window)

    Returns:
     (tp, fp, fn): pixel counts

    Raises:
     ValueError: if the masks differ in shape or tolerance is negative
    """
    true = np.asarray(true)
    pred = np.asarray(pred)
    if true.shape != pred.shape:
        raise ValueError(f"Shape mismatch: true {true.shape} vs pred {pred.shape}")
    if tolerance < 0:
        raise ValueError(f"tolerance must be >= 0, got {tolerance}")

    true_pos = true == 1
    pred_pos = pred == 1
    # Window sums use each mask's own shape for both axes, so non-square masks
    # are handled correctly.
    pred_nearby = _window_sum(pred, tolerance) > 0
    true_nearby = _window_sum(true, tolerance) > 0

    tp = true_pos & pred_nearby
    fn = true_pos & ~pred_nearby
    fp = pred_pos & ~true_nearby
    return np.sum(tp), np.sum(fp), np.sum(fn)

def calculate_metrics(al = 0.4, canopy_thresh = 100, thresholds = None):
    '''Runs the model on every test sample and reports relaxed metrics at the
       best of a few probability thresholds.

       The reported metrics are relaxed (1-px tolerance) F1, precision and
       recall, plus the mean absolute pixel-count error.

         Parameters:
          al (float): not used; kept so existing calls keep working
          canopy_thresh (int): not used; kept so existing calls keep working
          thresholds (iterable of float, optional): candidate probability
            thresholds; defaults to 7 * 0.05 and 8 * 0.05 (0.35, 0.40)

         Returns:
          val_loss (float): always nan. No loss is computed here; the value
            is kept so the 4-tuple returned to callers is unchanged.
          best_f1 (float): best relaxed F1 over `thresholds` (0 if none > 0)
          error (float): mean absolute pixel-count error at the best
            threshold (nan if no threshold reached F1 > 0)
          mismatches (np.ndarray): per-sample FP + FN at the best threshold

       Relies on the notebook globals x, test_y, predict_sess, predict_logits,
       predict_inp, predict_length, preprocess_sample and
       compute_f1_score_at_tolerance.
    '''
    if thresholds is None:
        # Same float values as the earlier `range(7, 9)` * 0.05 sweep.
        thresholds = [t * 0.05 for t in range(7, 9)]

    preds, trues = [], []
    for test_sample in range(x.shape[0]):
        batch_x = preprocess_sample(x, test_sample)[np.newaxis]
        lengths = np.full((batch_x.shape[0]), 12)
        y = predict_sess.run(predict_logits,
                             feed_dict={predict_inp: batch_x,
                                        predict_length: lengths})
        preds.append(y.reshape((14, 14)))
        trues.append(test_y[test_sample].reshape((14, 14)))

    n = len(preds)
    best_f1, best_thresh = 0, 0
    # Initialised so nothing is unbound if no threshold reaches F1 > 0.
    p = r = error = float('nan')
    best_mismatch = np.full((n, ), np.nan)

    for thresh in thresholds:
        tps_relaxed = np.empty((n, ))
        fps_relaxed = np.empty((n, ))
        fns_relaxed = np.empty((n, ))
        abs_error = np.empty((n, ))

        for sample, (prob, true) in enumerate(zip(preds, trues)):
            pred = np.where(prob >= thresh, 1., 0.)
            # NOTE(review): [1:-1] drops only the first/last ROW, not the
            # first/last column; confirm whether [1:-1, 1:-1] was intended.
            true_s = np.sum(true[1:-1])
            pred_s = np.sum(pred[1:-1])
            abs_error[sample] = abs(true_s - pred_s)
            tps_relaxed[sample], fps_relaxed[sample], fns_relaxed[sample] = (
                compute_f1_score_at_tolerance(true, pred))

        tp, fp, fn = tps_relaxed.sum(), fps_relaxed.sum(), fns_relaxed.sum()
        precision_r = tp / (tp + fp) if tp + fp > 0 else 0.
        recall_r = tp / (tp + fn) if tp + fn > 0 else 0.
        if precision_r + recall_r > 0:
            f1_r = 2 * ((precision_r * recall_r) / (precision_r + recall_r))
        else:
            f1_r = 0.

        if f1_r > best_f1:
            best_f1 = f1_r
            p = precision_r
            r = recall_r
            error = np.mean(abs_error)
            best_thresh = thresh
            # Keep the per-sample mismatches of the *best* threshold, not
            # whichever threshold happened to be evaluated last.
            best_mismatch = fps_relaxed + fns_relaxed

    print(f" Thresh: {np.around(best_thresh, 2)}"
          f" F1: {np.around(best_f1, 3)} P: {np.around(p, 3)} R: {np.around(r, 3)}"
          f" Error: {np.around(error, 3)}")
    return float('nan'), best_f1, error, best_mismatch

In [None]:
# Results table for the per-scale tree-cover error analysis below.
# NOTE: this rebinds `df`, replacing the test plot-id table loaded earlier.
df = pd.DataFrame(columns = ['scale', 'error', 'Reference tree cover (%)', 'Threshold'])

def check_treecover_accuracy_at_grain(y, pred, scale, mincc, maxcc, threshold):
    """Absolute tree-cover error in a scale x scale window, for plots in a cover bin.

    Rows/cols 2 : 2+scale are cropped from every sample (the same window in
    both arrays). Predicted probabilities below 0.25 are zeroed before the
    mean cover fractions are compared.

    Parameters:
     y (np.ndarray): reference labels, indexed (N, H, W)
     pred (np.ndarray): predicted probabilities, indexed (N, H, W); not modified
     scale (int): window side length in pixels
     mincc, maxcc (float): cover bin in percent. A sample is kept when
         mincc/100 <= reference cover < maxcc/100. For maxcc >= 100 the upper
         bound is inclusive, so fully covered plots are not dropped.
     threshold (bool): if true, predicted cover is the fraction of pixels with
         probability > 0.4; otherwise it is the mean probability

    Returns:
     np.ndarray: absolute cover-fraction errors (0-1) of the kept samples
    """
    y_scale = y[:, 2:2+scale, 2:2+scale]
    # Copy: basic slicing returns a view, so zeroing in place would silently
    # modify the caller's prediction array.
    pred_scale = pred[:, 2:2+scale, 2:2+scale].copy()
    pred_scale[pred_scale < 0.25] = 0.

    y_scale_mean = np.mean(y_scale, axis = (1, 2))
    if threshold:
        pred_scale_mean = np.mean(pred_scale > 0.4, axis = (1, 2))
    else:
        pred_scale_mean = np.mean(pred_scale, axis = (1, 2))

    in_bin = y_scale_mean >= mincc / 100
    if maxcc >= 100:
        in_bin &= y_scale_mean <= maxcc / 100
    else:
        in_bin &= y_scale_mean < maxcc / 100
    return np.abs(y_scale_mean - pred_scale_mean)[in_bin]

# Upper cover bound (percent) for each lower bound used in the analysis.
CC_BINS = {0: 10, 10: 40, 40: 100}

# NOTE(review): `preds_arr` is not created by any cell above. It is presumably
# the second return value of make_evaluation_csv(...) from a deleted cell —
# confirm and restore that cell so Restart & Run All works.
frames = []
for scale in [3, 5, 7, 10]:
    for mincc, maxcc in CC_BINS.items():
        for threshold in [False]:
            error1ha = check_treecover_accuracy_at_grain(test_y, preds_arr, scale, mincc, maxcc, threshold)
            cc_label = f'{str(mincc)}-{str(maxcc)}'
            frames.append(pd.DataFrame({"scale": [scale * 10] * len(error1ha),
                                        "error":  error1ha * 100,
                                        'Reference tree cover (%)': [cc_label] * len(error1ha),
                                        'Threshold': [threshold] * len(error1ha)}))
            print(scale, mincc, threshold, len(error1ha))

# DataFrame.append was removed in pandas 2.0. Concatenate once instead; this
# also avoids quadratic row growth. The frames carry the same columns as the
# empty table created above.
df = pd.concat(frames, ignore_index = True)
print(np.mean(df['error']))

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
# Violin plots of absolute cover error per window size, one panel per
# reference tree-cover bin (from `df` built in the previous cell).
sns.set(font_scale = 2.5)
sns.set_style("ticks")
# FacetGrid creates its own figure; a preceding plt.figure() only produced an
# extra, empty figure.
facet = sns.FacetGrid(df, col='Reference tree cover (%)', sharey = False, height = 12, aspect = 0.9)
facet.map(sns.violinplot, "scale", "error", 'Threshold', cut = 0, bw = 0.15, inner = 'quartile')
facet.set(xlabel = "Scale (meters)", ylabel = "Absolute error (%)");