In [1]:
# %%writefile run_em1d_multiprocessing.py
try:
    from multiprocessing import Pool
except ImportError:
    PARALLEL = False
else:
    PARALLEL = True
    
from SimPEG import Mesh, Maps, Utils, Survey, Problem, Props
from simpegem1d import *
from scipy.constants import mu_0
import numpy as np
from multiprocessing import Pool   
import multiprocessing
import scipy.sparse as sp
from functools import reduce
# from schwimmbad import MPIPool as Pool
import time 

def dot(args):
    """Return ``np.dot`` of the first two items of ``args``.

    The operands arrive packed in one sequence so the function can be used
    with map-style callers that pass a single argument per task. Any items
    beyond the first two are ignored.
    """
    left = args[0]
    right = args[1]
    return np.dot(left, right)

def set_mesh_1d():
    """Build the 1D tensor mesh used for a single sounding.

    Cell widths, in order: three log-spaced cells from 0.1 to 10, fifteen
    uniform cells of width 10, three log-spaced cells from 10 to 100, and one
    extra cell repeating the last width (22 cells in total). The mesh origin
    is 0.
    """
    uniform_width = 10.
    shallow_cells = np.logspace(-1, 1, 3)
    middle_cells = np.ones(15)*uniform_width
    deep_cells = np.logspace(1, 2, 3)
    # The final deep cell is duplicated to pad the bottom of the mesh.
    widths = np.r_[shallow_cells, middle_cells, deep_cells, deep_cells[-1]]
    return Mesh.TensorMesh([widths], [0.])

def run_simulation_FD(args):
    """
    Simulate one frequency-domain EM1D sounding.

    The survey set-up (receiver/source location, offsets, frequencies,
    field type, rxType and srcType) is hard-coded in the body.
    TODO: accept rxLoc, SrcLoc, mesh_1d, offset, frequency, field_type,
    rxType and srcType through ``args`` instead.

    Parameters
    ----------
    args : tuple
        ``(sigma, jacSwitch)`` packed into one argument so the function can
        be used with ``Pool.map``. ``sigma`` is the conductivity of each
        layer of the mesh from ``set_mesh_1d``; its natural logarithm is
        used as the model because the problem is built with an ``ExpMap``.
        ``jacSwitch`` is forwarded to ``EM1D`` and selects the output.

    Returns
    -------
    The projected derivative of the fields with respect to the model when
    ``jacSwitch`` is truthy, otherwise the projected fields (the response).
    """
    sigma, jacSwitch = args
    mesh_1d = set_mesh_1d()

    FDsurvey = EM1DSurveyFD()
    FDsurvey.rxLoc = np.array([0., 0., 100.+30.])
    FDsurvey.srcLoc = np.array([0., 0., 100.+30.])
    FDsurvey.fieldtype = 'secondary'
    FDsurvey.rxType = 'Hz'
    FDsurvey.srcType = 'VMD'
    FDsurvey.offset = np.r_[8., 8., 8.]
    # Mesh nodes and cell centres are negated before being given to the survey.
    depth = -mesh_1d.gridN[:-1]
    LocSigZ = -mesh_1d.gridCC
    topo = np.r_[0., 0., 100.]
    FDsurvey.depth = depth
    FDsurvey.topo = topo
    FDsurvey.LocSigZ = LocSigZ
    FDsurvey.frequency = np.r_[900., 7200., 56000]
    FDsurvey.Nfreq = FDsurvey.frequency.size
    FDsurvey.Setup1Dsystem()
    FDsurvey.SetOffset()

    expmap = Maps.ExpMap(mesh_1d)
    # The problem maps the model through exp(), so pass log-conductivity.
    m = np.log(sigma)

    prob = EM1D(
        mesh_1d, sigmaMap=expmap, filter_type='key_101',
        jacSwitch=jacSwitch,
        chi=np.zeros(FDsurvey.nlay)
    )
    # Make sure neither object is still paired before pairing them together.
    if prob.ispaired:
        prob.unpair()
    if FDsurvey.ispaired:
        FDsurvey.unpair()
    prob.pair(FDsurvey)

    if jacSwitch:
        u, dudsig = prob.fields(m)
        return FDsurvey.projectFields(dudsig)

    u = prob.fields(m)
    return FDsurvey.projectFields(u)

class GlobalEM1DProblem(Problem.BaseProblem):
    """
        The GlobalProblem allows you to run a whole bunch of SubProblems,
        potentially in parallel, potentially of different meshes.
        This is handy for working with lots of sources.

        The model is treated as conductivity directly (identity map) and is
        reshaped to (n_layer, n_sounding) in Fortran order, one column per
        sounding. Each column is simulated with ``run_simulation_FD``.
    """
    sigma, sigmaMap, sigmaDeriv = Props.Invertible(
        "Electrical conductivity (S/m)"
    )

    # Cached sparse Jacobian; built on the first getJ call.
    _Jmatrix = None
    # Number of worker processes; defaults to the machine's CPU count.
    n_cpu = None
    n_sounding = None
    n_layer = None

    def __init__(self, mesh, **kwargs):
        """
            Store keyword settings and derive the problem size from ``mesh``.

            ``mesh.nCx`` is used as the number of soundings and ``mesh.nCy``
            as the number of layers per sounding.
        """
        Utils.setKwargs(self, **kwargs)
        if self.n_cpu is None:
            self.n_cpu = multiprocessing.cpu_count()
        self.sigmaMap = Maps.IdentityMap(mesh)
        # temporary
        self.n_sounding = mesh.nCx
        self.n_layer = mesh.nCy

    def _run_soundings(self, jacSwitch):
        """
            Run ``run_simulation_FD`` once per sounding for the current model.

            Returns a list with one result per sounding, in sounding order.
            A worker pool is created only when ``PARALLEL`` is true, and it
            is always closed and joined, even if a worker raises.
        """
        Sigma = self.sigma.reshape((self.n_layer, self.n_sounding), order="F")
        tasks = [(Sigma[:, i], jacSwitch) for i in range(self.n_sounding)]
        if not PARALLEL:
            return [run_simulation_FD(task) for task in tasks]
        pool = Pool(self.n_cpu)
        try:
            return pool.map(run_simulation_FD, tasks)
        finally:
            pool.close()
            pool.join()

    def forward(self, m, f=None):
        """
            Set the model to ``m`` and return the responses of all soundings
            stacked into one array. ``f`` is accepted but not used.
        """
        self.model = m
        return np.hstack(self._run_soundings(False))

    def getJ(self, m):
        """
            Return the Jacobian as a block-diagonal CSR matrix with one
            block per sounding.

            NOTE: the matrix is cached after the first call and is not
            rebuilt when ``m`` changes; reset ``_Jmatrix`` to ``None`` to
            force a rebuild.
        """
        if self._Jmatrix is not None:
            return self._Jmatrix
        # Use the model that was passed in rather than whatever an earlier
        # forward() call happened to leave behind.
        self.model = m
        self._Jmatrix = sp.block_diag(self._run_soundings(True)).tocsr()
        return self._Jmatrix

    def Jvec(self, m, v, f=None):
        """
            Return the Jacobian times ``v``. ``f`` is accepted but not used.
        """
        J = self.getJ(m)
        # Possibility for parallel Jvec: keep J as a list of per-sounding
        # blocks and map a dot product over (block, slice of v) pairs.
        return J*v

    def Jtvec(self, m, v, f=None):
        """
            Return the transposed Jacobian times ``v``. ``f`` is accepted
            but not used.
        """
        J = self.getJ(m)
        # Possibility for parallel Jtvec: same per-sounding scheme as Jvec,
        # using the transposed blocks.
        return J.T*v


# class GlobalEM1DSurveyFD(Survey.BaseSurvey):
    
#     rxlocs = None
#     srclocs = None
#     frequency = None
    
#     @Utils.count
#     @Utils.requires('prob')
#     def dpred(self, m=None, f=None):
#         """dpred(m, f=None)

#             Create the projected data from a model.
#             The fields, f, (if provided) will be used for the predicted data
#             instead of recalculating the fields (which may be expensive!).

#             .. math::

#                 d_\\text{pred} = P(f(m))

#             Where P is a projection of the fields onto the data space.
#         """
#         return self.prob.forward(m)
    
#     @property
#     def nD(self):
#         return self.prob.G.shape[0]   
    
#     def read_xyz_data(self, fname):
        

In [2]:
# %%time
# from SimPEG import Mesh
# import numpy as np
# n = [4, 40, 400]
# t = []
# for n_sounding in n:
# # n_sounding = 4
#     start = time.time()
#     n_layer = 22
#     mesh = Mesh.TensorMesh([n_sounding, n_layer])
#     m = np.ones(mesh.nC)
#     prob = GlobalEM1DProblem(mesh, n_cpu=3)
#     pred = prob.forward(m)
#     J = prob.getJ(m)
#     end = time.time()
#     t.append(end-start)

In [6]:
# Time one forward run, one Jacobian build and one Jtvec/Jvec product.
# The guard keeps worker processes from re-running this benchmark when the
# file is executed as a script under the "spawn" start method, where each
# worker re-imports the main module.
if __name__ == "__main__":
    n_sounding = 100
    start = time.time()
    n_layer = 22
    mesh = Mesh.TensorMesh([n_sounding, n_layer])
    m = np.ones(mesh.nC)
    prob = GlobalEM1DProblem(mesh, n_cpu=2)
    pred = prob.forward(m)
    J = prob.getJ(m)
    # Size the data-space vector from J itself instead of hard-coding the
    # number of data per sounding.
    prob.Jtvec(m, np.ones(J.shape[0]))
    prob.Jvec(m, m)
    end = time.time()
    print(end-start)

2.6174027919769287


In [13]:
# def run_simulation(
#     rxLoc, SrcLoc, mesh_1d, offset, frequency,
#     field_type = 'secondary',
#     rxType = 'Hz',
#     srcType = 'VMD'
# ):
#     FDsurvey = EM1DSurveyFD()
#     depth = -mesh1D.gridN[:-1]
#     LocSigZ = -mesh1D.gridCC
#     nlay = depth.size
#     topo = np.r_[0., 0., 100.]
#     FDsurvey.depth = depth
#     FDsurvey.topo = topo
#     FDsurvey.LocSigZ = LocSigZ
#     FDsurvey.Nfreq = FDsurvey.frequency.size
#     FDsurvey.Setup1Dsystem()
#     FDsurvey.SetOffset()
#     sig_half = 1e-4
#     sig_blk = 1e-2
#     chi_half = 0.
#     expmap = Maps.ExpMap(mesh1D)
#     sig  = np.ones(nlay)*sig_half
#     blk_ind = (-50>LocSigZ) & (-100<LocSigZ)
#     sig[blk_ind] = sig_blk
#     m_true = np.log(sig)

#     WT0, WT1, YBASE = DigFilter.LoadWeights()
#     prob = EM1D(
#         mesh1D, sigmaMap=expmap, filter_type='key_101',
#         jacSwitch=True,
#         chi= np.zeros(FDsurvey.nlay)
#     )
#     if prob.ispaired:
#         prob.unpair()
#     if FDsurvey.ispaired:
#         FDsurvey.unpair()
#     prob.pair(FDsurvey)    
#     u, dudsig = prob.fields(m_true)
#     resp = FDsurvey.projectFields(u)
#     drespdsig = FDsurvey.projectFields(dudsig)
#     return resp, drespdsig
# !python run_em1d_multiprocessing.py