In [None]:
import os
import sys
import tempfile

sys.path.append('../')

import matplotlib.pyplot as plt
import numpy as np
import pygmt

from remit.utils.grid import make_dataarray, coeffs2map
from remit.data.models import load_lcs, load_ocean_age_model, load_vis_model, create_vim

%load_ext autoreload
%autoreload 2


In [None]:
# Load the two inputs needed to construct the VIM model. The VIM itself is not
# built here; it is created later via create_vim(ocean, vis, ...).
ocean = load_ocean_age_model()

# The ocean model's lon/lat/age are handed to the loader through `match`
# (presumably so that `vis` ends up on the same grid -- TODO confirm in load_vis_model).
vis = load_vis_model(match=(ocean.lon, ocean.lat, ocean.age))

# NOTE(review): disabled polarity-timescale code. `PolarityTimescale` is not among
# the imports at the top of this notebook, so these lines would fail if re-enabled as-is.
#pol = PolarityTimescale()
#polarity_da = make_dataarray(ocean.lon, ocean.lat, pol.Interpolator(ocean.age))
#polarity_da.data[np.isnan(ocean.age.data)] = np.nan


In [None]:
def plot_mag_map(fig, rad, projection, region, cmax, frame=None):
    """Draw one magnetic map panel (grid image, line overlays, coast, frame) on ``fig``.

    Parameters
    ----------
    fig : pygmt.Figure
        Figure (or currently active subplot panel) to draw into.
    rad : grid
        Grid to display; it is passed to ``pygmt.grdsample`` and the resampled
        result must support ``.to_netcdf`` (callers in this notebook pass
        xarray objects).
    projection, region : str
        GMT projection and region, forwarded unchanged to every plotting call.
    cmax : float
        Half-width of the symmetric colour range; the 'polar' CPT spans
        ``-cmax`` .. ``+cmax``.
    frame : list of str, optional
        Frame specification for ``fig.basemap``. Defaults to ``['afg', 'WSne']``.

    Notes
    -----
    ``pygmt.config`` is called without a context manager, so the
    COLOR_FOREGROUND / COLOR_BACKGROUND settings made here persist after the
    function returns (unchanged from the previous behaviour).
    """
    # None sentinel instead of a mutable list default shared between calls.
    if frame is None:
        frame = ['afg', 'WSne']

    radi = pygmt.grdsample(grid=rad, spacing='0.05d')
    pygmt.config(COLOR_FOREGROUND='red', COLOR_BACKGROUND='blue')
    pygmt.makecpt(cmap='polar', series='-{:f}/{:f}'.format(float(cmax), float(cmax)), reverse=True)

    # The resampled grid is handed to grdimage through a netCDF file. Use a
    # per-call scratch directory rather than a fixed './grids/_tmp.nc' so the
    # call does not depend on './grids' existing and leaves no stale file behind.
    with tempfile.TemporaryDirectory() as scratch_dir:
        grid_path = os.path.join(scratch_dir, '_tmp.nc')
        radi.to_netcdf(grid_path)
        fig.grdimage(region=region, projection=projection,
                     grid=grid_path, transparency=10)

    # NOTE(review): this gray CPT becomes the current CPT but none of the calls
    # below reference a CPT explicitly -- kept so the GMT session state after
    # the function is the same as before; confirm whether it is still needed.
    pygmt.config(COLOR_FOREGROUND='white', COLOR_BACKGROUND='black')
    pygmt.makecpt(cmap='gray', series='-60/60', reverse=True)

    # Line overlays read from local GMT files (file names suggest isochrons
    # C13y / C34 / M0 and ridges -- verify against the data source).
    fig.plot(data='./gis/Muller_etal_AREPS_2016_C13y.gmt', pen='0.6p,black')
    fig.plot(data='./gis/Muller_etal_AREPS_2016_C34.gmt', pen='0.6p,black')
    fig.plot(data='./gis/Muller_etal_AREPS_2016_M0.gmt', pen='0.6p,black')
    fig.plot(data='./gis/Muller_etal_AREPS_2016_Ridges.gmt', pen='1p,black,-')

    fig.coast(shorelines='0.5p,gray20', resolution='l', area_thresh=5000.,
              region=region, projection=projection, land='gray', transparency=20)
    fig.basemap(frame=frame, region=region, projection=projection)

    

def PlotPanels(fig, projection, region, figsize, cmax, altitude=0):
    """Fill a 2x3 subplot grid: the LCS map first, followed by five model maps.

    The function reads two names from the enclosing (notebook) scope rather
    than from its arguments:

    * ``models``  -- iterable of exactly five entries; entry ``[0]`` is used as
      the panel label and entry ``[1]`` as the grid to plot.
    * ``lcs_rad`` -- grid drawn in panel 0 under the label 'LCS'.

    ``altitude`` is accepted but not used by the body.
    """
    # Unpacking enforces that exactly five models are supplied.
    first, second, third, fourth, fifth = models

    with fig.subplot(nrows=2, ncols=3, figsize=figsize, frame="lrtb", autolabel='+gwhite+pblack', margins='0.15i'):
        # Panel order: reference LCS map, then the models in the order given.
        panel_entries = [('LCS', lcs_rad), first, second, third, fourth, fifth]
        for panel_index, entry in enumerate(panel_entries):
            with fig.set_panel(panel=panel_index, fixedlabel=entry[0]):
                plot_mag_map(fig, entry[1], projection=projection, region=region, cmax=cmax)

                  

In [None]:
# Value passed to load_lcs as `lmax` (presumably the maximum spherical-harmonic
# degree -- TODO confirm against load_lcs).
# NOTE(review): the plotting cell further down rebinds `lmax` to 130, so the LCS
# map and the model maps are built with different `lmax` values -- confirm this
# is intentional.
lmax = 140

lcs = load_lcs(lmax=lmax)
# Expand at a=lcs.r0 and keep the `rad` attribute of the result as an xarray
# object; this is the grid shown in the 'LCS' panels.
lcs_rad = lcs.expand(a=lcs.r0).rad.to_xarray()


In [None]:
# Map / figure configuration for the P-vs-lambda comparison figure.
projection='A-105/-25/?'
region = '-150/-60/-92/2r'
figsize="12i", "17i"
altitude = 0.
lmax = 130
cmax = 50

# Swept create_vim parameters: one subplot row per lambda, one column per P.
lmbda_values = [0.1, 0.2, 1, 5]
P_values = [0, 1, 2, 3, 4]
n_rows = len(lmbda_values)
n_cols = len(P_values)

# Make sure the output directory exists before the sweep runs, so that
# savefig cannot fail on a missing folder after all models have been computed.
figure_path = './figures/EPacificRise_MOR_P_vs_Lambda_{:d}km_lmax{:d}.png'.format(int(altitude/1000),int(lmax))
os.makedirs(os.path.dirname(figure_path), exist_ok=True)

fig = pygmt.Figure()

with fig.subplot(nrows=n_rows, ncols=n_cols, figsize=figsize, frame="lrtb", autolabel='+gwhite+pblack', margins='0.15i'):
    # Panel 0 shows the LCS map as the reference.
    with fig.set_panel(panel=0, fixedlabel='LCS'):
        plot_mag_map(fig, lcs_rad, projection=projection, region=region, cmax=cmax, frame=['af', 'Wsne'])

    panel = 1
    for i,lmbda in enumerate(lmbda_values):
        for j,P in enumerate(P_values):
            # The first (lambda, P) combination is skipped: its slot in the
            # grid (panel 0) is occupied by the LCS reference map.
            if (i==0 and j==0):
                continue
            totalvim = create_vim(ocean, vis, seafloor_layer='2d',
                          layer_boundary_depths=[0,500,1500,6500],
                          layer_weights=[5,2.3,1.2], MagMax=None, P=P, lmbda=lmbda, Mtrm=1, Mcrm=0)
            vsh, coeffs = totalvim.transform(lmax=lmax)
            model_rad = coeffs2map(coeffs, altitude=altitude, lmax=lmax, lmin=16)

            # Annotate only the outer edges of the grid: west axis on the
            # first column, south axis on the last row.
            labels='wsne'
            if panel % n_cols == 0:
                labels = labels.replace('w', 'W')
            if panel >= n_cols * (n_rows - 1):
                labels = labels.replace('s', 'S')
            with fig.set_panel(panel=panel, fixedlabel='P={:d}, lambda={:0.1f}'.format(P,lmbda)):
                plot_mag_map(fig, model_rad.to_xarray(), projection=projection, region=region, cmax=cmax, frame=['af', labels])

            panel+=1

fig.savefig(figure_path)
fig.show(width=1000)
