In [1]:
import numpy as np
import jax.numpy as jnp
import jaxopt

from itertools import product

from pointscat.forward_problem import PointScatteringProblem, compute_far_field
from pointscat.inverse_problem import unif_sample_disk, DiscreteMeasure


# Seed NumPy's legacy global RNG so that draws from it in later cells
# (presumably including the random frequency sampling) are reproducible.
np.random.seed(seed=0)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
# Ground truth: three point scatterers -- one row of `locations` (2-D position)
# and one entry of `amplitudes` per scatterer.
amplitudes = np.array([0.5, 2.0, 1.0])
locations = np.array([
    [-3.3, -3.7],
    [-2.8, 3.5],
    [3.2, 2.6],
])
wave_number = 1
# Forward model for these scatterers at the given wave number.
point_scat = PointScatteringProblem(locations, amplitudes, wave_number)

In [3]:
# Reconstruction domain: spike locations should lie in the open square
# (-box_size/2, box_size/2) in each coordinate. Check the ground truth
# against it rather than relying on the comment alone.
box_size = 10
assert np.all(np.abs(locations) < box_size / 2), \
    "true locations must lie inside (-box_size/2, box_size/2) in each coordinate"

num_frequencies = 50
# Must not exceed 2 * wave_number: frequencies with a larger norm have no
# incident/observation angle pair (the arccos used to build the angles from
# |k| / (2 * wave_number) would be undefined).
cutoff_frequency = 2 * wave_number
# Presumably uniform random samples in the disk of radius cutoff_frequency,
# drawn from NumPy's global RNG seeded in the first cell -- TODO confirm.
frequencies = unif_sample_disk(num_frequencies, cutoff_frequency)

In [4]:
def backscattering_angles(freqs, k0):
    """Incident/observation angle pairs whose direction difference realizes each frequency.

    For every frequency vector ``k`` with polar angle ``phi`` and norm ``|k|``,
    let ``alpha = arccos(|k| / (2 * k0))``. The returned angles are
    ``theta_inc = pi + phi - alpha`` and ``theta_obs = phi + alpha``, so that
    ``(cos theta_obs, sin theta_obs) - (cos theta_inc, sin theta_inc) == k / k0``.

    Parameters
    ----------
    freqs : array-like
        Frequency vectors, one 2-component vector per row.
    k0 : float
        Wave number.

    Returns
    -------
    (ndarray, ndarray)
        Incident angles and observation directions, one entry per frequency.

    Raises
    ------
    ValueError
        If some ``|k|`` exceeds ``2 * k0``. Such a frequency cannot be reached,
        and ``np.arccos`` would otherwise silently return NaN.
    """
    freqs = np.asarray(freqs, dtype=float).reshape(-1, 2)
    polar_angles = np.arctan2(freqs[:, 1], freqs[:, 0])
    norms = np.linalg.norm(freqs, axis=1)
    ratios = norms / (2 * k0)
    if np.any(ratios > 1):
        raise ValueError(
            f"frequencies must satisfy |k| <= 2 * wave_number = {2 * k0}; "
            f"got max |k| = {norms.max()}"
        )
    half_openings = np.arccos(ratios)
    return np.pi + polar_angles - half_openings, polar_angles + half_openings


incident_angles, observation_directions = backscattering_angles(frequencies, wave_number)

far_field = point_scat.compute_far_field(incident_angles, observation_directions)
far_field_born = point_scat.compute_far_field(incident_angles, observation_directions, born_approx=True)
# Real part stacked over the negated imaginary part; `sliding_obj` builds its
# model image with the same convention so the two vectors are comparable.
obs = np.concatenate([np.real(far_field), -np.imag(far_field)])

In [5]:
# Initial guess: a uniform 3 x 3 grid over [-box_size/2, box_size/2]^2
# (endpoints included). Rows are ordered with the first coordinate varying
# slowest, and every spike starts with unit amplitude.
x_tab = np.linspace(-box_size/2, box_size/2, 3)
init_locations = np.stack(np.meshgrid(x_tab, x_tab, indexing="ij"), axis=-1).reshape(-1, 2)
init_amplitudes = np.ones(init_locations.shape[0])

In [6]:
# Overwrite the first grid points with the true spike locations (one row per
# true spike, whatever their number), leaving the remaining grid points as is.
# NOTE(review): this initializes from the ground truth -- presumably a debugging
# aid for the amplitude-only fit below; drop it for a genuine inversion.
init_locations[:len(locations)] = locations

In [7]:
# Wrap the initial locations and amplitudes as a discrete measure
# (presumably one spike per row of init_locations -- confirm in pointscat).
estimated_measure = DiscreteMeasure(
    init_locations,
    init_amplitudes,
)

In [8]:
num_spikes = len(init_amplitudes)


def sliding_obj(amplitudes):
    """Half squared-error misfit between simulated and observed far-field data.

    Only the amplitudes are free. The spike positions, wave number and
    acquisition angles are read from the notebook globals ``init_locations``,
    ``wave_number``, ``incident_angles`` and ``observation_directions``, and the
    data from ``obs``.

    Parameters
    ----------
    amplitudes : array of length num_spikes
        Amplitudes of the spikes located at ``init_locations``.

    Returns
    -------
    scalar
        ``sum((image - obs) ** 2) / 2``, where ``image`` stacks the real part
        over the negated imaginary part of the simulated far field.
    """
    far_field = compute_far_field(init_locations, amplitudes, wave_number, incident_angles, observation_directions)
    # Same real / negated-imaginary stacking convention as `obs`.
    image = jnp.concatenate([jnp.real(far_field), -jnp.imag(far_field)])
    return jnp.sum((image - obs) ** 2) / 2


# Initial parameters. `jnp.asarray` keeps float64 when JAX's x64 mode is
# enabled and canonicalizes to float32 otherwise. That is the same dtype an
# explicit dtype='float64' request ends up with, but without the truncation
# warning.
# NOTE(review): enable `jax_enable_x64` at startup if double precision is intended.
x_0 = jnp.asarray(init_amplitudes)

# Infinite box constraints, i.e. effectively unconstrained amplitudes.
bounds = jnp.full(num_spikes, jnp.inf)
# NOTE(review): maxiter=1 looks like a debugging value; raise it for an actual fit.
solver = jaxopt.LBFGSB(fun=sliding_obj, maxiter=1)

In [None]:
# Run L-BFGS-B from x_0 within the box [-bounds, bounds] defined with the solver.
params, state = solver.run(init_params=x_0, bounds=(-bounds, bounds))