In [None]:
import os

from tqdm import tqdm

import jax
jax.config.update("jax_default_matmul_precision", "highest")

import jax.numpy as jnp
from jax import random, vmap

from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

from function_diffusion.models import Encoder, Decoder, DiT

from function_diffusion.utils.model_utils import (
    create_autoencoder_state,
    create_diffusion_state,
    create_optimizer,
    compute_total_params,
)
from function_diffusion.utils.train_utils import  sample_ode
from function_diffusion.utils.data_utils import create_dataloader
from function_diffusion.utils.checkpoint_utils import (
    create_checkpoint_manager,
    restore_checkpoint,
)

from data_utils import create_dataset
from model_utils import create_encoder_step, create_decoder_step

In [None]:
from configs import diffusion

# Build the experiment configuration used by every cell below.
# NOTE(review): 'fae,dit' presumably selects the autoencoder and DiT
# sub-configurations -- confirm against configs/diffusion.py.
config = diffusion.get_config('fae,dit')

In [None]:
def restore_fae_state(config, encoder, decoder):
    """Rebuild the function-autoencoder train state and load its checkpoint.

    A template state is created from `config`, `encoder` and `decoder`, and
    then passed to `restore_checkpoint` together with a checkpoint manager
    pointing at ``<cwd>/<config.autoencoder.model_name>/ckpt``.

    Args:
        config: Experiment config; `config.autoencoder.model_name` and
            `config.saving` are read here.
        encoder: Encoder module handed to `create_autoencoder_state`.
        decoder: Decoder module handed to `create_autoencoder_state`.

    Returns:
        The state object returned by `restore_checkpoint`.
    """
    # Only the optimizer is needed to build the template state; the learning
    # rate schedule returned alongside it is not used here.
    _, optimizer = create_optimizer(config)
    template_state = create_autoencoder_state(config, encoder, decoder, optimizer)

    # The checkpoint directory is resolved relative to the current working directory.
    fae_job_name = f"{config.autoencoder.model_name}"
    checkpoint_dir = os.path.join(os.getcwd(), fae_job_name, "ckpt")
    manager = create_checkpoint_manager(config.saving, checkpoint_dir)

    restored = restore_checkpoint(manager, template_state)
    print(f"Restored model {fae_job_name} from step", restored.step)
    return restored

In [None]:
# Initialize function autoencoder
# The encoder and decoder are built from their own sub-configs; their keyword
# arguments come straight from config.autoencoder.encoder / .decoder.
encoder = Encoder(**config.autoencoder.encoder)
decoder = Decoder(**config.autoencoder.decoder)

# Load the autoencoder weights from its checkpoint directory
# (restore_fae_state prints the restored step).
fae_state = restore_fae_state(config, encoder, decoder)

In [None]:
# Initialize diffusion model
dit = DiT(**config.diffusion)
# Create learning rate schedule and optimizer
# (only `tx` is used below; `lr` is kept as a notebook-level name)
lr, tx = create_optimizer(config)

# Create diffusion train state
# This is a freshly initialised template; the checkpoint is loaded in a later cell.
state = create_diffusion_state(config, dit, tx, use_conditioning=True)
num_params = compute_total_params(state)
# NOTE(review): the factor 4 presumably assumes 4-byte (float32) parameters -- confirm.
print(f"Model storage cost: {num_params * 4 / 1024 / 1024:.2f} MB of parameters")

In [None]:
# Locate the diffusion checkpoint directory: <cwd>/<config.diffusion.model_name>/ckpt
# (resolved against os.getcwd(), so the notebook must be started from the
# directory that contains the job folder).
job_name = f"{config.diffusion.model_name}"
ckpt_path = os.path.join(os.getcwd(), job_name, "ckpt")
# Create checkpoint manager
ckpt_mngr = create_checkpoint_manager(config.saving, ckpt_path)

# Restore the model from the checkpoint
# `state` is overwritten in place: the template built earlier is replaced by
# whatever restore_checkpoint returns.
state = restore_checkpoint(ckpt_mngr, state)
print(f"Restored model {job_name} from step", state.step)

In [None]:
# Report how many devices are visible in total and to this process.
num_devices = jax.device_count()
num_local_devices = jax.local_device_count()
print(f"Number of devices: {num_devices}")
print(f"Number of local devices: {num_local_devices}")

# One-dimensional device mesh with a single "batch" axis for data parallelism.
mesh = Mesh(mesh_utils.create_device_mesh((num_devices,)), "batch")

# Convert both train states to global arrays with an empty PartitionSpec,
# i.e. without splitting them along the "batch" axis.
state, fae_state = (
    multihost_utils.host_local_array_to_global_array(train_state, mesh, P())
    for train_state in (state, fae_state)
)

In [None]:
# Create encoder and decoder steps
# Both factories receive the device mesh; below, the resulting callables are
# invoked as encoder_step(params, batch) and decoder_step(params, z, coords).
encoder_step = create_encoder_step(encoder, mesh)
decoder_step = create_decoder_step(decoder, mesh)

In [None]:
# Get test dataset
# create_dataset returns a pair; only its second element (the test split) is kept.
_, test_dataset = create_dataset(config)
# Fixed evaluation order (shuffle=False) with two samples per batch.
test_loader = create_dataloader(test_dataset,
                                batch_size=2,
                                num_workers=config.dataset.num_workers,
                                shuffle=False)

In [None]:
# Uniform evaluation grid on the unit square: h points along x, w points along y.
h, w = 200, 100

# 'ij' indexing keeps x on the first axis, so both grids have shape (h, w).
x_coords, y_coords = jnp.meshgrid(
    jnp.linspace(0, 1, h),
    jnp.linspace(0, 1, w),
    indexing='ij',
)

# Flatten to an (h * w, 2) array of (x, y) query points, then convert it to a
# global array with an empty PartitionSpec (not split along the "batch" axis).
coords = jnp.stack([x_coords.ravel(), y_coords.ravel()], axis=-1)
coords = multihost_utils.host_local_array_to_global_array(coords, mesh, P())

In [None]:
# Conditional sampling over the first few test batches:
#   1. take subsampled, noise-perturbed u and v fields as conditioning,
#   2. encode them with the function autoencoder,
#   3. sample latents with the diffusion model (sample_ode),
#   4. decode the sampled latents on the uniform grid `coords`.
noise_level = 1.0     # scale of the Gaussian noise added to the u / v inputs
d = 2                 # stride used to subsample u and v along both spatial axes
num_eval_batches = 4  # stop after this many test batches

# Root PRNG key. Fresh subkeys are split off for every batch so that the
# u-noise, the v-noise and the initial latent z0 are mutually independent and
# no two batches receive the same noise realisation. (JAX keys must not be
# reused: the same key always yields the same samples.)
key = jax.random.PRNGKey(88)

u_input_list = []
v_input_list = []
p_pred_list = []
sdf_pred_list = []
p_true_list = []
sdf_true_list = []

for iters, batch in enumerate(tqdm(test_loader), start=1):
    key, u_key, v_key, z0_key = random.split(key, 4)

    batch = jax.tree.map(jnp.array, batch)
    # Channel slices 0..3 of the last axis are bound to u, v, p and sdf.
    # u and v are subsampled with stride d; p and sdf keep the full resolution
    # and serve as references only.
    u = batch[:, ::d, ::d, 0:1]
    v = batch[:, ::d, ::d, 1:2]
    p = batch[..., 2:3]
    sdf = batch[..., 3:4]

    u = u + noise_level * jax.random.normal(u_key, u.shape)
    v = v + noise_level * jax.random.normal(v_key, v.shape)

    # encoder_step is fed a 3-tuple whose middle entry carries the field; the
    # outer entries are all-ones placeholders of the same shape.
    # NOTE(review): presumably they stand in for coordinates / targets that the
    # encoder does not read -- confirm against create_encoder_step.
    u_batch = (jnp.ones_like(u), u, jnp.ones_like(u))
    v_batch = (jnp.ones_like(v), v, jnp.ones_like(v))

    # Shard the batch across devices
    u_batch = multihost_utils.host_local_array_to_global_array(
        u_batch, mesh, P("batch")
    )
    v_batch = multihost_utils.host_local_array_to_global_array(
        v_batch, mesh, P("batch")
    )

    # Encode each conditioning field once (params[0] is passed to the encoder
    # step, params[1] to the decoder step).
    z_u = encoder_step(fae_state.params[0], u_batch)
    z_v = encoder_step(fae_state.params[0], v_batch)

    z_c = jnp.concatenate([z_u, z_v], axis=-1)  # (b, l, 2c)

    # NOTE(review): z0 takes z_u's shape while the split below uses z_c's
    # width; this is only consistent if sample_ode returns latents as wide as
    # z_c -- confirm against sample_ode / the DiT output dimension.
    z0 = random.normal(z0_key, shape=z_u.shape)
    z1_new, _ = sample_ode(state, z0=z0, c=z_c, num_steps=100, use_conditioning=True)

    # First half of the sampled latent decodes to p, second half to sdf.
    c_dim = z_c.shape[-1]
    z_p_new = z1_new[..., :c_dim//2]
    z_sdf_new = z1_new[..., c_dim//2:]

    p_pred = decoder_step(fae_state.params[1], z_p_new, coords)
    sdf_pred = decoder_step(fae_state.params[1], z_sdf_new, coords)

    # Predictions are stored without a channel axis: (batch, h, w).
    p_pred = p_pred.reshape(-1, h, w)
    sdf_pred = sdf_pred.reshape(-1, h, w)

    u_input_list.append(u)
    v_input_list.append(v)
    p_pred_list.append(p_pred)
    sdf_pred_list.append(sdf_pred)
    p_true_list.append(p)
    sdf_true_list.append(sdf)

    if iters == num_eval_batches:
        break

# Concatenate all results
u_input = jnp.concatenate(u_input_list, axis=0)
v_input = jnp.concatenate(v_input_list, axis=0)
p_pred = jnp.concatenate(p_pred_list, axis=0)
sdf_pred = jnp.concatenate(sdf_pred_list, axis=0)
p_true = jnp.concatenate(p_true_list, axis=0)
sdf_true = jnp.concatenate(sdf_true_list, axis=0)

In [None]:
def compute_error(pred, y):
    """Relative L2 error of `pred` with respect to the reference `y`.

    Both arrays are flattened first, so they only need the same number of
    elements, not the same shape.

    Args:
        pred: Predicted values.
        y: Reference values.

    Returns:
        ``||pred - y||_2 / ||y||_2`` computed on the flattened arrays.
    """
    target = y.reshape(-1)
    residual = pred.reshape(-1) - target
    return jnp.linalg.norm(residual) / jnp.linalg.norm(target)

# Per-sample relative L2 error of the p prediction: vmap maps compute_error
# over the leading (sample) axis of both arrays, giving one value per sample.
error = vmap(compute_error)(p_pred, p_true)

# Summary statistics over the evaluated samples.
print("Mean relative error:", jnp.mean(error))
print("Max relative error:", jnp.max(error))
print("Min relative error:", jnp.min(error))
print("Std relative error:", jnp.std(error))

In [None]:
# Visualization of some examples
import matplotlib.pyplot as plt

k = 0  # index of the evaluated sample to display

# u_input and p_true keep a trailing channel axis, (N, H, W, 1), whereas the
# sampling cell stores p_pred as (N, h, w). Indexing p_pred with four indices
# therefore raises an IndexError; the prediction is reshaped to the reference's
# 2-D shape instead, which works with or without a trailing singleton axis.
input_field = u_input[k, :, :, 0]
reference_field = p_true[k, :, :, 0]
prediction_field = p_pred[k].reshape(reference_field.shape)

panels = [
    ('Input', input_field),
    ('Reference', reference_field),
    ('Prediction', prediction_field),
    ('Absolute Error', jnp.abs(prediction_field - reference_field)),
]

fig, axes = plt.subplots(1, len(panels), figsize=(17, 4))
for ax, (panel_title, panel_field) in zip(axes, panels):
    image = ax.pcolor(panel_field, cmap='jet')
    ax.set_title(panel_title)
    fig.colorbar(image, ax=ax)

fig.tight_layout()
plt.show()