In [None]:
import sys 
import datetime

sys.path.insert(0, '../')

from constants import * 
import datasets
import loss_fns
import models
import util

In [None]:
def keep_track(cur_count, cur_sum, cur_sq_sum, inputs):
    """ Folds one batch into the running per-band count, sum and sum of squares

    Bands are indexed along axis 2 of `inputs`; all other axes are pooled
    together for each band.

    Args:
        cur_count - (indexable, one slot per band) running number of values seen
        cur_sum - (indexable, one slot per band) running sum of the values
        cur_sq_sum - (indexable, one slot per band) running sum of squared values
        inputs - (npy arr) 5-D batch with the band axis at position 2

    Returns:
        (cur_count, cur_sum, cur_sq_sum) - the same three accumulators that
        were passed in, updated in place
    """
    for band_idx in range(inputs.shape[2]):
        # Promote to float64 *before* squaring and summing: in the input's own
        # dtype, integer data wraps around silently when squared and float32
        # data loses precision in the sum of squares.
        cur_band = np.asarray(inputs[:, :, band_idx, :, :], dtype=np.float64)
        cur_count[band_idx] += cur_band.size
        cur_sum[band_idx] += cur_band.sum()
        cur_sq_sum[band_idx] += np.square(cur_band).sum()
    return cur_count, cur_sum, cur_sq_sum

In [None]:
def load_data(model_name, args=None, dataloaders=None, X=None, y=None):
    """ Accumulates per-band count, sum and sum of squares over the 'train' split

    Args:
        model_name - (str) name of the model; must be one of DL_MODELS
        args - (argparse object) args parsed in from main; currently unused
        dataloaders - (dict of dataloaders) must contain a 'train' entry that
                      yields (inputs, targets, cloudmasks) batches, where
                      `inputs` exposes `.numpy()`
        X - (npy arr) data for non-dl models; currently unused
        y - (npy arr) labels for non-dl models; currently unused

    Returns:
        N, S, Q - (npy arrs, one entry per band) number of values, sum of
                  values and sum of squared values, as accumulated by
                  `keep_track`

    Raises:
        ValueError - if `model_name` is not a DL model, if no dataloaders are
                     given, or if the 'train' dataloader yields no batches
    """
    # Only the dataloader path is implemented; fail with a clear message
    # instead of hitting an UnboundLocalError at the return statement.
    if model_name not in DL_MODELS:
        raise ValueError(
            "load_data only supports DL models, got model_name={!r}".format(model_name))
    if dataloaders is None:
        raise ValueError("load_data requires `dataloaders` for DL models")

    N = S = Q = None
    for split in ('train',):
        for inputs, _targets, _cloudmasks in dataloaders[split]:
            batch = inputs.numpy()  # convert once per batch
            if N is None:
                # Size the accumulators from the first batch's band axis.
                num_bands = batch.shape[2]
                N = np.zeros((num_bands,))
                S = np.zeros((num_bands,))
                Q = np.zeros((num_bands,))

            N, S, Q = keep_track(N, S, Q, batch)

    if N is None:
        raise ValueError("the 'train' dataloader yielded no batches")
    return N, S, Q

In [None]:
# Compute per-band input statistics (mean / std) from the training dataloader.
# No model is trained in this cell; the train parser is only used to build the
# `args` object that `datasets.get_dataloaders` and `load_data` take.
NUM_RUNS = 1  # number of passes whose statistics are collected

means_list = []
stds_list = []

# The parser and its argument list do not change between runs, so build them once.
# NOTE(review): normalize / apply_transforms are switched off, presumably so the
# statistics describe the raw inputs -- confirm against the dataset code.
train_parser = util.get_train_parser()
stats_argv = ['--epochs', '1',
              '--model_name', 'fcn_crnn',
              '--dataset', 'full',
              '--num_classes', '4',
              '--country', 'ghana',
              '--batch_size', '1',
              '--hidden_dims', '4',
              '--crnn_num_layers', '1',
              '--use_s1', 'True',
              '--use_s2', 'True',
              '--sample_w_clouds', 'False',
              '--include_clouds', 'True',
              '--include_doy', 'True',
              '--bidirectional', 'False',
              '--shuffle', 'False',
              '--normalize', 'False',
              '--apply_transforms', 'False',
              '--least_cloudy', 'False']

for count in range(NUM_RUNS):
    print('count: ', count)

    # Parse inside the loop so every run starts from a fresh namespace.
    args = train_parser.parse_args(stats_argv)

    # load in data generator
    dataloaders = datasets.get_dataloaders(args.country, args.dataset, args)

    N, S, Q = load_data(args.model_name, args, dataloaders=dataloaders)
    means = S / N
    # Var[x] = E[x^2] - E[x]^2. Rounding can push this slightly below zero for
    # near-constant bands, which would make sqrt return NaN, so clamp at 0.
    variances = np.maximum((Q / N) - means**2, 0.0)
    stds = np.sqrt(variances)

    print(means)
    print(stds)

    means_list.append(means)
    stds_list.append(stds)

In [None]:
# Stack the per-run statistics row-wise: one row per run.
bnd_means = np.vstack(means_list)
bnd_stds = np.vstack(stds_list)

# For each statistic, report its average over the runs followed by the
# run-to-run spread.
for label, per_run in (("Means: ", bnd_means), ("Stdevs: ", bnd_stds)):
    print(label, per_run.mean(axis=0))
    print("+/- : ", per_run.std(axis=0))