# -----------------------------------------------------------------------------
# EXPLORING SEISMOGRAMS

This is a script to compare data to synthetics. 
At the top there are a number of parameters to change:
- What to plot (data/synthetics)
- What frequencies to plot
- What component to plot
- What phase to center on
- Which travel times to plot
- Organizing seismograms as a function of distance or azimuth

Code blocks can be run with shift-enter.
NOTE: every time you change something, you need to re-run that code block before re-running the last code block (which does the actual plotting).

# -----------------------------------------------------------------------------
# IMPORTS
These blocks only need to be run once (if all goes well)
## Importing libraries

In [None]:
%matplotlib inline

import obspy
import obspy.taup
from obspy import read
from obspy import UTCDateTime
from obspy.taup import TauPyModel
import obspy.signal
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cm
import geographiclib
import numpy as np
import sys
sys.path.insert(1, './basemap')
from mpl_toolkits.basemap import Basemap
import glob
import time



## Importing data
There is only one event available to you, so don't change this. <br>


In [None]:
# Identifier of the (only) available event; it prefixes the three pickled files.
event = '20161225'

# Real data first, then the two synthetic counterparts that can be selected
# in the plotting section.
datalist, instaseislist, shakemovielist = (
    read(event + filename)
    for filename in ('data.PICKLE', 'instaseis.PICKLE', 'shakemovie.PICKLE')
)

# Sampling to 5 samples/sec to speed things up a bit (real data only).
datalist.resample(5)

# 1D model used for travel time and ray predictions.
model = TauPyModel(model='ak135')
print('done loading')

# -----------------------------------------------------------------------------
# PLOTTING WAVEFORMS
## Define subplots
This sets the number of subplots. <br>
Choose from 'realdata', 'instaseis', 'shakemovie'

In [None]:
# One panel per entry, laid out side by side in the order given here.
# Valid entries: 'realdata', 'instaseis', 'shakemovie'.
subplots = [
    'realdata',
    'shakemovie',
]

## Components to plot
Choose 'Z' for vertical, 'T' for transverse, or 'R' for radial <br>
All three could be plotted with e.g. components = ['T','R','Z']

In [None]:
# Components to draw; each listed component is overlaid in every panel.
components = ['T']
# Line colours, matched to `components` by position (first component -> 'k', ...).
compcolors = ['k', 'r', 'b']

## Phase to center plot on
If the first phase doesn't exist at a specific distance, the secondary phase will be tried
(only one backup phase can be given).

In [None]:
# [primary phase, backup phase]: time zero of every trace is the arrival of
# the primary phase, or of the backup phase when the primary has no arrival.
center_phase = ['S', 'Sdiff']

## Predicted travel times to plot
phases follow taup naming conventions, see:  <br>
https://docs.obspy.org/packages/obspy.taup.html <br>
A number of phases are precomputed and stored with the data, others will slow down the script to compute on the fly

In [None]:
# Phases whose predicted arrival times are marked on the seismograms
# (each phase gets its own colour, assigned by position in this list).
tt_phases = [
    'S',
    'Sdiff',
    'ScS',
    'SKS',
    'SKKS',
]

## Set frequency limits
or 1/period  <br>
Frequencies up to 0.5 Hz are included. <br>
Note, the fmax is 1/17 Hz for the shakemovie synthetics. <br>

In [None]:
# Pass band of the filters applied before plotting, given as 1/period.
fmin = 1.0 / 40.0  # high-pass corner: 40 s period
fmax = 1.0 / 17.0  # low-pass corner: 17 s period

## Y-axis: Plot as function of distance or azimuth
Choose to plot with distance or azimuth <br>
Choose 'dist' or 'az' <br>
'round' sets if seismograms are plotted at values rounded to whole numbers (this generally makes the plot easier to look at). 

In [None]:
# Quantity used to position the seismograms on the y-axis: 'dist' or 'az'.
yaxis = 'dist'
# Snap each seismogram to the nearest whole number on the y-axis.
# NOTE(review): this name shadows the builtin `round` for the rest of the
# session; the plotting cell reads the flag under exactly this name.
round = True

## Limits for azimuth, distance and time
-azimuth limits (data is included between 300 and 360 deg) <br>
-distance limits (data is included between 90 and 130 deg) <br>
-Time limits (Window around center phase)

In [None]:
# Seismograms are kept only when azimuth and distance lie strictly inside
# these ranges; the time window is relative to the centre phase.
azmin, azmax = 325., 360.

distmin, distmax = 88, 110

timemin, timemax = -10, 80

## Normalization factor
factor to normalize all data by to improve plot <br>

In [None]:
# Scale applied to every normalised seismogram before it is offset to its
# position on the y-axis.
fact = 1.0

## Plotting data
All the plotting happens after this, and in principle doesn't
need to be changed... although you are welcome to play
around/ improve/ break it...

In [None]:

# Draw every selected seismogram, aligned on the centre phase, in one panel per
# entry of `subplots`, and mark the predicted arrival times of `tt_phases`.
# Reads the settings and streams defined in the cells above. Leaves `slon`,
# `slat`, `dists`, `elon`, `elat` and `edep` behind for the ray-path, pierce-point
# and map cells (the event values stay None when no seismogram was selected).


def _phase_time(seis, phase):
    """Return the arrival time of `phase` for trace `seis`, or None.

    The time stored with the trace (`seis.stats.traveltimes`) is used whenever
    an entry for the phase exists, even if that entry is None. Otherwise the
    first arrival predicted by `model` is used; None is returned when the
    model predicts no arrival at all.
    """
    try:
        return seis.stats.traveltimes[phase]
    except (AttributeError, KeyError, TypeError):
        arrs = model.get_travel_times(seis.stats['evdp'], seis.stats['dist'],
                                      phase_list=[phase])
        # An empty result must map to None (not IndexError) so that the
        # backup centre phase gets a chance.
        return arrs[0].time if len(arrs) > 0 else None


def _mark_arrival(xtime, yloc, colour):
    """Draw a tick at time `xtime`, spanning yloc +/- 0.5, in every panel."""
    for ax in axes:
        ax.plot([xtime, xtime], [yloc - 0.5, yloc + 0.5],
                color=colour, linewidth=2)


# Colormap for travel times
cNorm = colors.Normalize(vmin=0, vmax=len(tt_phases))
tt_cmap = cm.ScalarMappable(norm=cNorm, cmap=plt.get_cmap('jet'))

# Low-pass corner actually used. It is clamped once, up front, so that every
# panel is filtered identically; the user's `fmax` setting is left untouched.
lowpass_freq = fmax
if 'shakemovie' in subplots and fmax > 1./17.:
    print('Warning: max frequency set to 1/17 Hz as anything higher is not valid for shakemovie synthetics')
    lowpass_freq = min(lowpass_freq, 1./17.)
if fmax > .5:
    print('Warning: max frequency set to 0.5 Hz as anything higher is pre-filtered from the data')
    lowpass_freq = min(lowpass_freq, .5)

# Station / event bookkeeping handed on to the later cells
slon = []
slat = []
dists = []
axes = []
elon = elat = edep = None

# Stream belonging to each valid subplot name
streams = {
    'realdata': datalist,
    'instaseis': instaseislist,
    'shakemovie': shakemovielist,
}

fig = plt.figure(figsize=(18, 10))
for plotnum in range(len(subplots)):
    axes.append(plt.subplot(1, len(subplots), 1 + plotnum))

# Loop through data types
for plotnum, sub in enumerate(subplots):
    print('plotting ', sub, '...')
    if sub not in streams:
        # Skip the panel instead of plotting undefined or stale data
        print(sub, 'option is unknown, choose realdata, instaseis, or shakemovie')
        continue

    # Loop through components
    for c, comp in enumerate(components):
        # Work on a copy: the filters below modify the traces in place
        selectseis = streams[sub].select(channel='*' + comp).copy()
        # Travel-time ticks and station locations are handled once only,
        # while drawing the last component of the last panel
        last_pass = c == len(components) - 1 and plotnum == len(subplots) - 1

        # Every trace of this component is scaled by the peak amplitude of
        # the first trace that gets plotted
        relnorm = None

        # Loop through seismic receivers
        for seis in selectseis:
            if not (azmin < seis.stats['az'] < azmax):
                continue
            if not (distmin < seis.stats['dist'] < distmax):
                continue

            # Primary centre phase first, then the single backup phase
            center_time = None
            for candidate in center_phase[:2]:
                center_time = _phase_time(seis, candidate)
                if center_time is not None:
                    break
            if center_time is None:
                continue

            seis.filter('highpass', freq=fmin, corners=2, zerophase=True)
            seis.filter('lowpass', freq=lowpass_freq, corners=2, zerophase=True)

            # Time axis relative to the centre phase, and the sample indices
            # closest to the requested window edges
            xplot = seis.times(
                reftime=UTCDateTime(seis.stats['eventtime'])) - center_time
            win1 = np.argmin(np.abs(xplot - timemin))
            win2 = np.argmin(np.abs(xplot - timemax))

            yloc = np.round(seis.stats[yaxis]) if round else seis.stats[yaxis]
            if relnorm is None:
                relnorm = np.max(np.abs(seis.data))
            yplot = fact * seis.data / relnorm + yloc
            # Colours are cycled if more components than colours are given
            axes[plotnum].plot(xplot[win1:win2], yplot[win1:win2],
                               compcolors[c % len(compcolors)], linewidth=1)

            if not last_pass:
                continue

            # Plot travel times
            for ph, phase in enumerate(tt_phases):
                colour = tt_cmap.to_rgba(ph)
                if phase in seis.stats.traveltimes:
                    stored = seis.stats.traveltimes[phase]
                    if stored is not None:
                        _mark_arrival(stored - center_time, yloc, colour)
                else:
                    # Not precomputed: compute on the fly and mark each arrival
                    arrs = model.get_travel_times(seis.stats['evdp'], seis.stats['dist'],
                                                  phase_list=(phase,))
                    for arr in arrs:
                        _mark_arrival(arr.time - center_time, yloc, colour)

            # Save locations to plot
            if elon is None:
                elon = seis.stats['evlo']
                elat = seis.stats['evla']
                edep = seis.stats['evdp']
            slon.append(seis.stats['stlo'])
            slat.append(seis.stats['stla'])
            dists.append(seis.stats['dist'])

# Empty proxy lines give one legend entry per phase without adding real points
for ph, phase in enumerate(tt_phases):
    plt.plot([], [], '.', color=tt_cmap.to_rgba(ph), label=phase)
plt.legend()

# prettify plots
for i, ax in enumerate(axes):
    ax.set_xlabel('time (s) w.r.t. ' + center_phase[0])
    if yaxis == 'dist':
        if i == 0:
            ax.set_ylabel('distance (dg)')
        ax.set_ylim([distmin, distmax])
        ax.set_title(str(azmin) + ' < azimuth < ' + str(azmax))
    elif yaxis == 'az':
        if i == 0:
            ax.set_ylabel('azimuth (dg)')
        ax.set_ylim([azmin, azmax])
        ax.set_title(str(distmin) + ' < distance < ' + str(distmax))
    ax.set_xlim([timemin, timemax])
    ax.grid()



# -----------------------------------------------------------------------------
# PLOTTING RAY PATHS
Plotting ray paths for chosen phases.

In [None]:
# Polar axes on which the ray paths are drawn
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(1, 1, 1, polar=True)

# Phases to plot, e.g. tt_phases, center_phase, or a user-defined list ["PKJKP", "SKKS"]
plot_phases = tt_phases

widths = [1, 1, 1, 1, 1, 1]  # not read anywhere in this cell

# depth of earthquake in km
depth_earthquake = edep

# Distances to compute: one ray set at the smallest and one at the largest
# distance of the seismograms selected in the plotting cell.
dist_raypaths = [np.min(dists), np.max(dists)]

# computing ray paths: one set per distance, concatenated in order
ray_sets = [
    model.get_ray_paths(depth_earthquake, dist, phase_list=plot_phases)
    for dist in dist_raypaths
]
arrivals = ray_sets[0]
for more_rays in ray_sets[1:]:
    arrivals = arrivals + more_rays

# plotting
arrivals.plot_rays(ax=ax, legend=True, plot_all=False)



# -------------------------------------------------------------------------------
# PLOT TOMOGRAPHIC MODEL AND PIERCE POINTS
## Load depth slice
Model provided to plot on the map is S362ANI+M (Moulik & Ekstrom, 2014). This model is an update of the model S362ANI (Kustowski et al. 2008) which is used as the background model for the 3D shakemovie synthetics. <br>
Choose which depth to plot. This can be 25, 50, 75, 100, 125, 150, 200, 250, 300, 350, 400, 500, 600, 700, 800, 1000, 1250, 1500, 1750, 2000, 2250, 2500, 2750, 2800, or 2890 km. <br>
Choose what to plot 'vsv', 'vsh', or 'voigt'. <br>
Choose 'absolute' or 'relative'. 

In [None]:
# Settings selecting which depth slice to read and how to display it
model3D = 'S362ANI+M'  # No other model currently provided
depth_mod = 2800
to_plot = 'voigt'
abs_or_rel = 'relative'

# Load model: <model>/<model>.<to_plot>.<depth>.txt, read as three columns
# (lon, lat, vel)
slice_name = '.'.join([model3D, to_plot, str(int(depth_mod)), 'txt'])
mod = np.loadtxt(
    model3D + '/' + slice_name,
    dtype={'names': ('lon', 'lat', 'vel'), 'formats': ('f4', 'f4', 'f4')},
)

# Reshape for plotting: rows follow the unique latitudes, columns the unique
# longitudes
len1 = np.unique(mod['lon']).size
len2 = np.unique(mod['lat']).size
grid_shape = (len2, len1)
lon, lat, vel = (mod[field].reshape(grid_shape) for field in ('lon', 'lat', 'vel'))
print('model loaded')

## Compute pierce points
For given phase(s) and a depth


In [None]:
pp_depth = depth_mod  # Depth to compute pierce point at
pp_phases = center_phase  # Phase(s) for piercepoints, can be set to center_phase

# (lon, lat) pairs where each ray crosses pp_depth on its way down / back up
downgoing = []
upgoing = []
for sta_lat, sta_lon in zip(slat, slon):
    # Compute ray paths from the event to this station
    arrs = model.get_ray_paths_geo(source_depth_in_km=edep,
                                   source_latitude_in_deg=elat,
                                   source_longitude_in_deg=elon,
                                   receiver_latitude_in_deg=sta_lat,
                                   receiver_longitude_in_deg=sta_lon,
                                   phase_list=pp_phases)
    for arrival in arrs:
        path = arrival.path
        # Index of the deepest path sample; it splits the path in two legs
        turn = np.argmax(path['depth'])
        # Only select rays that go deep enough
        if not pp_depth < np.max(path['depth']):
            continue

        # Interpolate downgoing pierce point (samples before the deepest one)
        down_leg = path[:turn]
        downgoing.append((np.interp(pp_depth, down_leg['depth'], down_leg['lon']),
                          np.interp(pp_depth, down_leg['depth'], down_leg['lat'])))

        # Interpolate upgoing pierce point; the leg is reversed so that the
        # samples run from the receiver end back to the deepest sample
        up_leg = path[turn:][::-1]
        upgoing.append((np.interp(pp_depth, up_leg['depth'], up_leg['lon']),
                        np.interp(pp_depth, up_leg['depth'], up_leg['lat'])))

downgoing = np.asarray(downgoing)
upgoing = np.asarray(upgoing)
print('done computing pierce points')

## Plotting Map
Uses the selection of data from above

In [None]:
# Map of the event, the stations selected in the waveform plot and, optionally,
# the tomographic depth slice and the pierce points computed in the cells above.
plot_tomography = True  # Choose to plot background seismic model
plot_piercepoints = True  # Choose to plot pierce points computed above

# Plot map background
plt.figure(figsize=(10, 8))
m = Basemap(projection='ortho', lat_0=20., lon_0=-110, resolution='l')
m.drawcoastlines()
m.drawcountries()

# Plot tomographic model
if plot_tomography:
    x, y = m(lon, lat)
    filled = None
    if abs_or_rel == 'relative':
        # Percentage deviation from the mean of the slice, on a colour scale
        # symmetric about zero
        ref = np.mean(vel.ravel())
        rel = (vel - ref) / ref * 100.
        lim = np.ceil(np.max(np.abs(rel)))
        levels = np.linspace(-lim, lim, 33)
        filled = m.contourf(x, y, rel, levels=levels, cmap=cm.seismic_r)
    elif abs_or_rel == 'absolute':
        filled = m.contourf(x, y, vel, 33, cmap=cm.viridis_r)
    else:
        print(abs_or_rel, "option is unknown, choose 'absolute' or 'relative'")

    # Only add a colour bar when a slice was actually drawn
    if filled is not None:
        plt.colorbar(filled)
    plt.title('model ' + model3D + ' sliced at ' + str(int(depth_mod)) + 'km', fontsize=18)

# Plot pierce points (yellow: downgoing, cyan: upgoing)
if plot_piercepoints:
    if np.size(downgoing) and np.size(upgoing):
        x2, y2 = m(downgoing[:, 0], downgoing[:, 1])
        m.scatter(x2, y2, s=35, c='y', marker='o', alpha=1)
        x2, y2 = m(upgoing[:, 0], upgoing[:, 1])
        m.scatter(x2, y2, s=35, c='c', marker='o', alpha=1)
    else:
        # The pierce-point arrays are empty when no ray reaches pp_depth;
        # they then have no columns to index
        print('no pierce points to plot')

# Plot event
x2, y2 = m(elon + 360., elat)
m.scatter(x2, y2, s=600, marker='*', facecolors='y', alpha=1)

# Plot stations (negative longitudes are shifted by +360 first)
slon = [lo + 360. if lo < 0 else lo for lo in slon]
x2, y2 = m(slon, slat)
m.scatter(x2, y2, s=35, c='m', marker='^', alpha=1)
