In [1]:
# %%
import os
import ml_collections
import numpy as np
import cv2
import matlab.engine # the matlab engine for python
import jax
from IPython.display import HTML

from jwave import FourierSeries
from jwave.utils import load_image_to_numpy
from jwave.acoustics import simulate_wave_propagation
from jwave.geometry import Medium, Domain, TimeAxis
from jax import jit
from jax import numpy as jnp
from celluloid import Camera

import matplotlib.pyplot as plt
from wavebench.generate_data.time_varying.generate_data_rtc import generate_rtc
from wavebench import wavebench_dataset_path
from wavebench.utils import absolute_file_paths
from wavebench import wavebench_path
from wavebench.plot_utils import plot_images, remove_frame
from wavebench import wavebench_figure_path

# %%

# Experiment configuration. A dedicated ConfigDict is used so the settings live
# in their own container instead of being set as attributes on the
# `ml_collections` module object itself.
config = ml_collections.ConfigDict()

# Source-image dataset used for the initial pressure
# (options: 'thick_lines', 'mnist').
config.initial_pressure_type = 'thick_lines'

config.save_data = False
# Wave-speed model (options: 'gaussian_lens', 'gaussian_random_field').
config.medium_type = 'gaussian_lens'
config.device_id = 0  # index into jax.devices() used as the default device

config.domain_sidelen = 128  # grid points per side
config.domain_dx = 8  # grid spacing
# 128 points x 8 per point gives a 1024 x 1024 domain, in the units of
# `domain_dx` (presumably metres, matching the m/s wave speeds below).

config.medium_source_loc = (50, 55)  # grid index of the point mass for 'gaussian_lens'
config.medium_density = 2650  # constant density given to the Medium (presumably kg/m^3)
config.pml_size = 2  # PML thickness given to the Medium (presumably grid points)

# Range the wave-speed map is rescaled to.
min_wavespeed = 1400  # [m/s]
max_wavespeed = 4000  # [m/s]
# Value written at medium_source_loc before Gaussian blurring ('gaussian_lens').
point_mass_strength = -31000

data_path = os.path.join(
    wavebench_dataset_path, "time_varying", config.initial_pressure_type)


# Build the raw wave-speed map on the simulation grid, then rescale it
# linearly to [min_wavespeed, max_wavespeed].
if config.medium_type == 'gaussian_lens':
  # One strongly negative point value blurred with a wide Gaussian gives a
  # smooth region of low values (hence low wave speed after rescaling)
  # around medium_source_loc.
  lens_seed = np.ones((config.domain_sidelen, config.domain_sidelen))
  lens_seed[config.medium_source_loc] = point_mass_strength
  medium_sound_speed = cv2.GaussianBlur(
      lens_seed,
      ksize=(0, 0),  # let OpenCV derive the kernel size from sigmaX
      sigmaX=50,
      borderType=cv2.BORDER_REPLICATE)
elif config.medium_type == 'gaussian_random_field':
  # Pre-generated 128x128 float32 wave-speed model stored as raw binary.
  medium_sound_speed = np.fromfile(
      os.path.join(
          wavebench_dataset_path, "time_varying/wavespeed/cp_128x128_00001.H@"),
      dtype=np.float32).reshape(128, 128)

  if config.domain_sidelen != 128:
    medium_sound_speed = jax.image.resize(
        medium_sound_speed,
        (config.domain_sidelen, config.domain_sidelen),
        'bicubic')
else:
  raise NotImplementedError(
      f"Unsupported medium_type: {config.medium_type!r}")

# Normalize to [0, 1]. A constant map would divide by zero and silently give
# NaN wave speeds (which would break the simulation), so fail loudly instead.
speed_range = medium_sound_speed.max() - medium_sound_speed.min()
if not speed_range > 0:
  raise ValueError(
      "medium_sound_speed is constant (or contains NaN); cannot rescale it.")
normalized_sound_speed = (
    medium_sound_speed - medium_sound_speed.min()) / speed_range

config.medium_sound_speed = normalized_sound_speed * (
    max_wavespeed - min_wavespeed) + min_wavespeed

# Only a single example is generated: the first source image by sorted path.
config.source_list = sorted(absolute_file_paths(data_path))[:1]
# Fail here with a clear message rather than with an IndexError when the
# first source is indexed later.
assert config.source_list, f"No source images found under {data_path}"


In [2]:



# Make the selected device the default for all subsequent JAX computations.
jax.config.update(
    "jax_default_device", jax.devices()[config.device_id])

# Square simulation grid: N x N points with spacing dx along both axes.
domain = Domain(
    (config.domain_sidelen, config.domain_sidelen),
    (config.domain_dx, config.domain_dx))

# Density and PML size are passed to the constructor instead of being assigned
# after construction: attribute assignment on an existing Medium is not
# supported in jwave versions where Medium is an immutable module.
medium = Medium(
    domain=domain,
    # trailing singleton axis — presumably the field-channel dim jwave expects
    sound_speed=config.medium_sound_speed[..., np.newaxis],
    density=config.medium_density,
    pml_size=config.pml_size)

# Time stepping derived from the medium with CFL number 0.3, up to t_end.
time_axis = TimeAxis.from_medium(medium, cfl=0.3, t_end=0.2)

@jit
def record_pressure_traces(medium, initial_pressure):
    """Propagate `initial_pressure` through `medium` and return the field on the grid.

    The module-level `time_axis` is read when `jit` traces this function, so
    reassigning `time_axis` afterwards has no effect unless the function is
    re-traced.

    Returns:
        The simulated pressure as a plain grid array with singleton dimensions
        squeezed out (presumably shaped (time, x, y) — TODO confirm).
    """
    pressure_field = simulate_wave_propagation(
        medium, time_axis, p0=initial_pressure)
    grid_values = pressure_field.on_grid
    return grid_values.squeeze()


In [3]:
# Load the selected source image at the simulation grid size and divide by 255
# (assumes 8-bit pixel values) to get the initial pressure p0.
image = config.source_list[0]
grid_shape = (config.domain_sidelen, config.domain_sidelen)
p0_pixels = load_image_to_numpy(image, image_size=grid_shape) / 255

# Add a trailing singleton axis and wrap the array as a field on `domain`.
initial_pressure = FourierSeries(jnp.expand_dims(p0_pixels, -1), domain)

# Run the jit-compiled simulation.
pressure_traces = record_pressure_traces(medium, initial_pressure)


In [4]:
# Animate every FRAME_STRIDE-th snapshot of the simulated pressure field,
# save it as an MP4 and display it inline.
FRAME_STRIDE = 10  # keep every 10th frame to limit video length

fig, ax = plt.subplots(figsize=(4, 4))
camera = Camera(fig)

num_frames = pressure_traces.shape[0]
for snapshot_idx in range(0, num_frames, FRAME_STRIDE):
  # Fixed color limits so intensities are comparable across frames.
  ax.imshow(pressure_traces[snapshot_idx],
            cmap='coolwarm', vmin=-0.1, vmax=1.1)
  camera.snap()

animation = camera.animate()

# Close the figure so the inline backend does not also render a static copy.
plt.close(fig)

# `wavebench_figure_path` is treated as a directory, like
# `wavebench_dataset_path` in data_path. Plain string concatenation
# (f'{wavebench_figure_path}_rtc_...') would write a sibling file outside it.
video_path = os.path.join(
    wavebench_figure_path,
    f'rtc_{config.medium_type}_{config.initial_pressure_type}.mp4')
animation.save(video_path)

HTML(animation.to_html5_video())
