In [2]:
import os
import sys
import time
import math
import logging
import warnings
import numpy as np
from glob import glob

# import matplotlib.pyplot as plt
# %matplotlib inline

# warnings.filterwarnings("ignore")

# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

# logging.getLogger("tensorflow").setLevel(logging.ERROR) 

In [3]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from GroupNormalization import GroupNormalization

tf.autograph.set_verbosity(0)
tf.get_logger().setLevel('ERROR')

2023-10-29 14:56:03.749250: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-10-29 14:56:06.070756: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [7]:
print('tf.version = {}'.format(tf.__version__))
print('tf on GPU = {}'.format(tf.test.is_gpu_available()))

tf.version = 2.12.1
tf on GPU = True


2023-10-29 15:00:34.606082: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1635] Created device /device:GPU:0 with 31141 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:8a:00.0, compute capability: 7.0


In [53]:
sys.path.insert(0, '/glade/u/home/ksha/GAN_proj/')
sys.path.insert(0, '/glade/u/home/ksha/GAN_proj/libs/')

from namelist import *
import data_utils as du
import model_utils as mu
import graph_utils as gu

**Hyperparameters**

In [4]:
total_timesteps = 1000
norm_groups = 8  # Number of groups used in GroupNormalization layer
lr = 1e-4

img_size = 128
img_channels = 1
clip_min = -1.0
clip_max = 1.0

first_conv_channels = 64
widths = [64, 96, 128, 256]
has_attention = [False, False, True, True]
num_res_blocks = 2  # Number of residual blocks

L_valid = 1000
model_name = '/glade/work/ksha/GAN/models/DB_APCP128/'

In [5]:
class GaussianDiffusion:
    """Gaussian diffusion utility.

    Args:
        beta_start: Start value of the scheduled variance
        beta_end: End value of the scheduled variance
        timesteps: Number of time steps in the forward process
    """

    def __init__(
        self,
        beta_start=1e-4,
        beta_end=0.02,
        timesteps=1000,
        clip_min=-1.0,
        clip_max=1.0,
    ):
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.timesteps = timesteps
        self.clip_min = clip_min
        self.clip_max = clip_max

        # Define the linear variance schedule
        self.betas = betas = np.linspace(
            beta_start,
            beta_end,
            timesteps,
            dtype=np.float64,  # Using float64 for better precision
        )
        self.num_timesteps = int(timesteps)

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        self.betas = tf.constant(betas, dtype=tf.float32)
        self.alphas_cumprod = tf.constant(alphas_cumprod, dtype=tf.float32)
        self.alphas_cumprod_prev = tf.constant(alphas_cumprod_prev, dtype=tf.float32)

        # Calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = tf.constant(
            np.sqrt(alphas_cumprod), dtype=tf.float32
        )

        self.sqrt_one_minus_alphas_cumprod = tf.constant(
            np.sqrt(1.0 - alphas_cumprod), dtype=tf.float32
        )

        self.log_one_minus_alphas_cumprod = tf.constant(
            np.log(1.0 - alphas_cumprod), dtype=tf.float32
        )

        self.sqrt_recip_alphas_cumprod = tf.constant(
            np.sqrt(1.0 / alphas_cumprod), dtype=tf.float32
        )
        self.sqrt_recipm1_alphas_cumprod = tf.constant(
            np.sqrt(1.0 / alphas_cumprod - 1), dtype=tf.float32
        )

        # Calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        self.posterior_variance = tf.constant(posterior_variance, dtype=tf.float32)

        # Log calculation clipped because the posterior variance is 0 at the beginning
        # of the diffusion chain
        self.posterior_log_variance_clipped = tf.constant(
            np.log(np.maximum(posterior_variance, 1e-20)), dtype=tf.float32
        )

        self.posterior_mean_coef1 = tf.constant(
            betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
            dtype=tf.float32,
        )

        self.posterior_mean_coef2 = tf.constant(
            (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod),
            dtype=tf.float32,
        )

    def _extract(self, a, t, x_shape):
        """Extract some coefficients at specified timesteps,
        then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.

        Args:
            a: Tensor to extract from
            t: Timestep for which the coefficients are to be extracted
            x_shape: Shape of the current batched samples
        """
        batch_size = x_shape[0]
        out = tf.gather(a, t)
        return tf.reshape(out, [batch_size, 1, 1, 1])

    def q_mean_variance(self, x_start, t):
        """Extracts the mean, and the variance at current timestep.

        Args:
            x_start: Initial sample (before the first diffusion step)
            t: Current timestep
        """
        x_start_shape = tf.shape(x_start)
        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod, t, x_start_shape)
        log_variance = self._extract(
            self.log_one_minus_alphas_cumprod, t, x_start_shape
        )
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise):
        """Diffuse the data.

        Args:
            x_start: Initial sample (before the first diffusion step)
            t: Current timestep
            noise: Gaussian noise to be added at the current timestep
        Returns:
            Diffused samples at timestep `t`
        """
        x_start_shape = tf.shape(x_start)
        return (
            self._extract(self.sqrt_alphas_cumprod, t, tf.shape(x_start)) * x_start
            + self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start_shape)
            * noise
        )

    def predict_start_from_noise(self, x_t, t, noise):
        x_t_shape = tf.shape(x_t)
        return (
            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t_shape) * x_t
            - self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t_shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Compute the mean and variance of the diffusion
        posterior q(x_{t-1} | x_t, x_0).

        Args:
            x_start: Stating point(sample) for the posterior computation
            x_t: Sample at timestep `t`
            t: Current timestep
        Returns:
            Posterior mean and variance at current timestep
        """

        x_t_shape = tf.shape(x_t)
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t_shape) * x_start
            + self._extract(self.posterior_mean_coef2, t, x_t_shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t_shape)
        posterior_log_variance_clipped = self._extract(
            self.posterior_log_variance_clipped, t, x_t_shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, pred_noise, x, t, clip_denoised=True):
        x_recon = self.predict_start_from_noise(x, t=t, noise=pred_noise)
        if clip_denoised:
            x_recon = tf.clip_by_value(x_recon, self.clip_min, self.clip_max)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t
        )
        return model_mean, posterior_variance, posterior_log_variance

    def p_sample(self, pred_noise, x, t, clip_denoised=True):
        """Sample from the diffusion model.

        Args:
            pred_noise: Noise predicted by the diffusion model
            x: Samples at a given timestep for which the noise was predicted
            t: Current timestep
            clip_denoised (bool): Whether to clip the predicted noise
                within the specified range or not.
        """
        model_mean, _, model_log_variance = self.p_mean_variance(
            pred_noise, x=x, t=t, clip_denoised=clip_denoised
        )
        noise = tf.random.normal(shape=x.shape, dtype=x.dtype)
        # No noise when t == 0
        nonzero_mask = tf.reshape(
            1 - tf.cast(tf.equal(t, 0), tf.float32), [tf.shape(x)[0], 1, 1, 1]
        )
        return model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise

In [6]:
# Kernel initializer to use
def kernel_init(scale):
    scale = max(scale, 1e-10)
    return keras.initializers.VarianceScaling(
        scale, mode="fan_avg", distribution="uniform"
    )


class AttentionBlock(layers.Layer):
    """Applies self-attention.

    Args:
        units: Number of units in the dense layers
        groups: Number of groups to be used for GroupNormalization layer
    """

    def __init__(self, units, groups=8, **kwargs):
        self.units = units
        self.groups = groups
        super().__init__(**kwargs)

        self.norm = GroupNormalization(groups=groups)
        self.query = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        self.key = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        self.value = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        self.proj = layers.Dense(units, kernel_initializer=kernel_init(0.0))

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        height = tf.shape(inputs)[1]
        width = tf.shape(inputs)[2]
        scale = tf.cast(self.units, tf.float32) ** (-0.5)

        #inputs = self.norm(inputs)
        q = self.query(inputs)
        k = self.key(inputs)
        v = self.value(inputs)

        attn_score = tf.einsum("bhwc, bHWc->bhwHW", q, k) * scale
        attn_score = tf.reshape(attn_score, [batch_size, height, width, height * width])

        attn_score = tf.nn.softmax(attn_score, -1)
        attn_score = tf.reshape(attn_score, [batch_size, height, width, height, width])

        proj = tf.einsum("bhwHW,bHWc->bhwc", attn_score, v)
        proj = self.proj(proj)
        return inputs + proj


class TimeEmbedding(layers.Layer):
    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.half_dim = dim // 2
        self.emb = math.log(10000) / (self.half_dim - 1)
        self.emb = tf.exp(tf.range(self.half_dim, dtype=tf.float32) * -self.emb)

    def call(self, inputs):
        inputs = tf.cast(inputs, dtype=tf.float32)
        emb = inputs[:, None] * self.emb[None, :]
        emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=-1)
        return emb


def ResidualBlock(width, groups=8, activation_fn=keras.activations.swish):
    def apply(inputs):
        x, t = inputs
        input_width = x.shape[3]

        if input_width == width:
            residual = x
        else:
            residual = layers.Conv2D(
                width, kernel_size=1, kernel_initializer=kernel_init(1.0)
            )(x)

        temb = activation_fn(t)
        temb = layers.Dense(width, kernel_initializer=kernel_init(1.0))(temb)[
            :, None, None, :
        ]

        x = GroupNormalization(groups=groups)(x)
        x = activation_fn(x)
        x = layers.Conv2D(
            width, kernel_size=3, padding="same", kernel_initializer=kernel_init(1.0)
        )(x)

        x = layers.Add()([x, temb])
        x = GroupNormalization(groups=groups)(x)
        x = activation_fn(x)

        x = layers.Conv2D(
            width, kernel_size=3, padding="same", kernel_initializer=kernel_init(0.0)
        )(x)
        x = layers.Add()([x, residual])
        return x

    return apply


def DownSample(width):
    def apply(x):
        x = layers.Conv2D(
            width,
            kernel_size=3,
            strides=2,
            padding="same",
            kernel_initializer=kernel_init(1.0),
        )(x)
        return x

    return apply


def UpSample(width, interpolation="nearest"):
    def apply(x):
        x = layers.UpSampling2D(size=2, interpolation=interpolation)(x)
        x = layers.Conv2D(
            width, kernel_size=3, padding="same", kernel_initializer=kernel_init(1.0)
        )(x)
        return x

    return apply


def TimeMLP(units, activation_fn=keras.activations.swish):
    def apply(inputs):
        temb = layers.Dense(
            units, activation=activation_fn, kernel_initializer=kernel_init(1.0)
        )(inputs)
        temb = layers.Dense(units, kernel_initializer=kernel_init(1.0))(temb)
        return temb

    return apply


def build_model(
    img_size,
    img_channels,
    widths,
    has_attention,
    num_res_blocks=2,
    norm_groups=8,
    interpolation="nearest",
    activation_fn=keras.activations.swish,
):
    image_input = layers.Input(
        shape=(img_size, img_size, img_channels), name="image_input"
    )
    time_input = keras.Input(shape=(), dtype=tf.int64, name="time_input")

    x = layers.Conv2D(
        first_conv_channels,
        kernel_size=(3, 3),
        padding="same",
        kernel_initializer=kernel_init(1.0),
    )(image_input)

    temb = TimeEmbedding(dim=first_conv_channels * 4)(time_input)
    temb = TimeMLP(units=first_conv_channels * 4, activation_fn=activation_fn)(temb)

    skips = [x]

    # DownBlock
    for i in range(len(widths)):
        for _ in range(num_res_blocks):
            x = ResidualBlock(
                widths[i], groups=norm_groups, activation_fn=activation_fn
            )([x, temb])
            if has_attention[i]:
                x = AttentionBlock(widths[i], groups=norm_groups)(x)
            skips.append(x)

        if widths[i] != widths[-1]:
            x = DownSample(widths[i])(x)
            skips.append(x)

    # MiddleBlock
    x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)(
        [x, temb]
    )
    x = AttentionBlock(widths[-1], groups=norm_groups)(x)
    x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)(
        [x, temb]
    )

    # UpBlock
    for i in reversed(range(len(widths))):
        for _ in range(num_res_blocks + 1):
            x = layers.Concatenate(axis=-1)([x, skips.pop()])
            x = ResidualBlock(
                widths[i], groups=norm_groups, activation_fn=activation_fn
            )([x, temb])
            if has_attention[i]:
                x = AttentionBlock(widths[i], groups=norm_groups)(x)

        if i != 0:
            x = UpSample(widths[i], interpolation=interpolation)(x)

    # End block
    x = GroupNormalization(groups=norm_groups)(x)
    x = activation_fn(x)
    x = layers.Conv2D(1, (3, 3), padding="same", kernel_initializer=kernel_init(0.0))(x)
    return keras.Model([image_input, time_input], x, name="unet")

In [7]:
# Build the unet model
model = build_model(
    img_size=img_size,
    img_channels=img_channels,
    widths=widths,
    has_attention=has_attention,
    num_res_blocks=num_res_blocks,
    norm_groups=norm_groups,
    activation_fn=keras.activations.swish,
)

In [8]:
# Compile the model
model.compile(
    loss=keras.losses.MeanAbsoluteError(),
    optimizer=keras.optimizers.Adam(learning_rate=lr),
)

In [9]:
W_old = mu.dummy_loader(model_name)
model.set_weights(W_old)



In [10]:
gdf_util = GaussianDiffusion(timesteps=total_timesteps)

In [28]:
L_valid = 50

filenames = np.array(sorted(glob(BATCH_dir+'DB*.npy')))

L = len(filenames)
filename_valid = filenames[::20][:L_valid]
filename_train = list(set(filenames) - set(filename_valid))

L_train = len(filename_train)
Y_valid_HR = np.empty((L_valid, img_size, img_size, img_channels))

for i, name in enumerate(filename_valid):
    temp_data = np.load(name, allow_pickle=True)[()]
    patch = temp_data['patch'][..., None]
    patch = 2*(patch/np.max(patch)-0.5)
    Y_valid_HR[i, ...] = patch

In [29]:
t_valid_ = np.random.uniform(low=0, high=total_timesteps, size=(L_valid,))
t_valid = t_valid_.astype(int)

# sample random noise to be added to the images in the batch
noise_valid = np.random.normal(size=(L_valid, img_size, img_size, img_channels))
images_valid = np.array(gdf_util.q_sample(Y_valid_HR, t_valid, noise_valid))

In [30]:
pred_noise = model.predict([images_valid, t_valid])
record = np.mean(np.abs(noise_valid - pred_noise))
print('initial loss {}'.format(record))

initial loss 0.11760722623947549


In [31]:
x = noise_valid

for t in reversed(range(0, total_timesteps)):
    tt = tf.cast(tf.fill(L_valid, t), dtype=tf.int64)
    pred_noise = model.predict([x, tt])
    model_mean, _, model_log_variance =  gdf_util.p_mean_variance(pred_noise, x=x, t=tt, clip_denoised=True)
    nonzero_mask = (1 - (np.array(tt)==0)).reshape((len(noise_valid), 1, 1, 1))
    x = np.array(model_mean) + nonzero_mask * np.exp(0.5 * np.array(model_log_variance)) * np.random.normal(size=x.shape)

In [54]:
#cmap_pct, A = gu.precip_cmap()

In [100]:
# fig, AX = plt.subplots(1, 2, figsize=(9, 4.5))

# plt.tight_layout()

# CS1 = AX[0].pcolormesh(noise_valid[49, ..., 0], vmin=-1, vmax=1, cmap=cmap_pct)
# #plt.colorbar(CS1)
# AX[0].tick_params(axis="both", which="both", bottom=False, top=False, labelbottom=False, left=False, right=False, labelleft=False)
# AX[0].set_title('Forward diffusion (t=1000)', fontsize=14)

# CS2 = AX[1].pcolormesh(x[49, :, :, 0], vmin=-1, vmax=1, cmap=cmap_pct)
# #plt.colorbar(CS2)
# AX[1].tick_params(axis="both", which="both", bottom=False, top=False, labelbottom=False, left=False, right=False, labelleft=False)
# AX[1].set_title('Reverse diffusion (t=0)', fontsize=14)

In [99]:
# plt.figure(figsize=(8, 8))
# plt.tight_layout()
# plt.pcolormesh(noise_valid[49, ..., 0], vmin=-1, vmax=1, cmap=cmap_pct)
# plt.tick_params(axis="both", which="both", bottom=False, top=False, labelbottom=False, left=False, right=False, labelleft=False)

In [101]:
# n = 9

# xt1 = np.array(gdf_util.q_sample(Y_valid_HR[n, ...][None, ...], 1, noise_valid[n, ...][None, ...]))
# xt100 = np.array(gdf_util.q_sample(Y_valid_HR[n, ...][None, ...], 50, noise_valid[n, ...][None, ...]))
# xt500 = np.array(gdf_util.q_sample(Y_valid_HR[n, ...][None, ...], 100, noise_valid[n, ...][None, ...]))
# xt1000 = np.array(gdf_util.q_sample(Y_valid_HR[n, ...][None, ...], 999, noise_valid[n, ...][None, ...]))

# fig, AX = plt.subplots(2, 2, figsize=(9, 7))

# plt.tight_layout()

# CS1 = AX[0][0].pcolormesh(Y_valid_HR[n, ..., 0], vmin=-1, vmax=1, cmap=cmap_pct)
# AX[0][0].tick_params(axis="both", which="both", bottom=False, top=False, labelbottom=False, left=False, right=False, labelleft=False)
# AX[0][0].set_title('MRMS (128-by-128; 0.01 deg)', fontsize=14)

# CS2 = AX[0][1].pcolormesh(xt100[0, ..., 0], vmin=-1, vmax=1, cmap=cmap_pct)
# AX[0][1].tick_params(axis="both", which="both", bottom=False, top=False, labelbottom=False, left=False, right=False, labelleft=False)
# AX[0][1].set_title('Forward diffusion (t=50)', fontsize=14)

# CS3 = AX[1][0].pcolormesh(xt500[0, ..., 0], vmin=-1, vmax=1, cmap=cmap_pct)
# AX[1][0].tick_params(axis="both", which="both", bottom=False, top=False, labelbottom=False, left=False, right=False, labelleft=False)
# AX[1][0].set_title('Forward diffusion (t=100)', fontsize=14)

# CS4 = AX[1][1].pcolormesh(xt1000[0, ..., 0], vmin=-1, vmax=1, cmap=plt.cm.nipy_spectral_r)
# AX[1][1].tick_params(axis="both", which="both", bottom=False, top=False, labelbottom=False, left=False, right=False, labelleft=False)
# AX[1][1].set_title('Forward diffusion (t=1000)', fontsize=14)

In [None]:
epochs = 99999
L_train = 512
batch_size = 48

min_del = 0.0
max_tol = 3 # early stopping with 2-epoch patience
tol = 0

Y_batch_HR = np.empty((batch_size, img_size, img_size, img_channels))
Y_batch_HR[...] = np.nan

for i in range(epochs):
    print('epoch = {}'.format(i))
    if i == 0:
        print('Initial validation loss: {}'.format(record))
        
    start_time = time.time()
    # loop over batches
    
    for j in range(L_train):

        # collect training batches
        
        inds_rnd = du.shuffle_ind(L_train)
        inds_ = inds_rnd[:batch_size]

        for k, ind in enumerate(inds_):
            # import batch data
            temp_name = filename_train[ind]
            temp_data = np.load(temp_name, allow_pickle=True)[()]
            
            patch = temp_data['patch'][..., None]
            patch = 2*(patch/np.max(patch)-0.5)
            Y_batch_HR[k, ...] = patch

        # sample timesteps uniformly
        t_ = np.random.uniform(low=0, high=total_timesteps, size=(batch_size,))
        t = t_.astype(int)

        # sample random noise to be added to the images in the batch
        noise = np.random.normal(size=(batch_size, img_size, img_size, img_channels))
        images_t = np.array(gdf_util.q_sample(Y_batch_HR, t, noise))

        model.train_on_batch([images_t, t], noise)
        
    # on epoch-end
    pred_noise = model.predict([images_valid, t_valid])
    record_temp = np.mean(np.abs(noise_valid - pred_noise))
    model.save(model_name)
    # print out valid loss change
    if record - record_temp > min_del:
        print('Validation loss improved from {} to {}'.format(record, record_temp))
        record = record_temp
        
    else:
        print('Validation loss {} NOT improved'.format(record_temp))

    print("--- %s seconds ---" % (time.time() - start_time))
    # mannual callbacks

epoch = 0
Initial validation loss: 0.04726360960251018




Validation loss 0.06482305978707407 NOT improved
--- 562.0616343021393 seconds ---
epoch = 1




Validation loss 0.06268290248332475 NOT improved
--- 557.5204842090607 seconds ---
epoch = 2




Validation loss 0.06537165687700841 NOT improved
--- 862.9260144233704 seconds ---
epoch = 3




Validation loss 0.06388709119474965 NOT improved
--- 581.2239294052124 seconds ---
epoch = 4




Validation loss 0.06525931059516694 NOT improved
--- 571.5544695854187 seconds ---
epoch = 5




Validation loss 0.06549951202822332 NOT improved
--- 565.7670652866364 seconds ---
epoch = 6




Validation loss 0.06542037142380665 NOT improved
--- 560.3352584838867 seconds ---
epoch = 7




Validation loss 0.06558676257143567 NOT improved
--- 558.7391023635864 seconds ---
epoch = 8




Validation loss 0.06478595981533244 NOT improved
--- 557.9981825351715 seconds ---
epoch = 9




Validation loss 0.06556927765791153 NOT improved
--- 557.1158781051636 seconds ---
epoch = 10




Validation loss 0.06624470861103428 NOT improved
--- 906.5992517471313 seconds ---
epoch = 11




Validation loss 0.0683566086854175 NOT improved
--- 571.6454033851624 seconds ---
epoch = 12




Validation loss 0.06714016328456603 NOT improved
--- 566.1405913829803 seconds ---
epoch = 13




Validation loss 0.06602684290167231 NOT improved
--- 559.1043643951416 seconds ---
epoch = 14




Validation loss 0.06689863532043781 NOT improved
--- 558.4184081554413 seconds ---
epoch = 15




Validation loss 0.06682977720368358 NOT improved
--- 556.3433685302734 seconds ---
epoch = 16




Validation loss 0.068029294873693 NOT improved
--- 556.64697098732 seconds ---
epoch = 17




Validation loss 0.06807845223333192 NOT improved
--- 557.034410238266 seconds ---
epoch = 18




Validation loss 0.06766445151282477 NOT improved
--- 932.6213793754578 seconds ---
epoch = 19




Validation loss 0.06839045693527299 NOT improved
--- 578.8871653079987 seconds ---
epoch = 20




Validation loss 0.06836063049402714 NOT improved
--- 569.2685670852661 seconds ---
epoch = 21




Validation loss 0.06898118044224374 NOT improved
--- 565.1466281414032 seconds ---
epoch = 22




Validation loss 0.06975231991372935 NOT improved
--- 560.6628084182739 seconds ---
epoch = 23




Validation loss 0.06793546611364865 NOT improved
--- 560.0448389053345 seconds ---
epoch = 24




Validation loss 0.06872564142182658 NOT improved
--- 560.4207479953766 seconds ---
epoch = 25




Validation loss 0.07036058791068873 NOT improved
--- 560.7791764736176 seconds ---
epoch = 26




Validation loss 0.07058852430895451 NOT improved
--- 564.4891722202301 seconds ---
epoch = 27




Validation loss 0.06924581787424242 NOT improved
--- 969.8013830184937 seconds ---
epoch = 28




Validation loss 0.07187131853981342 NOT improved
--- 580.9059166908264 seconds ---
epoch = 29




Validation loss 0.0714325810280568 NOT improved
--- 573.9393050670624 seconds ---
epoch = 30




Validation loss 0.07214662119892692 NOT improved
--- 576.8839266300201 seconds ---
epoch = 31




Validation loss 0.06927335739916829 NOT improved
--- 567.3811116218567 seconds ---
epoch = 32




Validation loss 0.0708491959881131 NOT improved
--- 565.2825622558594 seconds ---
epoch = 33




Validation loss 0.06999735840499681 NOT improved
--- 562.8090789318085 seconds ---
epoch = 34




Validation loss 0.07165650782970251 NOT improved
--- 567.0356924533844 seconds ---
epoch = 35




Validation loss 0.0709252347410295 NOT improved
--- 567.5563044548035 seconds ---
epoch = 36




Validation loss 0.07144069273680338 NOT improved
--- 570.6778547763824 seconds ---
epoch = 37




Validation loss 0.07067229717236308 NOT improved
--- 1009.5932667255402 seconds ---
epoch = 38




Validation loss 0.06981427875529307 NOT improved
--- 582.6738147735596 seconds ---
epoch = 39




Validation loss 0.07052888002736071 NOT improved
--- 581.9670603275299 seconds ---
epoch = 40




Validation loss 0.07068385499286957 NOT improved
--- 580.9595561027527 seconds ---
epoch = 41




Validation loss 0.0703570701426098 NOT improved
--- 580.2310395240784 seconds ---
epoch = 42




Validation loss 0.06939566421766521 NOT improved
--- 571.2825398445129 seconds ---
epoch = 43




Validation loss 0.07465238298387271 NOT improved
--- 1063.8349649906158 seconds ---
epoch = 44




Validation loss 0.07332964180505369 NOT improved
--- 559.2754867076874 seconds ---
epoch = 45




Validation loss 0.07171591394753046 NOT improved
--- 560.798104763031 seconds ---
epoch = 46




Validation loss 0.07034513345865996 NOT improved
--- 560.5610065460205 seconds ---
epoch = 47




Validation loss 0.07109881135501546 NOT improved
--- 561.5224254131317 seconds ---
epoch = 48




Validation loss 0.07091938956350725 NOT improved
--- 1023.332884311676 seconds ---
epoch = 49


In [21]:
# 2. Sample timesteps uniformly
t_ = np.random.uniform(low=0, high=total_timesteps, size=(batch_size,))
t = t_.astype(int)

# 3. Sample random noise to be added to the images in the batch
noise = np.random.normal(size=(batch_size, 128, 128, 1))

images_t_ = np.array(gdf_util.q_sample(Y_valid_HR, t, noise))