# Diffusion Posterior Sampling with Gaussian Noise

In this notebook, we optimize the initial state of the buoyancy-driven flow with obstacles simulation to match the final state of a simulation. We use diffusion posterior sampling, DPS, with Gaussian noise as described in https://openreview.net/forum?id=OnD9zGAGT0k Algorithm 1 with a pretrained DDPM diffusion model. 

In [1]:
import sys

sys.path.append('github/smdp/buoyancy-flow') 
sys.path.append('github/smdp/buoyancy-flow/baselines/diffusion-posterior-sampling')

### Load the testing dataset

In [2]:
import h5py

# import dataloader
from dataloader_multi import DataLoader

file_test = 'github/smdp/buoyancy-flow/data/smoke_plumes_test_r0.h5'

dataKeys = None      
with h5py.File(file_test, 'r') as f:
    dataKeys = list(f.keys())

dataKeys = list(zip([file_test] * len(dataKeys), dataKeys))

test_data = DataLoader([file_test], dataKeys, name='test', batchSize=1)

2023-11-02 15:58:35.770202: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-11-02 15:58:35.867124: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-11-02 15:58:36.356161: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/cuda-11.5/lib64:/usr/local/cuda-11.5/lib64
2023-11-02 15:58:36.356214: W te

Length: 5


### Load the pretrained diffusion model
Define the model architecture and load the pretrained weights. 

In [3]:
from unet import Unet
import torch

# file path for stored weights
weight_file = 'github/diffusion-posterior-sampling-backup/results/ddpm-model-flow-2s3jcppm-20.pt'

model_spec = {'channels': 4,
              'image_size': 64,
              'data_shape': (4, 64, 64),
              'dim' : 64,
              'dim_mults' : (1, 2, 2, 4,)}


model = Unet(**model_spec)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

data = torch.load(weight_file,
                  map_location=device)
model.load_state_dict(data['model'])
model = model.to(device)

Load LPIPS for perceptual distance metric

In [4]:
# import LPIPS distance
import evaluation.lpips as lpips

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /home/benjamin/anaconda3/envs/smdp/lib/python3.8/site-packages/lpips/weights/v0.1/alex.pth


### Optimization with DPS
First, define all hyperparameters and simulation environment

In [5]:
# simulation time of initial state to be optimized
time_init = 35 # t=0.35

# diffusion posterior sampling parameter for scaling the gradient
zeta = 1.0

# inference time step to start optimizing (default: 0)
dps_optim_start = 0

params = {
    'batch_size' : 1,
    'DT' : 0.01,
    't1': 0.65,
    'time_init': time_init,
    'zeta': zeta,
    'image_channels' : 4,
    'image_size' : 64,
    'dps_optim_start' : dps_optim_start
}

Define optimization with DPS, Algorithm 1, Gaussian noise

In [6]:
from physics_check import batch_inflow, physics_forward, batch_geometries_pre_phiflow
from eval import eval_forward   
from phi.torch.flow import *
from sample import gather
from tqdm import tqdm

def optimize_sample(item, params):

    simulation_metadata = {}
    simulation_metadata['NSTEPS'] = int(params['t1'] / params['DT'])
    simulation_metadata['INFLOW'] = batch_inflow(item['INFLOW'], batchSize=params['batch_size'])
    simulation_metadata['INFLOW_1b'] = batch_inflow(item['INFLOW'], batchSize=1)
    bounds = item['BOUNDS']
    simulation_metadata['BOUNDS'] = Box(x=(bounds['_lower'][0], bounds['_upper'][0]),
                                        y=(bounds['_lower'][1], bounds['_upper'][1]))
    simulation_metadata['smoke_res'] = item['smoke_res']
    simulation_metadata['v_res'] = item['v_res']
    simulation_metadata['DT'] = params['DT']

    obstacles = [item['obstacle_list']]
    obstacles = batch_geometries_pre_phiflow(obstacles)

    smoke_state = torch.asarray(item['smoke'], dtype=torch.float32)
    vel_x_state = torch.asarray(item['vel_x'], dtype=torch.float32)
    vel_y_state = torch.asarray(item['vel_y'], dtype=torch.float32)
    mask_state = torch.asarray(item['mask'], dtype=torch.float32).to(device)

    init_state = [smoke_state[0][None], vel_x_state[0][None], vel_y_state[0][None], mask_state[0][None]]
    target_state = [smoke_state[-1][None].to('cuda:0'), vel_x_state[-1][None].to('cuda:0'),
                    vel_y_state[-1][None].to('cuda:0')]

    forward_fn = physics_forward(simulation_metadata)
    forward_fn = math.jit_compile(forward_fn)

    t0 = 0.64
    simulation_metadata['NSTEPS'] = int((params['t1'] - t0) / params['DT'])
    _ = eval_forward(init_state, obstacles, simulation_metadata, physics_forward_fn=forward_fn, t0=t0)
    
    def loss_function(init_state_):

        t0 = params['time_init'] * params['DT']

        init_state_.append(torch.zeros_like(init_state[3]).clone().detach().requires_grad_(False))

        simulation_metadata['NSTEPS'] = int((params['t1'] - t0) / params['DT'])

        out = eval_forward(init_state_, obstacles, simulation_metadata, physics_forward_fn=forward_fn, t0=t0)

        smoke_out = out[-1][0][0]
        vel_x_out = out[-1][1][0]
        vel_y_out = out[-1][2][0]

        smoke_target = target_state[0][0]
        vel_x_target = target_state[1][0]
        vel_y_target = target_state[2][0]

        norm = torch.linalg.norm(smoke_target - smoke_out) + torch.linalg.norm(
            vel_x_target - vel_x_out) + torch.linalg.norm(vel_y_target - vel_y_out)

        return torch.nn.functional.mse_loss(smoke_target, smoke_out) + torch.nn.functional.mse_loss(vel_x_target,
                                                                                                    vel_x_out) + torch.nn.functional.mse_loss(
            vel_y_target, vel_y_out), norm


    image_channels = params['image_channels']
    image_size = params['image_size']
    
    n_steps = 1000
    beta = torch.linspace(0.0001, 0.02, 1000).to(device)
    alpha = 1. - beta
    alpha_bar = torch.cumprod(alpha, dim=0)
    sigma2 = beta

    x = torch.randn([1, image_channels, image_size, image_size],
                    device=mask_state[:1].device)

    x[:, 0] = mask_state[:1]

    zeta_scale = params['zeta']

    pbar = tqdm(range(n_steps - 1))

    cutoff = params['dps_optim_start']

    for t_ in pbar:
        
        t = n_steps - t_ - 1
        t_m1 = t - 1
        t_in = x.new_full((1,), t, dtype=torch.long)
        t_in_m1 = x.new_full((1,), t_m1, dtype=torch.long)
        alpha_bar_t = gather(alpha_bar, t_in)
        alpha_bar_t_m1 = gather(alpha_bar, t_in_m1)
        alpha_t = gather(alpha, t_in)
        beta_t = gather(beta, t_in)
        var_t = gather(sigma2, t_in)
        eps = torch.randn(x.shape, device=x.device)

        x_grad_leaf = x.clone().detach().requires_grad_(True)

        if t_ > cutoff:

            s_hat = model(x_grad_leaf, t_in) / ((1 - alpha_bar_t) ** 0.5)
            x_hat_0 = (1 / (alpha_bar_t ** 0.5)) * (x_grad_leaf - (1 - alpha_bar_t) * s_hat)

            x_dash = (((alpha_t ** 0.5) * (1 - alpha_bar_t_m1)) / (1 - alpha_bar_t)) * x_grad_leaf
            x_dash = x_dash + (((alpha_bar_t_m1 ** 0.5) * beta_t) / (1 - alpha_bar_t)) * x_hat_0
            x_dash = x_dash + (var_t ** .5) * eps

            smoke_state_ = x_hat_0[:, 1]
            vel_x_state_ = x_hat_0[:, 2][:, :, :63]
            vel_y_state_ = x_hat_0[:, 3][:, :63, :]

            l, norm = loss_function([smoke_state_, vel_x_state_, vel_y_state_])
            l.backward()

            pbar.set_description("Loss: %s" % l.item())

            gradient_to_leaf = x_grad_leaf.grad[0]

            x = x_dash - zeta_scale * gradient_to_leaf * 1 / norm  # did not find scaling by norm in reference implementation by authors

        else:

            eps_theta = model(x_grad_leaf, t_in)
            eps_coef = (1 - alpha_t) / (1 - alpha_bar_t) ** .5
            mean = 1 / (alpha_t ** 0.5) * (x_grad_leaf - eps_coef * eps_theta)
            x_grad = mean + (var_t ** .5) * eps
            x = x_grad

        x[:, 0] = mask_state[:1]
        # pytorch clear cache
        torch.cuda.empty_cache()
        
    smoke_state_final = x[:, 1]
    vel_x_state_final = x[:, 2][:, :, :63]
    vel_y_state_final = x[:, 3][:, :63, :]
        
    state_final = [smoke_state_final, vel_x_state_final, vel_y_state_final, mask_state[0][None]]
        
    prediction = eval_forward(state_final, obstacles, simulation_metadata, 
                              physics_forward_fn=forward_fn, t0=params['time_init']*params['DT'])
    
    return [(marker_field.detach().cpu().numpy(), vel_x_field.detach().cpu().numpy(), 
             vel_y_field.detach().cpu().numpy(), mask_field.detach().cpu().numpy()) 
            for marker_field, vel_x_field, vel_y_field, mask_field in prediction]

In [7]:
results = {}

reconstruction_MSE = 0
lpips_smoke = 0

for key in dataKeys:

    item = test_data.load(key)

    prediction = optimize_sample(item, params)

    results[key] = prediction   

    smoke_state = torch.asarray(item['smoke'], dtype=torch.float32)

    reconstruction_MSE += torch.nn.functional.mse_loss(torch.tensor(prediction[-1][0][0]), smoke_state[-1]).item()

    lpips_smoke += lpips.lpips_dist(prediction[0][0], smoke_state[time_init][None].numpy())

lpips_smoke = lpips_smoke / len(dataKeys)
reconstruction_MSE = reconstruction_MSE / len(dataKeys)

print('Reconstruction MSE: ', reconstruction_MSE)
print('LPIPS smoke: ', lpips_smoke)

jit compile physics
tracing physics forwards...


  return torch.sparse_csr_tensor(row_pointers, column_indices, values, shape, device=values.device)
  tensor = torch.from_numpy(x)
  return torch.sparse_csr_tensor(row_pointers, column_indices, values, shape, device=values.device)
  return tuple([int(s) for s in tensor.shape])


tracing physics forwards...


  tensor = torch.tensor(x, device=self.get_default_device().ref)
  if coordinates.shape[0] != grid.shape[0]:  # repeating yields wrong result
  resolution = torch.tensor(self.staticshape(grid)[2:], dtype=coordinates.dtype, device=coordinates.device)
  if dim1 is None or dim1 == 1:
  coordinates = coordinates.repeat(batch_size, *[1] * (len(coordinates.shape-1))) if coordinates.shape[0] < batch_size else coordinates
  grid = grid.repeat(batch_size, *[1] * (len(grid.shape)-1)) if grid.shape[0] < batch_size else grid
  b_indices = self.unstack(indices[min(b, indices.shape[0] - 1)], -1)
  result.append(values[(min(b, values.shape[0] - 1),) + b_indices])
Loss: 2.4086344242095947: 100%|███████████████| 999/999 [00:36<00:00, 27.38it/s]


jit compile physics
tracing physics forwards...
tracing physics forwards...


Loss: 2.4835853576660156: 100%|███████████████| 999/999 [00:28<00:00, 35.22it/s]


jit compile physics
tracing physics forwards...
tracing physics forwards...


Loss: 1.5106109380722046: 100%|███████████████| 999/999 [00:28<00:00, 34.92it/s]


jit compile physics
tracing physics forwards...
tracing physics forwards...


Loss: 1.1385670900344849: 100%|███████████████| 999/999 [00:30<00:00, 32.76it/s]


jit compile physics
tracing physics forwards...
tracing physics forwards...


Loss: 0.8904870748519897: 100%|███████████████| 999/999 [00:28<00:00, 35.12it/s]


Reconstruction MSE:  0.7691944122314454
LPIPS smoke:  0.09035247787833214


### Save results

In [11]:
# save results to file
import pickle   

results_file = 'github/smdp/buoyancy-flow/evaluation/results/results_dps.pkl'

with open(results_file, 'wb') as f:
    pickle.dump(results, f)