In [None]:
import concurrent.futures
from functools import partial
from glob import glob
import os

import matplotlib.pyplot as plt
import numpy as np
from obspy import Stream
from pysep.utils.io import read_sem

In [None]:
# Toggle for saving figures to disk: the plotting cells call savefig only when this is True.
write = False

In [None]:
# First part of the database directory name (combined with database_type below).
# NOTE(review): `map` shadows the Python builtin of the same name for the rest
# of the notebook; renaming it would require updating every cell that reads it.
map = 'Brown'
# Database variant to analyse; it decides which pair of models is compared.
# Switch variants by swapping which of the two assignments is commented out.
database_type = 'SymGroups'
#database_type = 't20'

In [None]:
# Choose the two models to compare for the selected database variant.
# Exactly two entries are expected: later cells read models[0] and models[1].
if database_type == 'SymGroups':
    models = ['ISO', 'XISO']
    #models = ['MONO', 'TRIV']    # choose only two model types
elif database_type == 't20':
    models = ['t80', 't100']    # choose only two model types
else:
    # Without this branch an unknown value would leave `models` unassigned:
    # a NameError on a fresh kernel, or (worse) silent reuse of a stale
    # `models` left over from an earlier run of this cell.
    raise ValueError(
        f"unsupported database_type {database_type!r}; expected 'SymGroups' or 't20'"
    )

# Root directory of the simulation outputs for this map / database variant.
# NOTE(review): absolute, user-specific path; adjust when running elsewhere.
database_loc = f'/scratch/agupta7/specfem/rectangular_grid/{map}_{database_type}'

# Source and station metadata are taken from the first model's output directory
# and reused when reading the seismograms of every model.
source_file = f'{database_loc}/{models[0]}/OUTPUT_FILES/CMTSOLUTION'
stations_file = f'{database_loc}/{models[0]}/OUTPUT_FILES/STATIONS_FILTERED'

time_length_s = 40 # corresponds to -1s to 39s of data

In [None]:
# Read every matching seismogram file of each model in parallel and collect one
# ObsPy Stream per model in `st_list` (same order as `models`).
st_list = []
workers = os.cpu_count()
# Bind the source/station metadata once so the pool only has to map over file paths.
read_sem_new = partial(read_sem, source=source_file, stations=stations_file)

# A single process pool is shared by all models instead of spawning one per model.
with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as executor:
    for model in models:
        print(f'loading data for model: {model}')

        # glob() already returns a list; sorting makes the trace order reproducible.
        pattern = f'{database_loc}/{model}/OUTPUT_FILES/seismograms/*R*'
        fids = sorted(glob(pattern))
        if not fids:
            # Fail here with a clear message rather than with an IndexError on
            # an empty Stream several cells later.
            raise FileNotFoundError(f'no seismogram files match {pattern}')

        data = list(executor.map(read_sem_new, fids))

        # Merge the per-file results into a single Stream for this model.
        st = Stream()
        for st_data in data:
            st += st_data
        st_list.append(st)

In [None]:
# Cut every trace (in place) to a window of time_length_s seconds measured
# from that trace's own start time.
for stream in st_list:
    for trace in stream:
        window_start = trace.stats.starttime
        trace.trim(window_start, window_start + time_length_s)

# Differentiate copies of the two streams, leaving the trimmed originals in
# st_list untouched. st_mono / st_triv hold models[0] / models[1] respectively.
st_mono, st_triv = (st_list[idx].copy().differentiate() for idx in (0, 1))

In [None]:
# Station table read as strings. atleast_2d keeps the column indexing below
# valid when the file contains a single station (genfromtxt then returns 1-D).
# NOTE(review): column 0 is used as the station name and columns 2 / 3 as
# latitude / longitude; confirm against the STATIONS file layout.
stations_info = np.atleast_2d(np.genfromtxt(stations_file, dtype=str))

stations = stations_info[:, 0]
latitudes = stations_info[:, 2].astype(float)
longitudes = stations_info[:, 3].astype(float)

# Exclude every station whose name starts with 'C'.
mask = np.char.startswith(stations, 'C')

stations = stations[~mask]
latitudes = latitudes[~mask]
longitudes = longitudes[~mask]

# np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use
# whichever one the installed NumPy provides.
_trapezoid = getattr(np, 'trapezoid', None) or np.trapz

# For each station: time integral of the magnitude of the three-component
# difference between the two models' differentiated traces.
integral = np.zeros(len(stations))
for i, station in enumerate(stations):

    st1 = st_mono.select(station=station)
    st2 = st_triv.select(station=station)

    # Accumulate the squared per-component differences, then take the root to
    # get the norm of the difference vector at every sample.
    velocity = np.zeros(st1[0].stats.npts)
    for component in ['Z', 'Y', 'X']:
        v1 = st1.select(component=component)[0].data
        v2 = st2.select(component=component)[0].data
        velocity += (v1 - v2) ** 2
    velocity = np.sqrt(velocity)
    # Pass the sample interval so the result is an integral over time rather
    # than over sample index. This rescales the values; the ranking of stations
    # is unchanged as long as all traces share the same sample interval.
    integral[i] = _trapezoid(velocity, dx=st1[0].stats.delta)

In [None]:
# Station order from the largest integral to the smallest.
indices = np.flip(np.argsort(integral))

# Apply that order to every per-station array so they stay aligned.
sorted_stations, sorted_integrals, sorted_latitudes, sorted_longitudes = (
    per_station[indices]
    for per_station in (stations, integral, latitudes, longitudes)
)

In [None]:
components = ['Z','Y','X']
# Wave speeds used for the arrival-time markers.
# NOTE(review): presumably sqrt(modulus / density) with moduli in Pa and
# density in kg/m^3, i.e. speeds in m/s; confirm the constants.
vp = np.sqrt(118.3333E9/3378)
vs = np.sqrt(41.4333E9/3378)

# Plot the three components of both models for the highest-ranked station(s).
for station in sorted_stations[:1]:

    st1 = st_mono.select(station=station)
    st2 = st_triv.select(station=station)

    # Station and event coordinates from the SAC header of the first trace.
    x1 = st1[0].stats.sac['stlo']
    y1 = st1[0].stats.sac['stla']
    x2 = st1[0].stats.sac['evlo']
    y2 = st1[0].stats.sac['evla']

    evdp = 75000 # should be read in from the sac header

    # Straight-line source-to-station distance.
    # NOTE(review): this assumes stlo/stla/evlo/evla are Cartesian coordinates
    # in the same length unit as evdp; confirm.
    source_station_distance = np.sqrt( (y2-y1)**2 + (x2-x1)**2 + evdp**2 )

    # Travel times for the two wave speeds.
    tp = source_station_distance / vp
    ts = source_station_distance / vs

    fig, axs = plt.subplots(3,1,figsize=(10,20))
    max_velocity = 0
    for i, component in enumerate(components):
        tr1 = st1.select(component=component)[0]
        tr2 = st2.select(component=component)[0]
        axs[i].plot(tr1.times(), tr1.data, c='r', label=f'{models[0]}')
        axs[i].plot(tr2.times(), tr2.data, c='k', label=f'{models[1]}')
        # NOTE(review): tr.times() is measured from the trace start; if the
        # traces begin before the origin time, these markers are shifted by
        # that offset. Confirm.
        axs[i].axvline(x=tp, color='b', linestyle='--')
        axs[i].axvline(x=ts, color='b', linestyle='--')
        axs[i].set_title(f'component = {component}')
        axs[i].set_xlabel('time (s)')
        axs[i].set_ylabel('velocity (m/s)')
        axs[i].legend()
        # Track the peak absolute amplitude over all components and both
        # models. max(data, key=abs) would return the signed sample, so a
        # negative peak was dropped by the outer max() and the trace clipped.
        max_velocity = max(max_velocity, np.abs(tr1.data).max(), np.abs(tr2.data).max())

    # Use the same symmetric y-range on every panel so amplitudes are comparable.
    # Skipped for all-zero data, where the limits would collapse to a single value.
    if max_velocity > 0:
        for ax in axs:
            ax.set_ylim([-1.1 * max_velocity, 1.1 * max_velocity])

    fig.suptitle(f'{station}')

    # The file name does not contain the station, so plotting more than one
    # station would overwrite the same file.
    if write:
        fig.savefig(f'max_variability_seismograms_{database_type}.png', bbox_inches='tight')

    plt.show()

In [None]:
# Map view of the stations, each point coloured by its integral value.
scatter_fig, scatter_ax = plt.subplots(figsize=(10, 10))
points = scatter_ax.scatter(sorted_longitudes, sorted_latitudes, c=sorted_integrals, s=30)
scatter_fig.colorbar(points, ax=scatter_ax)
# Equal scaling on both axes so the station layout is not distorted.
scatter_ax.axis('equal')

if write:
    scatter_fig.savefig(f'variability_scatter_{models[0]}_vs_{models[1]}.png', bbox_inches='tight')

plt.show()