"""
Verification example for a multi-layer model with induced polarisation

This example is based on the publication
Lin et al. (2019), A discussion of 2D induced polarization effects in airborne
electromagnetic and inversion with a robust 1D laterally constrained inversion
scheme, Geophysics, vol 84, no 2, doi 10.1190/geo2018-0102.1

"""

# 0. Imports
############
import numpy as np
from matplotlib import pyplot as plt

from SimPEG import maps
import SimPEG.electromagnetics.time_domain as tdem
import pathlib

# 1. Synthetic test case (Lin et al., 2019)
###########################################
# Waveform: nodes (time [s], normalised current) of a piecewise-linear pulse;
# the absolute amplitude is applied later through peak_current * nturns.
peak_current = 100  # [A]
nturns = 16  # [-]
waveform = np.array([
    [-10.0e-3, 0.0],
    [-9.0e-3, 1.0],
    [0.0, 1.0],
    [5.0e-6, 0.0],
])
waveform_times = waveform[:, 0]
waveform_current = waveform[:, 1]

# Ground: three layers, the middle one carrying the IP (Cole-Cole type)
# parameters; there is one thickness fewer than layers.
resistivities = np.array([5.0e3, 5.0e2, 5.0e3])  # [Ohm.m]
thicknesses = np.array([20.0, 50.0])  # [m]
depths = thicknesses.cumsum()  # depth of each layer interface [m]
eta = np.array([0.0, 0.35, 0.0])  # chargeability [V/V]
tau = np.array([0.0, 1.0e-3, 0.0])  # time constant [s]
c = np.array([0.0, 0.5, 0.0])  # frequency exponent [-]

# TX: a horizontal loop of the requested area, discretised into 200 nodes
# on a circle centred below/above tx_center.
elevation = 30  # [m]
area = 300  # [m2]
tx_center = np.array((0, 0, elevation))
radius = np.sqrt(area / np.pi)  # circle radius giving the loop area
theta = np.linspace(0, 2 * np.pi, 200, endpoint=False)
pts = np.vstack((
    radius * np.cos(theta),
    radius * np.sin(theta),
    np.zeros(theta.size),
))
source_location = pts.T + tx_center  # (200, 3) node coordinates

# RX: central-loop receiver measuring the vertical component
receiver_location = tx_center
receiver_orientation = 'z'
gates = 10 ** np.linspace(-2, -5, 31)  # time channels [s], 1e-2 down to 1e-5
# Results
# Reference / previously computed curves stored as a NumPy .npz archive
# (saved with a .zip extension; np.load detects the zip format itself).
# The file is looked up next to this script rather than in the current
# working directory, so the example also runs when launched from elsewhere.
try:
    _data_dir = pathlib.Path(__file__).resolve().parent
except NameError:  # __file__ is undefined in interactive sessions/notebooks
    _data_dir = pathlib.Path.cwd()
# Copy every array into a plain dict, then close the archive.
with np.load(_data_dir / 'simpeg_ip_verification.zip') as _archive:
    data = dict(_archive)

# 2. Model
##########
# Waveform: piecewise-linear current built from the (time, current) nodes
waveform = tdem.sources.PiecewiseLinearWaveform(
    times=waveform_times,
    currents=waveform_current,
)

# Ground: the model vector is one resistivity per layer, passed through
# unchanged; the bottom layer has no entry in `thicknesses`, hence the +1.
n_layer = len(thicknesses) + 1
model_mapping = maps.IdentityMap(nP=n_layer)

# RX: a single dB/dt receiver sampled at every gate time
receiver_list = [
    tdem.receivers.PointMagneticFluxTimeDerivative(
        receiver_location,
        gates,
        orientation=receiver_orientation,
    ),
]

# TX: multi-turn loop driven by the waveform; the effective current is the
# peak current multiplied by the number of turns.
source_current = nturns * peak_current
source_list = [
    tdem.sources.LineCurrent(
        receiver_list=receiver_list,
        location=source_location,
        waveform=waveform,
        current=source_current,
    ),
]

# Survey
survey = tdem.Survey(source_list)

# 3. Computing
##############
# 1D layered forward simulation. Resistivity is the (identity-mapped) model,
# and the IP behaviour comes from the fixed per-layer eta / tau / c arrays.
simulation_ip = tdem.Simulation1DLayered(
    survey=survey,
    thicknesses=thicknesses,
    rhoMap=model_mapping,
    eta=eta,
    tau=tau,
    c=c,
)
dbdt_ip = np.abs(simulation_ip.dpred(resistivities))  # magnitude for log axes
# Store as (time, |dB/dt|) columns, replacing any entry already loaded under
# this key so the plots show the result of this run.
data['dBdt_typo_formula_corrections'] = np.column_stack((gates, dbdt_ip))

# 4. Plotting
#############
# One panel per SimPEG variant, each overlaid on the published reference
# curve; every data entry is unpacked as (time, dB/dt) columns.
fig, axes = plt.subplots(ncols=3, sharey=True, figsize=(10, 4))
titles = (
    'No code correction',
    'Typo correction only',
    'Typo+formula corrections',
)
data_labels = (
    'dBdt_no_correction',
    'dBdt_typo_correction',
    'dBdt_typo_formula_corrections'
)
ref_times, ref_dbdt = data['dBdt_lin_et_al_2019'].T
for ax, title, label in zip(axes, titles, data_labels):
    ax.loglog(
        ref_times, ref_dbdt, 'k--', label='Lin et al., 2019'
    )
    times, dbdt = data[label].T
    ax.loglog(
        times, dbdt, marker='o', mfc='none', mec='r', color='r', linestyle='--',
        label='SimPEG'
    )
    ax.set_title(title)
    ax.set_xlim((1.e-6, 1.e-1))
    ax.set_xlabel('Time(s)')
    # A white grid is invisible on matplotlib's default white axes
    # background; use a light grey and keep it beneath the curves.
    ax.grid(which='both', color='0.85')
    ax.set_axisbelow(True)
fig.suptitle('AEM 1D response with IP')
# With sharey=True, the y label/limits set on the first panel cover all panels.
axes[0].set_ylabel(r'$\mathrm{d}\mathrm{B}_\mathrm{z}\,/\,\mathrm{d}t$')
axes[0].set_ylim((1.e-11, 1.e-2))
axes[0].legend(loc=1)
fig.tight_layout()
fig.savefig('simpeg_ip_verification.png', dpi=200)
# Close this specific figure, not whichever figure happens to be current.
plt.close(fig)