In [None]:
import os
import pdb
import pywt
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from obspy import read
import pickle
import re

from scipy.ndimage import median_filter
from sklearn.preprocessing import MinMaxScaler
from obspy.signal.filter import bandpass

import tensorflow as tf
from sklearn.decomposition import PCA
from tensorflow.keras.models import Sequential
from keras.layers import Conv2D, Flatten, Dense, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras import layers, models
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dropout, Dense, GlobalAveragePooling2D
from sklearn.model_selection import train_test_split

In [None]:
lunar_cat = 'data/lunar/training/catalogs/apollo12_catalog_GradeA_final.csv'
lunar_df = pd.read_csv(lunar_cat)

# Same date-prefix rule as clean_filename(): keep everything up to the last
# YYYY-MM-DD stamp so catalog rows and mseed files map to the same key.
# A filename without a date stamp is kept whole instead of crashing on
# `None.group()`.
DATE_PREFIX_RE = re.compile(r"^.*\d{4}-\d{2}-\d{2}")

# lunar_event: cleaned filename -> list of 'time_rel(sec)' values, one entry per
# catalog row that maps to that key, in catalog order.
lunar_event = {}
for filename, detection_time in zip(lunar_df['filename'], lunar_df['time_rel(sec)']):
    match = DATE_PREFIX_RE.match(filename)
    cleaned_filename = match.group() if match else filename
    lunar_event.setdefault(cleaned_filename, []).append(detection_time)

In [None]:
# Display the catalog mapping: date-stamped filename prefix -> list of relative arrival times (s).
lunar_event

In [None]:
def wavelet_transform(data, wavelet, scales):
    """Continuous wavelet transform of ``data``, returned as magnitudes.

    Parameters
    ----------
    data : array-like
        Signal to transform (a 1-D trace in this notebook).
    wavelet : str
        PyWavelets continuous wavelet name, e.g. ``'cmor1.0-0.5'``.
    scales : array-like
        Scales at which the transform is evaluated.

    Returns
    -------
    (np.ndarray, np.ndarray)
        ``|coefficients|`` as returned by ``pywt.cwt`` (one row per scale),
        and the frequencies ``pywt.cwt`` reports for those scales.
    """
    cwt_matrix, scale_frequencies = pywt.cwt(data, scales, wavelet)
    magnitude = np.abs(cwt_matrix)
    return magnitude, scale_frequencies

In [None]:
def plot_trace(tr_times, tr_data, arrival, coefficients, scales):
    """Plot a seismic trace with arrival marker(s) above its wavelet scalogram.

    Parameters
    ----------
    tr_times : array-like
        Sample times in seconds; shared x-axis of both panels.
    tr_data : array-like
        Trace samples, same length as ``tr_times``.
    arrival : float or sequence of float
        Arrival time(s) to mark with vertical red lines. Sequences are accepted
        because the catalog dict ``lunar_event`` stores a list per file.
    coefficients : 2-D array
        Scalogram with one row per scale and one column per sample.
    scales : array-like
        Scales for the rows of ``coefficients``; assumed ascending (as
        ``np.arange`` produces) so row 0 is the bottom of the image.
    """
    tr_times = np.asarray(tr_times)
    scales = np.asarray(scales)
    fig, (ax, ax2) = plt.subplots(2, 1, figsize=(10, 6))

    ax.plot(tr_times, tr_data)

    # axvline takes a single scalar x, so draw one line per arrival; only the
    # first line carries the legend label to avoid duplicate legend entries.
    for i, arrival_time in enumerate(np.atleast_1d(arrival)):
        ax.axvline(x=arrival_time, color='red',
                   label='Rel. Arrival' if i == 0 else '_nolegend_')
    ax.legend(loc='upper left')

    ax.set_xlim([tr_times.min(), tr_times.max()])
    ax.set_ylabel('Velocity (m/s)')
    ax.set_title('Seismic Trace', fontweight='bold')

    # origin='lower' places row 0 (the first scale) at the bottom so the rows
    # line up with the y extent [scales.min(), scales.max()]; the default
    # origin='upper' would draw them upside-down relative to the axis labels.
    ax2.imshow(coefficients, extent=[tr_times.min(), tr_times.max(), scales.min(), scales.max()],
               origin='lower', aspect='auto', interpolation='bilinear', cmap='jet')
    ax2.set_ylabel('Scales')
    ax2.set_xlabel('Time (s)')
    ax2.set_title('Wavelet Coefficients', fontweight='bold')

    plt.tight_layout()
    plt.show()

In [None]:
def preprocess_mseed(file_path, arrival, minfreq=0.5, maxfreq=1.5,
                     wavelet='cmor1.0-0.5', scales=None, plot=False):
    """Load a MiniSEED file, band-pass its first trace and compute the CWT magnitude.

    Parameters
    ----------
    file_path : str
        Path to the MiniSEED file. Only the first trace (``st[0]``) is used.
    arrival : float or sequence of float
        Arrival time(s) in seconds relative to the trace start. Only used by
        the optional diagnostic plot (callers may pass a placeholder such as 0).
    minfreq, maxfreq : float
        Band-pass corner frequencies in Hz (4 corners, zero-phase).
    wavelet : str
        PyWavelets continuous wavelet name. For ``'cmorB-C'`` the first number
        is the bandwidth B and the second the center frequency C, so the
        default means bandwidth 1.0, center frequency 0.5.
    scales : array-like or None
        CWT scales; ``None`` means ``np.arange(1, 20)`` (19 scales).
    plot : bool
        If True, draw the trace and scalogram with ``plot_trace``.

    Returns
    -------
    tuple
        ``(filtered_data, coefficients, basename, tr_times)``: the band-passed
        samples, ``|CWT|`` with one row per scale, the file's basename and the
        sample times from ``tr.times()``.

    Raises
    ------
    ValueError
        If every sample of the trace is NaN.
    """
    if scales is None:
        scales = np.arange(1, 20)

    st = read(file_path)
    tr = st[0]

    nan_mask = np.isnan(tr.data)
    if nan_mask.any():
        # Trace.interpolate() resamples onto a new time grid; it does not fill
        # NaN gaps. Fill them by linear interpolation over the valid samples so
        # no NaNs are fed into the band-pass filter.
        print(f"Warning: Missing values found in {tr.id}. Filling them by linear interpolation.")
        valid = ~nan_mask
        if not valid.any():
            raise ValueError(f"Trace {tr.id} in {file_path} contains only NaN samples.")
        data = tr.data.astype(np.float64)
        sample_idx = np.arange(data.size)
        data[nan_mask] = np.interp(sample_idx[nan_mask], sample_idx[valid], data[valid])
        tr.data = data

    tr_times = tr.times()
    tr.filter("bandpass", freqmin=minfreq, freqmax=maxfreq, corners=4, zerophase=True)

    coefficients, _frequencies = wavelet_transform(tr.data, wavelet, scales)

    if plot:
        # Diagnostic view of the trace and its scalogram.
        plot_trace(tr_times, tr.data, arrival, coefficients, np.asarray(scales))
    return tr.data, coefficients, os.path.basename(file_path), tr_times

In [None]:
def process_file(file_path, event_time):
    """Run ``preprocess_mseed`` on one file, reporting failures instead of raising.

    Parameters
    ----------
    file_path : str
        Path of the file to preprocess.
    event_time : float or sequence of float
        Passed straight through as ``preprocess_mseed``'s ``arrival`` argument.

    Returns
    -------
    tuple or None
        ``(filename, filtered_data, wavelet_coefficients, timeline)``. This is
        ``preprocess_mseed``'s result with the filename moved to the front.
        Returns None if any exception was raised; the error is printed.
    """
    try:
        signal, coeffs, name, times = preprocess_mseed(file_path, event_time)
    except Exception as err:
        print(f"Error processing {file_path}: {str(err)}")
        return None
    return name, signal, coeffs, times

In [None]:
def clean_filename(filename):
    """Cut ``filename`` down to the prefix ending at its last YYYY-MM-DD stamp.

    The regex ``.*`` is greedy, so when several dates appear the match runs to
    the last one. A name containing no date stamp is returned unchanged.
    """
    prefix = re.match(r"^.*\d{4}-\d{2}-\d{2}", filename)
    return filename if prefix is None else prefix.group()

In [None]:
def filter_wavelet_coefficients_with_time(coefficients, timeline, threshold=0.2):
    """Zero weak CWT coefficients and drop time columns with no strong response.

    Parameters
    ----------
    coefficients : array-like, 2-D
        Scalogram indexed (scale, time), the layout produced by
        ``wavelet_transform`` / ``pywt.cwt``.
    timeline : sequence
        Time value for each column of ``coefficients``; must be integer-indexable.
    threshold : float
        Fraction of the global peak ``|coefficient|``. Entries with magnitude
        below ``threshold * max|coefficients|`` are zeroed.

    Returns
    -------
    filtered_coefficients : np.ndarray
        ``coefficients`` with sub-threshold entries zeroed. Only the columns
        containing at least one significant entry are kept.
    relevant_time_steps : list
        The ``timeline`` entries of the retained columns, in order.
    """
    coefficients = np.asarray(coefficients)
    if coefficients.size == 0:
        # np.max of an empty array raises; nothing can be significant anyway.
        return coefficients[:, :0], []

    magnitude = np.abs(coefficients)
    significant_mask = magnitude >= threshold * magnitude.max()
    # Compute the per-column mask once, vectorized. A per-column Python loop is
    # very slow for traces with hundreds of thousands of samples.
    keep_columns = significant_mask.any(axis=0)

    filtered_coefficients = (coefficients * significant_mask)[:, keep_columns]
    relevant_time_steps = [timeline[i] for i in np.flatnonzero(keep_columns)]
    return filtered_coefficients, relevant_time_steps

In [None]:
mseed_training_directory = 'data/lunar/training/data/S12_GradeA'
mseed_testing_directory = 'data/lunar/test/data/'

# Number of test sub-directories to process (sorted by name so the choice is
# deterministic); presumably limited to bound memory. Set to None for all.
MAX_TEST_DIRS = 1

# Loop through each mseed file and apply the preprocessing
# (band-pass filter + continuous wavelet transform).


def _fit_width(coeff, width):
    """Zero-pad or truncate a (scale, time) array along time to exactly ``width`` columns."""
    n_cols = coeff.shape[1]
    if n_cols >= width:
        return coeff[:, :width]
    return np.pad(coeff, ((0, 0), (0, width - n_cols)), mode='constant', constant_values=0)


cnn_train_input = []
train_labels = []
cnn_test_input = []

for filename in sorted(os.listdir(mseed_training_directory)):
    if not filename.endswith(".mseed"):
        continue
    file_path = os.path.join(mseed_training_directory, filename)
    print(f"Processing file: {filename}")
    cleaned_filename = clean_filename(filename)
    # Label = catalog arrival time(s) for this file ('time_rel(sec)' values).
    arrival_times = lunar_event.get(cleaned_filename)
    if arrival_times is None:
        print(f"Skipping {filename}: no catalog entry for {cleaned_filename}")
        continue
    result = process_file(file_path, arrival_times)
    if result is None:  # process_file already printed the error
        continue
    _, filtered_data, wavelet_coefficients, timeline = result
    # float32 halves memory relative to float64 and is the dtype Keras trains in.
    cnn_train_input.append(wavelet_coefficients.astype(np.float32))
    train_labels.append(arrival_times)

test_dirs = sorted(
    d for d in os.listdir(mseed_testing_directory)
    if os.path.isdir(os.path.join(mseed_testing_directory, d))
)
if MAX_TEST_DIRS is not None:
    test_dirs = test_dirs[:MAX_TEST_DIRS]

for directory in test_dirs:
    test_dir_path = os.path.join(mseed_testing_directory, directory)
    for filename in sorted(os.listdir(test_dir_path)):
        if not filename.endswith(".mseed"):
            continue
        file_path = os.path.join(test_dir_path, filename)
        print(f"Processing file: {filename}")
        # Test files have no catalog arrival; 0 is only a placeholder for plotting.
        result = process_file(file_path, 0)
        if result is None:
            continue
        _, filtered_data, wavelet_coefficients, timeline = result
        cnn_test_input.append(wavelet_coefficients.astype(np.float32))

if not cnn_train_input:
    raise RuntimeError(f"No training inputs could be built from {mseed_training_directory}")

# The model's input width is fixed by the training data, so the test inputs are
# padded/truncated to the TRAINING width. Padding them to their own maximum
# would give model.predict a shape mismatch.
max_time_steps = max(coeff.shape[1] for coeff in cnn_train_input)
max_time_steps_test = max((coeff.shape[1] for coeff in cnn_test_input), default=0)
if max_time_steps_test > max_time_steps:
    print(f"Warning: test traces have up to {max_time_steps_test} samples but the training "
          f"width is {max_time_steps}; the extra samples are truncated.")

padded_coefficients = [_fit_width(coeff, max_time_steps) for coeff in cnn_train_input]
padded_test_coefficients = [_fit_width(coeff, max_time_steps) for coeff in cnn_test_input]

# Convert to NumPy arrays: (num_samples, n_scales, max_time_steps)
cnn_train_input = np.array(padded_coefficients)
cnn_test_input = np.array(padded_test_coefficients)


In [None]:
# Each file has a variable number of arrivals. Right-pad every list with zeros
# up to the largest count so the labels stack into a (num_samples, max_events) array.
# NOTE(review): 0 s is also a legitimate arrival time, so padded slots are
# indistinguishable from a real arrival at the trace start.
max_events = max(map(len, train_labels))
train_labels_padded = [
    np.pad(times, (0, max_events - len(times)), 'constant')
    for times in train_labels
]
train_labels = np.array(train_labels_padded)

In [None]:
print("cnn_train_input shape:", cnn_train_input.shape)  # (num_samples, 19, max_time_steps), plus a trailing 1 after this cell has run
print("train_labels shape:", train_labels.shape)  # Should be (num_samples, max_events)
assert cnn_train_input.shape[0] == train_labels.shape[0], "need exactly one label row per training sample"

# Conv2D expects (num_samples, height, width, channels), so add a single channel
# axis. The test inputs need the same axis, or model.predict rejects them.
# The ndim guards keep the cell safe to re-run.
if cnn_train_input.ndim == 3:
    cnn_train_input = cnn_train_input[..., np.newaxis]
if cnn_test_input.ndim == 3:
    cnn_test_input = cnn_test_input[..., np.newaxis]

In [None]:
# Seed the Python, NumPy and TensorFlow RNGs so weight initialisation and
# dropout are reproducible.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

model = Sequential()
# Take the input shape from the data (n_scales, padded_width, channels). A
# hard-coded (19, 572427, 1) breaks as soon as the file set or padding changes.
model.add(layers.Input(shape=cnn_train_input.shape[1:]))

# First convolutional block
model.add(Conv2D(16, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(1, 10)))  # downsample along time only
model.add(Dropout(0.5))
# layers.BatchNormalization comes from tensorflow.keras, so every layer is from
# the same Keras package as Sequential. The bare `keras` import mixes the two.
model.add(layers.BatchNormalization())

# Second convolutional block
model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(1, 10)))  # further downsampling along time
model.add(Dropout(0.5))
model.add(layers.BatchNormalization())

# Collapse the (scale, time) feature map to one value per channel
model.add(GlobalAveragePooling2D())

# Fully connected head
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))  # Dropout to reduce overfitting

# One linear output per label column so predictions match the
# (num_samples, max_events) shape of train_labels. A single unit would be
# broadcast against every label column by the MSE loss.
model.add(Dense(train_labels.shape[1], activation='linear'))

In [None]:
# Print the layer-by-layer output shapes and parameter counts of the model.
model.summary()

In [None]:
# Regress arrival time(s) in seconds. MSE is the training loss; MAE (in the same
# units as the labels) is reported as a readable metric.
model.compile(optimizer=Adam(), loss='mean_squared_error', metrics=['mae'])

# Keras takes the validation split from the last 20% of samples (before
# shuffling). Stop once validation loss stalls, and keep the best weights
# rather than whatever the final epoch produced.
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

history = model.fit(
    cnn_train_input,
    train_labels,
    epochs=10,
    batch_size=5,
    validation_split=0.2,
    callbacks=[early_stopping],
)

In [None]:
# The model's input shape is fixed by the training data. Check the test array
# against it first, so a padding or channel-axis mismatch fails with a readable message.
expected_input_shape = tuple(model.input_shape[1:])
if tuple(cnn_test_input.shape[1:]) != expected_input_shape:
    raise ValueError(
        f"cnn_test_input has per-sample shape {tuple(cnn_test_input.shape[1:])}, "
        f"but the model expects {expected_input_shape}"
    )
predictions = model.predict(cnn_test_input)

# Predicted arrival time(s), one row per test file, in the same units as the
# training labels (the catalog 'time_rel(sec)' column).
print(predictions)

In [None]:
# summary() must be called; the bare attribute only displays the bound-method repr.
model.summary()