In [None]:
# Star import first, so the explicit imports below cannot be shadowed by it.
from stride import *
import mosaic
import numpy as np
# await mosaic.interactive('off')
await mosaic.interactive('on', num_workers=2, log_level='info')

runtime = mosaic.runtime()
%matplotlib widget
sos = 1500.
dx =  0.2e-3 # lambda / 15
cfl = 0.5 

grid_size = int (3150)
space = Space(shape=(grid_size , grid_size), extra=(100, 100), absorbing=(80, 80), spacing=dx) # wavelenght : 3e-03 meter
time = Time(start=0.0e-7, step= cfl * dx / sos , num= int (15000 ))
grid = Grid(space, time)

problem = Problem(name='fullbodyring', space=space, time=time)

In [None]:
def ellipt_coordinates(num, radius, center=(0.0, 0.0)):
    """Return ``num`` points evenly spaced in angle on an axis-aligned ellipse.

    Parameters
    ----------
    num : int
        Number of points. The angles are ``2*pi*k/num`` for ``k = 0..num-1``.
        The end point ``2*pi`` is excluded, so no point is duplicated.
    radius : sequence of float
        Semi-axes of the ellipse. ``radius[0]`` applies to x and
        ``radius[1]`` to y.
    center : sequence of float, optional
        Centre ``(x, y)`` of the ellipse. The default is the origin.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(num, 2)`` with one ``(x, y)`` row per point.
    """
    angles = np.linspace(0, 2 * np.pi, num, endpoint=False)
    # Vectorised over all angles instead of filling the array point by point.
    return np.column_stack((center[0] + radius[0] * np.cos(angles),
                            center[1] + radius[1] * np.sin(angles)))

In [None]:

# Receivers: the default transducer plus a ring of 512 elements created by
# the geometry's built-in "elliptical" layout, with radius 0.3 on both axes.
problem.transducers.default()
geometry_type = "elliptical"
num_receiver = 512
problem.geometry.default(geometry_type, num_receiver, radius=[0.3, 0.3])
receivers = problem.geometry.locations

# The source cell takes `problem.geometry.num_locations` as the first free
# location id, so at this point the geometry must hold exactly the receivers.
# This guards against hidden state, e.g. locations left over from an earlier run.
assert problem.geometry.num_locations == num_receiver, problem.geometry.num_locations



In [None]:
# Sources: a second, smaller ring of 256 elements (radius 0.22), appended to
# the geometry after the receivers. Location ids continue from the last receiver id.
offset = problem.geometry.num_locations
# Ring centre: grid_size / 2 grid points from the origin along each in-plane axis.
coords_offset = grid_size / 2 * dx

num_source = 256
source_transducer = problem.transducers.get(0)
coordinates = ellipt_coordinates(num_source, [0.22, 0.22])
for index, in_plane in enumerate(coordinates):
    location = in_plane + coords_offset
    if len(location) != problem.geometry.space.dim:
        # 3D grid: put the ring in the mid-plane along the last axis. Only the
        # in-plane coordinates are shifted; z must not receive the offset as well.
        location = np.append(location, problem.geometry.space.limit[2] / 2)
    problem.geometry.add(index + offset, source_transducer, location)

sources = problem.geometry.locations[offset:]

In [None]:
# Source wavelet: a Gaussian-windowed cosine at the centre frequency,
#   w(t) = exp(-(t - t0)^2 / b^2) * cos(2*pi*f_c*(t - t0)),
# with width b = 1 / f_c and delay t0 = 3% of the simulated duration.
f_centre = 0.5e6
# Build the time axis from integer sample indices so it has exactly `time.num`
# samples. `np.arange` with a float step can return one extra sample through
# rounding, and that sample would not fit in the wavelet buffer below.
t = np.arange(time.num) * time.step
b = 1 / f_centre
delay = 0.03 * time.num * time.step
wave = np.exp(-(t - delay)**2 / b**2) * np.cos(2 * np.pi * f_centre * (t - delay))

# One shot per source; every shot records on all receivers.
for source in sources:
    problem.acquisitions.add(Shot(source.id,
                                  sources=[source], receivers=receivers,
                                  geometry=problem.geometry, problem=problem))

# Every shot is driven by the same wavelet.
for shot in problem.acquisitions.shots:
    shot.wavelets.data[0, :] = wave

print(f"total number of shots:{len(problem.acquisitions.shots)}")


In [None]:
# Numerical phantom: a uniform background with three circular inclusions.
# The inclusion centres sit on the vertices of an equilateral triangle with
# side 0.08 m, centred grid_size / 2 grid points from the origin on each axis.
background_sos = 1500.
# Floating-point velocities; the original integer array was the wrong type for a speed-of-sound field.
vp_data = np.full((grid_size, grid_size), background_sos, dtype=np.float32)

side = 0.08 / dx                      # triangle side, in grid points
circumradius = side / np.sqrt(3)      # triangle centre to each vertex, in grid points
angles = np.deg2rad([90, 210, 330])
centers = [(circumradius * np.cos(theta), circumradius * np.sin(theta)) for theta in angles]
radii = [0.02 / dx, 0.03 / dx, 0.04 / dx]   # inclusion radii, in grid points
sos_vals = [1450, 1480, 1550]               # inclusion speeds of sound [m/s]

# Open index grids: `rows` varies along axis 0, `cols` along axis 1.
rows, cols = np.ogrid[:grid_size, :grid_size]
origin = grid_size / 2

# The loop variable is deliberately not called `sos`, so this cell does not
# overwrite the global background speed defined in the setup cell.
# NOTE(review): center[0] is applied along axis 1 and center[1] along axis 0.
# Confirm this matches the intended x/y orientation of the Stride grid.
for center, radius, inclusion_sos in zip(centers, radii, sos_vals):
    mask = (cols - center[0] - origin)**2 + (rows - center[1] - origin)**2 <= radius**2
    vp_data[mask] = inclusion_sos

vp_true = ScalarField(name='vp', grid=problem.grid, data=vp_data)
problem.medium.add(vp_true)
problem.plot()

In [None]:
# Print every shot in the acquisition; there is one shot per source, so expect long output.
# The loop variable is `shot_id` rather than `id`, so the builtin `id` is not shadowed.
for shot_id in problem.acquisitions.shot_ids:
    print(problem.acquisitions.get(shot_id))

In [None]:

pde = IsoAcousticDevito.remote(grid=problem.grid, len=runtime.num_workers)

await forward(problem, pde, vp_true, dump=True)

problem.acquisitions.plot()

In [None]:
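# Hand-written version of the forward-modelling loop that `forward(...)` above
# performs. It is kept commented out for reference only.
# # Get all remaining shot IDs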
# shot_ids = problem.acquisitions.remaining_shot_ids

# # Run an asynchronous loop across all shot IDs
# @runtime.async_for(shot_ids)
# async def loop(worker, shot_id):
#     runtime.logger.info('Giving shot %d to %s' % (shot_id, worker.uid))

#     # Fetch one sub-problem corresponding to a shot ID
#     sub_problem = problem.sub_problem(shot_id)
    
#     # Access the source wavelets of this shot
#     wavelets = sub_problem.shot.wavelets
    
#     # Execute the PDE forward
#     traces = await pde(wavelets, vp_true,
#                        problem=sub_problem,
#                        runtime=worker).result()

#     runtime.logger.info('Shot %d retrieved' % sub_problem.shot_id)

#     # Store the retrieved traces into the shot
#     shot = problem.acquisitions.get(shot_id)
#     shot.observed.data[:] = traces.data

#     runtime.logger.info('Retrieved traces for shot %d' % sub_problem.shot_id)

# # Because this is an asynchronous loop, it needs to be awaited 
# _ = await loop

# # Plot the result
# _ = problem.acquisitions.plot()

In [None]:
# Model to be inverted: a speed-of-sound field with gradients enabled,
# initialised to a homogeneous 1500 m/s. It is built on `problem.grid`, the
# same grid object used for `vp_true` and the PDE, rather than a separate Grid
# instance.
vp = ScalarField.parameter(name='vp', grid=problem.grid, needs_grad=True)
vp.fill(1500.)

# NOTE(review): the medium already holds `vp_true` under the name 'vp'. This
# call presumably replaces it with the inversion parameter — confirm.
problem.medium.add(vp)

In [None]:
# Remote instances (sized by the worker count) of the misfit function and of
# the pre-processing steps applied to wavelets, observed data and modelled traces.
num_workers = runtime.num_workers
loss = L2DistanceLoss.remote(len=num_workers)
process_wavelets = ProcessWavelets.remote(len=num_workers)
process_observed = ProcessObserved.remote(len=num_workers)
process_wavelets_observed = ProcessWaveletsObserved.remote(len=num_workers)
process_traces = ProcessTraces.remote(len=num_workers)

In [None]:
# Gradient-descent optimiser for `vp`. The global gradient and the updated
# model are post-processed at each step; the model step uses the bounds
# [1400, 2100] m/s (presumably clipping — confirm).
process_grad = ProcessGlobalGradient()
process_model = ProcessModelIteration(min=1400., max=2100.)
step_size = 10

optimiser = GradientDescent(
    vp,
    step_size=step_size,
    process_grad=process_grad,
    process_model=process_model,
)

In [None]:
from stride import OptimisationLoop

# Clear the previous Devito operators
await pde.clear_operators()

optimisation_loop = OptimisationLoop()

# Specify a series of frequency bands, which we will introduce gradually 
# into the inversion in order to better condition it
max_freqs = [0.2e6, 0.8e6]

num_blocks = len(max_freqs)
num_iters = 4

# Start iterating over each block in the optimisation
for block, f_max in optimisation_loop.blocks(num_blocks, max_freqs):

    # Proceed through every iteration in the block
    for iteration in block.iterations(num_iters):
        runtime.logger.info('Starting iteration %d (out of %d), '
                            'block %d (out of %d)' %
                            (iteration.id+1, block.num_iterations, block.id+1,
                             optimisation_loop.num_blocks))

        # Select some shots for this iteration
        shot_ids = problem.acquisitions.select_shot_ids(num=15, randomly=True)

        # Clear the gradient buffers of the variable
        vp.clear_grad()

        # Asynchronously loop over all the selected shot IDs
        @runtime.async_for(shot_ids)
        async def loop(worker, shot_id):
            runtime.logger.info('Giving shot %d to %s' % (shot_id, worker.uid))

            # Fetch one sub-problem corresponding to the shot ID
            sub_problem = problem.sub_problem(shot_id)
            wavelets = sub_problem.shot.wavelets
            observed = sub_problem.shot.observed

            # Pre-process the wavelets and observed
            wavelets = process_wavelets(wavelets, f_max=f_max, filter_relaxation=0.75, runtime=worker)
            observed = process_observed(observed, f_max=f_max, filter_relaxation=0.75, runtime=worker)
            processed = process_wavelets_observed(wavelets, observed, f_max=f_max, runtime=worker)
            wavelets = processed.outputs[0]
            observed = processed.outputs[1]
            
            # Execute the PDE forward
            modelled = pde(wavelets, vp, problem=sub_problem, runtime=worker)

            # Pre-process the modelled and observed traces
            traces = process_traces(modelled, observed, f_max=f_max, filter_relaxation=0.75, runtime=worker)
            # and use these pre-processed versions to calculate the
            # value of the loss_freq function
            fun = await loss(traces.outputs[0], traces.outputs[1],
                             problem=sub_problem, runtime=worker).result()

            iteration.add_loss(fun)
            runtime.logger.info('Functional value for shot %d: %s' % (shot_id, fun))

            # Now, we can calculate the gradient by executing the adjoint of the
            # forward process
            await fun.adjoint()

            runtime.logger.info('Retrieved gradient for shot %d' % sub_problem.shot_id)

        # Because this is an async loop, it needs to be awaited    
        _ = await loop
        # Update the vp with the calculated gradient by taking a step with the optimiser
        await optimiser.step()

        runtime.logger.info('Done iteration %d (out of %d), '
                            'block %d (out of %d) - Total loss_freq %e' %
                            (iteration.id+1, block.num_iterations, block.id+1,
                             optimisation_loop.num_blocks, iteration.total_loss))
        runtime.logger.info('====================================================================')

# Plot the vp afterwards   
vp.plot()