In [1]:
import os
import sys
import time
import math
import logging
import warnings
import numpy as np
from glob import glob

warnings.filterwarnings("ignore")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
logging.getLogger("tensorflow").setLevel(logging.ERROR) 

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from GroupNormalization import GroupNormalization

tf.autograph.set_verbosity(0)
tf.get_logger().setLevel('ERROR')

In [4]:
print('tf.version = {}'.format(tf.__version__))
print('tf on GPU = {}'.format(tf.test.is_gpu_available()))

tf.version = 2.12.1
tf on GPU = True


In [5]:
tf.config.run_functions_eagerly(True)

In [6]:
sys.path.insert(0, '/glade/u/home/ksha/GAN_proj/')
sys.path.insert(0, '/glade/u/home/ksha/GAN_proj/libs/')

from namelist import *
import data_utils as du
import model_utils as mu
#import graph_utils as gu

**Hyperparameters**

In [7]:
class GaussianDiffusion:
    """Gaussian diffusion utility.

    Args:
        beta_start: Start value of the scheduled variance
        beta_end: End value of the scheduled variance
        timesteps: Number of time steps in the forward process
    """

    def __init__(
        self,
        beta_start=1e-4,
        beta_end=0.02,
        timesteps=100,
        clip_min=-1.0,
        clip_max=1.0,
    ):
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.timesteps = timesteps
        self.clip_min = clip_min
        self.clip_max = clip_max

        # Define the linear variance schedule
        # ============ Linear ============== #
        # self.betas = betas = np.linspace(
        #     beta_start,
        #     beta_end,
        #     timesteps,
        #     dtype=np.float64,  # Using float64 for better precision
        # )
        # =========== tuned =============== #
        sch_ = (beta_end/beta_start)**(1/3)
        self.betas = betas = np.linspace(1, sch_, timesteps, dtype=np.float64)
        
        self.num_timesteps = int(timesteps)

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        self.betas = tf.constant(betas, dtype=tf.float32)
        self.alphas_cumprod = tf.constant(alphas_cumprod, dtype=tf.float32)
        self.alphas_cumprod_prev = tf.constant(alphas_cumprod_prev, dtype=tf.float32)

        # Calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = tf.constant(
            np.sqrt(alphas_cumprod), dtype=tf.float32
        )

        self.sqrt_one_minus_alphas_cumprod = tf.constant(
            np.sqrt(1.0 - alphas_cumprod), dtype=tf.float32
        )

        self.log_one_minus_alphas_cumprod = tf.constant(
            np.log(1.0 - alphas_cumprod), dtype=tf.float32
        )

        self.sqrt_recip_alphas_cumprod = tf.constant(
            np.sqrt(1.0 / alphas_cumprod), dtype=tf.float32
        )
        self.sqrt_recipm1_alphas_cumprod = tf.constant(
            np.sqrt(1.0 / alphas_cumprod - 1), dtype=tf.float32
        )

        # Calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        self.posterior_variance = tf.constant(posterior_variance, dtype=tf.float32)

        # Log calculation clipped because the posterior variance is 0 at the beginning
        # of the diffusion chain
        self.posterior_log_variance_clipped = tf.constant(
            np.log(np.maximum(posterior_variance, 1e-20)), dtype=tf.float32
        )

        self.posterior_mean_coef1 = tf.constant(
            betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
            dtype=tf.float32,
        )

        self.posterior_mean_coef2 = tf.constant(
            (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod),
            dtype=tf.float32,
        )

    def _extract(self, a, t, x_shape):
        """Extract some coefficients at specified timesteps,
        then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.

        Args:
            a: Tensor to extract from
            t: Timestep for which the coefficients are to be extracted
            x_shape: Shape of the current batched samples
        """
        batch_size = x_shape[0]
        out = tf.gather(a, t)
        return tf.reshape(out, [batch_size, 1, 1, 1])

    def q_mean_variance(self, x_start, t):
        """Extracts the mean, and the variance at current timestep.

        Args:
            x_start: Initial sample (before the first diffusion step)
            t: Current timestep
        """
        x_start_shape = tf.shape(x_start)
        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod, t, x_start_shape)
        log_variance = self._extract(
            self.log_one_minus_alphas_cumprod, t, x_start_shape
        )
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise):
        """Diffuse the data.

        Args:
            x_start: Initial sample (before the first diffusion step)
            t: Current timestep
            noise: Gaussian noise to be added at the current timestep
        Returns:
            Diffused samples at timestep `t`
        """
        x_start_shape = tf.shape(x_start)
        return (
            self._extract(self.sqrt_alphas_cumprod, t, tf.shape(x_start)) * x_start
            + self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start_shape)
            * noise
        )

    def predict_start_from_noise(self, x_t, t, noise):
        x_t_shape = tf.shape(x_t)
        return (
            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t_shape) * x_t
            - self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t_shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Compute the mean and variance of the diffusion
        posterior q(x_{t-1} | x_t, x_0).

        Args:
            x_start: Stating point(sample) for the posterior computation
            x_t: Sample at timestep `t`
            t: Current timestep
        Returns:
            Posterior mean and variance at current timestep
        """

        x_t_shape = tf.shape(x_t)
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t_shape) * x_start
            + self._extract(self.posterior_mean_coef2, t, x_t_shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t_shape)
        posterior_log_variance_clipped = self._extract(
            self.posterior_log_variance_clipped, t, x_t_shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, pred_noise, x, t, clip_denoised=True):
        x_recon = self.predict_start_from_noise(x, t=t, noise=pred_noise)
        if clip_denoised:
            x_recon = tf.clip_by_value(x_recon, self.clip_min, self.clip_max)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t
        )
        return model_mean, posterior_variance, posterior_log_variance

    def p_sample(self, pred_noise, x, t, clip_denoised=True):
        """Sample from the diffusion model.

        Args:
            pred_noise: Noise predicted by the diffusion model
            x: Samples at a given timestep for which the noise was predicted
            t: Current timestep
            clip_denoised (bool): Whether to clip the predicted noise
                within the specified range or not.
        """
        model_mean, _, model_log_variance = self.p_mean_variance(
            pred_noise, x=x, t=t, clip_denoised=clip_denoised
        )
        noise = tf.random.normal(shape=x.shape, dtype=x.dtype)
        # No noise when t == 0
        nonzero_mask = tf.reshape(
            1 - tf.cast(tf.equal(t, 0), tf.float32), [tf.shape(x)[0], 1, 1, 1]
        )
        return model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise

In [8]:
# Kernel initializer to use
def kernel_init(scale):
    scale = max(scale, 1e-10)
    return keras.initializers.VarianceScaling(
        scale, mode="fan_avg", distribution="uniform"
    )

class TimeEmbedding(layers.Layer):
    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.half_dim = dim // 2
        self.emb = math.log(10000) / (self.half_dim - 1)
        self.emb = tf.exp(tf.range(self.half_dim, dtype=tf.float32) * -self.emb)

    def call(self, inputs):
        inputs = tf.cast(inputs, dtype=tf.float32)
        emb = inputs[:, None] * self.emb[None, :]
        emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=-1)
        return emb


def ResidualBlock(width, groups=8, activation_fn=keras.activations.swish):
    def apply(inputs):
        x, t = inputs
        input_width = x.shape[3]

        if input_width == width:
            residual = x
        else:
            residual = layers.Conv2D(
                width, kernel_size=1, kernel_initializer=kernel_init(1.0)
            )(x)

        temb = activation_fn(t)
        temb = layers.Dense(width, kernel_initializer=kernel_init(1.0))(temb)[
            :, None, None, :
        ]

        x = layers.GroupNormalization(groups=groups)(x)
        x = activation_fn(x)
        x = layers.Conv2D(
            width, kernel_size=3, padding="same", kernel_initializer=kernel_init(1.0)
        )(x)

        x = layers.Add()([x, temb])
        x = layers.GroupNormalization(groups=groups)(x)
        x = activation_fn(x)

        x = layers.Conv2D(
            width, kernel_size=3, padding="same", kernel_initializer=kernel_init(0.0)
        )(x)
        x = layers.Add()([x, residual])
        return x

    return apply


def DownSample(width):
    def apply(x):
        x = layers.Conv2D(
            width,
            kernel_size=3,
            strides=2,
            padding="same",
            kernel_initializer=kernel_init(1.0),
        )(x)
        return x

    return apply


def UpSample(width, interpolation="nearest"):
    def apply(x):
        x = layers.UpSampling2D(size=2, interpolation=interpolation)(x)
        x = layers.Conv2D(
            width, kernel_size=3, padding="same", kernel_initializer=kernel_init(1.0)
        )(x)
        return x

    return apply


def TimeMLP(units, activation_fn=keras.activations.swish):
    def apply(inputs):
        temb = layers.Dense(
            units, activation=activation_fn, kernel_initializer=kernel_init(1.0)
        )(inputs)
        temb = layers.Dense(units, kernel_initializer=kernel_init(1.0))(temb)
        return temb

    return apply


# def build_model(input_shape, gfs_shape, widths, has_attention, num_res_blocks=2, norm_groups=8,
#                 interpolation="nearest", activation_fn=keras.activations.swish,):
    
#     image_input = layers.Input(shape=input_shape, name="image_input")
#     time_input = keras.Input(shape=(), dtype=tf.int64, name="time_input")
#     gfs_input = layers.Input(shape=gfs_shape, name="gfs_input")
    
#     x = layers.Conv2D(first_conv_channels, kernel_size=(3, 3), padding="same",
#                       kernel_initializer=kernel_init(1.0),)(image_input)

#     temb = TimeEmbedding(dim=first_conv_channels * 4)(time_input)
#     temb = TimeMLP(units=first_conv_channels * 4, activation_fn=activation_fn)(temb)

#     skips = [x]

#     # DownBlock
#     for i in range(len(widths)):
#         for _ in range(num_res_blocks):
#             x = ResidualBlock(widths[i], groups=norm_groups, activation_fn=activation_fn)([x, temb])
            
#             if has_attention[i]:
#                 x_gfs = layers.Conv2D(widths[i], kernel_size=(2**i, 2**i), strides=2**i, padding="same",)(gfs_input)
#                 x = layers.MultiHeadAttention(num_heads=norm_groups, key_dim=widths[i])(x, x_gfs)
                
#             skips.append(x)

#         if widths[i] != widths[-1]:
#             x = DownSample(widths[i])(x)
#             skips.append(x)

#     # MiddleBlock
#     x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)([x, temb])
    
#     x_gfs = layers.Conv2D(widths[-1], kernel_size=(8, 8), strides=8, padding="same",)(gfs_input)
#     x = layers.MultiHeadAttention(num_heads=norm_groups, key_dim=widths[-1])(x, x_gfs)
    
#     x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)([x, temb])

#     # UpBlock
#     for i in reversed(range(len(widths))):
#         for _ in range(num_res_blocks + 1):
#             x = layers.Concatenate(axis=-1)([x, skips.pop()])
#             x = ResidualBlock(widths[i], groups=norm_groups, activation_fn=activation_fn)([x, temb])
            
#             if has_attention[i]:
#                 x_gfs = layers.Conv2D(widths[i], kernel_size=(2**i, 2**i), strides=2**i, padding="same",)(gfs_input)
#                 x = layers.MultiHeadAttention(num_heads=norm_groups, key_dim=widths[i])(x, x_gfs)

#         if i != 0:
#             x = UpSample(widths[i], interpolation=interpolation)(x)

#     # End block
#     x = layers.GroupNormalization(groups=norm_groups)(x)
#     x = activation_fn(x)
#     x = layers.Conv2D(1, (3, 3), padding="same", kernel_initializer=kernel_init(0.0))(x)
#     return keras.Model([image_input, time_input, gfs_input], x, name="unet")

def build_model(input_shape, gfs_shape, widths, has_attention, num_res_blocks=2, norm_groups=8,
                interpolation="nearest", activation_fn=keras.activations.swish,):
    
    image_input = layers.Input(shape=input_shape, name="image_input")
    time_input = keras.Input(shape=(), dtype=tf.int64, name="time_input")
    gfs_input = layers.Input(shape=gfs_shape, name="gfs_input")
    
    x = layers.Conv2D(first_conv_channels, kernel_size=(3, 3), padding="same",
                      kernel_initializer=kernel_init(1.0),)(image_input)

    temb = TimeEmbedding(dim=first_conv_channels * 4)(time_input)
    temb = TimeMLP(units=first_conv_channels * 4, activation_fn=activation_fn)(temb)

    skips = [x]

    # DownBlock
    for i in range(len(widths)):
        for _ in range(num_res_blocks):
            x = ResidualBlock(widths[i], groups=norm_groups, activation_fn=activation_fn)([x, temb])
            
            if has_attention[i]:
                x_gfs = gfs_input
                for _ in range(i):
                    x_gfs = layers.Conv2D(widths[i], kernel_size=(2, 2), strides=2, padding="same",)(x_gfs)
                x = layers.MultiHeadAttention(num_heads=norm_groups, key_dim=widths[i])(x, x_gfs)
                
            skips.append(x)

        if widths[i] != widths[-1]:
            x = DownSample(widths[i])(x)
            skips.append(x)

    # MiddleBlock
    x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)([x, temb])
    
    x_gfs = gfs_input
    for _ in range(3):
        x_gfs = layers.Conv2D(widths[i], kernel_size=(2, 2), strides=2, padding="same",)(x_gfs)
    x = layers.MultiHeadAttention(num_heads=norm_groups, key_dim=widths[i])(x, x_gfs)
    x = layers.MultiHeadAttention(num_heads=norm_groups, key_dim=widths[-1])(x, x_gfs)
    
    x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)([x, temb])

    # UpBlock
    for i in reversed(range(len(widths))):
        for _ in range(num_res_blocks + 1):
            x = layers.Concatenate(axis=-1)([x, skips.pop()])
            x = ResidualBlock(widths[i], groups=norm_groups, activation_fn=activation_fn)([x, temb])
            
            if has_attention[i]:
                x_gfs = gfs_input
                for _ in range(i):
                    x_gfs = layers.Conv2D(widths[i], kernel_size=(2, 2), strides=2, padding="same",)(x_gfs)
                x = layers.MultiHeadAttention(num_heads=norm_groups, key_dim=widths[i])(x, x_gfs)

        if i != 0:
            x = UpSample(widths[i], interpolation=interpolation)(x)

    # End block
    x = layers.GroupNormalization(groups=norm_groups)(x)
    x = activation_fn(x)
    x = layers.Conv2D(input_shape[-1], (3, 3), padding="same", kernel_initializer=kernel_init(0.0))(x)
    return keras.Model([image_input, time_input, gfs_input], x, name="unet")

# config model

In [9]:
total_timesteps = 100
norm_groups = 8  # Number of groups used in GroupNormalization layer
lr = 1e-4

clip_min = -1.0
clip_max = 1.0


widths = [64, 96, 128, 256]
first_conv_channels = widths[0]

has_attention = [False, False, False, True]
num_res_blocks = 1  # Number of residual blocks

model_name = '/glade/work/ksha/GAN/models/LDM_base/'


input_shape = (32, 32, 16)
gfs_shape = (32, 32, 256)

In [10]:
# Build the unet model
model = build_model(input_shape=input_shape, gfs_shape=gfs_shape, widths=widths,
                    has_attention=has_attention, num_res_blocks=num_res_blocks, 
                    norm_groups=norm_groups, activation_fn=keras.activations.swish)

#17,565,729
#17,602,017

In [11]:
# Compile the model
model.compile(loss=keras.losses.MeanAbsoluteError(), optimizer=keras.optimizers.Adam(learning_rate=1e-4),)

In [12]:
# W_old = mu.dummy_loader(model_name)
# model.set_weights(W_old)

In [13]:
gdf_util = GaussianDiffusion(timesteps=total_timesteps)

In [14]:
BATCH_dir = '/glade/campaign/cisl/aiml/ksha/BATCH_LDM/'

In [15]:
def reverse_diffuse(model, x_in1, x_in2, total_timesteps, gdf_util):
    L_valid = len(x_in1)
    x_out = np.empty(x_in1.shape)

    for i in range(L_valid):
        x1 = x_in1[i, ...][None, ...]
        x2 = x_in2[i, ...][None, ...]
        
        for t in reversed(range(0, total_timesteps)):
            tt = tf.cast(tf.fill(1, t), dtype=tf.int64)
            pred_noise = model.predict([x1, tt, x2], verbose=0)
            model_mean, _, model_log_variance =  gdf_util.p_mean_variance(pred_noise, x=x1, t=tt, clip_denoised=True)
            nonzero_mask = (1 - (np.array(tt)==0)).reshape((1, 1, 1, 1))
            x1 = np.array(model_mean) + nonzero_mask * np.exp(0.5 * np.array(model_log_variance)) * np.random.normal(size=x1.shape)
        x_out[i, ...] = x1

    return x_out

In [16]:
def mean_absolute_error(y_true, y_pred):
    return np.mean(np.abs(y_true - y_pred))

In [17]:
F_x = 0.1
F_y = 1/2.76

L_valid = 32

filenames = np.array(sorted(glob(BATCH_dir+'*.npy')))

L = len(filenames)
filename_valid = filenames[:][:L_valid]
filename_train = list(set(filenames) - set(filename_valid))

L_train = len(filename_train)

Y_valid = np.empty((L_valid, 32, 32, 16))
X_valid = np.empty((L_valid, 32, 32, 256))

for i, name in enumerate(filename_valid):
    temp_data = np.load(name, allow_pickle=True)[()]
    X_valid[i, ...] = F_x*temp_data['GFS_latent']
    Y_valid[i, ...] = F_y*temp_data['Y_latent']

In [18]:
t_valid_ = np.random.uniform(low=0, high=total_timesteps, size=(L_valid,)) #(total_timesteps-1)*np.ones(L_valid) #
t_valid = t_valid_.astype(int)

# sample random noise to be added to the images in the batch
noise_valid = np.random.normal(size=(L_valid, 32, 32, 16))
images_valid = np.array(gdf_util.q_sample(Y_valid, t_valid, noise_valid))

In [19]:
# start_time = time.time()
# Y_pred = reverse_diffuse(model, images_valid, X_valid, total_timesteps, gdf_util)
# print("--- %s seconds ---" % (time.time() - start_time))

In [20]:
# x_in1 = images_valid
# x_in2 = X_valid

# x1 = x_in1[i, ...][None, ...]
# x2 = x_in2[i, ...][None, ...]
        
# for t in reversed(range(0, total_timesteps)):
#     tt = tf.cast(tf.fill(1, t), dtype=tf.int64)
#     pred_noise = model.predict([x1, tt, x2], verbose=0)
#     model_mean, _, model_log_variance =  gdf_util.p_mean_variance(pred_noise, x=x1, t=tt, clip_denoised=True)
#     nonzero_mask = (1 - (np.array(tt)==0)).reshape((1, 1, 1, 1))
#     x1 = np.array(model_mean) + nonzero_mask * np.exp(0.5 * np.array(model_log_variance)) * np.random.normal(size=x1.shape)

In [21]:
# pred_noise = model.predict([images_valid, t_valid, X_valid])
# record = np.mean(np.abs(noise_valid - pred_noise))
# print('initial loss {}'.format(record))

In [None]:
# ====================== #
filenames = np.array(sorted(glob(BATCH_dir+'*.npy')))
L = len(filenames)
filename_valid = filenames[:L_valid]
filename_train = list(set(filenames) - set(filename_valid))
L_train = len(filename_train)
# ====================== #

epochs = 99999
N_batch = 128
batch_size = 64 #32

min_del = 0.0
max_tol = 3 # early stopping with 2-epoch patience
tol = 0

Y_batch = np.empty((batch_size, 32, 32, 16))
X_batch = np.empty((batch_size, 32, 32, 256))

for i in range(epochs):
    
    print('epoch = {}'.format(i))
    if i == 0:
        pred_noise = model.predict([images_valid, t_valid, X_valid])
        record = np.mean(np.abs(noise_valid - pred_noise))
        #print('initial loss {}'.format(record))
        print('Initial validation loss: {}'.format(record))
        
    start_time = time.time()
    # loop over batches
    
    for j in range(N_batch):

        # collect training batches
        
        inds_rnd = du.shuffle_ind(L_train)
        inds_ = inds_rnd[:batch_size]

        for k, ind in enumerate(inds_):
            # import batch data
            temp_name = filename_train[ind]
            temp_data = np.load(temp_name, allow_pickle=True)[()]
            X_batch[k, ...] = F_x*temp_data['GFS_latent']
            Y_batch[k, ...] = F_y*temp_data['Y_latent']

        # sample timesteps uniformly
        t_ = np.random.uniform(low=0, high=total_timesteps, size=(batch_size,))
        t = t_.astype(int)

        # sample random noise to be added to the images in the batch
        noise = np.random.normal(size=(batch_size, 32, 32, 16))
        images_t = np.array(gdf_util.q_sample(Y_batch, t, noise))

        model.train_on_batch([images_t, t, X_batch], noise)
        
    # on epoch-end
    # Y_pred = reverse_diffuse(model, images_valid, X_valid, total_timesteps, gdf_util)
    # record_temp = mean_absolute_error(Y_valid, Y_pred)
    pred_noise = model.predict([images_valid, t_valid, X_valid])
    record_temp = np.mean(np.abs(noise_valid - pred_noise))
    
    # print out valid loss change
    if record - record_temp > min_del:
        print('Validation loss improved from {} to {}'.format(record, record_temp))
        record = record_temp
        model.save(model_name)
        
    else:
        print('Validation loss {} NOT improved'.format(record_temp))

    print("--- %s seconds ---" % (time.time() - start_time))
    # mannual callbacks

epoch = 0
Initial validation loss: 0.797513172304641
Validation loss improved from 0.797513172304641 to 0.5559610733509978




--- 471.420521736145 seconds ---
epoch = 1
Validation loss improved from 0.5559610733509978 to 0.3403905973686788




--- 402.20895886421204 seconds ---
epoch = 2
Validation loss improved from 0.3403905973686788 to 0.22919090064767977




--- 389.10267972946167 seconds ---
epoch = 3
Validation loss improved from 0.22919090064767977 to 0.2027314528156189




--- 389.27508449554443 seconds ---
epoch = 4
Validation loss improved from 0.2027314528156189 to 0.18128846542139904




--- 398.464955329895 seconds ---
epoch = 5
Validation loss improved from 0.18128846542139904 to 0.16057621433272762




--- 461.9826326370239 seconds ---
epoch = 6
Validation loss improved from 0.16057621433272762 to 0.14289869076055428




--- 533.4157667160034 seconds ---
epoch = 7
Validation loss improved from 0.14289869076055428 to 0.1284261688872051




--- 527.249320268631 seconds ---
epoch = 8
Validation loss improved from 0.1284261688872051 to 0.11621552889745633




--- 538.7512350082397 seconds ---
epoch = 9
Validation loss improved from 0.11621552889745633 to 0.1059909117553423




--- 408.60932087898254 seconds ---
epoch = 10
Validation loss improved from 0.1059909117553423 to 0.09716345015866287




--- 404.4019560813904 seconds ---
epoch = 11
Validation loss improved from 0.09716345015866287 to 0.0896161981162007




--- 416.0437467098236 seconds ---
epoch = 12
Validation loss improved from 0.0896161981162007 to 0.08273929819448905




--- 416.95712447166443 seconds ---
epoch = 13
Validation loss improved from 0.08273929819448905 to 0.07722468002175827




--- 438.0843915939331 seconds ---
epoch = 14
Validation loss improved from 0.07722468002175827 to 0.07227678327661065




--- 484.0047941207886 seconds ---
epoch = 15
Validation loss improved from 0.07227678327661065 to 0.06798663757803595




--- 417.7964925765991 seconds ---
epoch = 16
Validation loss improved from 0.06798663757803595 to 0.06409516549404144




--- 405.55396795272827 seconds ---
epoch = 17
Validation loss improved from 0.06409516549404144 to 0.060697814228430895




--- 387.5766806602478 seconds ---
epoch = 18
Validation loss improved from 0.060697814228430895 to 0.05797634668941367




--- 404.45651412010193 seconds ---
epoch = 19
Validation loss improved from 0.05797634668941367 to 0.05545675097770402




--- 455.4145929813385 seconds ---
epoch = 20
Validation loss improved from 0.05545675097770402 to 0.052888227889087716




--- 424.20606565475464 seconds ---
epoch = 21
Validation loss improved from 0.052888227889087716 to 0.05097449092689509




--- 458.1558723449707 seconds ---
epoch = 22
Validation loss improved from 0.05097449092689509 to 0.048967857496122404




--- 415.1968629360199 seconds ---
epoch = 23
Validation loss improved from 0.048967857496122404 to 0.04747930620211306




--- 550.704274892807 seconds ---
epoch = 24
Validation loss improved from 0.04747930620211306 to 0.04543046127853685




--- 613.0236401557922 seconds ---
epoch = 25
Validation loss improved from 0.04543046127853685 to 0.04400997001089581




--- 526.2651669979095 seconds ---
epoch = 26
Validation loss improved from 0.04400997001089581 to 0.043091280638278974




--- 474.00400161743164 seconds ---
epoch = 27
Validation loss improved from 0.043091280638278974 to 0.04134070614884069




--- 445.8767900466919 seconds ---
epoch = 28
Validation loss improved from 0.04134070614884069 to 0.040291605338835426




--- 420.67860651016235 seconds ---
epoch = 29
Validation loss improved from 0.040291605338835426 to 0.03936077880618272




--- 517.3442347049713 seconds ---
epoch = 30
Validation loss improved from 0.03936077880618272 to 0.03887766951133284




--- 588.3358027935028 seconds ---
epoch = 31
Validation loss improved from 0.03887766951133284 to 0.03724375688500586




--- 543.2640764713287 seconds ---
epoch = 32
Validation loss improved from 0.03724375688500586 to 0.036526212945955996




--- 508.91004371643066 seconds ---
epoch = 33
Validation loss improved from 0.036526212945955996 to 0.03549530360254592




--- 484.5173268318176 seconds ---
epoch = 34
Validation loss improved from 0.03549530360254592 to 0.03465105223460907




--- 612.5016539096832 seconds ---
epoch = 35
Validation loss improved from 0.03465105223460907 to 0.033886106169804406




--- 534.8786282539368 seconds ---
epoch = 36
Validation loss improved from 0.033886106169804406 to 0.03326138906762922




--- 618.2382049560547 seconds ---
epoch = 37
Validation loss improved from 0.03326138906762922 to 0.032582565489026015




--- 553.2573931217194 seconds ---
epoch = 38
Validation loss improved from 0.032582565489026015 to 0.031982645766156184




--- 617.039960861206 seconds ---
epoch = 39
Validation loss improved from 0.031982645766156184 to 0.03186714786632341




--- 520.8185844421387 seconds ---
epoch = 40
Validation loss improved from 0.03186714786632341 to 0.0309703045973853




--- 403.2380769252777 seconds ---
epoch = 41
Validation loss 0.031139389200003646 NOT improved
--- 387.69413447380066 seconds ---
epoch = 42
Validation loss improved from 0.0309703045973853 to 0.030427886833236353




--- 451.3786346912384 seconds ---
epoch = 43
Validation loss improved from 0.030427886833236353 to 0.02991023203615282




--- 428.2513630390167 seconds ---
epoch = 44
Validation loss improved from 0.02991023203615282 to 0.029149318553441404




--- 413.3736572265625 seconds ---
epoch = 45
Validation loss improved from 0.029149318553441404 to 0.02909479913815121




--- 405.3245346546173 seconds ---
epoch = 46
Validation loss improved from 0.02909479913815121 to 0.028391452501605147




--- 407.5135726928711 seconds ---
epoch = 47
Validation loss improved from 0.028391452501605147 to 0.028006451677511913




--- 425.7228271961212 seconds ---
epoch = 48
Validation loss improved from 0.028006451677511913 to 0.027501218645962083




--- 390.87215781211853 seconds ---
epoch = 49
Validation loss improved from 0.027501218645962083 to 0.02719560320606552




--- 386.37936782836914 seconds ---
epoch = 50
Validation loss improved from 0.02719560320606552 to 0.026697202905771315




--- 408.9744567871094 seconds ---
epoch = 51
Validation loss 0.02693225497113441 NOT improved
--- 361.71693873405457 seconds ---
epoch = 52


In [23]:
pred_noise = model.predict([images_valid, t_valid, X_valid])



In [23]:
images_t.shape

(32, 32, 32, 16)

In [21]:
# 2. Sample timesteps uniformly
t_ = np.random.uniform(low=0, high=total_timesteps, size=(batch_size,))
t = t_.astype(int)

# 3. Sample random noise to be added to the images in the batch
noise = np.random.normal(size=(batch_size, 128, 128, 1))

images_t_ = np.array(gdf_util.q_sample(Y_valid_HR, t, noise))