# Tree segmentation with multitemporal Sentinel 1/2 imagery

## John Brandt
## December 2023

## This notebook finetunes the TTC decoder for a new task

## Package Loading

In [1]:
from tqdm import tqdm_notebook, tnrange
import tensorflow as tf

sess = tf.Session()
from keras import backend as K
K.set_session(sess)

from time import sleep

import keras
from tensorflow.python.keras.layers import *
from tensorflow.python.keras.layers import ELU
from keras.losses import binary_crossentropy
from tensorflow.python.ops import array_ops
from tensorflow.python.keras.layers import Conv2D, Lambda, Dense, Multiply, Add
from tensorflow.initializers import glorot_normal, lecun_normal
from scipy.ndimage import median_filter
from skimage.transform import resize

import pandas as pd
import numpy as np
from random import shuffle
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
import os
import random
import itertools
from tensorflow.contrib.framework import arg_scope
from keras.regularizers import l1
from tensorflow.layers import batch_normalization
from tensorflow.python.util import deprecation as deprecation
deprecation._PRINT_DEPRECATION_WARNINGS = False

os.environ['KMP_DUPLICATE_LIB_OK']='True'

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])





  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
2023-12-22 12:39:56.310226: I tensorflow/core/platform/cpu_feature_guard.cc:145] This TensorFlow binary is optimized with Intel(R) MKL-DNN to use the following CPU instructions in performance critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in non-MKL-DNN operations, rebuild TensorFlow with the appropriate compiler flags.
2023-12-22 12:39:56.310907: I tensorflow/core/common_runtime/process_util.cc:115] Creating new thread pool with default inter op setting: 8. Tune using inter_op_parallelism_threads for best performance.
Using TensorFlow backend.





## Utility scripts

In [2]:
%run ../src/layers/zoneout.py
%run ../src/layers/losses.py
%run ../src/layers/adabound.py
%run ../src/layers/convgru.py
%run ../src/layers/dropblock.py
%run ../src/layers/extra_layers.py
%run ../src/layers/stochastic_weight_averaging.py
%run ../src/preprocessing/indices.py
%run ../src/preprocessing/slope.py
#%run ../src/utils/metrics.py
#%run ../src/utils/lovasz.py




# Hyperparameter definitions

In [3]:
# Model / training hyperparameters. Several of these are re-assigned in later
# cells (noted inline); the later assignment is the one in effect.
ZONE_OUT_PROB = 0.90  # NOTE: gru_block does not pass this to its ZoneoutWrappers (they use 0.75)
ACTIVATION_FUNCTION = 'swish'

INITIAL_LR = 1e-3
DROPBLOCK_MAXSIZE = 5

N_CONV_BLOCKS = 1
FINAL_ALPHA = 0.33
LABEL_SMOOTHING = 0.03

L2_REG = 0.
BATCH_SIZE = 32  # re-assigned to 4 in the training cell
MAX_DROPBLOCK = 0.6

FRESH_START = True  # re-assigned in the optimizer and checkpoint-restore cells
best_val = 0.2  # re-assigned to 0.72 in the training cell

START_EPOCH = 1
END_EPOCH = 100

n_bands = 17  # 13 raw bands + 4 vegetation indices
initial_flt = 32  # filter widths are re-assigned in the model cell (initial_flt = 64)
mid_flt = 32 * 2
high_flt = 32 * 2 * 2

temporal_model = True
input_size = 124  # spatial size of each input tile (pixels)
output_size = 14  # re-assigned to input_size - 14 in the placeholder cell

# Custom layer definitions

### Conv GRU Block

In [4]:
def gru_block(inp, length, size, flt, scope, train, normalize = True,
              zoneout_prob = 0.75):
    '''Bidirectional convolutional GRU block with
       zoneout and CSSE blocks in each time step

         Parameters:
          inp (tf.Tensor): (B, T, H, W, C) input
          length (tf.Tensor): (B,) number of valid time steps per sample
          size (list of int): spatial shape [H, W], passed to ConvGRUCell
                              as `shape`
          flt (int): number of convolution filters per direction
          scope (str): tensorflow variable scope
          train (tf.Tensor): boolean flag to differentiate between
                             train/test ops (drives zoneout)
          normalize (bool): passed to ConvGRUCell as `normalize`
                            (internal normalization inside the cells)
          zoneout_prob (float): `zoneout_drop_prob` for both ZoneoutWrappers.
                                Defaults to 0.75, the value previously
                                hard-coded, so existing callers are unchanged.

         Returns:
          gru (tf.Tensor): (B, H, W, flt*2) bi-GRU output (forward and
                           backward final states concatenated)
          steps (tf.Tensor): per-step outputs of both directions
                             concatenated on the last axis
    '''
    with tf.variable_scope(scope):
        # Log the zoneout probability actually applied (the global
        # ZONE_OUT_PROB is not used here).
        print(f"GRU input shape {inp.shape}, zoneout: {zoneout_prob}")

        # normalize is internal group normalization within the reset gate
        # sse is internal SSE block within the state cell
        cell_fw = ConvGRUCell(shape = size, filters = flt,
                              kernel = [3, 3], padding = 'VALID',
                              normalize = normalize, sse = True)
        cell_bw = ConvGRUCell(shape = size, filters = flt,
                              kernel = [3, 3], padding = 'VALID',
                              normalize = normalize, sse = True)

        cell_fw = ZoneoutWrapper(
            cell_fw, zoneout_drop_prob = zoneout_prob, is_training = train)
        cell_bw = ZoneoutWrapper(
            cell_bw, zoneout_drop_prob = zoneout_prob, is_training = train)
        steps, out = convGRU(inp, cell_fw, cell_bw, length)
        gru = tf.concat(out, axis = -1)
        steps = tf.concat(steps, axis = -1)
        print(f"GRU block output shape {gru.shape}")
    return gru, steps

# Model definition

## Placeholders

In [5]:
# Scale 0 disables the regularizer (TensorFlow logs this); `reg` is not
# referenced in the visible cells.
reg = tf.contrib.layers.l2_regularizer(0.)
temporal_model = True
n_bands = 17
# Label / prediction grid is smaller than the input tile (124 -> 110)
output_size = input_size - 14

if temporal_model:
    # (batch, time, H, W, bands): the model uses steps [:-1] for the GRU and
    # the last step as the median composite
    inp = tf.placeholder(tf.float32, shape=(None, 5, input_size, input_size, n_bands))
    # Number of valid GRU steps per sample; the default covers a batch of one
    length = tf.placeholder_with_default(np.full((1,), 4), shape = (None,))
else:
    # NOTE(review): `length` is not defined on this path, but the model cell uses it
    inp = tf.placeholder(tf.float32, shape=(None, input_size, input_size, n_bands))
    
labels = tf.placeholder(tf.float32, shape=(None, output_size, output_size))# (batch, output_size, output_size) label grid
mask = tf.placeholder(tf.float32, shape = (None, output_size, output_size)) # per-pixel mask passed to bce_surface_loss
keep_rate = tf.placeholder_with_default(1.0, ()) # For DropBlock (keep probability)
is_training = tf.placeholder_with_default(False, (), 'is_training') # For DropBlock
alpha = tf.placeholder(tf.float32, shape = ()) # For loss scheduling
ft_lr = tf.placeholder_with_default(0.001, shape = ()) # NOTE(review): named like a fine-tuning learning rate; not fed or used in the visible cells
loss_weight = tf.placeholder_with_default(1.0, shape = ()) # `weight` argument of bce_surface_loss
beta_ = tf.placeholder_with_default(0.0, shape = ()) # For loss scheduling, not currently implemented

INFO:tensorflow:Scale of 0 disables regularizer.




## Layers

In [6]:
# Filter widths (master model: 32, 64, 96, ~230k params)
initial_flt = 64
mid_flt = initial_flt * 2
high_flt = initial_flt * 2 * 2

# Encoder branch 1: bidirectional ConvGRU over every time step except the
# last one (the last step is the median composite, handled below)
gru_input = inp[:, :-1, ...]
gru, steps = gru_block(inp = gru_input, length = length,
                            size = [124, 124, ], # + 2 here for reflect pad
                            flt = initial_flt // 2,
                            scope = 'down_16',
                            train = is_training)
with tf.variable_scope("gru_drop"):
    drop_block = DropBlock2D(keep_prob=keep_rate, block_size=4)
    gru = drop_block(gru, is_training)

# Encoder branch 2: conv on the median composite (last time step)
median_input = inp[:, -1, ...]
median_conv = conv_swish_gn(inp = median_input, is_training = is_training, stride = (1, 1),
            kernel_size = 3, scope = 'conv_median', filters = initial_flt,
            keep_rate = keep_rate, activation = True, use_bias = False, norm = True,
            csse = True, dropblock = True, weight_decay = None,
                            window_size = 104)
print(f"Median conv: {median_conv.shape}")

# Fuse both encoder branches
concat1 = tf.concat([gru, median_conv], axis = -1)

concat = conv_swish_gn(inp = concat1, is_training = is_training, stride = (1, 1),
            kernel_size = 3, scope = 'conv_concat', filters = initial_flt,
            keep_rate = keep_rate, activation = True, use_bias = False, norm = True,
            csse = True, dropblock = True, weight_decay = None, padding = "SAME",
                       window_size = 104)
print(f"Concat: {concat.shape}")

# MaxPool-conv-swish-GroupNorm-csse (VALID padding trims 1 px per side)
pool1 = MaxPool2D()(concat)
conv1 = conv_swish_gn(inp = pool1, is_training = is_training, stride = (1, 1),
            kernel_size = 3, scope = 'conv1', filters = mid_flt,
            keep_rate = keep_rate, activation = True, use_bias = False, norm = True, padding = "VALID",
            csse = True, dropblock = True, weight_decay = None)
print(f"Conv1: {conv1.shape}")

# MaxPool-conv-swish-csse-DropBlock
pool2 = MaxPool2D()(conv1)
conv2 = conv_swish_gn(inp = pool2, is_training = is_training, stride = (1, 1),
            kernel_size = 3, scope = 'conv2', filters = high_flt,
            keep_rate = keep_rate, activation = True, use_bias = False, norm = True,
            csse = True, dropblock = True, weight_decay = None, block_size = 4, padding = "VALID",
                     window_size = 24)
print("Encoded", conv2.shape)

# Decoder 4 - 8, upsample-conv-swish-csse-concat-conv-swish
up2 = tf.keras.layers.UpSampling2D((2, 2), interpolation = 'nearest')(conv2)
up2 = conv_swish_gn(inp = up2, is_training = is_training, stride = (1, 1),
                    kernel_size = 3, scope = 'up2', filters = mid_flt,
                    keep_rate = keep_rate, activation = True,use_bias = False, norm = True,
                    csse = True, dropblock = True, weight_decay = None)
# Crop the encoder features so they match the upsampled decoder map
conv1_crop = Cropping2D(2)(conv1)
print(conv1_crop.shape)
up2 = tf.concat([up2, conv1_crop], -1)
up2 = conv_swish_gn(inp = up2, is_training = is_training, stride = (1, 1),
                    kernel_size = 3, scope = 'up2_out', filters = mid_flt,
                    keep_rate =  keep_rate, activation = True,use_bias = False, norm = True,
                    csse = True, dropblock = True, weight_decay = None)

# Decoder 8 - 14 upsample-conv-swish-csse-concat-conv-swish
up3 = tf.keras.layers.UpSampling2D((2, 2), interpolation = 'nearest')(up2)
up3 = conv_swish_gn(inp = up3, is_training = is_training, stride = (1, 1),
                    kernel_size = 3, scope = 'up3', filters = initial_flt,
                    keep_rate = keep_rate, activation = True, use_bias = False, norm = True,
                    csse = True, dropblock = True, weight_decay = None,
                    window_size = 104)
# Skip connection from the fused encoder features (`concat`, not the raw GRU output)
gru_crop = Cropping2D(6)(concat)
print(up3.shape)
print(gru_crop.shape)
up3 = tf.concat([up3, gru_crop], -1)

up3out = conv_swish_gn(inp = up3, is_training = is_training, stride = (1, 1),
                    kernel_size = 3, scope = 'out', filters = initial_flt,
                    keep_rate  = keep_rate, activation = True,use_bias = False, norm = True,
                    csse = True, dropblock = False, weight_decay = None, padding = "VALID",
                       window_size = 104)

# Output head: 1x1 conv + sigmoid. The bias starts at -log(0.7/0.3), so the
# initial predicted probability is sigmoid(bias) = 0.3 (focal-loss style prior).
init = tf.constant_initializer([-np.log(0.7/0.3)])
print(f"The decoder output is {up3out.shape}")
fm = tf.layers.Conv2D(filters = 1,
            kernel_size = (1, 1),
            padding = 'valid',
            activation = 'sigmoid',
            bias_initializer = init, name = 'conv2d')(up3out)
print(fm)


GRU input shape (?, 4, 124, 124, 17), zoneout: 0.9

(3, 3, 49, 64)
(3, 3, 49, 64)
GRU block output shape (?, 124, 124, 64)

conv_median 3 Conv 2D Group Norm RELU CSSE NoBias DropBlock
The non normalized feats are Tensor("conv_median_conv/conv_median/x/mul:0", shape=(?, 124, 124, 64), dtype=float32)
The non normalized feats are Tensor("swish_f32:0", shape=(?, 124, 124, 64), dtype=float32)

Median conv: (?, 124, 124, 64)
conv_concat 3 Conv 2D Group Norm RELU CSSE NoBias DropBlock
The non normalized feats are Tensor("conv_concat_conv/conv_concat/x/mul:0", shape=(?, 124, 124, 64), dtype=float32)
The non normalized feats are Tensor("swish_f32_1:0", shape=(?, 124, 124, 64), dtype=float32)
Concat: (?, 124, 124, 64)
conv1 3 Conv 2D Group Norm RELU CSSE NoBias DropBlock
The non normalized feats are Tensor("conv1_conv/conv1/ws_conv2d_2/Conv2D:0", shape=(?, 60, 60, 128), dtype=float32)
The non normalized feats are Tensor("swish_f32_2:0", shape=(?, 60, 60, 128), dtype=float32)
Conv1: (?, 60, 60, 

In [7]:
# Variable scopes whose trainable weights are updated during fine-tuning:
# the 1x1 output head, the decoder, and the conv1/conv2 encoder blocks.
# tf.get_collection filters by `re.match`, so each entry is a regex PREFIX:
# overlapping prefixes ("conv2d" / "conv2d_5", "conv2" / "conv2d*",
# "up2" / "up2_out", ...) select the same variable more than once. Duplicates
# in an optimizer's var_list each get their own update op, i.e. the variable
# is stepped several times per iteration, so they are removed below.
FINETUNE_SCOPES = [
    "conv2d_5", "conv2d", "csse_out", "out",
    "up3", "up3_drop", "csse_up3",
    "up2_out", "up2_out_drop", "csse_up2_out",
    "up2", "up2_drop", "csse_up2",
    "conv2", "csse_conv2",
    "conv1", "csse_conv1",
    # Uncomment to also fine-tune the fusion conv:
    # "conv_concat", "csse_conv_concat",
]

# Collect each variable once, keeping first-seen order (identity-based, since
# TF variables are not reliably hashable by value).
finetune_vars = []
_seen_finetune_ids = set()
for _scope in FINETUNE_SCOPES:
    for _var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, _scope):
        if id(_var) not in _seen_finetune_ids:
            _seen_finetune_ids.add(id(_var))
            finetune_vars.append(_var)

In [8]:
def grad_norm(gradients):
    """Global L2 norm of a collection of gradients.

    Accepts either raw gradient tensors or the ``(gradient, variable)`` pairs
    returned by ``Optimizer.compute_gradients``; for pairs only the gradient
    is used (otherwise the variable values would be folded into the norm).
    ``None`` gradients (variables that do not affect the loss) are skipped.

    Returns:
        A scalar tensor; ``0.0`` when there are no non-None gradients
        (``tf.stack`` of an empty list would fail).
    """
    grads = [g[0] if isinstance(g, tuple) else g for g in gradients]
    grads = [g for g in grads if g is not None]
    if not grads:
        return tf.constant(0.0)
    norm = tf.compat.v1.norm(
        tf.stack([tf.compat.v1.norm(grad) for grad in grads])
    )
    return norm

FRESH_START = True

OUT = input_size - 14  # label / prediction grid size (same as output_size)
if FRESH_START:
    # We use the AdaBound optimizer
    optimizer = AdaBoundOptimizer(2e-4, 2e-2)

    train_loss2 = bce_surface_loss(tf.reshape(labels, (-1, OUT, OUT, 1)), fm,
                                   weight = loss_weight,
                                   alpha = alpha, beta = beta_, mask = mask)
    train_loss = train_loss2

    # If there is any L2 regularization, add it. Current model does not use
    l2_loss = tf.losses.get_regularization_loss()
    if len(tf.losses.get_regularization_losses()) > 0:
        train_loss = train_loss + l2_loss

    test_loss = bce_surface_loss(tf.reshape(labels, (-1, OUT, OUT, 1)),
                                 fm, weight = loss_weight,
                                 alpha = alpha, beta = beta_, mask = mask)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    # Fine-tuning step: only `finetune_vars` are updated; UPDATE_OPS run first.
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(train_loss, var_list = finetune_vars)

    # Sharpness aware minimization (SAM)
    # Adapted from https://github.com/sayakpaul/Sharpness-Aware-Minimization-TensorFlow
    # For tensorflow 1.15
    #
    # NOTE(review): in TF1 graph mode `assign_add` / `assign_sub` only build
    # ops; they execute only when fetched or used as control dependencies.
    # Their results are discarded here, so the perturbation never happens,
    # `sam_gradients` equal ordinary gradients, and `train_step` is a plain
    # AdaBound step over ALL trainable variables. The graph (and its variable
    # set, which the positional checkpoint mapping relies on) is kept as is.
    trainable_params = tf.trainable_variables()
    gradients = optimizer.compute_gradients(loss=train_loss, var_list=None)
    # Norm over the gradients only (not the (grad, var) pairs)
    gradient_norm = grad_norm([grad for grad, _ in gradients])
    scale = 0.05 / (gradient_norm + 1e-12)
    perturbed_params, e_ws = [], []
    for grad, param in gradients:
        if grad is None:
            continue  # variable does not influence the loss; nothing to perturb
        e_w = grad * scale
        param.assign_add(e_w)
        perturbed_params.append(param)
        e_ws.append(e_w)

    sam_gradients = optimizer.compute_gradients(loss=train_loss, var_list=None)
    # Undo each perturbation on the same variable it was applied to
    for param, e_w in zip(perturbed_params, e_ws):
        param.assign_sub(e_w)
    train_step = optimizer.apply_gradients(sam_gradients)

    # Create a saver to save the model each epoch
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    saver = tf.train.Saver(max_to_keep = 150)





2023-12-22 12:40:02.930471: I tensorflow/core/common_runtime/process_util.cc:115] Creating new thread pool with default inter op setting: 8. Tune using inter_op_parallelism_threads for best performance.





In [9]:
def make_saver_varlist(path):

    current_items = []
    vars_dict = {}
    for var_current in tf.global_variables():
        current_items.append(var_current) 
    names = [x.op.name for x in current_items]
    names = np.argsort(names)
    current_items = [current_items[x] for x in names]
    
    ckpt_items = []
    for var_ckpt in tf.train.list_variables(path):
        if 'BackupVariables' not in var_ckpt[0]:
            if 'StochasticWeightAveraging' not in var_ckpt[0]:
                if 'global_step' not in var_ckpt[0]:
                    if 'is_training' not in var_ckpt[0]:
                        if 'n_models' not in var_ckpt[0]:
                            ckpt_items.append(var_ckpt[0])
    
    ckptdict = {}
    for y, x in zip(ckpt_items, current_items):
        ckptdict[y] = x
    return ckptdict

In [10]:
# Map checkpoint variable names onto this graph's variables for restoring
ckptdict = make_saver_varlist(path='../models/172-ttc-dec2023-3/')

In [11]:
# Restore the pretrained weights (positional name mapping from ckptdict) and
# any saved training metrics.
model_path  = "../models/172-ttc-dec2023-3/"
saver = tf.train.Saver(ckptdict)
FRESH_START = False
os.makedirs(model_path, exist_ok=True)

metrics_path = f"{model_path}metrics.npy"
if os.path.isfile(metrics_path):
    metrics = np.load(metrics_path)
    print(f"Loading {metrics_path}")
else:
    print("Starting anew")
    metrics = np.zeros((6, 300))

if not FRESH_START:
    path = model_path
    latest_ckpt = tf.train.latest_checkpoint(path)
    # latest_checkpoint returns None for an empty/missing directory; fail with
    # a clear message instead of passing None to saver.restore.
    if latest_ckpt is None:
        raise FileNotFoundError(f"No checkpoint found in {path}")
    saver.restore(sess, latest_ckpt)

Starting anew
INFO:tensorflow:Restoring parameters from ../models/172-ttc-dec2023-3/-0


In [12]:
#saver.save(sess, '../models/loss-avg-tf2/model')

In [13]:
def initialize_uninitialized(sess):
    """Run initializers only for global variables that are not yet initialized.

    Variables that already hold values (e.g. restored from a checkpoint)
    are left untouched.
    """
    all_vars = tf.global_variables()
    init_flags = sess.run([tf.is_variable_initialized(v) for v in all_vars])
    pending = [v for v, is_ready in zip(all_vars, init_flags) if not is_ready]
    if pending:
        sess.run(tf.variables_initializer(pending))

# Data loading

*  Load in CSV data from Collect Earth
*  Reconstruct the X, Y grid for the Y data per sample
*  Calculate remote sensing indices
*  Stack X, Y, length data
*  Apply median filter to DEM

In [14]:
import hickle as hkl

def normalize_subtile(subtile):
    """Clip each band to [min_all[b], max_all[b]] and rescale it to [-1, 1].

    Operates band by band along the last axis; `subtile` is modified in
    place and also returned.
    """
    for band_idx in range(subtile.shape[-1]):
        lo, hi = min_all[band_idx], max_all[band_idx]
        subtile[..., band_idx] = np.clip(subtile[..., band_idx], lo, hi)
        center = (hi + lo) / 2
        half_width = (hi - lo) / 2
        subtile[..., band_idx] = (subtile[..., band_idx] - center) / half_width
    return subtile
 
def make_and_smooth_indices(arr):
    """Compute four vegetation indices and smooth them over time.

    The indices are EVI, BI, MSAVI2 and GRNDVI, in that channel order. They
    are computed into a float32 array and then passed through the Whittaker
    `Smoother` (lmbd=50, size=12, outsize=12).

    Parameters:
        arr (np.ndarray): raw band stack; axes 0-2 size the index array and
            axes 1-2 size the smoother (assumed (time, x, y, band) with 12
            time steps -- TODO confirm)

    Returns:
        np.ndarray: smoothed indices with 4 channels
    """
    smoother = Smoother(lmbd = 50,
                        size = 12,
                        nbands = 4,
                        dimx = arr.shape[1],
                        dimy = arr.shape[2],
                        outsize = 12)

    index_fns = (evi, bi, msavi2, grndvi)
    raw_indices = np.zeros((*arr.shape[:3], len(index_fns)), dtype = np.float32)
    for channel, index_fn in enumerate(index_fns):
        raw_indices[..., channel] = index_fn(arr)
    return smoother.interpolate_array(raw_indices)

def load_individual_sample(fpath, ypath, f):
    '''Loads one training sample (X stack and label raster).

         Parameters:
          fpath (str): directory prefix of the X data (.hkl preferred, else .npy)
          ypath (str): directory prefix of the label GeoTIFFs
          f (str): sample identifier (file name without extension)

         Returns:
          x (np.ndarray): X scaled by 1/65535 and, when only the 13 raw bands
                          are stored, extended with 4 smoothed indices; 2 px
                          are cropped from each spatial edge and each band is
                          normalized with normalize_subtile
          y (np.ndarray): first band of the label raster scaled by 1/255
    '''
    hkl_file = fpath + f + '.hkl'
    if os.path.exists(hkl_file):
        x = hkl.load(hkl_file) / 65535
    else:
        x = np.load(fpath + f + ".npy") / 65535

    if x.shape[-1] == 13:
        # Raw bands only: append the four smoothed vegetation indices
        i = make_and_smooth_indices(x)
        out = np.zeros((x.shape[0], x.shape[1], x.shape[2], 17), dtype = np.float32)
        out[..., :13] = x
        out[..., 13:] = i
    else:
        # Indices already stored: shift them by fixed offsets (the same
        # magnitudes as the negative minima of bands 13-16 in min_all).
        out = x
        out[..., -1] *= 2
        out[..., -1] -= 0.7193834232943873
        out[..., -2] -= 0.09731556326714398
        # (previously written with a trailing comma, i.e. subtracting a 1-tuple)
        out[..., -3] -= 0.4973397113668104
        out[..., -4] -= 0.1409399364817101

    # Close the label raster deterministically instead of leaking the handle
    with rs.open(ypath + f + ".tif") as src:
        y = src.read(1) / 255
    return normalize_subtile(out[:, 2:-2, 2:-2, :]), y

def augment_single_sample(x, y):
    '''Builds the 5-frame model input from a 12-step time series.

       Frames 0-3 are one randomly chosen time step from each quarter of the
       series (steps 0-2, 3-5, 6-8 and 9-11). Frame 4 is the per-pixel
       median of 2-9 time steps drawn at random with replacement. No spatial
       flips or rotations are applied, and `y` is returned unchanged.

         Parameters:
          x (np.ndarray): (T, H, W, C) time series; T must be >= 12
          y (np.ndarray): label grid, passed through untouched

         Returns:
          x (np.ndarray): (5, H, W, C) float64 array
          y (np.ndarray): the input `y`
    '''
    samples_to_median = np.random.randint(0, 12, size=(12,))
    # One time step per quarter. Drawing with size=1 and taking [0] keeps the
    # random stream identical to earlier versions, while avoiding NumPy's
    # deprecated array->scalar assignment and the removed `np.int` alias.
    samples_to_select = np.array(
        [np.random.randint(lo, lo + 3, size=1)[0] for lo in (0, 3, 6, 9)],
        dtype=int,
    )
    n_samples = np.random.randint(2, 10)

    x_batch = np.zeros((5,) + tuple(x.shape[1:]))
    for frame, step in enumerate(samples_to_select):
        x_batch[frame] = x[step]
    x_batch[4] = np.median(x[samples_to_median[:n_samples]], axis = 0)
    return x_batch, y

## Standardization

In [15]:
# Per-band clipping bounds used by normalize_subtile: band b is clipped to
# [min_all[b], max_all[b]] and then rescaled linearly to [-1, 1].
# Bands 0-12 are the raw bands (already divided by 65535 in
# load_individual_sample); bands 13-16 are the four indices produced by
# make_and_smooth_indices (evi, bi, msavi2, grndvi). The index minima equal
# the negated offsets subtracted from precomputed indices in
# load_individual_sample.
min_all = [0.006576638437476157, 0.0162050812542916, 0.010040436408026246, 
           0.013351644159609368, 0.01965362020294499, 0.014229037918669413, 
           0.015289539940489814, 0.011993591210803388, 0.008239871824216068, 
           0.006546120393682765, 0.0, 0.0, 0.0, -0.1409399364817101, 
           -0.4973397113668104, -0.09731556326714398, -0.7193834232943873]

max_all = [0.2691233691920348, 0.3740291447318227, 0.5171435111009385, 0.6027466239414053,
           0.5650263218127718, 0.5747005416952773, 0.5933928435187305, 0.6034943160143434, 
           0.7472037842374304, 0.4, 0.509269855802243, 0.948334642387533, 
           0.6729257769285485, 0.8177635298774327, 0.35768999002433816, 0.7545951919107605, 
           0.7602693339366691]

## Load and process test data

# Evaluation metrics

In [16]:
def compute_f1_score_at_tolerance(true, pred, tolerance = 1):
    """Counts relaxed true positives, false positives and false negatives.

    Because of coregistration errors, a labelled pixel counts as a true
    positive if any prediction lies within `tolerance` pixels (square
    window). Otherwise it counts as a false negative. A predicted pixel is a
    false positive only if no labelled pixel lies within that window.
    Predictions next to (but not on) a label count as neither tp nor fp.

    Parameters:
        true (np.ndarray): 2-D binary (0/1) label grid
        pred (np.ndarray): 2-D binary (0/1) prediction grid, same shape
        tolerance (int): window radius in pixels (default 1 -> 3x3 window)

    Returns:
        tuple: (tp, fp, fn) counts
    """
    n_rows, n_cols = true.shape[0], true.shape[1]
    tp = np.zeros_like(true)
    fp = np.zeros_like(true)
    fn = np.zeros_like(true)

    for x in range(n_rows):
        min_x = max(0, x - tolerance)
        max_x = min(n_rows, x + tolerance + 1)
        for y in range(n_cols):
            min_y = max(0, y - tolerance)
            # Bound columns by the column count (was n_rows: wrong for non-square grids)
            max_y = min(n_cols, y + tolerance + 1)
            if true[x, y] == 1:
                if np.sum(pred[min_x:max_x, min_y:max_y]) > 0:
                    tp[x, y] = 1
                else:
                    fn[x, y] = 1
            if pred[x, y] == 1:
                if np.sum(true[min_x:max_x, min_y:max_y]) > 0:
                    if true[x, y] == 1:
                        tp[x, y] = 1
                else:
                    fp[x, y] = 1

    return np.sum(tp), np.sum(fp), np.sum(fn)

def calculate_metrics(al = 0.4, canopy_thresh = 100):
    '''Evaluates the model on (test_x, test_y) and prints summary metrics.

       Runs the network on every test sample whose label sum is below
       (canopy_thresh / 100) * 197. It then sweeps the probability
       thresholds 0.35 and 0.40 and keeps the one with the best relaxed
       (1-pixel tolerance) F1.

         Parameters:
          al (float): unused; kept for backward compatibility
          canopy_thresh (int): percentage cut-off on the label sum used to
                               select which test samples are evaluated

         Returns:
          val_loss (float): mean test loss
          best_f1 (float): best relaxed F1 over the threshold sweep
          error (float): mean absolute difference in positive-pixel counts
                         (rows 1:-1) at the best threshold; NaN if no
                         threshold produced a positive F1
    '''
    best_f1, best_thresh = 0, 0
    # Defaults so the summary cannot hit unbound names when no threshold
    # yields a positive F1 (e.g. no positive predictions at all).
    p, r, error = np.nan, np.nan, np.nan
    preds, trues, vls = [], [], []

    for test_sample in range(len(test_x)):
        if np.sum(test_y[test_sample]) < ((canopy_thresh/100) * 197):
            # NOTE(review): this reshape (13 steps, 28x28) and length=12 do
            # not match the current graph (`inp` is
            # (None, 5, input_size, input_size, n_bands)), and the `mask`
            # placeholder used by test_loss is not fed -- confirm before use.
            x_input = test_x[test_sample].reshape(1, 13, 28, 28, n_bands)
            y, vl = sess.run([fm, test_loss], feed_dict={inp: x_input,
                                                          length: np.full((1,), 12),
                                                          is_training: False,
                                                          labels: test_y[test_sample].reshape(1, OUT, OUT),
                                                          loss_weight: 0.1,
                                                          alpha: 0.33,
                                                          })
            preds.append(y.reshape((OUT, OUT)))
            vls.append(vl)
            trues.append(test_y[test_sample].reshape((OUT, OUT)))

    # These threshes are just for ROC
    for thresh in range(7, 9):
        tps_relaxed = np.empty((len(preds), ))
        fps_relaxed = np.empty((len(preds), ))
        fns_relaxed = np.empty((len(preds), ))
        abs_error = np.empty((len(preds), ))

        for sample in range(len(preds)):
            pred = np.copy(preds[sample])
            true = trues[sample]

            pred[np.where(pred >= thresh*0.05)] = 1
            pred[np.where(pred < thresh*0.05)] = 0

            true_s = np.sum(true[1:-1])
            pred_s = np.sum(pred[1:-1])
            abs_error[sample] = abs(true_s - pred_s)
            tp_relaxed, fp_relaxed, fn_relaxed = compute_f1_score_at_tolerance(true, pred)
            tps_relaxed[sample] = tp_relaxed
            fps_relaxed[sample] = fp_relaxed
            fns_relaxed[sample] = fn_relaxed

        oa_error = np.mean(abs_error)
        precision_r = np.sum(tps_relaxed) / (np.sum(tps_relaxed) + np.sum(fps_relaxed))
        recall_r = np.sum(tps_relaxed) / (np.sum(tps_relaxed) + np.sum(fns_relaxed))
        f1_r = 2*((precision_r* recall_r) / (precision_r + recall_r))

        if f1_r > best_f1:
            best_f1 = f1_r
            p = precision_r
            r = recall_r
            error = oa_error
            best_thresh = thresh*0.05

    # Precision/recall labels previously swapped (R showed precision, P recall)
    print(f"Val loss: {np.around(np.mean(vls), 3)}"
          f" Thresh: {np.around(best_thresh, 2)}"
          f" F1: {np.around(best_f1, 3)} P: {np.around(p, 3)} R: {np.around(r, 3)}"
          f" Error: {np.around(error, 3)}")
    return np.mean(vls), best_f1, error

# Data augmentation

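The cells below prepare the training data for cut mix. In cut mix, random samples are spliced together when their output labels have similar tree cover distributions (within the same k-means cluster). These cells list the training samples (excluding those flagged as bad), compute each sample's percentage tree cover, and bucket the samples by cover. Cut mix is not strictly necessary, but it gives a small performance improvement.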

In [17]:
import rasterio as rs
LEN = 4
FIGS_DIR = '/Volumes/John/data/train-17k-may2023/figs/'
BAD_DIR = '/Volumes/John/data/train-17k-may2023/bad/'

# Sample ids: .png file names in FIGS_DIR minus the 4-character extension
train_xs= [x[:-4] for x in os.listdir(FIGS_DIR) if '.png' in x]
print(len(train_xs))
# Drop samples flagged as bad (set lookup instead of a linear list scan)
train_bad = [x[:-4] for x in os.listdir(BAD_DIR)]
_bad_ids = set(train_bad)
train_xs = [x for x in train_xs if x not in _bad_ids]
print(len(train_xs))
x_path = '/Volumes/John/train-ard-128/'
y_path = '/Volumes/John/data/train-17k-may2023/train-y/'

def load_and_augment_xy(x, y, f):
    """Load sample `f` and return an augmented (input, label) pair.

    Args:
        x: directory prefix passed through to `load_individual_sample` for inputs.
        y: directory prefix passed through to `load_individual_sample` for labels.
        f: sample ID.

    Returns:
        The two values produced by `augment_single_sample`, as a tuple.
    """
    raw_x, raw_y = load_individual_sample(x, y, f)
    aug_x, aug_y = augment_single_sample(raw_x, raw_y)
    return aug_x, aug_y

# Percent tree cover per training label: mean of band 1 divided by 2.55
# (presumably the raster stores cover on a 0-255 scale -- confirm).
percs = np.zeros((len(train_xs)))
for i, sample_id in enumerate(train_xs):
    # Context manager closes each raster instead of leaking one handle per label.
    with rs.open(y_path + sample_id + '.tif') as label_src:
        percs[i] = np.mean(label_src.read(1)) / 2.55

2294
2043


In [18]:
def _cover_bin(lower, upper):
    """Indices of samples whose tree cover lies in the half-open interval (lower, upper]."""
    return np.flatnonzero((percs > lower) & (percs <= upper))

# Stratify samples by percent tree cover. Note the variable names do not match
# the ranges (e.g. `fifties` is 60-90 %, `seventies` is 90-100 %).
zeros = np.flatnonzero(percs == 0)
fives = _cover_bin(0, 5)
tens = _cover_bin(5, 10)
twenties = _cover_bin(10, 20)
thirties = _cover_bin(20, 35)
forties = _cover_bin(35, 60)
fifties = _cover_bin(60, 90)
seventies = _cover_bin(90, 100)
print(*(len(bin_idx) for bin_idx in
        (zeros, fives, tens, twenties, thirties, forties, fifties, seventies)))

343 272 221 275 189 181 172 390


In [19]:
# Sample IDs of the current training batch; must exist before the training loop,
# which snapshots it into `old_batch` at the start of every step.
batches = list()

In [20]:
#### from tqdm.notebook import trange
%run ../src/preprocessing/indices.py
%run ../src/preprocessing/whittaker_smoother.py

best_val = 0.72
fine_tune = False
ft_epochs = 0
BATCH_SIZE = 4

# loss2 125-160 is 0.4 alpha, >0.6 surface, 0.33 loss weight
# loss45 is 0.4 alpha, >0.45 surface, 0.4 loss weight
# loss45 250 - 300 is 0.4 alpha, >0.45 surface, 0.4 loss weight with minimum surface loss

SWA = False
for i in range(1, 10):
    al = 0.3
    ft_steps = (i - 1) * 240
    ft_learning_rate = 0.02
    #if nepochs < 5:
     #   ft_learning_rate *= (0.2 * nepochs)
    be = 0.0
    test_al = al
    #random.shuffle(train_xs)
    op = train_op# if fine_tune else train_op
    
    print(f"starting epoch {i}, " 
          f"alpha: {al}, beta: {be}, "
          f"drop: {np.max(((1. - (i * 0.005)), 0.6))} "
          f"Learning rate: {ft_learning_rate}"
         )
    
    loss = train_loss
    losses = []
    
    for k in range(0, 240*8, 8):
        ft_steps += 1
        if ft_steps < 600:
            ft_learning_rate = (ft_steps / 600) * 2e-2
            print(ft_learning_rate)
        else:
            ft_learning_rate = 2e-2
        #try:
        x_batch = np.zeros((8, 5, 124, 124, 17), dtype = np.float32)
        y_batch = np.zeros((8, 110, 110), dtype = np.float32)
        mask_batch = np.zeros_like(y_batch)
        
        tochoose = [zeros, fives, tens, twenties, thirties, forties, fifties, seventies]
        old_batch = np.copy(np.array(batches))
        batches = []
        for i in range(0, 8):
            #print(train_xs[k + i])
            rng = np.random.randint(len(tochoose[i]))
            sample = train_xs[tochoose[i][rng]]
            batches.append(sample)
            x_batcha, y_batcha = load_and_augment_xy(x_path, y_path, sample)
            y_batcha = np.pad(y_batcha, (48, 48), 'constant', constant_values=(0, 0))
            mask_batcha = np.zeros_like(y_batcha)
            mask_batcha[48:-48, 48:-48] = 1.
            x_batch[i] = x_batcha
            y_batch[i] = y_batcha
            mask_batch[i] = mask_batcha
        y_batch[y_batch > 0.9] = 1.
        #x_batch = x_batch[:, :, 48:-48, 48:-48]
        #y_batch = y_batch[:, 48:-48, 48:-48]
        #mask_batch = np.ones_like(y_batch)
        #print(x_batch.shape, y_batch.shape, mask_batch.shape)
        out = sess.run([op, fm], #
                          feed_dict={inp: x_batch,
                                     length: np.full((8,), 4),
                                     labels: y_batch,
                                     mask: mask_batch,
                                     is_training: True,
                                     loss_weight: 1.,
                                     keep_rate: 0.75,
                                     alpha: al,
                                     beta_: be,
                                     ft_lr: ft_learning_rate,
                                     })
    
        predmean = np.mean(out[1][:, 48:-48, 48:-48].squeeze(), axis = (1, 2))
        labelmean = np.mean(y_batch[:, 48:-48, 48:-48].squeeze(), axis = (1, 2))
        predmean = np.around(predmean, 2)
        labelmean = np.around(labelmean, 2)
        for a, b, c in zip(np.array(batches), predmean, labelmean):
            print(a, b, c)
        ##losses.append(tr)
        #except KeyboardInterrupt:
        #        print('Interrupted')
        #        break
        #except:
        #    continue
    
    print(f"Epoch {i}: Loss {np.around(np.mean(losses[:-1]), 3)}")
    #saver.save(sess, f"../models/epoch30-{str(i)}/model")
    output_node_names = ['conv2d/Sigmoid']
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names)


    # Save the frozen graph
    with open(f'../models/tmp/predict_graph-{str(i)}.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())

starting epoch 1, alpha: 0.3, beta: 0.0, drop: 0.995 Learning rate: 0.02
3.3333333333333335e-05


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations


141370751 0.11 0.0
200058 0.07 0.02
135787414 0.02 0.09
135732397 0.09 0.17
140860628 0.24 0.32
136318424 0.54 0.57
3117060010 0.57 0.69
20048 0.95 1.0
6.666666666666667e-05
2200114 0.01 0.0
139379537 0.02 0.01
139379145 0.1 0.07
139341808 0.08 0.12
135702863 0.38 0.31
138872341 0.46 0.41
139027214 0.52 0.71
135703926 0.94 1.0
0.0001
139291647 0.03 0.0
139077430 0.02 0.03
135804029 0.05 0.05
135224895 0.1 0.16
139207917 0.13 0.24
135703898 0.51 0.38
138872498 0.89 0.88
135703851 0.88 0.96
0.00013333333333333334
141370751 0.11 0.0
139379801 0.13 0.03
138173606 0.04 0.08
135702947 0.1 0.15
135224843 0.22 0.31
138872600 0.43 0.43
5232050030 0.53 0.69
135787639 0.91 0.9
0.00016666666666666666
139397648 0.01 0.0
137517365 0.05 0.04
139379145 0.1 0.07
135804089 0.06 0.18
139025822 0.19 0.23
135703846 0.62 0.49
137532551 0.62 0.87
1234277 0.88 0.96
0.0002
139208166 0.05 0.0
139379445 0.01 0.04
135724967 0.09 0.08
11527110010 0.11 0.19
135505956 0.05 0.2
135505842 0.47 0.51
136089094 0.75 0.83

24005 0.23 0.0
135787448 0.06 0.04
31005 0.13 0.07
135224686 0.16 0.18
139077754 0.45 0.28
9636050010 0.6 0.39
135703328 0.54 0.65
9017080020 0.96 0.99
0.0016666666666666666
200072 0.02 0.0
135704174 0.26 0.03
139379145 0.11 0.07
137862315 0.15 0.17
135703890 0.4 0.34
139365986 0.56 0.41
200159 0.71 0.65
8691060010 0.95 1.0
0.0017000000000000001
135787400 0.01 0.0
135732396 0.1 0.03
7867110010 0.21 0.09
135840880 0.11 0.1
139365859 0.22 0.23
10288060040 0.51 0.5
137532551 0.67 0.87
9835060010 0.96 0.99
0.0017333333333333335
2500232 0.03 0.0
2600248 0.03 0.02
139323565 0.06 0.06
139160350 0.21 0.12
20089 0.4 0.22
5101050010 0.5 0.4
240028 0.4 0.62
135680051 0.96 1.0
0.0017666666666666666
135672785 0.03 0.0
137535092 0.02 0.02
135654602 0.04 0.08
139420305 0.25 0.13
140786039 0.51 0.28
135704062 0.38 0.49
80052 0.78 0.88
138872365 0.96 1.0
0.0018
2100178 0.02 0.0
135703106 0.1 0.04
138173606 0.04 0.08
13826110040 0.23 0.13
135191161 0.35 0.33
135787440 0.68 0.51
139686336 0.78 0.81
22310

139025733 0.02 0.0
135787272 0.03 0.04
139323565 0.06 0.06
800155 0.05 0.11
135803878 0.47 0.23
135751791 0.31 0.37
80016 0.8 0.81
2178080010 0.32 0.97
0.003266666666666667
139291657 0.07 0.0
27007 0.04 0.04
8725110010 0.03 0.06
40073 0.19 0.16
135344159 0.16 0.2
137517062 0.58 0.58
8008 0.82 0.77
138872365 0.97 1.0
0.0033000000000000004
139077432 0.01 0.0
135841016 0.03 0.03
139025529 0.23 0.1
135680980 0.31 0.19
139048986 0.19 0.32
139669334 0.3 0.45
300021 0.92 0.88
1234274 0.89 0.93
0.003333333333333333
139419876 0.01 0.0
135787438 0.04 0.05
139379571 0.16 0.07
135702611 0.07 0.1
139207917 0.14 0.24
136029544 0.26 0.37
9394050040 0.71 0.91
5530050040 0.97 1.0
0.0033666666666666667
900182 0.03 0.0
135732487 0.04 0.02
135672824 0.09 0.09
135681029 0.24 0.15
137517146 0.21 0.2
137527070 0.51 0.46
3117060010 0.72 0.69
139186714 0.96 1.0
0.0034000000000000002
2500241 0.02 0.0
11976050030 0.03 0.02
139413342 0.12 0.09
2100165 0.18 0.15
135703839 0.2 0.22
135703898 0.57 0.38
80090 0.81 0.

2000150 0.04 0.0
139292058 0.07 0.05
139178381 0.03 0.05
135732431 0.11 0.13
138948227 0.24 0.22
139027689 0.32 0.48
12689060010 0.78 0.79
139027999 0.98 1.0
0.004866666666666667
135542446 0.01 0.0
135702943 0.35 0.04
137532577 0.18 0.09
139177929 0.41 0.2
2175080020 0.73 0.25
20079 0.33 0.48
136134609 0.54 0.8
139178376 0.98 1.0
0.0049
2500223 0.03 0.0
2000130 0.15 0.04
139025598 0.03 0.07
139160350 0.2 0.12
139077536 0.15 0.29
136134714 0.24 0.37
80051 0.58 0.89
9519050010 0.97 1.0
0.004933333333333334
139379771 0.02 0.0
139188995 0.04 0.03
2190110020 0.02 0.08
4763080040 0.36 0.16
139277261 0.09 0.25
9636050010 0.65 0.39
12485080010 0.76 0.66
8999110010 0.68 0.94
0.004966666666666667
200072 0.01 0.0
139027219 0.03 0.04
137587691 0.06 0.08
135840880 0.07 0.1
139048875 0.02 0.23
139146774 0.12 0.4
135703903 0.91 0.86
135703912 0.94 1.0
0.005
280023 0.01 0.0
135787360 0.03 0.02
139025529 0.26 0.1
135703980 0.1 0.13
135703931 0.23 0.28
139178384 0.45 0.52
136134609 0.54 0.8
137532557 0.

2500225 0.02 0.0
135787403 0.02 0.04
136456566 0.2 0.08
138900715 0.03 0.15
400235 0.56 0.31
139277255 0.24 0.36
135703864 0.66 0.69
138872365 0.97 1.0
0.006466666666666667
139146789 0.01 0.0
135787345 0.1 0.04
135724926 0.07 0.05
137532493 0.1 0.11
200243 0.15 0.23
135704000 0.48 0.42
139686336 0.76 0.81
1234175 0.87 1.0
0.006500000000000001
139178427 0.01 0.0
135732396 0.07 0.03
139291487 0.05 0.05
139027210 0.03 0.1
135224677 0.16 0.32
12004060030 0.67 0.59
135703811 0.86 0.87
30041 0.96 1.0
0.006533333333333334
140750829 0.13 0.0
270041 0.04 0.01
800174 0.25 0.09
136075854 0.13 0.19
10985060030 0.54 0.31
14550060010 0.41 0.57
135703988 0.65 0.88
200121 0.89 1.0
0.006566666666666666
139208290 0.03 0.0
139146778 0.01 0.02
135703058 0.12 0.07
135680829 0.3 0.18
400231 0.35 0.21
135703870 0.56 0.53
500224 0.89 0.68
137517124 0.94 1.0
0.006600000000000001
2000150 0.04 0.0
136446310 0.11 0.04
138872508 0.06 0.07
135702611 0.05 0.1
4267110010 0.21 0.34
136077612 0.46 0.47
80020 0.44 0.65


  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


INFO:tensorflow:Froze 65 variables.
INFO:tensorflow:Converted 65 variables to const ops.
starting epoch 2, alpha: 0.3, beta: 0.0, drop: 0.99 Learning rate: 0.02
0.008033333333333333
2500221 0.04 0.0
200088 0.01 0.01
139160447 0.05 0.06
140860608 0.06 0.11
2847060010 0.23 0.34
5006080020 0.58 0.41
14252110010 0.66 0.92
8205080040 0.96 1.0
0.008066666666666666
2600214 0.01 0.0
135732418 0.02 0.04
139077462 0.36 0.06
139319999 0.19 0.11
12953080030 0.2 0.3
140474083 0.42 0.43
400229 0.87 0.82
1234193 0.69 0.94
0.008100000000000001
2600227 0.03 0.0
136434333 0.01 0.04
135724919 0.09 0.1
135681019 0.06 0.18
139025822 0.17 0.23
8393080040 0.47 0.4
135704456 0.86 0.75
136410973 0.76 1.0
0.008133333333333334
1234219 0.02 0.0
139160528 0.02 0.04
135840931 0.11 0.08
135809884 0.23 0.16
135680808 0.22 0.22
139179727 0.35 0.45
5668050030 0.82 0.62
9565110040 0.94 0.97
0.008166666666666666
12973080040 0.01 0.0
600192 0.07 0.04
135191289 0.16 0.07
9190110041 0.06 0.11
138948044 0.2 0.33
135191261 0.

139189061 0.02 0.0
2000163 0.04 0.01
135809729 0.07 0.1
139397782 0.06 0.11
8625110010 0.31 0.25
400247 0.23 0.36
12069060030 0.9 0.73
1234161 0.78 1.0
0.009633333333333334
250049 0.08 0.0
139027186 0.04 0.03
137587691 0.06 0.08
135702870 0.13 0.17
135703805 0.35 0.31
135787415 0.21 0.38
80052 0.76 0.88
139027236 0.97 1.0
0.009666666666666667
200061 0.01 0.0
27003 0.03 0.02
135787376 0.02 0.08
135702870 0.13 0.17
135344159 0.21 0.2
13574080010 0.41 0.4
136411025 0.71 0.71
9519050010 0.96 1.0
0.0097
280023 0.01 0.0
139365837 0.06 0.04
135698200 0.1 0.09
135787319 0.06 0.11
135787078 0.51 0.34
137547458 0.16 0.4
1234194 0.61 0.76
139146791 0.88 0.95
0.009733333333333333
2500225 0.02 0.0
2000136 0.02 0.02
135787379 0.06 0.09
135703529 0.14 0.17
135841059 0.27 0.21
300144 0.24 0.56
139077723 0.75 0.68
600178 0.93 1.0
0.009766666666666667
2000159 0.02 0.0
139189633 0.1 0.04
135751783 0.03 0.07
138173604 0.06 0.13
139365788 0.15 0.24
136134445 0.43 0.45
136410944 0.65 0.65
4953050010 0.97 0.

137535073 0.01 0.0
135787435 0.04 0.03
9353050030 0.03 0.06
137517095 0.14 0.11
139379486 0.54 0.3
139178384 0.44 0.52
135787281 0.78 0.73
7574080041 0.97 1.0
0.011233333333333333
2000122 0.09 0.0
139027173 0.03 0.03
135680099 0.16 0.09
13543110010 0.13 0.18
139342340 0.25 0.24
5101050010 0.47 0.4
1234159 0.66 0.86
310039 0.72 1.0
0.011266666666666668
2000109 0.01 0.0
139379751 0.06 0.05
139048878 0.17 0.06
135787430 0.2 0.13
138872338 0.26 0.26
137517062 0.65 0.58
137535114 0.81 0.84
135703907 0.96 0.99
0.0113
139070804 0.28 0.0
2000136 0.02 0.02
139189017 0.18 0.09
135224777 0.11 0.15
135787358 0.18 0.31
9807110010 0.38 0.4
137517104 0.7 0.79
137517117 0.93 1.0
0.011333333333333332
600104 0.08 0.0
137532541 0.04 0.03
135787352 0.05 0.08
135703289 0.14 0.14
139027215 0.31 0.31
135703794 0.46 0.44
135703990 0.62 0.78
137517139 0.94 1.0
0.011366666666666667
139379564 0.01 0.0
139025636 0.02 0.04
139161946 0.12 0.09
7524080040 0.25 0.16
139052583 0.33 0.29
135703283 0.46 0.49
9394050040 

KeyboardInterrupt: 

In [None]:
# Freeze the current weights (e.g. after interrupting training) into a standalone
# GraphDef. The file name is tagged with whatever value `i` currently holds.
output_node_names = ['conv2d/Sigmoid']
frozen_graph_def = tf.graph_util.convert_variables_to_constants(
    sess, sess.graph_def, output_node_names)

graph_path = '../models/tmp/predict_graph-' + str(i) + '.pb'
with open(graph_path, 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())

In [None]:
# Plot the model output for one sample of the last training step.
# In the training loop `out = sess.run([op, fm, ...])`, so out[0] is the result of
# running the train op (typically None), not a prediction; out[1] is the `fm` fetch.
# `old_batch` matches `out` if the loop was stopped while loading the next batch or
# inside sess.run; if it was stopped while printing, `batches[idx]` matches instead.
idx = 0
l = sns.heatmap(out[1][idx].squeeze(), vmin = 0.0, vmax = 1)
l.set_title(old_batch[idx])

In [None]:

def load_individual_sample(fpath, ypath, f, verbose=True):
    """Load one sample's input stack and label, optionally printing debug checksums.

    Reads `<fpath><f>.hkl` if it exists, otherwise `<fpath><f>.npy`, and divides
    by 65535. If the last axis has 13 bands, 4 channels from
    `make_and_smooth_indices` are appended (17 total). Otherwise the last 4
    channels are shifted/scaled by fixed constants. The first axis is then reduced
    to 4 medians over consecutive groups of 3, plus one overall median (5 entries).
    The result is cropped by 2 px on each spatial edge and passed to
    `normalize_subtile`.

    Args:
        fpath: input directory prefix (joined by string concatenation, so keep
            the trailing '/').
        ypath: label directory prefix containing `<f>.tif`.
        f: sample ID as a string.
        verbose: print the file actually loaded and intermediate sums
            (debugging aid; defaults to True as before).

    Returns:
        (x, y): x is the normalized 5-entry stack; y is band 1 of the label
        raster divided by 255.
    """
    hkl_file = fpath + f + '.hkl'
    if os.path.exists(hkl_file):
        loaded_file = hkl_file
        x = hkl.load(hkl_file) / 65535
    else:
        loaded_file = fpath + f + ".npy"
        x = np.load(loaded_file) / 65535
    if verbose:
        # Report the path actually loaded (.hkl or .npy), not always the .hkl one.
        print(loaded_file)
        print(np.sum(x))
    if x.shape[-1] == 13:
        indices = make_and_smooth_indices(x)
        out = np.zeros((x.shape[0], x.shape[1], x.shape[2], 17), dtype = np.float32)
        out[..., :13] = x
        out[..., 13:] = indices
    else:
        # Presumably the input already carries the 4 index channels; re-center
        # them with fixed constants (TODO: document where these come from).
        out = x
        out[..., -1] *= 2
        out[..., -1] -= 0.7193834232943873
        out[..., -2] -= 0.09731556326714398
        out[..., -3] -= 0.4973397113668104
        out[..., -4] -= 0.1409399364817101
    # Median over the whole first axis, appended later as the 5th entry.
    median = np.median(out, axis = 0)
    if verbose:
        print(np.sum(out))
    # Assumes 12 entries on the first axis: 4 groups of 3 consecutive ones, each
    # reduced to its median.
    out = np.reshape(out, (4, 3, out.shape[1], out.shape[2], out.shape[3]))
    out = np.median(out, axis = 1, overwrite_input = True)
    if verbose:
        print(np.sum(out))
    out = np.concatenate([out, median[np.newaxis]], axis = 0)
    if verbose:
        print(np.sum(out))
    x_out = normalize_subtile(out[:, 2:-2, 2:-2, :])
    # Close the label raster after reading instead of leaking the handle.
    with rs.open(ypath + f + ".tif") as label_src:
        y = label_src.read(1) / 255
    return x_out, y

# Inspect one sample without augmentation. Call the loader directly rather than
# redefining `load_and_augment_xy` here: a redefinition would silently replace the
# augmenting version that the training loop calls, disabling augmentation for any
# training run executed after this cell.
x, y = load_individual_sample(x_path, y_path, str(135840882))
np.sum(x, axis = (0, 1, 2))

In [None]:
# Dump the first sample of the most recent training batch for offline inspection.
np.save(file='batch_x.npy', arr=x_batch[0])

In [None]:
#sess.run(swa_to_weights)
# Checkpoint all session variables; keep up to 150 checkpoints from this Saver.
saver = tf.train.Saver(max_to_keep = 150)
checkpoint_prefix = "../models/loss2/model"
# tf.train.Saver.save raises when the parent directory does not exist, so create
# it idempotently instead of relying on a manual (commented-out) os.mkdir.
os.makedirs(os.path.dirname(checkpoint_prefix), exist_ok=True)
save_path = saver.save(sess, checkpoint_prefix)


In [None]:

# Freeze the current session into a standalone GraphDef for inference.
# NOTE(review): other cells freeze 'conv2d/Sigmoid'; confirm that 'conv2d_5/Sigmoid'
# is the intended output node for this graph.
output_node_names = ['conv2d_5/Sigmoid']
frozen_graph_def = tf.graph_util.convert_variables_to_constants(
    sess,
    sess.graph_def,
    output_node_names)

# Save the frozen graph, creating the target directory if it does not exist yet.
os.makedirs('../models/loss3', exist_ok=True)
with open('../models/loss3/predict_graph.pb', 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())