In [None]:
import os

from tqdm import tqdm

import jax
jax.config.update("jax_default_matmul_precision", "highest")

import jax.numpy as jnp
from jax import random, vmap

from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

from function_diffusion.models import Encoder, Decoder, DiT

from function_diffusion.utils.model_utils import (
    create_autoencoder_state,
    create_diffusion_state,
    create_optimizer,
    compute_total_params,
)
from function_diffusion.utils.train_utils import  sample_ode
from function_diffusion.utils.data_utils import create_dataloader
from function_diffusion.utils.checkpoint_utils import (
    create_checkpoint_manager,
    restore_checkpoint,
)

from linear_elasticity.data_utils import create_dataset
from model_utils import create_encoder_step, create_decoder_step

In [None]:
from configs import diffusion

# Comma-separated selector handed to the config factory.
# NOTE(review): presumably picks the "fae" and "dit" config variants — confirm
# against configs/diffusion.py.
config_name = 'fae,dit'
config = diffusion.get_config(config_name)

In [None]:
def restore_fae_state(config, encoder, decoder):
    """Rebuild the autoencoder train state and restore it from its checkpoint.

    The checkpoint directory is ``<cwd>/<config.autoencoder.model_name>/ckpt``.

    Args:
        config: Experiment config. This function reads
            ``config.autoencoder.model_name`` and ``config.saving`` and forwards
            the whole object to the optimizer and train-state factories.
        encoder: Encoder module passed on to ``create_autoencoder_state``.
        decoder: Decoder module passed on to ``create_autoencoder_state``.

    Returns:
        The state object returned by ``restore_checkpoint``.
    """
    # The autoencoder's model name doubles as the name of its run directory.
    fae_job_name = f"{config.autoencoder.model_name}"
    checkpoint_dir = os.path.join(os.getcwd(), fae_job_name, "ckpt")

    # Build a train state to hand to the restore call; the learning-rate
    # schedule returned next to the optimizer is not needed here.
    _, optimizer = create_optimizer(config)
    template_state = create_autoencoder_state(config, encoder, decoder, optimizer)

    manager = create_checkpoint_manager(config.saving, checkpoint_dir)
    fae_state = restore_checkpoint(manager, template_state)
    print(f"Restored model {fae_job_name} from step", fae_state.step)

    return fae_state

In [None]:
# Build the function autoencoder: each module takes its keyword arguments
# from the matching section of the autoencoder config.
ae_config = config.autoencoder
encoder = Encoder(**ae_config.encoder)
decoder = Decoder(**ae_config.decoder)

# Restore the autoencoder state from its checkpoint directory.
fae_state = restore_fae_state(config, encoder, decoder)

In [None]:
# Build the DiT from the diffusion section of the config.
dit = DiT(**config.diffusion)

# Optimizer (and its learning-rate schedule) required to construct the train state.
lr, tx = create_optimizer(config)

# Diffusion train state, created with conditioning switched off.
state = create_diffusion_state(config, dit, tx, use_conditioning=False)

# Report the parameter footprint, counting 4 bytes per parameter.
num_params = compute_total_params(state)
param_mb = num_params * 4 / 1024 / 1024
print(f"Model storage cost: {param_mb:.2f} MB of parameters")

In [None]:
# Checkpoints of the diffusion run are looked up in <cwd>/<model_name>/ckpt
job_name = f"{config.diffusion.model_name}"
ckpt_path = os.path.join(os.getcwd(), job_name, "ckpt")
# Create checkpoint manager
ckpt_mngr = create_checkpoint_manager(config.saving, ckpt_path)

# Restore the model from the checkpoint; this overwrites the `state` created
# in the previous cell with whatever restore_checkpoint returns
state = restore_checkpoint(ckpt_mngr, state)
print(f"Restored model {job_name} from step", state.step)

In [None]:
# Report how many devices are visible globally and on this host.
num_devices = jax.device_count()
num_local_devices = jax.local_device_count()
print(f"Number of devices: {num_devices}")
print(f"Number of local devices: {num_local_devices}")

# One-dimensional device mesh over all devices with a single "batch" axis,
# used for data parallelism.
device_grid = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(device_grid, "batch")

# Both train states are placed with an empty PartitionSpec, i.e. they are not
# split along the "batch" axis.
replicated = P()
state = multihost_utils.host_local_array_to_global_array(state, mesh, replicated)
fae_state = multihost_utils.host_local_array_to_global_array(fae_state, mesh, replicated)

In [None]:
# Create encoder and decoder steps bound to the device mesh.
# NOTE(review): presumably these are compiled/sharded apply functions taking
# (params, inputs...) — confirm in model_utils.
encoder_step = create_encoder_step(encoder, mesh)
decoder_step = create_decoder_step(decoder, mesh)

In [None]:
# Only the second dataset returned by create_dataset is used here; the first
# one (presumably the training split — confirm) is discarded.
_, test_dataset = create_dataset(config)

# Iterate the test set in order (no shuffling), two samples per batch.
loader_options = dict(
    batch_size=2,
    num_workers=config.dataset.num_workers,
    shuffle=False,
)
test_loader = create_dataloader(test_dataset, **loader_options)

In [None]:
# Evaluation grid: h x w points spaced uniformly over the unit square.
h, w = 256, 256

# 'ij' indexing keeps the first array axis aligned with x and the second with y.
x_coords, y_coords = jnp.meshgrid(
    jnp.linspace(0, 1, h),
    jnp.linspace(0, 1, w),
    indexing='ij',
)

# Flatten both grids and pair them up into an (h * w, 2) array of (x, y) points,
# then place the array on the mesh with an empty PartitionSpec.
coords = jnp.stack([x_coords.ravel(), y_coords.ravel()], axis=1)
coords = multihost_utils.host_local_array_to_global_array(coords, mesh, P())

In [None]:
# Generate samples with the diffusion model and decode them on the evaluation
# grid, collecting predictions next to the corresponding test fields.
rng_key = jax.random.PRNGKey(888)

# Number of integration steps handed to the ODE sampler.
num_ode_steps = 100
# Number of test batches to process; set to None to run on the full test set.
max_batches = 4

u_pred_list = []
v_pred_list = []
u_true_list = []
v_true_list = []
div_pred_list = []

iters = 0
for iters, batch in enumerate(tqdm(test_loader), start=1):
    # The three-way split is kept so the sampled noise matches earlier runs of
    # this notebook; only the last sub-key is consumed.
    rng_key, _unused_key, noise_key = random.split(rng_key, 3)

    batch = jax.tree.map(jnp.array, batch)

    uv = batch
    # The encoder step takes a 3-tuple; only the middle entry carries the data,
    # the other two slots are filled with ones of the same shape.
    uv_batch = (jnp.ones_like(uv), uv, jnp.ones_like(uv))

    # Shard the batch across devices
    uv_batch = multihost_utils.host_local_array_to_global_array(
        uv_batch, mesh, P("batch")
    )
    # Encode only to obtain the latent shape; the latent values themselves are
    # not used for generation.
    z_uv = encoder_step(fae_state.params[0], uv_batch)

    # Start from Gaussian noise in latent space, integrate the sampling ODE
    # without conditioning, and decode the result on the evaluation grid.
    z0 = random.normal(noise_key, shape=z_uv.shape)
    z1_new, _ = sample_ode(
        state, z0=z0, c=None, num_steps=num_ode_steps, use_conditioning=False
    )
    u_pred, v_pred, div_pred = decoder_step(fae_state.params[1], z1_new, coords)

    # Decoder outputs are reshaped back onto the (h, w) grid.
    u_pred_list.append(u_pred.reshape(-1, h, w))
    v_pred_list.append(v_pred.reshape(-1, h, w))
    div_pred_list.append(div_pred.reshape(-1, h, w))

    # The last axis of the test batch holds the two field components.
    u_true_list.append(uv[..., 0])
    v_true_list.append(uv[..., 1])

    if max_batches is not None and iters >= max_batches:
        break

# Concatenate every collected batch along the leading axis. The v components are
# stacked as well, so v_pred / v_true cover all processed batches instead of
# only the last one.
u_pred = jnp.vstack(u_pred_list)
v_pred = jnp.vstack(v_pred_list)
u_true = jnp.vstack(u_true_list)
v_true = jnp.vstack(v_true_list)
div_pred = jnp.vstack(div_pred_list)

In [None]:
# Side-by-side view of one generated sample and the magnitude of its predicted
# divergence.
import matplotlib.pyplot as plt

k = 0  # index of the sample to display

fig, (ax_gen, ax_div) = plt.subplots(1, 2, figsize=(10, 4))

# Left panel: generated u field.
ax_gen.set_title('Generated')
gen_mesh = ax_gen.pcolor(u_pred[k, :, :], cmap='jet')
fig.colorbar(gen_mesh, ax=ax_gen)

# Right panel: absolute value of the predicted divergence.
ax_div.set_title('Pred Divergence')
div_mesh = ax_div.pcolor(jnp.abs(div_pred[k, :, :]), cmap='jet')
fig.colorbar(div_mesh, ax=ax_div)

fig.tight_layout()
plt.show()