In [7]:
import xarray as xr
import os
import numpy as np
import cartopy.crs as ccrs
import matplotlib.pyplot as plt

## Cargar datos

In [8]:
def traverseDir(root):
    for (dirpath, dirnames, filenames) in os.walk(root):
        for file in filenames:
            if file.endswith(('.nc')):
                yield os.path.join(dirpath, file)

In [9]:
root = '/home/jovyan/shared/data/students/pablo23hf/ERA5'

In [10]:
files = np.sort(list(traverseDir(root)))

In [11]:
ds_hourly = xr.open_mfdataset(files[0:72], 
                       concat_dim='time',
                       combine = 'nested',
                       chunks={"time":100})

In [13]:
# Recorto dominio para que sea cuadrado y múltiplo de 2
#n_cells=48
n_cells=8
lon_bnds= ds_hourly.longitude[10:10+n_cells]
lat_bnds= ds_hourly.latitude[0:0+n_cells]
ds_hourly = ds_hourly.sel(longitude = lon_bnds, latitude = lat_bnds)

## Generar los archivos .npz y .npy

In [14]:
import os
import numpy as np
import xarray as xr
import pandas as pd
from dask.diagnostics import ProgressBar

pbar = ProgressBar()
pbar.register()

#PARAMS
# for training data:
startdate = '19700101'
enddate = '19741231'
# for test data:
# startdate = '19750101'
# enddate = '19751231'

outpath='/home/jovyan/work/prueba_datos_ERA5'
ds_hourly['tp'] = ds_hourly['tp']*1000
ds_hourly=ds_hourly['tp']
# convert to 32bit
ds_hourly = ds_hourly.astype('float32')


# convert to numpy array
ds_hourly = ds_hourly.values

# now we want to reshape to (days,tperday,lat,lon)
t_per_day = int(24/1)

ntime,ny,nx = ds_hourly.shape
ndays = ntime / t_per_day
assert(ndays.is_integer())
ndays = int(ndays)
reshaped = ds_hourly.reshape((ndays,t_per_day,ny,nx))

final = reshaped

#np.savez_compressed(f'{outpath}/prueba.npz',data=final)
#np.save(f'{outpath}/prueba', final)
np.savez_compressed(f'{outpath}/prueba{startdate}-{enddate}.npz',data=final)
np.save(f'{outpath}/prueba{startdate}-{enddate}', final)

[########################################] | 100% Completed |  5.7s


## Valid training samples

In [17]:
import os
import pickle
import numpy as np
import numba
from dask.diagnostics import ProgressBar

pbar = ProgressBar()
pbar.register()

os.system('mkdir -p data')

#PARAMS
# for training data:
startdate = '19700101'
enddate = '19741231'
# for test data:
# startdate = '19750101'
# enddate = '19751231'

ndomain = n_cells  # gridpoints
stride = n_cells  # |ndomain # in which steps to scan the whole domain

tp_thresh_daily = 5  # mm. in the radardate the unit is mm/h, but then on 5 minutes steps.
# the conversion is done automatically in this script
n_thresh = 20
# END PARAMS

#if ndomain % 2 != 0:
#    raise ValueError(f'ndomain must be an even number')

datapath = '/home/jovyan/work/prueba_datos_ERA5'

#ifile = f'/home/jovyan/work/prueba_datos_ERA5/prueba.npy'
ifile = f'{datapath}/prueba{startdate}-{enddate}.npy'

data = np.load(ifile, mmap_mode='r')

if len(data.shape) != 4:
    raise ValueError(f'data has wrong number of dimensions {len(data.shape)} instead of 4')

# compute daily sum, which is the sum over the hour axis
n_days,nhour, ny, nx = data.shape

# compute all valid indices
# for this, we try out all ndomain x ndomain squares shifted by strides, and check whether they have any missing data,
# and if not, whether they adhere to the criteria set by tp_thresh_daily and n_thresh
# since this contains many for loops, we speed it up with numba


@numba.jit
def filter(data):
    final_valid_idcs = []
    # loop over timeslices
    for tidx in numba.prange(n_days):
        #print(tidx, '/', n_days)
        # daily sum
        sub = np.sum(data[tidx],axis=0)
        # loop over all possible boxes
        for ii in range(0, ny - ndomain, stride):
            for jj in range(0, nx - ndomain, stride):
                subsub = sub[ii:ii + ndomain, jj:jj + ndomain]
                # check for nan values
                if not np.any(np.isnan(subsub)):
                    # if at least n_thresh points are above the threshold,
                    # we use this box
                    if np.sum(subsub > tp_thresh_daily) >= n_thresh:
                        final_valid_idcs.append((tidx, ii, jj))


    return final_valid_idcs


final_valid_idcs = filter(data)


pickle.dump(final_valid_idcs, open(f'/home/jovyan/work/prueba_datos_ERA5/valid_indices_ERA5-{startdate}-{enddate}.pkl', 'wb'))

print(f'found {len(final_valid_idcs)} valid samples')

found 0 valid samples


In [18]:
data

memmap([[[[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
           8.92207026e-04, 8.92207026e-04, 2.49594450e-02],
          [5.34765422e-03, 8.02241266e-03, 1.15875155e-02, ...,
           1.78255141e-03, 1.78255141e-03, 2.67475843e-03],
          [1.15875155e-02, 1.15875155e-02, 2.31768936e-02, ...,
           8.02241266e-03, 2.31768936e-02, 1.24797225e-02],
          ...,
          [1.69370323e-02, 2.22846866e-02, 2.85245478e-02, ...,
           5.43743372e-02, 8.02241266e-03, 7.13206828e-03],
          [1.51544809e-02, 1.96099281e-02, 8.02241266e-03, ...,
           4.90266830e-02, 2.13943422e-02, 1.24797225e-02],
          [8.91461968e-03, 7.13206828e-03, 4.45730984e-03, ...,
           3.83295119e-02, 1.60448253e-02, 9.80496407e-03]],

         [[1.69370323e-02, 1.87195837e-02, 1.78273767e-02, ...,
           8.02241266e-03, 5.34765422e-03, 8.02241266e-03],
          [2.67475843e-03, 4.45730984e-03, 8.02241266e-03, ...,
           8.92207026e-04, 0.00000000e+00, 8.922070

## Training

In [9]:
import pickle
import os
import numpy as np
import tensorflow as tf
import pandas as pd
import matplotlib

matplotlib.use('agg')
from pylab import plt
from tqdm import trange
from skimage.util import view_as_windows
from matplotlib.colors import LogNorm
from tensorflow.keras.utils import GeneratorEnqueuer
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input

#PARAMS
# for training data:
startdate = '19700101'
enddate = '19741231'

ndomain = 16  # gridpoints
stride = 16

tp_thresh_daily = 5  # mm. in the radardate the unit is mm/h, but then on 5 minutes steps.
# the conversion is done automatically in this script
n_thresh = 20

# normalization of daily sums
# we ues the 99.9 percentile of 2010
#norm_scale = 127.4
norm_scale = 1

# neural network parameters
n_disc = 5
GRADIENT_PENALTY_WEIGHT = 10  # As per the paper
latent_dim = 100
batch_size = 32 # this is used as global variable in randomweightedaverage
# the training is done with increasing batch size. each tuple is
# a combination nof number of epochs and batch_size
#n_epoch_and_batch_size_list = ((5, 32), (10, 64), (10, 128), (20, 256))
n_epoch_and_batch_size_list = ((50, 32),)

plot_format = 'png'

name='results'

plotdirs = f'/home/jovyan/work/plots_{name}/'

outdirs = f'/home/jovyan/work/trained_models/{name}/'

# note for colab: sometimes mkdir does not work that way. in this case
# you have to create the directories manually
os.system(f'mkdir -p {plotdirs}')
os.system(f'mkdir -p {outdirs}')

# load data and precomputed indices

converted_data_paths = '/home/jovyan/work/prueba_datos_ERA5/'

indices_data_paths = '/home/jovyan/work/prueba_datos_ERA5/'

data_ifile = f'{converted_data_paths}/prueba{startdate}-{enddate}.npy'

indices_file = f'{indices_data_paths}/valid_indices_ERA5-{startdate}-{enddate}.pkl'
print('loading data')
# load the data as memmap
data = np.load(data_ifile, mmap_mode='r')


indices_all = pickle.load(open(indices_file, 'rb'))
# convert to array
indices_all = np.array(indices_all)
# this has shape (nsamples,3)
# each row is (tidx,yidx,xidx)
print('finished loading data')

# the data has dimensions (sample,hourofday,x,y)
n_days, nhours, ny, nx = data.shape
n_channel=1
# sanity checks
assert (len(data.shape) == 4)
assert (len(indices_all.shape) == 2)
assert (indices_all.shape[1] == 3)
assert (nhours == 24 // 1)
assert (np.max(indices_all[:, 0]) < n_days)
assert (np.max(indices_all[:, 1]) < ny)
assert (np.max(indices_all[:, 2]) < nx)
assert (data.dtype == 'float32')

n_samples = len(indices_all)


def generate_real_samples(n_batch):
    """get random sampples and do the last preprocessing on them"""
    while True:
        # get random sample of indices from the precomputed indices
        # for this we generate random indices for the index list (confusing termoonology, since we use
        # indices to index the list of indices...
        ixs = np.random.randint(n_samples, size=n_batch)
        idcs_batch = indices_all[ixs]

        # now we select the data corresponding to these indices

        data_wview = view_as_windows(data, (1, 1, ndomain, ndomain))[..., 0, 0, :,:]
        batch = data_wview[idcs_batch[:, 0], :, idcs_batch[:, 1], idcs_batch[:, 2]]
        # add empty channel dimension (necessary for keras, which expects a channel dimension)
        batch = np.expand_dims(batch, -1)
        # compute daily sum (which is the condition)
        batch_cond = np.sum(batch, axis=1) # daily sum

        # the data now is in mm/hour, but we want it as fractions of the daily sum for each day
        for i in range(n_batch):
            batch[i] = batch[i] / batch_cond[i]

        # normalize daily sum
        batch_cond = batch_cond / norm_scale
        assert (batch.shape == (n_batch, nhours, ndomain, ndomain, 1))
        assert (batch_cond.shape == (n_batch, ndomain, ndomain, 1))
        assert (~np.any(np.isnan(batch)))
        assert (~np.any(np.isnan(batch_cond)))
        assert (np.max(batch) <= 1)
        assert (np.min(batch) >= 0)

        yield [batch, batch_cond]


def generate_latent_points(n_batch):
    # generate points in the latent space and a random condition
    latent = np.random.normal(size=(n_batch, latent_dim))
    # randomly select conditions
    ixs = np.random.randint(0, n_samples, size=n_batch)
    idcs_batch = indices_all[ixs]

    data_wview = view_as_windows(data, (1, 1, ndomain, ndomain))[..., 0, 0, :,:]
    batch = data_wview[idcs_batch[:, 0], :, idcs_batch[:, 1], idcs_batch[:, 2]]
    # add empty channel dimension (necessary for keras, which expects a channel dimension)
    batch = np.expand_dims(batch, -1)
    batch_cond = np.sum(batch, axis=1) # daily sum
    # normalize daily sum
    batch_cond = batch_cond / norm_scale
    assert (batch_cond.shape == (n_batch, ndomain, ndomain, 1))
    assert (~np.any(np.isnan(batch_cond)))
    return [latent, batch_cond]


def generate_latent_points_as_generator(n_batch):
    while True:
        yield generate_latent_points(n_batch)


def generate_fake_samples(n_batch):
    # generate points in latent space
    latent, cond = generate_latent_points(n_batch)
    # predict outputs
    generated = generator.predict([latent, cond])
    return [generated, cond]


def generate(cond):
    latent = np.random.normal(size=(1, latent_dim))
    cond = np.expand_dims(cond, 0)
    return generator.predict([latent, cond])


def wasserstein_loss(y_true, y_pred):
    return tf.reduce_mean(y_true * y_pred)


class RandomWeightedAverage(tf.keras.layers.Layer):

    def call(self, inputs, **kwargs):
        global batch_size
        alpha = tf.random.uniform((batch_size,1, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

    def compute_output_shape(self, input_shape):
        return input_shape[0]


class GradientPenalty(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(GradientPenalty, self).__init__(**kwargs)

    def build(self, input_shapes):
        # Create a trainable weight variable for this layer.
        super(GradientPenalty, self).build(input_shapes)  # Be sure to call this somewhere!

    def call(self, inputs):
        target, wrt = inputs
        grad = K.gradients(target, wrt)[0]
        return K.sqrt(K.sum(K.batch_flatten(K.square(grad)), axis=1, keepdims=True))-1

    def compute_output_shape(self, input_shapes):
        return (input_shapes[1][0], 1)


# pixel-wise feature vector normalization layer
# from https://machinelearningmastery.com/how-to-train-a-progressive-growing-gan-in-keras-for-synthesizing-faces/
class PixelNormalization(tf.keras.layers.Layer):
    # initialize the layer
    def __init__(self, **kwargs):
        super(PixelNormalization, self).__init__(**kwargs)

    # perform the operation
    def call(self, inputs):
        # calculate square pixel values
        values = inputs ** 2.0
        # calculate the mean pixel values
        mean_values = K.mean(values, axis=-1, keepdims=True)
        # ensure the mean is not zero
        mean_values += 1.0e-8
        # calculate the sqrt of the mean squared value (L2 norm)
        l2 = K.sqrt(mean_values)
        # normalize values by the l2 norm
        normalized = inputs / l2
        return normalized

    # define the output shape of the layer
    def compute_output_shape(self, input_shape):
        return input_shape

def create_discriminator():
    # we add the condition as additional channel. For this we
    # expand its dimensions alon the nhours axis via linear scaling
    in_cond = tf.keras.layers.Input(shape=(ndomain, ndomain, 1))
    # add nhours dimension (size 1 for now)
    cond_expanded = tf.keras.layers.Reshape((1, ndomain, ndomain, 1))(in_cond)
    cond_expanded = tf.keras.layers.Lambda(lambda x: tf.keras.backend.repeat_elements(x, rep=nhours, axis=1))(
        cond_expanded)
    in_sample = tf.keras.layers.Input(shape=(nhours, ndomain, ndomain, 1))

    in_combined = tf.keras.layers.Concatenate(axis=-1)([in_sample, cond_expanded])
    kernel_size = (3, 3, 3)
    main_net = tf.keras.Sequential([

        tf.keras.layers.Conv3D(64, kernel_size=kernel_size, strides=2, input_shape=(nhours, ndomain, ndomain, 2),
                               padding="valid"),  # 11x7x7x32
        tf.keras.layers.LeakyReLU(alpha=0.2),
        tf.keras.layers.Dropout(0.25),

        tf.keras.layers.Conv3D(128, kernel_size=kernel_size, strides=2, padding="same"),
        tf.keras.layers.LeakyReLU(alpha=0.2),
        tf.keras.layers.Dropout(0.25),

        tf.keras.layers.Conv3D(256, kernel_size=kernel_size, strides=2, padding="same"),
        tf.keras.layers.LeakyReLU(alpha=0.2),
        tf.keras.layers.Dropout(0.25),

        tf.keras.layers.Conv3D(256, kernel_size=kernel_size, strides=2, padding="same"),
        tf.keras.layers.LeakyReLU(alpha=0.2),
        tf.keras.layers.Dropout(0.25),

        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(1, activation='linear'),
    ])
    out = main_net(in_combined)
    model = tf.keras.Model(inputs=[in_sample, in_cond], outputs=out)

    return model


def create_generator():

    # for the moment, the flat approach is used
    init = tf.keras.initializers.RandomNormal(stddev=0.02)
    # define model

    n_nodes = 256 * 2 * 2 * 3
    in_latent = tf.keras.layers.Input(shape=(latent_dim,))
    # the condition is a 2d array (ndomain x ndomain), we simply flatten it
    in_cond = tf.keras.layers.Input(shape=(ndomain, ndomain, n_channel))
    in_cond_flat = tf.keras.layers.Flatten()(in_cond)
    in_combined = tf.keras.layers.Concatenate()([in_latent, in_cond_flat])

    main_net = tf.keras.Sequential([
        tf.keras.layers.Dense(n_nodes, kernel_initializer=init),
        tf.keras.layers.LeakyReLU(alpha=0.2),
        tf.keras.layers.Reshape((3, 2, 2, 256)),

        tf.keras.layers.UpSampling3D(size=(2, 2, 2)),
        tf.keras.layers.Conv3D(256, (3, 3, 3), padding='same', kernel_initializer=init),
        PixelNormalization(),
        tf.keras.layers.LeakyReLU(alpha=0.2),

        tf.keras.layers.UpSampling3D(size=(2, 2, 2)),
        tf.keras.layers.Conv3D(128, (3, 3, 3), padding='same', kernel_initializer=init),
        PixelNormalization(),
        tf.keras.layers.LeakyReLU(alpha=0.2),

        tf.keras.layers.UpSampling3D(size=(2, 2, 2)),
        tf.keras.layers.Conv3D(64, (3, 3, 3), padding='same', kernel_initializer=init),
        PixelNormalization(),
        tf.keras.layers.LeakyReLU(alpha=0.2),
        # output 24x16x16x1
        tf.keras.layers.Conv3D(1, (3, 3, 3), activation='linear', padding='same', kernel_initializer=init),
        # softmax per gridpoint, thus over nhours, which is axis 1 (Softmax also counts the batch axis)
        tf.keras.layers.Softmax(axis=1),
        # check for Nans (only for debugging)
        tf.keras.layers.Lambda(
            lambda x: tf.debugging.check_numerics(x, 'found nan in output of per_gridpoint_softmax')),

    ])

    out = main_net(in_combined)
    model = tf.keras.Model(inputs=[in_latent, in_cond], outputs=out)

    return model


print('building networks')
generator = create_generator()
critic = create_discriminator()
generator.trainable = False
# Image input (real sample)
real_img = tf.keras.layers.Input(shape=(nhours,ndomain,ndomain,n_channel))
# Noise input
z_disc = tf.keras.layers.Input(shape=(latent_dim,))
# Generate image based of noise (fake sample) and add label to the input
label = tf.keras.layers.Input(shape=(ndomain, ndomain, n_channel))
fake_img = generator([z_disc, label])
# Discriminator determines validity of the real and fake images
fake = critic([fake_img, label])
valid = critic([real_img, label])

# Construct weighted average between real and fake images
interpolated_img = RandomWeightedAverage()([real_img, fake_img])

# Determine validity of weighted sample
validity_interpolated = critic([interpolated_img, label])
# here we use the approach from https://github.com/jleinonen/geogan/blob/master/geogan/gan.py,
# where gradient panely is a keras layer, and then 'mse' used as loss for this output
disc_gp = GradientPenalty()([validity_interpolated, interpolated_img])

# default from https://arxiv.org/pdf/1704.00028.pdf
optimizer = tf.optimizers.Adam(lr=0.0001, beta_1=0, beta_2=0.9)

critic_model = tf.keras.Model(inputs=[real_img, label, z_disc], outputs=[valid, fake, disc_gp])
critic_model.compile(loss=[wasserstein_loss,
                                wasserstein_loss,
                                'mse'],
                          optimizer=optimizer,
                          loss_weights=[1, 1, 10])

# For the generator we freeze the critic's layers
critic.trainable = False
generator.trainable = True

# Sampled noise for input to generator
z_gen = Input(shape=(latent_dim,))
# add label to the input
label = tf.keras.layers.Input(shape=(ndomain, ndomain, n_channel))
# Generate images based of noise
img = generator([z_gen, label])
# Discriminator determines validity
valid = critic([img, label])
# Defines generator model
generator_model = tf.keras.Model([z_gen, label], valid)
generator_model.compile(loss=wasserstein_loss, optimizer=optimizer)
print('finished building networks')

# plot some real samples
# plot a couple of samples
plt.figure(figsize=(25, 25))
n_plot = 30
[X_real, cond_real] = next(generate_real_samples(n_plot))
for i in range(n_plot):
    plt.subplot(n_plot, 25, i * 25 + 1)
    plt.imshow(cond_real[i, :, :].squeeze(), cmap=plt.cm.gist_earth_r, norm=LogNorm(vmin=0.01, vmax=1))
    plt.axis('off')
    for j in range(1, 24):
        plt.subplot(n_plot, 25, i * 25 + j + 1)
        plt.imshow(X_real[i, j, :, :].squeeze(), vmin=0, vmax=1, cmap=plt.cm.hot_r)
        plt.axis('off')
plt.colorbar()
plt.savefig(f'{plotdirs}/real_samples.{plot_format}')

hist = {'d_loss': [], 'g_loss': []}
print(f'start training on {n_samples} samples')


def train(n_epochs, _batch_size, start_epoch=0):
    """
        train with fixed batch_size for given epochs
        make some example plots and save model after each epoch
    """
    global batch_size
    batch_size = _batch_size
    # create a dataqueue with the keras facilities. this allows
    # to prepare the data in parallel to the training
    sample_dataqueue = GeneratorEnqueuer(generate_real_samples(batch_size),
                                         use_multiprocessing=True)
    sample_dataqueue.start(workers=2, max_queue_size=10)
    sample_gen = sample_dataqueue.get()

    # targets for loss function
    gan_sample_dataqueue = GeneratorEnqueuer(generate_latent_points_as_generator(batch_size),
                                         use_multiprocessing=True)
    gan_sample_dataqueue.start(workers=2, max_queue_size=10)
    gan_sample_gen = gan_sample_dataqueue.get()

    # targets for loss function
    valid = -np.ones((batch_size, 1))
    fake = np.ones((batch_size, 1))
    dummy = np.zeros((batch_size, 1))  # Dummy gt for gradient penalty

    bat_per_epo = int(n_samples / batch_size)

    # we need to call the discriminator once in order
    # to initialize the input shapes
    [X_real, cond_real] = next(sample_gen)
    latent = np.random.normal(size=(batch_size, latent_dim))
    critic_model.predict([X_real, cond_real, latent])
    for i in trange(n_epochs):
        epoch = 1 + i + start_epoch
        # enumerate batches over the training set
        for j in trange(bat_per_epo):

            for _ in range(n_disc):
                # fetch a batch from the queue
                [X_real, cond_real] = next(sample_gen)
                latent = np.random.normal(size=(batch_size, latent_dim))
                d_loss = critic_model.train_on_batch([X_real, cond_real,latent], [valid, fake, dummy])
                # we get for losses back here. average, valid, fake, and gradient_penalty
                # we want the average of valid and fake
                d_loss = np.mean([d_loss[1], d_loss[2]])


            # train generator
            # prepare points in latent space as input for the generator
            [latent, cond] = next(gan_sample_gen)
            # update the generator via the discriminator's error
            g_loss = generator_model.train_on_batch([latent, cond], valid)
            # summarize loss on this batch
            print(f'{epoch}, {j + 1}/{bat_per_epo}, d_loss {d_loss}' + \
                  f' g:{g_loss} ')  # , d_fake:{d_loss_fake} d_real:{d_loss_real}')

            if np.isnan(g_loss) or np.isnan(d_loss):
                raise ValueError('encountered nan in g_loss and/or d_loss')

            hist['d_loss'].append(d_loss)
            hist['g_loss'].append(g_loss)


        # plot generated examples
        plt.figure(figsize=(25, 25))
        n_plot = 30
        X_fake, cond_fake = generate_fake_samples(n_plot)
        for iplot in range(n_plot):
            plt.subplot(n_plot, 25, iplot * 25 + 1)
            plt.imshow(cond_fake[iplot, :, :].squeeze(), cmap=plt.cm.gist_earth_r, norm=LogNorm(vmin=0.01, vmax=1))
            plt.axis('off')
            for jplot in range(1, 24):
                plt.subplot(n_plot, 25, iplot * 25 + jplot + 1)
                plt.imshow(X_fake[iplot, jplot, :, :].squeeze(), vmin=0, vmax=1, cmap=plt.cm.hot_r)
                plt.axis('off')
        plt.colorbar()
        plt.suptitle(f'epoch {epoch:04d}')
        plt.savefig(f'{plotdirs}/fake_samples_{epoch:04d}_{i:06d}.{plot_format}')#cambio i por j

        # plot loss
        plt.figure()
        plt.plot(hist['d_loss'], label='d_loss')
        plt.plot(hist['g_loss'], label='g_loss')
        plt.ylabel('batch')
        plt.legend()
        plt.savefig(f'{plotdirs}/training_loss.{plot_format}')
        pd.DataFrame(hist).to_csv('hist.csv')
        plt.close('all')

        generator.save(f'{outdirs}/gen_{epoch:04d}.h5')
        critic.save(f'{outdirs}/disc_{epoch:04d}.h5')


# the training is done with increasing batch size,
# as defined in n_epoch_and_batch_size_list at the beginning of the script
start_epoch = 0
for n_epochs, batch_size in  n_epoch_and_batch_size_list:
    train(n_epochs, batch_size, start_epoch)
    start_epoch = start_epoch + n_epochs #this is only needed for correct plot labelling

loading data
finished loading data
building networks
finished building networks
start training on 3 samples


  0%|          | 0/50 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
  0%|          | 0/50 [00:13<?, ?it/s]


UnboundLocalError: local variable 'j' referenced before assignment

## Evaluates the GAN and makes analysis plots

In [None]:
#! /pfs/nobackup/home/s/sebsc/miniconda3/envs/pr-disagg-env/bin/python
#SBATCH -A SNIC2019-3-611
#SBATCH --time=06:00:00
#SBATCH -N 1
#SBATCH --exclusive
"""
this script uses the trained generator to create precipitation scenarios.
a number of daily sum conditions are sampled from the test-data,
and for each sub-daily scenarios are generated with the generator.
The results are shown in various plots
"""

import pickle
import os
import numpy as np
import tensorflow as tf
import pandas as pd
import matplotlib
import matplotlib.colors as mcolors
matplotlib.use('agg')
from pylab import plt
import seaborn as sns
import scipy.stats
from tqdm import trange
from skimage.util import view_as_windows
from matplotlib.colors import LogNorm
from tensorflow.keras import backend as K

# for reproducability, we set a fixed seed to the random number generator
np.random.seed(354)

# we need to specify train start and enddate to get correct filenames
train_startdate = '19700101'
train_enddate = '19741231'

eval_startdate = '19750101'
eval_enddate = '19751231'


# parameters (need to be the same as in training)
ndomain = 16  # gridpoints
stride = 16
latent_dim = 100

tp_thresh_daily = 40  # mm. in the radardate the unit is mm/h, but then on 5 minutes steps.
# the conversion is done automatically in this script
n_thresh = 80

# here we need to choose which epoch we use from the saved models (we saved them at the end of every
# epoch). visual inspection of the images generated from the training set showed
# that after epoch 20, things starts to detoriate. Therefore we use epoch 20.
epoch = 20
# normalization of daily sums
# we ues the 99.9 percentile of 2010
#norm_scale = 127.4
norm_scale = 1

plot_format = 'png'

name = 'results'

plotdir = f'/home/jovyan/work/plots_generated_{name}_rev1/'

outdir = f'/home/jovyan/work/trained_models/{name}/'

# note for colab: sometimes mkdir does not work that way. in this case
# you have to create the directories manually
os.system(f'mkdir -p {plotdir}')
os.system(f'mkdir -p {outdir}')

# load data and precomputed indices for the test data

converted_data_path = '/home/jovyan/work/prueba_datos_ERA5/'

indices_data_path = '/home/jovyan/work/prueba_datos_ERA5'


data_ifile = f'{converted_data_path}/prueba{startdate}-{enddate}.npy'


indices_file = f'{indices_data_path}/valid_indices_ERA5-{startdate}-{enddate}.pkl'
print('loading data')
# load the data as memmap
data = np.load(data_ifile, mmap_mode='r')

indices_all = pickle.load(open(indices_file, 'rb'))
# convert to array
indices_all = np.array(indices_all)
# this has shape (nsamples,3)
# each row is (tidx,yidx,xidx)
print('finished loading data')

# the data has dimensions (sample,hourofday,x,y)
n_days, nhours, ny, nx = data.shape
n_channel = 1
# sanity checks
assert (len(data.shape) == 4)
assert (len(indices_all.shape) == 2)
assert (indices_all.shape[1] == 3)
assert (nhours == 24 // 1)
assert (np.max(indices_all[:, 0]) < n_days)
assert (np.max(indices_all[:, 1]) < ny)
assert (np.max(indices_all[:, 2]) < nx)
assert (data.dtype == 'float32')

n_samples = len(indices_all)

print(f'evaluate in {n_samples} samples')

print('load the trained generator')
generator_file = f'{outdir}/gen_{epoch:04d}.h5'

# we need the custom layer PixelNormalization to load the generator
class PixelNormalization(tf.keras.layers.Layer):
    # initialize the layer
    def __init__(self, **kwargs):
        super(PixelNormalization, self).__init__(**kwargs)

    # perform the operation
    def call(self, inputs):
        # calculate square pixel values
        values = inputs ** 2.0
        # calculate the mean pixel values
        mean_values = K.mean(values, axis=-1, keepdims=True)
        # ensure the mean is not zero
        mean_values += 1.0e-8
        # calculate the sqrt of the mean squared value (L2 norm)
        l2 = K.sqrt(mean_values)
        # normalize values by the l2 norm
        normalized = inputs / l2
        return normalized

    # define the output shape of the layer
    def compute_output_shape(self, input_shape):
        return input_shape


gen = tf.keras.models.load_model(generator_file, compile=False,
                                 custom_objects={'PixelNormalization': PixelNormalization})


# in order to use the model, we need to compile it (even though we dont need the los function
# and optimizer here, since we only do prediction)
def wasserstein_loss(y_true, y_pred):
    # we use -1 for fake, and +1 for real labels
    return tf.reduce_mean(y_true * y_pred)

gen.compile(loss=wasserstein_loss, optimizer=tf.keras.optimizers.RMSprop(lr=0.00005))


def generate_real_samples_and_conditions(n_batch):
    """get random sampples and do the last preprocessing on them"""
    # get random sample of indices from the precomputed indices
    # for this we generate random indices for the index list (confusing termoonology, since we use
    # indices to index the list of indices...
    ixs = np.random.randint(n_samples, size=n_batch)
    idcs_batch = indices_all[ixs]

    # now we select the data corresponding to these indices
    data_wview = view_as_windows(data, (1, 1, ndomain, ndomain))[..., 0, 0, :, :]
    batch = data_wview[idcs_batch[:, 0], :, idcs_batch[:, 1], idcs_batch[:, 2]]
    # add empty channel dimension (necessary for keras, which expects a channel dimension)
    batch = np.expand_dims(batch, -1)
    # compute daily sum (which is the condition)
    batch_cond = np.sum(batch, axis=1)  # daily sum

    # the data now is in mm/hour, but we want it as fractions of the daily sum for each day
    for i in range(n_batch):
        batch[i] = batch[i] / batch_cond[i]

    # normalize daily sum
    batch_cond = batch_cond / norm_scale
    assert (batch.shape == (n_batch, nhours, ndomain, ndomain, 1))
    assert (batch_cond.shape == (n_batch, ndomain, ndomain, 1))
    assert (~np.any(np.isnan(batch)))
    assert (~np.any(np.isnan(batch_cond)))
    assert (np.max(batch) <= 1)
    assert (np.min(batch) >= 0)

    return [batch, batch_cond]


plt.rcParams['savefig.bbox'] = 'tight'
cmap = plt.cm.gist_earth_r
plotnorm = LogNorm(vmin=0.01, vmax=50)

# for each (real) condition, generate a couple of fake
# distributions, and plot them all together

n_to_generate = 20
n_per_batch = 10
n_batches = n_to_generate // n_per_batch
n_fake_per_real = 10
plotcount = 0
for ibatch in trange(n_batches):

    reals, conds = generate_real_samples_and_conditions(n_per_batch)

    for real, cond in zip(reals, conds):
        plotcount += 1
        # for each cond, make several predictions with different latent noise
        latent = np.random.normal(size=(n_fake_per_real, latent_dim))
        # for efficiency reason, we dont make a single forecast with the network, but
        # we batch all n_fake_per_real together
        cond_batch = np.repeat(cond[np.newaxis], repeats=n_fake_per_real, axis=0)
        generated = gen.predict([latent, cond_batch])


        # make a matrix of mapplots.
        # first column: condition (daily mean), the same for every row
        # first row: real fractions per hour
        # rest of the rows: generated fractions per hour, 1 row per realization
        fig = plt.figure(figsize=(25, 12))
        n_plot = n_fake_per_real + 1
        ax = plt.subplot(n_plot, 25, 1)
        # compute unnormalized daily sum. squeeze away empty channel dimension (for plotting)
        dsum = cond.squeeze() * norm_scale
        plt.imshow(dsum, cmap=cmap, norm=plotnorm)
        plt.axis('off')
        ax.annotate('real', xy=(0, 0.5), xytext=(-5, 0), xycoords='axes fraction', textcoords='offset points',
                    size='large', ha='right', va='center', rotation='vertical')
        ax.annotate(f'daily sum', xy=(0.5, 1), xytext=(0, 5), xycoords='axes fraction', textcoords='offset points',
                    size='large', ha='center', va='baseline')
        for jplot in range(1, 24 + 1):
            ax = plt.subplot(n_plot, 25, jplot + 1)
            plt.imshow(real[jplot - 1, :, :].squeeze(), vmin=0, vmax=1, cmap=plt.cm.Greys)
            plt.axis('off')
            ax.annotate(f'{jplot:02d}'':00', xy=(0.5, 1), xytext=(0, 5),
                        xycoords='axes fraction', textcoords='offset points',
                        size='large', ha='center', va='baseline')
        # plot fake samples
        for iplot in range(n_fake_per_real):
            plt.subplot(n_plot, 25, (iplot + 1) * 25 + 1)
            plt.imshow(dsum, cmap=cmap, norm=plotnorm)
            plt.axis('off')
            for jplot in range(1, 24 + 1):
                plt.subplot(n_plot, 25, (iplot + 1) * 25 + jplot + 1)
                im = plt.imshow(generated[iplot, jplot - 1, :, :].squeeze(), vmin=0, vmax=1, cmap=plt.cm.Greys)
                plt.axis('off')
        fig.subplots_adjust(right=0.93)
        cbar_ax = fig.add_axes([0.93, 0.15, 0.007, 0.7])
        cbar = fig.colorbar(im, cax=cbar_ax)
        cbar.set_label('fraction of daily precipitation', fontsize=16)
        cbar.ax.tick_params(labelsize=16)

        plt.savefig(f'{plotdir}/generated_fractions_{epoch:04d}_{plotcount:04d}_allhours.{plot_format}')

        # now the same, but showing absolute precipitation fields
        # compute absolute precipitation from fraction of daily sum.
        # this can be done with numpy broadcasting.
        # we also have to multiply with norm_scale (because cond is normalized)
        generated_scaled = generated * cond * norm_scale

        real_scaled = real * cond * norm_scale
        fig = plt.figure(figsize=(25, 12))
        # plot real one
        ax = plt.subplot(n_plot, 25, 1)
        im = plt.imshow(dsum, cmap=cmap, norm=plotnorm)
        plt.axis('off')
        ax.annotate('real', xy=(0, 0.5), xytext=(-5, 0), xycoords='axes fraction', textcoords='offset points',
                    size='large', ha='right', va='center', rotation='vertical')
        ax.annotate(f'daily sum', xy=(0.5, 1), xytext=(0, 5), xycoords='axes fraction', textcoords='offset points',
                    size='large', ha='center', va='baseline')

        for jplot in range(1, 24 + 1):
            ax = plt.subplot(n_plot, 25, jplot + 1)
            plt.imshow(real_scaled[jplot - 1, :, :].squeeze(), cmap=cmap, norm=plotnorm)
            plt.axis('off')
            ax.annotate(f'{jplot:02d}'':00', xy=(0.5, 1), xytext=(0, 5),
                        xycoords='axes fraction', textcoords='offset points',
                        size='large', ha='center', va='baseline')
        # plot fake samples
        for iplot in range(n_fake_per_real):
            plt.subplot(n_plot, 25, (iplot + 1) * 25 + 1)
            plt.imshow(dsum, cmap=cmap, norm=plotnorm)
            plt.axis('off')
            for jplot in range(1, 24 + 1):
                plt.subplot(n_plot, 25, (iplot + 1) * 25 + jplot + 1)
                plt.imshow(generated_scaled[iplot, jplot - 1, :, :].squeeze(), cmap=cmap, norm=plotnorm)
                plt.axis('off')
        fig.subplots_adjust(right=0.93)
        cbar_ax = fig.add_axes([0.93, 0.15, 0.007, 0.7])
        cbar = fig.colorbar(im, cax=cbar_ax)
        cbar.set_label('precipitation [mm]', fontsize=16)
        cbar.ax.tick_params(labelsize=16)
        plt.savefig(f'{plotdir}/generated_precip_{epoch:04d}_{plotcount:04d}_allhours.{plot_format}')

        np.save(f'data/real_precip_for_mapplots_{plotcount}.npy', real_scaled)

        # same as before, but only every 3rd hour.
        # rest of the rows: generated fractions per 3rd hour, 1 row per realization
        fig = plt.figure(figsize=(12, 12))
        n_plot = n_fake_per_real + 1
        ax = plt.subplot(n_plot, 9, 1)
        # compute unnormalized daily sum. squeeze away empty channel dimension (for plotting)
        dsum = cond.squeeze() * norm_scale
        plt.imshow(dsum, cmap=cmap, norm=plotnorm)
        plt.axis('off')
        ax.annotate('real', xy=(0, 0.5), xytext=(-5, 0), xycoords='axes fraction', textcoords='offset points',
                    size='large', ha='right', va='center', rotation='vertical')
        ax.annotate(f'daily sum', xy=(0.5, 1), xytext=(0, 5), xycoords='axes fraction', textcoords='offset points',
                    size='large', ha='center', va='baseline')
        for jplot in range(1, 8 + 1):
            ax = plt.subplot(n_plot, 9, jplot + 1)
            plt.imshow(real[jplot*3 - 1, :, :].squeeze(), vmin=0, vmax=1, cmap=plt.cm.Greys)
            plt.axis('off')
            hour = jplot*3
            ax.annotate(f'{hour:02d}'':00', xy=(0.5, 1), xytext=(0, 5),
                        xycoords='axes fraction', textcoords='offset points',
                        size='large', ha='center', va='baseline')
        # plot fake samples
        for iplot in range(n_fake_per_real):
            plt.subplot(n_plot, 8+1, (iplot + 1) * 9 + 1)
            plt.imshow(dsum, cmap=cmap, norm=plotnorm)
            plt.axis('off')
            for jplot in range(1, 8 + 1):
                plt.subplot(n_plot, 9, (iplot + 1) * 9 + jplot + 1)
                im = plt.imshow(generated[iplot, jplot*3 - 1, :, :].squeeze(), vmin=0, vmax=1, cmap=plt.cm.Greys)
                plt.axis('off')
        fig.subplots_adjust(right=0.93)
        cbar_ax = fig.add_axes([0.93, 0.15, 0.007, 0.7])
        cbar = fig.colorbar(im, cax=cbar_ax)
        cbar.set_label('fraction of daily precipitation', fontsize=16)
        cbar.ax.tick_params(labelsize=16)

        plt.savefig(f'{plotdir}/generated_fractions_{epoch:04d}_{plotcount:04d}.{plot_format}')

        # now the same, but showing absolute precipitation fields
        # compute absolute precipitation from fraction of daily sum.
        # this can be done with numpy broadcasting.
        # we also have to multiply with norm_scale (because cond is normalized)
        generated_scaled = generated * cond * norm_scale

        real_scaled = real * cond * norm_scale
        fig = plt.figure(figsize=(12, 12))
        n_plot = n_fake_per_real + 1
        ax = plt.subplot(n_plot, 9, 1)
        # compute unnormalized daily sum. squeeze away empty channel dimension (for plotting)
        dsum = cond.squeeze() * norm_scale
        plt.imshow(dsum, cmap=cmap, norm=plotnorm)
        plt.axis('off')
        ax.annotate('real', xy=(0, 0.5), xytext=(-5, 0), xycoords='axes fraction', textcoords='offset points',
                    size='large', ha='right', va='center', rotation='vertical')
        ax.annotate(f'daily sum', xy=(0.5, 1), xytext=(0, 5), xycoords='axes fraction', textcoords='offset points',
                    size='large', ha='center', va='baseline')
        for jplot in range(1, 8 + 1):
            ax = plt.subplot(n_plot, 9, jplot + 1)
            plt.imshow(real_scaled[jplot*3 - 1, :, :].squeeze(), cmap=cmap, norm=plotnorm)
            plt.axis('off')
            hour = jplot*3
            ax.annotate(f'{hour:02d}'':00', xy=(0.5, 1), xytext=(0, 5),
                        xycoords='axes fraction', textcoords='offset points',
                        size='large', ha='center', va='baseline')
        # plot fake samples
        for iplot in range(n_fake_per_real):
            plt.subplot(n_plot, 8+1, (iplot + 1) * 9 + 1)
            plt.imshow(dsum, cmap=cmap, norm=plotnorm)
            plt.axis('off')
            for jplot in range(1, 8 + 1):
                plt.subplot(n_plot, 9, (iplot + 1) * 9 + jplot + 1)
                im = plt.imshow(generated_scaled[iplot, jplot*3 - 1, :, :].squeeze(), cmap=cmap, norm=plotnorm)
                plt.axis('off')
        fig.subplots_adjust(right=0.93)
        cbar_ax = fig.add_axes([0.93, 0.15, 0.007, 0.7])
        cbar = fig.colorbar(im, cax=cbar_ax)
        cbar.set_label('precipitation [mm]', fontsize=16)
        cbar.ax.tick_params(labelsize=16)
        plt.savefig(f'{plotdir}/generated_precip_{epoch:04d}_{plotcount:04d}.{plot_format}')

        plt.close('all')


# compute statistics over
# many generated smaples
# we compute the areamean,
n_sample = 10000
amean_fraction_gen = []
amean_fraction_real = []
amean_gen = []
amean_real = []
dists_real = []
dists_gen = []


# for each real conditoin, we crate 1 fake sample
for i in trange(n_sample):
    real, cond = generate_real_samples_and_conditions(1)
    latent = np.random.normal(size=(1, latent_dim))
    generated = gen.predict([latent, cond])

    generated = generated.squeeze()
    real = real.squeeze()
    cond = cond.squeeze()
    # compute area means
    amean_fraction_gen.append(np.mean(generated, axis=(1, 2)).squeeze())
    amean_fraction_real.append(np.mean(real, axis=(1, 2)).squeeze())
    amean_gen.append(np.mean(generated * cond * norm_scale, axis=(1, 2)).squeeze())
    amean_real.append(np.mean(real * cond * norm_scale, axis=(1, 2)).squeeze())
    dists_real.append(real * cond * norm_scale)
    dists_gen.append(generated * cond * norm_scale)




amean_fraction_gen = np.array(amean_fraction_gen)
amean_fraction_real = np.array(amean_fraction_real)
amean_gen = np.array(amean_gen)
amean_real = np.array(amean_real)
dists_gen = np.array(dists_gen)
dists_real = np.array(dists_real)
np.save('data/generated_samples.npy',dists_gen)
np.save('data/real_samples.npy', dists_real)

def ecdf(data):
    x = np.sort(data)
    n = x.size
    y = np.arange(1, n+1) / n
    return(x, y)


sns.set_palette('colorblind')
# ecdf of area means. the hours are flattened
plt.figure()
ax1 = plt.subplot(211)
plt.plot(*ecdf(amean_gen.flatten()), label='gen')
plt.plot(*ecdf(amean_real.flatten()), label='real')
plt.legend(loc='upper left')
sns.despine()
plt.xlabel('mm/h')
plt.ylabel('ecdf areamean')
plt.semilogx()
# ecdf of (flattened) spatial data
ax2 = plt.subplot(212)
plt.plot(*ecdf(dists_gen.flatten()), label='gen')
plt.plot(*ecdf(dists_real.flatten()), label='real')
plt.legend(loc='upper left')
sns.despine()
plt.ylabel('ecdf')
plt.xlabel('mm/h')
plt.semilogx()
plt.tight_layout()
plt.savefig(f'{plotdir}/ecdf_allx_{epoch:04d}.png', dpi=400)
# cut at 0.1mm/h
ax1.set_xlim(xmin=0.5)
ax1.set_ylim(ymin=0.8, ymax=1.01)
ax2.set_xlim(xmin=0.1)
ax2.set_ylim(ymin=0.6, ymax=1.01)
plt.savefig(f'{plotdir}/ecdf_{epoch:04d}.png', dpi=400)

plt.close('all')
# free some memory
del dists_gen
del dists_real

# convert to pandas data frame, with timeofday ('hour') as additional column
res_df = []
for i in range(24):
    _df1 = pd.DataFrame({'fraction': amean_fraction_gen[:, i],
                         'precip': amean_gen[:, i],
                         'typ': 'generated',
                         'hour': i + 1}, index=np.arange(len(amean_gen)))
    _df2 = pd.DataFrame({'fraction': amean_fraction_real[:, i].squeeze(),
                         'precip': amean_real[:, i],
                         'typ': 'real',
                         'hour': i + 1}, index=np.arange(len(amean_gen)))
    res_df.append(_df1)
    res_df.append(_df2)


df = pd.concat(res_df)
df.to_csv(f'{plotdir}/gen_and_real_ameans_{epoch:04d}.csv')

# make boxplot
for showfliers in (True, False):

    plt.figure()
    plt.subplot(211)
    sns.boxplot('hour', 'precip', data=df, hue='typ', showfliers=showfliers)
    plt.xlabel('')
    sns.despine()
    plt.subplot(212)
    sns.boxplot('hour', 'fraction', data=df, hue='typ', showfliers=showfliers)
    sns.despine()
    plt.suptitle(f'n={n_sample}')
    plt.savefig(f'{plotdir}/daily_cycle_showfliers{showfliers}_{epoch:04d}.svg')


## for a single real one, generate a large
# number of fake distributions, and then
# plot the areamean in a lineplot
# we generate 100 fake distributions with different noise accross the samples
# and additionally 10 fake ones that use the same noise for all plots
# the latter we plot in the same color (1 seperate color for each generated one)
# so that we can compare them accross the plots

n_to_generate = 20
n_fake_per_real = 100
n_fake_per_real_samenoise = 10
plotcount = 0
hours = np.arange(1, 24 + 1)
# use same noise for all samples
latent_shared = np.random.normal(size=(n_fake_per_real_samenoise, latent_dim))
for isample in trange(n_to_generate):
    real, cond = generate_real_samples_and_conditions(1)
    latent= np.random.normal(size=(n_fake_per_real, latent_dim))
    # for efficiency reason, we dont make a single forecast with the network, but
    # we batch all n_fake_per_real together
    cond_batch = np.repeat(cond, repeats=n_fake_per_real, axis=0)
    cond_batch_samenoise = np.repeat(cond, repeats=n_fake_per_real_samenoise, axis=0)
    generated = gen.predict([latent, cond_batch], verbose=1)
    generated_samenoise = gen.predict([latent_shared, cond_batch_samenoise], verbose=1)
    real = real.squeeze()
    generated = generated.squeeze()
    generated_samenoise = generated_samenoise.squeeze()
    # compute are mean
    amean_real = np.mean(real * cond.squeeze() * norm_scale, (1, 2))
    amean_gen = np.mean(generated * cond.squeeze() * norm_scale, (2, 3))  # generated has a time dimension
    amean_gen_samenoise = np.mean(generated_samenoise * cond.squeeze() * norm_scale, (2, 3))  # generated has a time dimension

    plt.figure(figsize=(7, 3))
    plt.plot(hours, amean_gen.T, label='_nolegend_', alpha=0.3,color='#1b9e77')
    plt.plot(hours, amean_gen_samenoise.T, label='_nolegend_', alpha=1)
    plt.plot(hours, amean_real, label='real', color='black')
    plt.xlabel('hour')
    plt.ylabel('precipitation [mm/hour]')
    plt.legend()
    sns.despine()
    plt.savefig(f'{plotdir}/distribution_lineplot_samenosie_{epoch:04d}_{isample:04d}.svg')
    plt.close('all')

# take two conditions, and
# then plot the areamean of the resulting distributions, and check whether they are different
# we use the same noise for both, to avoid finding effects that only might come from the noise
n_fake_per_real = 1000
latent = np.random.normal(size=(n_fake_per_real, latent_dim))
for isample in trange(20):
    real1, cond1 = generate_real_samples_and_conditions(1)

    cond_batch1 = np.repeat(cond1, repeats=n_fake_per_real, axis=0)
    generated1 = gen.predict([latent, cond_batch1], verbose=1)
    real2, cond2 = generate_real_samples_and_conditions(1)
    cond_batch2 = np.repeat(cond2, repeats=n_fake_per_real, axis=0)
    generated2 = gen.predict([latent, cond_batch2], verbose=1)

    amean_fraction_real1 = np.mean(real1.squeeze(), (1, 2)).squeeze()
    amean_fraction_gen1 = np.mean(generated1, (2, 3)).squeeze()  # generated has a time dimension
    amean_fraction_real2 = np.mean(real2.squeeze(), (1, 2)).squeeze()
    amean_fraction_gen2 = np.mean(generated2.squeeze(), (2, 3)).squeeze()  # generated has a time dimension

    res_df = []
    for i in range(24):
        _df1 = pd.DataFrame({'fraction': amean_fraction_gen1[:, i],
                             'cond': 1,
                             'hour': i + 1}, index=np.arange(len(amean_fraction_gen1)))
        _df2 = pd.DataFrame({'fraction': amean_fraction_gen2[:, i],
                             'cond': 2,
                             'hour': i + 1}, index=np.arange(len(amean_fraction_gen1)))
        res_df.append(_df1)
        res_df.append(_df2)

    df = pd.concat(res_df)
    df.to_csv(f'{plotdir}/check_conditional_dist_samenoise_{epoch:04d}_{isample:04d}.csv')
    pvals_per_hour = []
    for hour in range(1,24+1):
        sub = df.query('hour==@hour')
        _, p = scipy.stats.ks_2samp(sub.query('cond==1')['fraction'], sub.query('cond==2')['fraction'])
        pvals_per_hour.append(p)
    np.savetxt(f'{plotdir}/check_conditional_dist_samenoise_KSpval_{epoch:04d}_{isample:04d}.txt', pvals_per_hour)
    for showfliers in (True, False):
        fig = plt.figure(constrained_layout=True, figsize=(6, 4.8))
        gs = fig.add_gridspec(2, 2)
        ax1 = fig.add_subplot(gs[0, 0])
        im = ax1.imshow(cond1.squeeze(), cmap=cmap, norm=plotnorm)
        plt.title('cond 1')
        plt.axis('off')
        plt.colorbar(im)
        ax2 = fig.add_subplot(gs[0, 1])
        im = ax2.imshow(cond2.squeeze(), cmap=cmap, norm=plotnorm)
        plt.title('cond 2')
        plt.axis('off')
        plt.colorbar(im)
        ax3 = fig.add_subplot(gs[1, :])
        sns.boxplot('hour', 'fraction', hue='cond', data=df, ax=ax3, showfliers=showfliers)
        sns.despine()
        plt.savefig(f'{plotdir}/check_conditional_dist_samenoise_showfliers{showfliers}_{epoch:04d}_{isample:04d}.svg')

    plt.close('all')

loading data
finished loading data
evaluate in 3 samples
load the trained generator


100%|██████████| 2/2 [06:36<00:00, 198.26s/it]
100%|██████████| 10000/10000 [09:17<00:00, 17.93it/s]
Exception in thread Thread-48:
Traceback (most recent call last):
  File "/opt/conda/envs/pr-disagg-env/lib/python3.7/threading.py", line 926, in _bootstrap_inner
    self.run()
  File "/opt/conda/envs/pr-disagg-env/lib/python3.7/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/envs/pr-disagg-env/lib/python3.7/multiprocessing/pool.py", line 412, in _handle_workers
    pool._maintain_pool()
  File "/opt/conda/envs/pr-disagg-env/lib/python3.7/multiprocessing/pool.py", line 248, in _maintain_pool
    self._repopulate_pool()
  File "/opt/conda/envs/pr-disagg-env/lib/python3.7/multiprocessing/pool.py", line 241, in _repopulate_pool
    w.start()
  File "/opt/conda/envs/pr-disagg-env/lib/python3.7/multiprocessing/process.py", line 112, in start
    self._popen = self._Popen(self)
  File "/opt/conda/envs/pr-disagg-env/lib/python3.7/multiprocessin