In [None]:
import obspy
from obspy import UTCDateTime as utct
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec
from ipywidgets import interact, interact_manual
from taup_distance import taup_distance
from obspy.taup import TauPyModel
import matplotlib.dates as mdates
from helpers.create_specgrams import origin_time
model = TauPyModel('model_data/model4taup10.npz')

# Alternative (unused) data source, kept for reference:
#   st = obspy.read(f'event_data/{event}/waveforms/waveforms_VBB.mseed')
#   st.filter('highpass', freq=1./10)

# Travel-time tables from helpers/traveltimes.npz; plot_waveform looks them up
# as TT[phase]. Materialised into a plain dict so the .npz handle is closed
# right away and the arrays are not re-read from the archive on every redraw.
with np.load('helpers/traveltimes.npz') as tt_file:
    TT = {key: tt_file[key] for key in tt_file.files}

# Per-event caches used by plot_waveform, all keyed by event name.
all_st = {}        # integrated obspy Stream
all_spec_all = {}  # component ('Z'/'N'/'E') -> spectrogram array
all_f_spec = {}    # spectrogram frequency axis
all_t_spec = {}    # spectrogram time axis

for event in origin_time.keys():
    st = obspy.read(f'helpers/{event}.mseed')
    # In-place integration of every trace (presumably velocity -> displacement;
    # TODO confirm the recorded quantity).
    st.integrate()

    spec_all = {}
    for comp in ['Z', 'N', 'E']:
        # `with` closes each .npz file; indexing it reads the arrays into
        # memory, so they remain valid after the file is closed.
        with np.load(f'helpers/{event}_spec_XB.ELYSE.02.BH{comp}.npz') as data:
            # NOTE(review): the f/t axes are assumed identical for all three
            # components -- only the last component's axes are kept.
            f_spec = data['f']
            t_spec = data['t']
            spec_all[comp] = data['spec']

    all_f_spec[event] = f_spec
    all_t_spec[event] = t_spec
    all_spec_all[event] = spec_all
    all_st[event] = st

In [None]:
def plot_waveform(event='S0235b',
                  t_P=150, t_S=200, plot_spec=False):
    """Plot three-component data for one event next to a travel-time panel.

    The left column (rows Z, N, E) shows either the integrated waveforms or
    the precomputed spectrograms (in dB), cropped to 100-900 s. A secondary
    top axis on the Z panel shows the same window in absolute UTC time. The
    top-right panel shows travel-time curves from ``TT`` and the epicentral
    distance estimated from the picked S-P time.

    Parameters
    ----------
    event : str
        Event name; key into ``origin_time`` and the ``all_*`` caches.
    t_P, t_S : float
        Picked P and S times in seconds, on the same relative axis as
        ``tr.times()`` / ``t_spec``. P is marked on Z, S on N and E.
    plot_spec : bool
        If True, show spectrograms instead of waveforms.

    Returns
    -------
    tuple
        ``(ax_Z, ax_N, ax_E)``, the three data axes.
    """
    t_min, t_max = 100, 900   # displayed window, seconds
    depth_km = 50             # source depth assumed for the distance estimate

    f_spec = all_f_spec[event]
    t_spec = all_t_spec[event]
    spec_all = all_spec_all[event]
    st = all_st[event]

    fig = plt.figure(constrained_layout=True)
    fig.suptitle(event)
    gs = GridSpec(3, 3, figure=fig)
    ax_Z = fig.add_subplot(gs[0, 0:2])
    ax_N = fig.add_subplot(gs[1, 0:2])
    ax_E = fig.add_subplot(gs[2, 0:2])
    ax_dist = fig.add_subplot(gs[0, 2])
    comp_axes = (('Z', ax_Z), ('N', ax_N), ('E', ax_E))

    ax_time = ax_Z.twiny()
    if plot_spec:
        for comp, ax in comp_axes:
            spec = 20 * np.log10(spec_all[comp])
            ax.pcolormesh(t_spec, f_spec, spec,
                          vmin=-210,
                          vmax=np.percentile(spec, q=90))
            ax.set_yscale('log')
    else:
        # NOTE(review): assumes the traces in `st` are ordered Z, N, E -- confirm.
        for tr, (_, ax) in zip(st, comp_axes):
            ax.plot(tr.times(), tr.data * 1e9, lw=0.5)
    for _, ax in comp_axes:
        ax.set_xlim(t_min, t_max)

    # Absolute-time axis for the same window.
    # NOTE(review): assumes relative time 0 coincides with the origin time.
    ax_time.set_xlim(utct(origin_time[event] + t_min).datetime,
                     utct(origin_time[event] + t_max).datetime)
    locator = mdates.AutoDateLocator(minticks=4, maxticks=7)
    ax_time.xaxis.set_major_locator(locator)
    # The concise formatter must actually be installed to take effect.
    ax_time.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))

    ax_Z.axvline(t_P, c='r')
    ax_N.axvline(t_S, c='r')
    ax_E.axvline(t_S, c='r')
    # Only the bottom (E) panel keeps relative-time tick labels.
    for ax in (ax_Z, ax_N):
        ax.set_xticks([])

    dist = taup_distance.get_dist(model=model, tSmP=t_S - t_P, depth=depth_km)

    def first_arrival(phase):
        """Travel time (s) of the first `phase` arrival at `dist`, or None."""
        arrivals = model.get_travel_times(distance_in_degree=dist,
                                          source_depth_in_km=depth_km,
                                          phase_list=[phase])
        return arrivals[0].time if arrivals else None

    if dist is not None:
        t_P_theo = first_arrival('P')
        t_S_theo = first_arrival('S')
        # Skip the markers instead of crashing when a phase has no arrival.
        if t_P_theo is not None and t_S_theo is not None:
            ax_dist.plot(dist, t_S_theo, 'o')
            ax_dist.plot(dist, t_P_theo, 'o')
            ax_dist.plot([dist, dist], [t_P_theo, t_S_theo], 'k')
        ax_dist.text(x=20, y=50, s=f'distance: {dist:4.1f} deg')
    for phase in ['P', 'S', 'PP', 'SS', 'ScS']:
        t = np.asarray(TT[phase])
        ax_dist.plot(t[:, 0], t[:, 1], label=phase)
    ax_dist.legend(fontsize='x-small')
    ax_dist.set_xlim(0, 100)
    ax_dist.set_ylim(0, 1000)
    ax_dist.xaxis.tick_top()
    ax_dist.xaxis.set_label_position('top')
    ax_dist.yaxis.tick_right()
    ax_dist.yaxis.set_label_position('right')
    ax_dist.set_xlabel('distance [deg]')
    # y values are absolute travel times; the black segment spans t_S - t_P.
    ax_dist.set_ylabel('travel time [s]')
    # No fig.tight_layout(): it conflicts with constrained_layout=True and
    # would replace the constrained layout engine.
    return ax_Z, ax_N, ax_E


# Interactive explorer: event dropdown, P/S pick sliders (100-900 s) and a
# waveform/spectrogram toggle. ipywidgets lays the controls out in the order
# of plot_waveform's signature, regardless of keyword order here.
_ = interact(
    plot_waveform,
    plot_spec=False,
    t_S=(100, 900),
    t_P=(100, 900),
    event=list(origin_time),
)

In [None]:
# Display the available event names (the keys of `origin_time`).
origin_time.keys()