In [None]:
import os 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'
import sys 
sys.path.insert(1, os.path.dirname(os.getcwd()))
from functions import *

## Predictions

In [None]:
# load model and data
# NOTE(review): safe_mode=False allows Keras to deserialize arbitrary Python code
# (e.g. Lambda layers) stored in the archive -- only load 'cae.keras' from a
# trusted source.
autoencoder = tf.keras.models.load_model('cae.keras', safe_mode=False)

# Constants consumed by load_data below; load_constants / get_model_constants
# come from the star import at the top (presumably the local `functions` module).
# NOTE(review): names suggest physical parameters (Pr, Ra, temperatures) -- confirm.
const_dict = load_constants()
Uf, P, T_h, T_0, Pr, Ra = get_model_constants(const_dict)
# NOTE(review): 2000 / 500 look like train / validation sizes -- TODO confirm
# against load_data's signature. Only the validation array and the 1-D x / z
# coordinates (later fed to np.meshgrid) are kept; the other outputs are discarded.
_, data_val, x, z, _ = load_data(2000, 500, Uf, P, T_h, T_0)

In [None]:
# Reconstruct the validation set with the trained autoencoder.
preds_val = autoencoder.predict(data_val, batch_size=10, verbose=0)

# Element-wise squared reconstruction error, averaged over every axis except the
# last one -> one validation MSE per channel.
mse_val = (preds_val - data_val)**2
val_losses = mse_val.mean(axis=(0,1,2))

def ssim(preds, data, c1=1e-5, c2=1e-5):
  """Whole-field SSIM between predicted and reference fields.

  Statistics are taken over the entire spatial extent (axes 1 and 2) of each
  sample, not over sliding windows. They are computed separately for every
  remaining trailing index (e.g. each channel). The result is
  luminance * contrast * structure, with c2 / 2 as the structure-term stabiliser.

  Args:
    preds: predicted fields; must have the same shape as `data`.
    data: reference fields with samples on axis 0 and the two spatial axes on
      axes 1 and 2, e.g. (n, nx, nz) or (n, nx, nz, n_channels).
    c1: luminance stabiliser (default 1e-5).
    c2: contrast stabiliser; c2 / 2 stabilises the structure term (default 1e-5).

  Returns:
    ndarray of SSIM values with axes 1 and 2 reduced, e.g. (n, n_channels).

  Raises:
    ValueError: if `preds` and `data` differ in shape.
  """
  if preds.shape != data.shape:
    raise ValueError(f'ssim: shape mismatch, preds {preds.shape} vs data {data.shape}')

  spatial = (1, 2)

  # Keep the reduced axes so the means broadcast against the full arrays for any
  # number of trailing axes; squeeze them for the per-field statistics.
  mu_full = data.mean(axis=spatial, keepdims=True)
  mu_hat_full = preds.mean(axis=spatial, keepdims=True)
  mu = mu_full.squeeze(axis=spatial)
  mu_hat = mu_hat_full.squeeze(axis=spatial)
  sigma = data.std(axis=spatial)        # population std (ddof=0)
  sigma_hat = preds.std(axis=spatial)

  # Cross-covariance, normalised like the ddof=0 standard deviations above.
  sigma_cross = np.mean((data - mu_full) * (preds - mu_hat_full), axis=spatial)

  # SSIM components
  luminance = (2 * mu * mu_hat + c1) / (mu**2 + mu_hat**2 + c1)
  contrast = (2 * sigma * sigma_hat + c2) / (sigma**2 + sigma_hat**2 + c2)
  structure = (sigma_cross + c2 / 2) / (sigma * sigma_hat + c2 / 2)

  return luminance * contrast * structure


# Per-sample, per-channel SSIM over the validation set.
ssim_map = ssim(preds_val, data_val)

# Spread (inter-quartile range) and mean of the SSIM across samples, per channel.
iqr = np.subtract(np.quantile(ssim_map, 0.75, axis=0), np.quantile(ssim_map, 0.25, axis=0))
ssim_val = np.mean(ssim_map, axis=0)

print(f'Validation MSE: (u) {val_losses[0]:.2e}, (w) {val_losses[1]:.2e}, (p) {val_losses[2]:.2e}, (T) {val_losses[3]:.2e}\n')

print(
  f'SSIM: (u) {ssim_val[0]:.2e}, (w) {ssim_val[1]:.2e}, (p) {ssim_val[2]:.2e}, (T) {ssim_val[3]:.2e}',
  f'IQR:  (u) {iqr[0]:.2e}, (w) {iqr[1]:.2e}, (p) {iqr[2]:.2e}, (T) {iqr[3]:.2e}',
  sep='\n',
)

## Figures

In [None]:
import matplotlib.pyplot as plt
from matplotlib import ticker
from matplotlib.ticker import FuncFormatter
# Tick formatter: one-decimal scientific notation (e.g. 1.2e-03).
formatter = FuncFormatter(lambda value, _pos: f'{value:.1e}')

In [None]:
# One validation snapshot: reference fields (top row), autoencoder reconstruction
# (middle row) and absolute reconstruction error (bottom row), one column per variable.
sample_idx = -150  # snapshot shown: 150th from the end of the validation set
n_levels = 40      # filled-contour levels per panel

X, Z = np.meshgrid(x, z)
# Per-channel colour limits over the whole validation set, shared by the data and
# prediction rows so the two are directly comparable.
mins = data_val.min(axis=(0,1,2))
maxs = data_val.max(axis=(0,1,2))

fig, ax = plt.subplots(3, 4, figsize=(20,9), sharex=True, sharey=True, layout='constrained')

ax[0,0].set_ylabel('Data', fontsize=15)
ax[1,0].set_ylabel('Predictions', fontsize=15)
ax[2,0].set_ylabel('Abs. error', fontsize=15)

for j, v in enumerate([r'$u$',r'$w$',r'$p$',r'$\theta$']):
  levels = np.linspace(mins[j], maxs[j], n_levels)
  # np.meshgrid(x, z) yields (len(z), len(x)) grids, hence the transpose.
  truth = data_val[sample_idx,:,:,j].T
  pred = preds_val[sample_idx,:,:,j].T

  # extend='both' fills values outside [mins, maxs] (possible for predictions)
  # with the colormap's under/over colours instead of leaving them blank.
  ax[0,j].contourf(X, Z, truth, cmap='jet', vmin=mins[j], vmax=maxs[j], levels=levels, extend='both')
  im1 = ax[1,j].contourf(X, Z, pred, cmap='jet', vmin=mins[j], vmax=maxs[j], levels=levels, extend='both')
  im_err = ax[2,j].contourf(X, Z, np.abs(pred - truth), cmap='jet', levels=n_levels)

  # For contour sets the colorbar takes its `extend` from the ContourSet itself,
  # so it is not passed here.
  cbar1 = plt.colorbar(im1, ax=ax[1,j], orientation='horizontal', shrink=0.7, aspect=20, pad=0.05)
  cbar1.locator = ticker.MaxNLocator(nbins=5)
  cbar1.update_ticks()

  # Errors are small, so use the scientific-notation tick formatter.
  cbar_err = plt.colorbar(im_err, ax=ax[2,j], orientation='horizontal', format=formatter, shrink=0.7, aspect=20, pad=0.05)
  cbar_err.locator = ticker.MaxNLocator(nbins=5)
  cbar_err.update_ticks()

  ax[0,j].set_title(v, fontsize=15)

plt.show()