# Latent diffusion model without GFS embeddings

* The cross-attention version

In [1]:
import os
import sys
import time
import math
import logging
import warnings
import numpy as np
from glob import glob

# supress regular warnings
warnings.filterwarnings("ignore")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
logging.getLogger("tensorflow").setLevel(logging.ERROR) 

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# supress tensorflow warnings
tf.autograph.set_verbosity(0)
tf.get_logger().setLevel('ERROR')

# adjust for time step embedding layer
tf.config.run_functions_eagerly(True)

In [2]:
np.random.seed(999)

In [3]:
sys.path.insert(0, '/glade/u/home/ksha/GAN_proj/')
sys.path.insert(0, '/glade/u/home/ksha/GAN_proj/libs/')

from namelist import *
import data_utils as du
import model_utils as mu

## Hyperparameters

In [4]:
total_timesteps = 50 # diffusion time steps
norm_groups = 8 # number of attention heads, number of layer normalization groups 

# min-max values of the diffusion target (learning target) 
clip_min = -1.0
clip_max = 1.0

precip_max = np.log(100+1)

input_shape = (128, 256, 1) # the tensor shape of reverse diffusion input
gfs_shape = (128, 256, 8) # the tensor shape of GFS embeddings

widths = [64, 96, 128, 256] # number of convolution kernels per up-/downsampling level
feature_sizes = [32, 16, 8, 4]

left_attention = [False, False, True, True] # True: use multi-head attnetion on each up-/downsampling level
right_attention = [False, False, True, True]
num_res_blocks = 2  # Number of residual blocks

N_atten1 = np.sum(left_attention)
N_atten2 = np.sum(right_attention)

load_weights = True # True: load previous weights
# location of the previous weights
model_name = '/glade/work/ksha/GAN/models/LDM_025_resize{}-{}_res{}_tune1/'.format(
    N_atten1, N_atten2, num_res_blocks)

# location for saving new weights
model_name_save = '/glade/work/ksha/GAN/models/LDM_025_resize{}-{}_res{}_tune2/'.format(
    N_atten1, N_atten2, num_res_blocks)

lr = 1e-5 # learning rate

# samples per epoch = N_batch * batch_size
epochs = 99999
N_batch = 128
batch_size = 16

## Model design

In [5]:
def build_model(input_shape, gfs_shape, widths, feature_sizes, left_attention, right_attention, num_res_blocks=2, norm_groups=8,
                interpolation='bilinear', activation_fn=keras.activations.swish,):

    first_conv_channels = widths[0]
    
    image_input = layers.Input(shape=input_shape, name="image_input")
    time_input = keras.Input(shape=(), dtype=tf.int64, name="time_input")
    gfs_input = layers.Input(shape=gfs_shape, name="gfs_input")
    
    x = layers.Conv2D(first_conv_channels, kernel_size=(3, 3), padding="same",
                      kernel_initializer=mu.kernel_init(1.0),)(image_input)

    temb = mu.TimeEmbedding(dim=first_conv_channels * 4)(time_input)
    temb = mu.TimeMLP(units=first_conv_channels * 4, activation_fn=activation_fn)(temb)

    skips = [x]

    # DownBlock
    has_attention = left_attention
    for i in range(len(widths)):
        for _ in range(num_res_blocks):
            x = mu.ResidualBlock(widths[i], groups=norm_groups, activation_fn=activation_fn)([x, temb])
            
            if has_attention[i]:
                # GFS cross-attention inputs
                size_ = feature_sizes[i]
                x_gfs = gfs_input
                x_gfs = layers.Resizing(size_, 2*size_, interpolation='bilinear')(x_gfs)

                x_gfs = layers.Conv2D(int(0.5*widths[i]), kernel_size=(3, 3), padding="same",)(x_gfs)
                x_gfs = layers.GroupNormalization(groups=norm_groups)(x_gfs)
                x_gfs = activation_fn(x_gfs)

                x_gfs = layers.Conv2D(widths[i], kernel_size=(3, 3), padding="same",)(x_gfs)
                x_gfs = layers.GroupNormalization(groups=norm_groups)(x_gfs)
                x_gfs = activation_fn(x_gfs)
                
                x = layers.MultiHeadAttention(num_heads=norm_groups, key_dim=widths[i])(x, x_gfs)
                
            skips.append(x)

        if widths[i] != widths[-1]:
            x = mu.DownSample(widths[i])(x)
            skips.append(x)

    # MiddleBlock
    x = mu.ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)([x, temb])
    
    size_ = feature_sizes[-1]
    x_gfs = gfs_input
    x_gfs = layers.Resizing(size_, 2*size_, interpolation='bilinear')(x_gfs)
    
    x_gfs = layers.Conv2D(int(0.5*widths[-1]), kernel_size=(3, 3), padding="same",)(x_gfs)
    x_gfs = layers.GroupNormalization(groups=norm_groups)(x_gfs)
    x_gfs = activation_fn(x_gfs)

    x_gfs = layers.Conv2D(widths[-1], kernel_size=(3, 3), padding="same",)(x_gfs)
    x_gfs = layers.GroupNormalization(groups=norm_groups)(x_gfs)
    x_gfs = activation_fn(x_gfs)
    
    x = layers.MultiHeadAttention(num_heads=norm_groups, key_dim=widths[-1])(x, x_gfs)
    
    x = mu.ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)([x, temb])

    # UpBlock
    has_attention = right_attention
    for i in reversed(range(len(widths))):
        for _ in range(num_res_blocks + 1):
            x = layers.Concatenate(axis=-1)([x, skips.pop()])
            x = mu.ResidualBlock(widths[i], groups=norm_groups, activation_fn=activation_fn)([x, temb])
            
            if has_attention[i]:
                
                # GFS cross-attention inputs
                size_ = feature_sizes[i]
                x_gfs = gfs_input
                x_gfs = layers.Resizing(size_, 2*size_, interpolation='bilinear')(x_gfs)

                x_gfs = layers.Conv2D(int(0.5*widths[i]), kernel_size=(3, 3), padding="same",)(x_gfs)
                x_gfs = layers.GroupNormalization(groups=norm_groups)(x_gfs)
                x_gfs = activation_fn(x_gfs)

                x_gfs = layers.Conv2D(widths[i], kernel_size=(3, 3), padding="same",)(x_gfs)
                x_gfs = layers.GroupNormalization(groups=norm_groups)(x_gfs)
                x_gfs = activation_fn(x_gfs)
                
                x = layers.MultiHeadAttention(num_heads=norm_groups, key_dim=widths[i])(x, x_gfs)
                
        if i != 0:
            x = mu.UpSample(widths[i], interpolation=interpolation)(x)

    # End block
    x = layers.GroupNormalization(groups=norm_groups)(x)
    x = activation_fn(x)
    x = layers.Conv2D(input_shape[-1], (3, 3), padding="same", kernel_initializer=mu.kernel_init(0.0))(x)
    return keras.Model([image_input, time_input, gfs_input], x, name="unet")


In [6]:
# Reverse diffusino model
model = build_model(input_shape=input_shape, gfs_shape=gfs_shape, widths=widths, 
                    feature_sizes=feature_sizes, left_attention=left_attention, right_attention=right_attention, 
                    num_res_blocks=num_res_blocks, norm_groups=norm_groups, activation_fn=keras.activations.swish)

In [8]:
# Compile the mdoel
model.compile(loss=keras.losses.MeanAbsoluteError(), optimizer=keras.optimizers.Adam(learning_rate=lr),)

# load previous weights
if load_weights:
    W_old = mu.dummy_loader(model_name)
    model.set_weights(W_old)

# configure the forward diffusion steps
gdf_util = mu.GaussianDiffusion(timesteps=total_timesteps)

## Validation set preparation

In [9]:
L_valid = 270 # number of validation samples

# locations of training data
BATCH_dir = '/glade/campaign/cisl/aiml/ksha/BATCH_LDM_025/'

# preparing training batches
filenames = np.array(sorted(glob(BATCH_dir+'*2023*.npy')))

L = len(filenames)
filename_valid = filenames[::50][:L_valid]

Y_valid = np.empty((L_valid,)+input_shape)
X_valid = np.empty((L_valid,)+gfs_shape)

for i, name in enumerate(filename_valid):
    temp_data = np.load(name, allow_pickle=True)[()]
    X_valid[i, ...] = temp_data['GFS']
    Y_valid[i, ...] = 2*(temp_data['MRMS']/precip_max-0.5)
Y_valid[Y_valid>1.0] = 1.0

# validate on random timesteps
t_valid_ = np.random.uniform(low=0, high=total_timesteps, size=(L_valid,))
t_valid = t_valid_.astype(int)

# sample random noise to be added to the images in the batch
noise_valid = np.random.normal(size=((L_valid,)+input_shape))
images_valid = np.array(gdf_util.q_sample(Y_valid, t_valid, noise_valid))

# validation prediction example:
# pred_noise = model.predict([images_valid, t_valid, X_valid])

In [11]:
pred_noise = model.predict([images_valid, t_valid, X_valid])
record = np.mean(np.abs(noise_valid - pred_noise))
print('Initial validation loss: {}'.format(record))

Initial validation loss: 0.03732102262436367


## Training loop

In [10]:
# collect all training batches
filename_train1 = sorted(glob(BATCH_dir+'*2021*.npy'))
filename_train2 = sorted(glob(BATCH_dir+'*2022*.npy'))

filename_train = list(filename_train1) + list(filename_train2)
L_train = len(filename_train)

In [11]:
min_del = 0.0
max_tol = 3 # early stopping with 2-epoch patience
tol = 0

Y_batch = np.empty((batch_size,)+input_shape)
X_batch = np.empty((batch_size,)+gfs_shape)

for i in range(epochs):    
    print('epoch = {}'.format(i))
    if i == 0:
        pred_noise = model.predict([images_valid, t_valid, X_valid])
        record = np.mean(np.abs(noise_valid - pred_noise))
        #print('initial loss {}'.format(record))
        print('Initial validation loss: {}'.format(record))
        
    start_time = time.time()
    # loop over batches
    for j in range(N_batch):
        
        inds_rnd = du.shuffle_ind(L_train) # shuffle training files
        inds_ = inds_rnd[:batch_size] # select training files
        
        # collect training batches
        for k, ind in enumerate(inds_):
            # import batch data
            temp_name = filename_train[ind]
            temp_data = np.load(temp_name, allow_pickle=True)[()]
            X_batch[k, ...] = temp_data['GFS']
            Y_batch[k, ...] = 2*(temp_data['MRMS']/precip_max-0.5)
            
        Y_batch[Y_batch>1.0] = 1.0

        # sample timesteps uniformly
        t_ = np.random.uniform(low=0, high=total_timesteps, size=(batch_size,))
        t = t_.astype(int)
        
        # sample random noise to be added to the images in the batch
        noise = np.random.normal(size=(batch_size,)+input_shape)
        images_t = np.array(gdf_util.q_sample(Y_batch, t, noise))
        
        # train on batch
        model.train_on_batch([images_t, t, X_batch], noise)
        
    # on epoch-end
    pred_noise = model.predict([images_valid, t_valid, X_valid])
    record_temp = np.mean(np.abs(noise_valid - pred_noise))
    
    # print out valid loss change
    if record - record_temp > min_del:
        print('Validation loss improved from {} to {}'.format(record, record_temp))
        record = record_temp
        print("Save to {}".format(model_name_save))
        model.save(model_name_save)
        
    else:
        print('Validation loss {} NOT improved'.format(record_temp))

    print("--- %s seconds ---" % (time.time() - start_time))
    # mannual callbacks


epoch = 0
Initial validation loss: 0.03988743426908134
Validation loss improved from 0.03988743426908134 to 0.0380500046397584
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 382.7728805541992 seconds ---
epoch = 1
Validation loss improved from 0.0380500046397584 to 0.03803164265120117
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 382.5710651874542 seconds ---
epoch = 2
Validation loss 0.038082282364895255 NOT improved
--- 297.5637457370758 seconds ---
epoch = 3
Validation loss 0.03804873609656813 NOT improved
--- 297.39590787887573 seconds ---
epoch = 4
Validation loss improved from 0.03803164265120117 to 0.03791200756951528
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 390.5275945663452 seconds ---
epoch = 5
Validation loss 0.037970427600414934 NOT improved
--- 301.19238090515137 seconds ---
epoch = 6
Validation loss improved from 0.03791200756951528 to 0.03779376910423886
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 389.1069667339325 seconds ---
epoch = 7
Validation loss improved from 0.03779376910423886 to 0.03773410468035273
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 392.3885190486908 seconds ---
epoch = 8
Validation loss 0.037864383583463314 NOT improved
--- 299.65380096435547 seconds ---
epoch = 9
Validation loss 0.037972250496816656 NOT improved
--- 332.04002714157104 seconds ---
epoch = 10
Validation loss 0.03773697903432527 NOT improved
--- 335.9906802177429 seconds ---
epoch = 11
Validation loss 0.03775461290422561 NOT improved
--- 336.6466872692108 seconds ---
epoch = 12
Validation loss improved from 0.03773410468035273 to 0.03771438653474483
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 425.84297943115234 seconds ---
epoch = 13
Validation loss improved from 0.03771438653474483 to 0.03766659332805671
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 392.3420763015747 seconds ---
epoch = 14
Validation loss 0.03769522034551742 NOT improved
--- 300.34046173095703 seconds ---
epoch = 15
Validation loss 0.037682053005707795 NOT improved
--- 302.8091037273407 seconds ---
epoch = 16
Validation loss improved from 0.03766659332805671 to 0.03763868403768055
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 383.5532786846161 seconds ---
epoch = 17
Validation loss 0.03770679710404308 NOT improved
--- 296.2361342906952 seconds ---
epoch = 18
Validation loss 0.037828993398880656 NOT improved
--- 296.36062002182007 seconds ---
epoch = 19
Validation loss 0.037674154394488624 NOT improved
--- 297.7257604598999 seconds ---
epoch = 20
Validation loss 0.03772923930487647 NOT improved
--- 294.6773455142975 seconds ---
epoch = 21
Validation loss 0.037674720951792716 NOT improved
--- 297.32176780700684 seconds ---
epoch = 22
Validation loss 0.037683050349037864 NOT improved
--- 297.3580901622772 seconds ---
epoch = 23
Validation loss 0.03775688922706251 NOT improved
--- 294.47267603874207 seconds ---
epoch = 24
Validation loss improved from 0.03763868403768055 to 0.03760060109912628
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 383.90272068977356 seconds ---
epoch = 25
Validation loss improved from 0.03760060109912628 to 0.03754288937130601
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 386.18426036834717 seconds ---
epoch = 26
Validation loss 0.0376096670996019 NOT improved
--- 296.6584937572479 seconds ---
epoch = 27
Validation loss 0.037572997252047234 NOT improved
--- 297.3257749080658 seconds ---
epoch = 28
Validation loss 0.03759963143971753 NOT improved
--- 301.20568323135376 seconds ---
epoch = 29
Validation loss 0.03758376156950309 NOT improved
--- 294.9192633628845 seconds ---
epoch = 30
Validation loss 0.03754526525729822 NOT improved
--- 295.3470706939697 seconds ---
epoch = 31
Validation loss improved from 0.03754288937130601 to 0.03752749322464404
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 384.0536425113678 seconds ---
epoch = 32
Validation loss 0.03770271384550706 NOT improved
--- 298.57740092277527 seconds ---
epoch = 33
Validation loss improved from 0.03752749322464404 to 0.03750192241173372
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 385.98226165771484 seconds ---
epoch = 34
Validation loss 0.037529787594636016 NOT improved
--- 300.5511803627014 seconds ---
epoch = 35
Validation loss 0.037693698806163786 NOT improved
--- 295.3760747909546 seconds ---
epoch = 36
Validation loss 0.03750600624327547 NOT improved
--- 297.4278745651245 seconds ---
epoch = 37
Validation loss 0.03752657496150639 NOT improved
--- 295.6003723144531 seconds ---
epoch = 38
Validation loss 0.03753569661324586 NOT improved
--- 295.9379518032074 seconds ---
epoch = 39
Validation loss 0.037607024348616874 NOT improved
--- 298.5415894985199 seconds ---
epoch = 40
Validation loss improved from 0.03750192241173372 to 0.03747706681756768
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 386.8579885959625 seconds ---
epoch = 41
Validation loss 0.037497681214874526 NOT improved
--- 301.0545971393585 seconds ---
epoch = 42
Validation loss 0.03754681501005251 NOT improved
--- 297.5281147956848 seconds ---
epoch = 43
Validation loss 0.037592518373314524 NOT improved
--- 293.6817219257355 seconds ---
epoch = 44
Validation loss improved from 0.03747706681756768 to 0.037454468644935
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 384.72926354408264 seconds ---
epoch = 45
Validation loss improved from 0.037454468644935 to 0.03745314544872295
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 387.77791953086853 seconds ---
epoch = 46
Validation loss 0.03746416163979841 NOT improved
--- 295.92307782173157 seconds ---
epoch = 47
Validation loss 0.03749742092325636 NOT improved
--- 296.7400915622711 seconds ---
epoch = 48
Validation loss 0.03758099301620717 NOT improved
--- 297.82272505760193 seconds ---
epoch = 49
Validation loss improved from 0.03745314544872295 to 0.037444059330715554
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 391.40760827064514 seconds ---
epoch = 50
Validation loss 0.0374765976132695 NOT improved
--- 293.94353675842285 seconds ---
epoch = 51
Validation loss improved from 0.037444059330715554 to 0.03741259305204699
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 389.63240480422974 seconds ---
epoch = 52
Validation loss 0.037487177403142315 NOT improved
--- 294.58331966400146 seconds ---
epoch = 53
Validation loss 0.03751557034900253 NOT improved
--- 315.0151665210724 seconds ---
epoch = 54
Validation loss 0.0374561430247498 NOT improved
--- 326.33166241645813 seconds ---
epoch = 55
Validation loss 0.03744267445695326 NOT improved
--- 319.8150351047516 seconds ---
epoch = 56
Validation loss 0.037419891659404285 NOT improved
--- 323.64186334609985 seconds ---
epoch = 57
Validation loss 0.03744828809864342 NOT improved
--- 339.11074018478394 seconds ---
epoch = 58
Validation loss improved from 0.03741259305204699 to 0.0374032875362101
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 444.6967821121216 seconds ---
epoch = 59
Validation loss improved from 0.0374032875362101 to 0.037392658476622424
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 387.87478137016296 seconds ---
epoch = 60
Validation loss improved from 0.037392658476622424 to 0.03734908431437221
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 378.3278024196625 seconds ---
epoch = 61
Validation loss 0.037647692135827596 NOT improved
--- 293.97376108169556 seconds ---
epoch = 62
Validation loss 0.03755404850895701 NOT improved
--- 301.7577579021454 seconds ---
epoch = 63
Validation loss 0.037470603347100444 NOT improved
--- 292.3572111129761 seconds ---
epoch = 64
Validation loss 0.03741101800811885 NOT improved
--- 292.819237947464 seconds ---
epoch = 65
Validation loss 0.03746211377164849 NOT improved
--- 293.39149713516235 seconds ---
epoch = 66
Validation loss improved from 0.03734908431437221 to 0.03734062143556079
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 382.38566064834595 seconds ---
epoch = 67
Validation loss 0.03739866049203444 NOT improved
--- 291.934175491333 seconds ---
epoch = 68
Validation loss 0.03741834614761141 NOT improved
--- 291.6418402194977 seconds ---
epoch = 69
Validation loss 0.03756266722191525 NOT improved
--- 292.12633633613586 seconds ---
epoch = 70
Validation loss 0.03746798557970368 NOT improved
--- 292.5692365169525 seconds ---
epoch = 71
Validation loss 0.03740182692049369 NOT improved
--- 293.3640570640564 seconds ---
epoch = 72
Validation loss improved from 0.03734062143556079 to 0.03731325856988152
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 390.8638813495636 seconds ---
epoch = 73
Validation loss 0.037437789804464966 NOT improved
--- 288.0594472885132 seconds ---
epoch = 74
Validation loss 0.037330492684992965 NOT improved
--- 290.8726146221161 seconds ---
epoch = 75
Validation loss 0.037331965022318246 NOT improved
--- 296.05985856056213 seconds ---
epoch = 76
Validation loss 0.037347612171833106 NOT improved
--- 337.36700463294983 seconds ---
epoch = 77
Validation loss 0.03732851835413321 NOT improved
--- 348.5465364456177 seconds ---
epoch = 78
Validation loss 0.037377768241974356 NOT improved
--- 297.9385120868683 seconds ---
epoch = 79
Validation loss 0.03750843675827367 NOT improved
--- 297.58144640922546 seconds ---
epoch = 80
Validation loss improved from 0.03731325856988152 to 0.037308217414540226
Save to /glade/work/ksha/GAN/models/LDM_025_resize2-2_res2_tune0/




--- 387.1082637310028 seconds ---
epoch = 81
Validation loss 0.037472838707092296 NOT improved
--- 293.94818329811096 seconds ---
epoch = 82


KeyboardInterrupt: 

## Plot examples

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [23]:
pred_noise = model.predict([images_valid, t_valid, X_valid])



In [None]:
def reverse_diffuse(model, x_in1, x_in2, total_timesteps, gdf_util):
    L_valid = len(x_in1)
    x_out = np.empty(x_in1.shape)

    for i in range(L_valid):
        x1 = x_in1[i, ...][None, ...]
        x2 = x_in2[i, ...][None, ...]
        
        for t in reversed(range(0, total_timesteps)):
            tt = tf.cast(tf.fill(1, t), dtype=tf.int64)
            pred_noise = model.predict([x1, tt, x2], verbose=0)
            model_mean, _, model_log_variance =  gdf_util.p_mean_variance(pred_noise, x=x1, t=tt, clip_denoised=True)
            nonzero_mask = (1 - (np.array(tt)==0)).reshape((1, 1, 1, 1))
            x1 = np.array(model_mean) + nonzero_mask * np.exp(0.5 * np.array(model_log_variance)) * np.random.normal(size=x1.shape)
        x_out[i, ...] = x1

    return x_out

In [None]:
# start_time = time.time()
# Y_pred = reverse_diffuse(model, images_valid, X_valid, total_timesteps, gdf_util)
# print("--- %s seconds ---" % (time.time() - start_time))

**out-of-box reverse diffusion tests**

In [None]:
# x_in1 = images_valid
# x_in2 = X_valid

# x1 = x_in1[i, ...][None, ...]
# x2 = x_in2[i, ...][None, ...]
        
# for t in reversed(range(0, total_timesteps)):
#     tt = tf.cast(tf.fill(1, t), dtype=tf.int64)
#     pred_noise = model.predict([x1, tt, x2], verbose=0)
#     model_mean, _, model_log_variance =  gdf_util.p_mean_variance(pred_noise, x=x1, t=tt, clip_denoised=True)
#     nonzero_mask = (1 - (np.array(tt)==0)).reshape((1, 1, 1, 1))
#     x1 = np.array(model_mean) + nonzero_mask * np.exp(0.5 * np.array(model_log_variance)) * np.random.normal(size=x1.shape)