# SimPEG(emg3d) -- small dev-example

Use either via

- `conda env create -f environment-simpeg-emg3d.yml`

(creates env `simpeg-emg3d`), or manually by installing

- SimPEG: `pip install git+https://github.com/simpeg/simpeg@refs/pull/1515/head`
- emg3d: `pip install git+https://github.com/emsig/emg3d@simpeg`
- Additionally: `xarray`, `matplotlib`, `ipympl`

## TODOs:
- Clean up the two PRs
- Change this example to create survey and simulation with emg3d, and use simpeg only for the inversion.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

import emg3d
import simpeg
import simpeg.electromagnetics.frequency_domain as FDEM

In [None]:
%matplotlib widget

## Create a grid

Very small now, for dev purposes.

In [None]:
# Survey frequency and z-coordinate of the seafloor
frequency = 2.
seafloor = -2000.

# Smallest cell width; it controls the size of the problem:
# - 100. => 81,920 cells, 5 sources, 41 receivers
# - 200. => 18,432 cells, 3 sources, 21 receivers
min_cell = 200.

mesh = emg3d.construct_mesh(
    # The smallest cell is located at the center, which lies on an edge
    center=(min_cell/2, 0, seafloor),
    center_on_edge=True,
    # Reference frequency and reference properties
    # NOTE(review): the properties were originally commented as "reference
    # resistivity", but mapping='Conductivity' is passed - confirm which
    # of the two is intended.
    frequency=1.,
    properties=(0.3, 1, 1, 0.3),
    mapping='Conductivity',
    min_width_limits=min_cell,
    # Domain in which we want precise results
    domain=(
        [-2000, 2000],                  # x-dir: where we have receivers
        [-1000, 1000],                  # y-dir: just extent of target
        [seafloor-1000, seafloor+500],  # z-dir
    ),
)

# Cells taking part in the inversion: all cells centered below the seafloor
inds_active = mesh.cell_centers[:, 2] < seafloor

mesh

## Create a model

Deep marine scenario, so we can ignore the air layer for developing.

In [None]:
# Background conductivity of 1 everywhere
sigma = np.ones(mesh.n_cells)

# Water layer: every cell centered above the seafloor
sigma[mesh.cell_centers[:, 2] > seafloor] = 1/0.3

# Target: a block of 2 km x 2 km, located 500-700 m below the seafloor
inds_target = (
    (np.abs(mesh.cell_centers[:, 0]) < 1000) &
    (np.abs(mesh.cell_centers[:, 1]) < 1000) &
    (mesh.cell_centers[:, 2] < seafloor-500) &
    (mesh.cell_centers[:, 2] > seafloor-700)
)
sigma[inds_target] = 1/100.

# True model and homogeneous start model, both as conductivities
model_true = emg3d.Model(mesh, sigma, mapping='Conductivity')
model_start = emg3d.Model(mesh, 1.0, mapping='Conductivity')

# QC: interactive slicer; shows resistivities, not conductivities
mesh.plot_3d_slicer(
    1/sigma,
    pcolor_opts=dict(
        edgecolors='grey',
        linewidth=0.5,
        cmap='Spectral_r',
        norm=LogNorm(vmin=0.3, vmax=100),
    ),
    xlim=[-3000, 3000],
    ylim=[-2000, 2000],
    zlim=[seafloor-1500, seafloor+500],
    zslice=seafloor-500,
)

# QC: show the model summary
model_true

## Create a survey

For now, we have sources and receivers on the corresponding edges. However, this has to become flexible!

In [None]:
# Receivers sit at the x-cell-centers within +/- 2100 of the origin
rec_x = mesh.cell_centers_x[np.abs(mesh.cell_centers_x) < 2100]

# Receivers at all of these positions, 100 above the seafloor
rec = emg3d.surveys.txrx_coordinates_to_dict(
    emg3d.RxElectricPoint,
    (rec_x, 0, seafloor+100, 0, 0),
)

# Sources at every eighth receiver position
src = emg3d.surveys.txrx_coordinates_to_dict(
    emg3d.TxElectricDipole,
    (rec_x[::8], 0, seafloor+100, 0, 0),
)

survey = emg3d.Survey(sources=src, receivers=rec, frequencies=frequency)
survey

In [None]:
# Coordinates of all sources and receivers of the survey
src_coords = survey.source_coordinates()
rec_coords = survey.receiver_coordinates()

# QC: horizontal resistivity slice (z-index 12) with the acquisition layout
mesh.plot_slice(
    1./sigma,
    normal='Z',
    ind=12,
    grid=True,
    pcolor_opts=dict(cmap='Spectral_r', norm=LogNorm(vmin=1, vmax=100)),
    range_x=(-3000, 3000),
    range_y=(-2000, 2000),
)

# Receivers as red triangles, sources as white stars
plt.plot(rec_coords[0], rec_coords[1], 'rv')
plt.plot(src_coords[0], src_coords[1], 'w*')

## Create a Simulation and observed data

In [None]:
# The observed data and the responses of the start model are stored on disk
# after they are computed; if the file exists, it is loaded instead.
try:
    ddata = emg3d.load('toy-simpeg.h5')
    sim = ddata['sim']
    model_true = ddata['model_true']
except FileNotFoundError:
    # Create an emg3d Simulation instance
    sim = emg3d.simulations.Simulation(
        survey=survey.copy(),
        model=model_true,
        gridding='both',
        max_workers=10,
        gridding_opts={'center_on_edge': False},
        receiver_interpolation='linear',
        solver_opts={'tol_gradient': 1e-3},
        tqdm_opts=False,
    )

    # Observed data from the true model
    sim.compute(observed=True, min_offset=100)
    sim.clean('computed')

    # Responses of the start model, kept as data set 'start'
    sim.model = model_start
    sim.compute()
    sim.survey.data['start'] = sim.survey.data.synthetic
    sim.clean('computed')

    # Store simulation and true model for the next run
    emg3d.save('toy-simpeg.h5', sim=sim, model_true=model_true)


sim

## Create SimPEG-Simulation

In [None]:
# TODO move the mapping inside
active_map = simpeg.maps.InjectActiveCells(mesh, inds_active, sigma[~inds_active])
nP = int(inds_active.sum())
conductivity_map = active_map * simpeg.maps.ExpMap(nP=nP)

In [None]:
# Define the Simulation
simulation = emg3d.inversion.simpeg.FDEMSimulationNew(
    sim,
    sigmaMap=conductivity_map,
    verbose=False,
)

# True and initial model
m_true =  np.log(sigma[inds_active])
m0 =  np.ones(m_true.shape) * np.log(1.)

# WORK in PROGRESS - HERE

In [None]:
# Fields for the true model, re-used to predict the true/observed data
f = simulation.fields(m_true)
d_true = simulation.dpred(m_true, f=f)

In [None]:
# Predicted data for the initial model
d_0 = simulation.dpred(m0)

In [None]:
# Uncertainty model: relative error and noise floor, combined as the
# square root of the sum of their squares
noise_floor = 1e-14
relative_error = 0.01
standard_deviation = np.sqrt(
    np.abs(relative_error*d_true)**2 + noise_floor**2
)

In [None]:
# Difference between true and initial data, weighted by the standard deviation
residual = (d_true - d_0)/standard_deviation

### Question @Seogi: Why are there fewer data per source than there are receivers?

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

# Left: amplitudes of the true/observed data and of the initial-model data
ax1.set_title('Responses (V/m)')
ax1.semilogy(abs(d_true), 'o', label='true/observed')
ax1.semilogy(abs(d_0), '.', label='initial model')
ax1.legend()

# Right: weighted residuals, one colour per source
ax2.set_title('Weighted Residuals')
# Number of sources, taken from the source dict the survey was built with
# (the previously used `xyz_src` is not defined in this notebook).
nsrc = len(src)
# NOTE(review): the "-4" is an empirical number of data per source that are
# missing compared to the number of receivers - see the open question above
# this cell; confirm where it comes from.
nrec = len(rec_x)-4
x = np.arange(len(residual))
for s in range(nsrc):
    per_src = slice(s*nrec, (s+1)*nrec)
    ax2.semilogy(x[per_src], abs(residual[per_src]), '.')

plt.show()

In [None]:
# Observed data and uncertainties. The notebook-level `survey` is an
# emg3d.Survey; the SimPEG data container is given the survey of the SimPEG
# simulation instead.
# NOTE(review): assumes `simulation.survey` is the SimPEG survey (it is the
# object that provides `nD` in the adjoint test) - confirm.
em_data = simpeg.data.ComplexData(
    simulation.survey, dobs=d_true, standard_deviation=standard_deviation
)

# L2 data misfit between these data and the simulation
dmis = simpeg.data_misfit.L2DataMisfit(data=em_data, simulation=simulation)

In [None]:
# Absolute tolerance on |w^T J v - v^T J^T w| used by test_adjoint()
adjoint_tol = 1e-10
def test_misfit():
    """Check `Jvec` against `dpred` with a derivative test around `m0`.

    Returns
    -------
    passed
        The value returned by `simpeg.tests.check_derivative`.
    """
    # The fields have to belong to the linearization point m0; the
    # notebook-level `f` holds the fields of `m_true`.
    f0 = simulation.fields(m=m0)
    passed = simpeg.tests.check_derivative(
        lambda m: (
            simulation.dpred(m),
            lambda mx: simulation.Jvec(m0, mx, f=f0),
        ),
        m0,
        plotIt=False,
        num=3,
    )
    return passed

def test_adjoint():
    """Adjoint test: compare w^T (J v) with v^T (J^T w) at `m0`.

    `v` is a random model-space vector (one entry per active cell), `w` a
    random data-space vector. The absolute difference of the two products
    is compared with the notebook-level `adjoint_tol`, and both the
    difference and the two products are printed.

    Returns
    -------
    passed
        True if the absolute difference is below `adjoint_tol`.
    """
    # Fields at the linearization point
    f = simulation.fields(m=m0)

    v = np.random.rand(int(inds_active.sum()))
    w = np.random.rand(simulation.survey.nD)

    wtJv = np.vdot(w, simulation.Jvec(m0, v, f=f)).real
    vtJtw = np.vdot(v, simulation.Jtvec(m0, w, f=f))

    passed = np.abs(wtJv - vtJtw) < adjoint_tol
    print("Adjoint Test", np.abs(wtJv - vtJtw), passed)
    print(wtJv, vtJtw)
    return passed
    
def test_dataObj():
    """Check the derivative of the data misfit `dmis` around `m0`.

    Returns
    -------
    passed
        The value returned by `simpeg.tests.check_derivative`.
    """
    passed = simpeg.tests.check_derivative(
        lambda m: [dmis(m), dmis.deriv(m)], m0, plotIt=False, num=2
    )
    return passed

In [None]:
# Derivative check of the data misfit
test_dataObj()

In [None]:
# Derivative check of the predicted data (Jvec)
test_misfit()

In [None]:
# Adjoint check of Jvec against Jtvec
test_adjoint()

In [None]:
%%time

# Regularization (model objective function) on the active cells, with the
# initial model as reference model
reg = simpeg.regularization.WeightedLeastSquares(
    mesh,
    active_cells=inds_active,
    reference_model=m0,
    alpha_s=1e-8,
    alpha_x=1,
    alpha_y=10,
    alpha_z=1,
)

# Optimization: the same maximum number of iterations is used for the outer
# loop, the line search, and CG (set nit = 2 for a quick run)
nit = 20
opt = simpeg.optimization.InexactGaussNewton(
    maxIter=nit,
    maxIterLS=nit,
    maxIterCG=nit,
    tolCG=1e-3,
)

# Inverse problem: data misfit + regularization + optimization
inv_prob = simpeg.inverse_problem.BaseInvProblem(dmis, reg, opt)

# Directives: initial beta, beta cooling, stopping criterion, and storage of
# the results of every iteration
starting_beta = simpeg.directives.BetaEstimate_ByEig(beta0_ratio=1)
beta_schedule = simpeg.directives.BetaSchedule(coolingFactor=2, coolingRate=1)
target_misfit = simpeg.directives.TargetMisfit(chifact=1)
save = simpeg.directives.SaveOutputDictEveryIteration()

directives_list = [starting_beta, beta_schedule, target_misfit, save]

em_inversion = simpeg.inversion.BaseInversion(
    inv_prob, directiveList=directives_list
)

# Run the inversion, starting from the initial model
recovered_conductivity_model = em_inversion.run(m0)

In [None]:
# Show the target value of the TargetMisfit directive
target_misfit.target

In [None]:
# Index of the last stored entry = number of stored entries
iteration = len(save.outDict)

# Observed data versus the predicted data of that last entry
plt.figure(figsize=(10, 4))
plt.semilogy(abs(em_data.dobs), 'o', label='Observed')
plt.semilogy(
    abs(save.outDict[iteration]['dpred']), '.', label='Predicted'
)
plt.legend()

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(8, 4))

# Conductivities on the full mesh: estimated (last stored entry) and true
sigm_est = conductivity_map * save.outDict[iteration]['m']
sigmas = [sigm_est, sigma]
titles = ["Estimated", "True"]

# One panel per model, shown as resistivity on a y-normal slice
for ii, (ax, title) in enumerate(zip(axs, titles)):
    out = mesh.plot_slice(
        1./sigmas[ii],
        normal='Y',
        grid=False,
        pcolor_opts=dict(
            cmap='Spectral_r', norm=LogNorm(vmin=0.33, vmax=100)
        ),
        ax=ax,
    )
    ax.set_aspect(1)
    ax.set_xlim(-2000, 2000)
    ax.set_ylim(-4000, 0)
    # Only the left panel keeps its y-ticks
    if ii == 1:
        ax.set_yticks([])
    ax.set_title(title)
    cb = plt.colorbar(out[0], ax=ax, fraction=0.03, orientation='horizontal')
    cb.set_label("Resistivity (Ω m)")

In [None]:
# Show the emg3d report for reproducibility
emg3d.Report()