In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to library source are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import jax

# Report which platform JAX selected as its default backend. This uses the
# public `jax.default_backend()` helper instead of the deprecated internal
# `jax.lib.xla_bridge` module, which is not importable on recent JAX releases.
print(f"Jax is using: {jax.default_backend()}")

# Setup

## Domain

In [None]:
from jwave.geometry import Domain

# Grid definition passed to `Domain`: `N` holds one entry per axis (a 2-D,
# 128 x 128 grid) and `dx` holds the matching per-axis step.
# NOTE(review): `dx` is presumably the grid spacing in metres (0.1 mm) --
# confirm against the jwave `Domain` documentation.
N = (128, 128)
dx = (0.1e-3, 0.1e-3)

domain = Domain(N, dx)

## Acoustic medium

In [None]:
from jwave.geometry import Medium

# Build the acoustic medium on the grid defined above. Only a scalar
# `sound_speed` is supplied here.
# NOTE(review): presumably in m/s and applied uniformly over the domain --
# confirm against the jwave `Medium` documentation.
medium = Medium(domain=domain, sound_speed=1500.0)
# Print the medium so its settings are visible in the cell output.
print(medium)

## Time

In [None]:
from jwave.geometry import TimeAxis

# Derive the simulation time axis from the medium rather than hard-coding it.
# NOTE(review): `cfl` is presumably the CFL number that controls the time
# step -- confirm against `TimeAxis.from_medium`.
time_axis = TimeAxis.from_medium(medium, cfl=0.3)

In [None]:
# Bare expression: the notebook displays the repr of the TimeAxis built from
# the medium, for inspection.
time_axis

## Initial pressure

In [None]:
from jax import numpy as jnp

from jwave import FourierSeries
from jwave.geometry import circ_mask

# Initial pressure field, wrapped as a FourierSeries on `domain`.
#   * `circ_mask(N, 4, (80, 60))` builds a mask on the N-shaped grid
#     (presumably a disc of radius 4 centred at grid index (80, 60) --
#     TODO confirm the argument order against jwave.geometry.circ_mask).
#   * `expand_dims(..., axis=-1)` appends a trailing singleton axis.
#   * Multiplying by 1.0 turns the mask into floating-point values.
p0 = FourierSeries(
    1.0 * jnp.expand_dims(circ_mask(N, 4, (80, 60)), axis=-1),
    domain,
)

In [None]:
from matplotlib import pyplot as plt

from jwave.utils import show_field

# Render the initial pressure distribution and label the figure.
show_field(p0)
# Plain string literal: the title has no placeholders, so no f-prefix is needed.
plt.title("Initial pressure field")
plt.show()

# Run simulation

In [None]:
from jax import jit

from jwave.acoustics import simulate_wave_propagation


@jit
def compiled_simulator(medium, p0):
    """Run the wave-propagation simulation under `jax.jit`.

    Args:
        medium: Acoustic medium forwarded to `simulate_wave_propagation`.
        p0: Initial pressure field, forwarded as the `p0` keyword argument.

    Returns:
        Whatever `simulate_wave_propagation` returns for these inputs.

    Note:
        `time_axis` is not a parameter; it is read from the enclosing
        notebook scope, so it must be defined before the first call.
    """
    wavefield = simulate_wave_propagation(medium, time_axis, p0=p0)
    return wavefield

In [None]:
# First call of the jitted simulator: besides running the simulation, this
# call also pays the one-off tracing/compilation cost.
pressure = compiled_simulator(medium, p0)

In [None]:
# Index of the snapshot to visualise.
# NOTE(review): presumably the leading axis of `pressure` is the time step --
# confirm against the output of simulate_wave_propagation.
t = 250
show_field(pressure[t])
# Use the same index into the time axis array so the title matches the
# frame being shown.
plt.title(f"Pressure field at t={time_axis.to_array()[t]}")
plt.show()

# Timings

In [None]:
# Benchmark the already-compiled simulator. `block_until_ready()` is called on
# the result's `.params` so each timed run waits for the computation to finish
# instead of only measuring JAX's asynchronous dispatch.
%timeit compiled_simulator(medium, p0).params.block_until_ready()