In [1]:
import sys 
import datetime

sys.path.insert(0, '../')

from constants import * 
import datasets
import loss_fns
import models
import util

Using TensorFlow backend.


In [2]:
def keep_track(cur_count, cur_sum, cur_sq_sum, inputs):
    for band_idx in range(inputs.shape[2]):
        cur_band = inputs[:, :, band_idx, :, :]
        cur_count[band_idx] += len(cur_band.flatten())
        cur_sum[band_idx] += np.sum(cur_band)
        cur_sq_sum[band_idx] += np.sum(cur_band**2)
    return cur_count, cur_sum, cur_sq_sum

In [3]:
def load_data(model_name, args=None, dataloaders=None, X=None, y=None, splits=('train',)):
    """Accumulate per-band count / sum / sum-of-squares over dataloader splits.

    Despite the name, nothing is trained here. Every batch of the requested
    splits is fed through ``keep_track``, so the caller can derive per-band
    mean (S / N) and std (sqrt(Q / N - (S / N)**2)).

    Args:
        model_name - (str) name of the model; must be in DL_MODELS
        args - unused; kept for call-site compatibility
        dataloaders - (dict of dataloaders) keyed by split name; each yields
            (inputs, targets, cloudmasks) tuples whose ``inputs`` has ``.numpy()``
        X - unused; kept for call-site compatibility
        y - unused; kept for call-site compatibility
        splits - split names to aggregate over (default: training split only)

    Returns:
        (N, S, Q): per-band numpy arrays of element counts, sums and sums of
        squares. The band axis is axis 2 of each input batch.

    Raises:
        ValueError: if model_name is not a DL model, dataloaders is missing,
            or the requested splits yield no batches.
    """
    # Previously these cases fell through to an UnboundLocalError on the return.
    if model_name not in DL_MODELS:
        raise ValueError(f"load_data only supports DL models, got {model_name!r}")
    if dataloaders is None:
        raise ValueError("dataloaders is required for DL models")

    N = S = Q = None
    for split in splits:
        for inputs, _targets, _cloudmasks in dataloaders[split]:
            batch = inputs.numpy()  # convert once per batch
            if N is None:
                # Size the accumulators from the first batch's band axis.
                n_bands = batch.shape[2]
                N, S, Q = np.zeros(n_bands), np.zeros(n_bands), np.zeros(n_bands)
            N, S, Q = keep_track(N, S, Q, batch)

    if N is None:
        raise ValueError(f"no batches found in splits {list(splits)}")
    return N, S, Q

In [4]:
# Compute per-band normalisation statistics (mean / std) from the training split.
# Each run rebuilds the dataloaders and re-scans the data. The saved outputs show
# the statistics vary between runs, presumably from random sampling inside the
# dataset (TODO confirm). The next cell summarises that run-to-run spread.
# NOTE(review): the saved outputs below came from 2 runs. With NUM_RUNS = 1, the
# "+/-" spread printed by the next cell is all zeros.
NUM_RUNS = 1
HDF5_FILEPATH = '/home/data/ghana/data.hdf5'  # TODO: make configurable / relative

means_list = []
stds_list = []

# The argument list is identical for every run, so parse it once.
train_parser = util.get_train_parser()
args = train_parser.parse_args(['--epochs', str(1),
                                '--model_name', 'fcn_crnn',
                                '--dataset', 'full',
                                '--num_classes', str(4),
                                '--country', 'ghana',
                                '--hdf5_filepath', HDF5_FILEPATH,
                                '--batch_size', str(1),
                                '--hidden_dims', str(4),
                                '--crnn_num_layers', str(1),
                                '--use_s1', str(True),
                                '--use_s2', str(True),
                                '--sample_w_clouds', str(False),
                                '--include_clouds', str(True),
                                '--include_doy', str(True),
                                '--bidirectional', str(False),
                                '--shuffle', str(False),
                                '--normalize', str(False),
                                '--apply_transforms', str(False),
                                '--least_cloudy', str(False)])

for count in range(NUM_RUNS):
    print('count: ', count)

    # load in data generator (rebuilt every run, as before)
    dataloaders = datasets.get_dataloaders(args.grid_dir, args.country, args.dataset, args)

    N, S, Q = load_data(args.model_name, args, dataloaders=dataloaders)
    means = S / N
    # E[x^2] - E[x]^2 can come out slightly negative through floating-point
    # cancellation on near-constant bands. Clip it so sqrt does not yield NaN.
    variances = np.maximum(Q / N - means ** 2, 0.0)
    stds = np.sqrt(variances)

    print(means)
    print(stds)

    means_list.append(means)
    stds_list.append(stds)

count:  0
[-1.03430326e+01 -1.68884483e+01  1.14935244e+00  1.39528606e-02
  2.61804050e+03  2.51302351e+03  2.60890642e+03  2.71480983e+03
  3.19964221e+03  3.53313875e+03  3.32707585e+03  3.75451141e+03
  2.82330860e+03  2.03768521e+03  8.85984018e-02  1.73613201e-01]
[3.62755244e+00 4.94754780e+00 1.05513105e+01 5.92549638e-01
 2.21432477e+03 2.12968352e+03 2.22625548e+03 2.13607543e+03
 2.11700885e+03 2.18226745e+03 2.04947709e+03 2.16941124e+03
 1.24830882e+03 9.46449411e+02 8.31678880e-01 5.61052841e-01]
count:  1
[-1.03313250e+01 -1.68705949e+01  1.15225130e+00  2.14210776e-02
  2.60858940e+03  2.50461811e+03  2.60151852e+03  2.70885934e+03
  3.19410304e+03  3.52727807e+03  3.32283869e+03  3.74879827e+03
  2.82020792e+03  2.03524199e+03  9.28364883e-02  1.79311038e-01]
[3.58523023e+00 4.85314646e+00 1.07147364e+01 5.92104442e-01
 2.22143744e+03 2.13622006e+03 2.23239042e+03 2.14413354e+03
 2.12367537e+03 2.18663752e+03 2.05660485e+03 2.17314794e+03
 1.24916887e+03 9.48023792e+02

In [5]:
# Stack each run's per-band statistics (one row per run). For both the means
# and the stds, report the average across runs and the population std across
# runs as a rough stability check.
bnd_means = np.vstack(means_list)
bnd_stds = np.vstack(stds_list)

for label, stacked in (("Means: ", bnd_means), ("Stdevs: ", bnd_stds)):
    print(label, stacked.mean(axis=0))
    print("+/- : ", stacked.std(axis=0))

Means:  [-1.03188577e+01 -1.68562630e+01  1.15186084e+00  1.97034215e-02
  2.61619311e+03  2.51126700e+03  2.60739752e+03  2.71425497e+03
  3.19940404e+03  3.53242898e+03  3.32763300e+03  3.75349526e+03
  2.81829437e+03  2.03369264e+03  9.07834845e-02  1.81667560e-01]
+/- :  [2.55724212e-02 3.13214393e-02 1.72702865e-03 3.33387171e-03
 1.09366891e+01 1.04072988e+01 1.08356659e+01 1.07485572e+01
 1.16860950e+01 1.26783961e+01 1.18292223e+01 1.31423342e+01
 1.00767777e+01 7.46907206e+00 4.77159695e-03 4.59240440e-03]
Stdevs:  [3.60755247e+00 4.90097923e+00 1.17266831e+01 5.91574226e-01
 2.22642428e+03 2.14090566e+03 2.23718373e+03 2.14824661e+03
 2.12642293e+03 2.18887671e+03 2.05886182e+03 2.17427895e+03
 1.24299343e+03 9.43253283e+02 8.30565130e-01 5.62406905e-01]
+/- :  [2.04688068e-02 4.47734636e-02 1.26424229e+00 9.74702743e-04
 7.69169343e+00 7.25054763e+00 7.47731237e+00 7.93110147e+00
 7.80939519e+00 7.82822654e+00 8.28118530e+00 8.14410972e+00
 7.71959190e+00 5.12940426e+00 9.05