创建自定义数据集类

In [2]:
import seisbench as sb
help(sb.data)

AttributeError: module 'seisbench' has no attribute 'data'

In [None]:
import seisbench.data as sbd
import numpy as np
import pandas as pd
import os
import torch
from obspy import read, UTCDateTime
from obspy.clients.fdsn import Client

class ObspyDataset(sbd.WaveformDataset):
    """使用ObsPy获取的数据创建自定义SeisBench数据集"""
    
    def __init__(self, waveforms=None, metadata=None, path=None):
        """
        初始化数据集
        
        Args:
            waveforms: 波形字典 {trace_name: ndarray}
            metadata: 元数据DataFrame
            path: 缓存路径
        """
        super().__init__()
        
        if waveforms is not None and metadata is not None:
            self.metadata = metadata
            self.waveforms = waveforms
        elif path is not None:
            # 如果提供了路径，尝试从缓存加载
            self._load_dataset(path)
        else:
            self.metadata = pd.DataFrame()
            self.waveforms = {}
    
    def from_obspy_events(self, catalog, client_name="IRIS", before_p=30, after_p=90, 
                          component_map={"Z": 0, "N": 1, "E": 2}, distance_range=(0, 100)):
        """
        从ObsPy事件目录创建数据集
        
        Args:
            catalog: ObsPy事件目录
            client_name: FDSN客户端名称
            before_p: P波到时前的秒数
            after_p: P波到时后的秒数
            component_map: 分量映射
            distance_range: 震中距范围 (km)
        """
        client = Client(client_name)
        metadata_records = []
        waveforms = {}
        
        for event in catalog:
            event_id = str(event.resource_id).split('/')[-1]
            origin = event.preferred_origin() or event.origins[0]
            magnitude = event.preferred_magnitude() or event.magnitudes[0]
            
            # 获取一些台站
            try:
                stations = client.get_stations(
                    starttime=origin.time - 60,
                    endtime=origin.time + 60,
                    latitude=origin.latitude,
                    longitude=origin.longitude,
                    minradius=distance_range[0] / 111.19,  # 转换为度
                    maxradius=distance_range[1] / 111.19,  # 转换为度
                    level="station"
                )
                
                # 对每个台站
                for network in stations:
                    for station in network:
                        dist_km = origin.distance(
                            latitude=station.latitude,
                            longitude=station.longitude
                        ) * 111.19  # 转换为km
                        
                        # 计算理论P波到时
                        # 简化计算：假设速度为6km/s
                        p_travel_time = dist_km / 6.0
                        p_time = origin.time + p_travel_time
                        s_time = origin.time + p_travel_time * 1.73  # 假设Vs/Vp = 1/1.73
                        
                        # 获取波形
                        try:
                            start_time = p_time - before_p
                            end_time = p_time + after_p
                            
                            st = client.get_waveforms(
                                network=network.code,
                                station=station.code,
                                location="00",  # 主要位置
                                channel="BH?",  # 宽频带高增益
                                starttime=start_time,
                                endtime=end_time
                            )
                            
                            if len(st) >= 3:  # 确保有3个分量
                                # 预处理
                                st.detrend("demean")
                                st.filter("bandpass", freqmin=1, freqmax=45)
                                st.trim(start_time, end_time, pad=True, fill_value=0)
                                
                                # 统一采样率
                                target_sampling_rate = 100
                                st.interpolate(sampling_rate=target_sampling_rate)
                                
                                # 确保所有跟踪有相同数量的样本
                                min_npts = min([tr.stats.npts for tr in st])
                                for tr in st:
                                    tr.data = tr.data[:min_npts]
                                
                                # 创建三分量数据
                                three_comp = np.zeros((3, min_npts))
                                for tr in st:
                                    comp = tr.stats.channel[-1]
                                    if comp in component_map:
                                        three_comp[component_map[comp]] = tr.data
                                
                                # 归一化
                                three_comp = three_comp / np.max(np.abs(three_comp), axis=1, keepdims=True)
                                
                                # 计算相对于波形起始的样本点
                                p_sample = int((p_time - start_time) * target_sampling_rate)
                                s_sample = int((s_time - start_time) * target_sampling_rate)
                                
                                # 创建唯一的跟踪名称
                                trace_name = f"{event_id}_{network.code}_{station.code}"
                                
                                # 存储波形
                                waveforms[trace_name] = three_comp
                                
                                # 添加元数据
                                metadata_records.append({
                                    "trace_name": trace_name,
                                    "trace_sampling_rate": target_sampling_rate,
                                    "trace_start_time": start_time.isoformat(),
                                    "p_arrival_sample": p_sample,
                                    "s_arrival_sample": s_sample,
                                    "source_magnitude": magnitude.mag,
                                    "source_magnitude_type": magnitude.magnitude_type,
                                    "source_latitude": origin.latitude,
                                    "source_longitude": origin.longitude,
                                    "source_depth_km": origin.depth / 1000,  # 转换为km
                                    "source_distance_km": dist_km,
                                    "network": network.code,
                                    "station": station.code
                                })
                                
                        except Exception as e:
                            print(f"获取波形错误: {e}")
                            continue
            
            except Exception as e:
                print(f"获取台站错误: {e}")
                continue
        
        # 创建元数据DataFrame
        self.metadata = pd.DataFrame(metadata_records)
        self.waveforms = waveforms
        
        print(f"已创建包含 {len(self.metadata)} 个样本的数据集")
        return self
    
    def save(self, path):
        """保存数据集到磁盘"""
        if not os.path.exists(path):
            os.makedirs(path)
            
        # 保存元数据
        self.metadata.to_csv(os.path.join(path, "metadata.csv"), index=False)
        
        # 保存波形为NPZ文件
        np.savez(os.path.join(path, "waveforms.npz"), **self.waveforms)
        
        print(f"数据集已保存到 {path}")
    
    def _load_dataset(self, path):
        """从磁盘加载数据集"""
        # 加载元数据
        self.metadata = pd.read_csv(os.path.join(path, "metadata.csv"))
        
        # 加载波形
        waveforms_file = np.load(os.path.join(path, "waveforms.npz"))
        self.waveforms = {key: waveforms_file[key] for key in waveforms_file.files}
        
        print(f"从 {path} 加载了包含 {len(self.metadata)} 个样本的数据集")
    
    def __getitem__(self, idx):
        """获取单个样本"""
        row = self.metadata.iloc[idx]
        trace_name = row["trace_name"]
        
        # 创建样本字典
        sample = {
            "X": self.waveforms[trace_name],  # 波形数据
            "p_arrival_sample": row["p_arrival_sample"],
            "s_arrival_sample": row["s_arrival_sample"],
        }
        
        # 添加其他元数据
        for col in self.metadata.columns:
            if col != "trace_name" and col not in sample:
                sample[col] = row[col]
                
        return sample
    
    def __len__(self):
        """数据集长度"""
        return len(self.metadata)

使用ObsPy获取数据并创建数据集

In [None]:
from obspy import UTCDateTime
from obspy.clients.fdsn import Client
import seisbench.models as sbm
import matplotlib.pyplot as plt
import numpy as np
import torch

# 1. 使用ObsPy获取地震目录
print("获取地震目录...")
client = Client("USGS")
starttime = UTCDateTime("2019-07-01")
endtime = UTCDateTime("2019-07-05")
catalog = client.get_events(starttime=starttime, endtime=endtime,
                           minmagnitude=5.5)
print(f"获取到 {len(catalog)} 个地震事件")

# 2. 创建自定义数据集
print("从地震事件创建数据集...")
dataset = ObspyDataset().from_obspy_events(
    catalog=catalog, 
    client_name="IRIS", 
    before_p=20,  # P波到时前20秒
    after_p=100,  # P波到时后100秒
    distance_range=(20, 90)  # 20-90km震中距
)

# 3. 保存数据集(可选)
dataset.save("my_obspy_dataset")

# 4. 显示数据集信息
print("\n数据集信息:")
print(f"样本数量: {len(dataset)}")
print(f"元数据列: {dataset.metadata.columns.tolist()}")
print("\n元数据示例:")
print(dataset.metadata.head())

# 5. 可视化第一个样本
if len(dataset) > 0:
    sample = dataset[0]
    
    plt.figure(figsize=(12, 8))
    
    # 波形
    plt.subplot(311)
    plt.plot(sample["X"][0])
    plt.title("Z分量")
    plt.axvline(sample["p_arrival_sample"], color='r', linestyle='--', label='P波')
    plt.axvline(sample["s_arrival_sample"], color='g', linestyle='--', label='S波')
    plt.legend()
    
    plt.subplot(312)
    plt.plot(sample["X"][1])
    plt.title("N分量")
    
    plt.subplot(313)
    plt.plot(sample["X"][2])
    plt.title("E分量")
    
    plt.tight_layout()
    plt.show()
    
    # 显示波形信息
    print(f"\n波形信息:")
    print(f"- 形状: {sample['X'].shape}")
    print(f"- P波到时 (样本点): {sample['p_arrival_sample']}")
    print(f"- S波到时 (样本点): {sample['s_arrival_sample']}")
    print(f"- 震级: {sample['source_magnitude']} {sample['source_magnitude_type']}")
    print(f"- 震中距: {sample['source_distance_km']} km")

使用SeisBench模型进行震相拾取和震级估计

In [None]:
import seisbench.models as sbm
import torch
import matplotlib.pyplot as plt
import numpy as np

# 1. 加载预训练的PhaseNet模型用于震相拾取
try:
    phasenet = sbm.PhaseNet.from_pretrained("stead")
    print("已加载PhaseNet模型进行震相拾取")
except Exception as e:
    print(f"无法加载预训练模型: {e}")
    phasenet = sbm.PhaseNet()  # 使用未训练的模型
    print("使用未训练的PhaseNet模型")

# 2. 加载EQTransformer模型用于事件检测和震相拾取
try:
    eqt = sbm.EQTransformer.from_pretrained("original")
    print("已加载EQTransformer模型")
except Exception as e:
    print(f"无法加载预训练EQTransformer模型: {e}")
    eqt = sbm.EQTransformer()
    print("使用未训练的EQTransformer模型")

# 3. 对第一个样本进行震相拾取
if len(dataset) > 0:
    sample = dataset[0]
    waveform = sample["X"]
    waveform_tensor = torch.from_numpy(waveform).float().unsqueeze(0)  # [1, 3, 时间步]
    
    # PhaseNet推理
    print("\n执行PhaseNet推理...")
    with torch.no_grad():
        phasenet_pred = phasenet(waveform_tensor)
    
    # EQTransformer推理
    print("执行EQTransformer推理...")
    with torch.no_grad():
        eqt_pred = eqt(waveform_tensor)
    
    # 可视化结果
    plt.figure(figsize=(12, 12))
    
    # 原始波形
    plt.subplot(511)
    plt.plot(waveform[0])
    plt.title("Z分量波形")
    plt.axvline(sample["p_arrival_sample"], color='r', linestyle='--', label='理论P波')
    plt.axvline(sample["s_arrival_sample"], color='g', linestyle='--', label='理论S波')
    plt.legend()
    
    # PhaseNet P波概率
    plt.subplot(512)
    plt.plot(phasenet_pred["p"].squeeze().numpy(), 'r-')
    plt.title("PhaseNet P波概率")
    plt.axvline(sample["p_arrival_sample"], color='r', linestyle='--')
    
    # PhaseNet S波概率
    plt.subplot(513)
    plt.plot(phasenet_pred["s"].squeeze().numpy(), 'g-')
    plt.title("PhaseNet S波概率")
    plt.axvline(sample["s_arrival_sample"], color='g', linestyle='--')
    
    # EQTransformer P波概率
    plt.subplot(514)
    plt.plot(eqt_pred["p"].squeeze().numpy(), 'r-')
    plt.title("EQTransformer P波概率")
    plt.axvline(sample["p_arrival_sample"], color='r', linestyle='--')
    
    # EQTransformer S波概率
    plt.subplot(515)
    plt.plot(eqt_pred["s"].squeeze().numpy(), 'g-')
    plt.title("EQTransformer S波概率")
    plt.axvline(sample["s_arrival_sample"], color='g', linestyle='--')
    
    plt.tight_layout()
    plt.show()
    
    # 提取拾取结果
    def extract_picks(prob, threshold=0.5):
        """从概率中提取拾取点"""
        triggers = np.where(prob > threshold)[0]
        picks = []
        
        if len(triggers) > 0:
            # 按照连续性分组
            clusters = np.split(triggers, np.where(np.diff(triggers) > 10)[0] + 1)
            for cluster in clusters:
                if len(cluster) > 0:
                    max_prob_idx = cluster[np.argmax(prob[cluster])]
                    picks.append(max_prob_idx)
        
        return picks
    
    # 提取PhaseNet拾取点
    pnet_p_picks = extract_picks(phasenet_pred["p"].squeeze().numpy(), threshold=0.5)
    pnet_s_picks = extract_picks(phasenet_pred["s"].squeeze().numpy(), threshold=0.5)
    
    # 提取EQTransformer拾取点
    eqt_p_picks = extract_picks(eqt_pred["p"].squeeze().numpy(), threshold=0.5)
    eqt_s_picks = extract_picks(eqt_pred["s"].squeeze().numpy(), threshold=0.5)
    
    print("\n震相拾取结果:")
    print(f"PhaseNet P波拾取: {pnet_p_picks}")
    print(f"PhaseNet S波拾取: {pnet_s_picks}")
    print(f"EQTransformer P波拾取: {eqt_p_picks}")
    print(f"EQTransformer S波拾取: {eqt_s_picks}")
    print(f"理论P波到时: {sample['p_arrival_sample']}")
    print(f"理论S波到时: {sample['s_arrival_sample']}")

震级估算实现

In [None]:
# 震级估算模型 (简易实现)
class MagnitudeEstimator:
    """简单的震级估算器"""
    
    def __init__(self):
        # 这里可以加载预训练模型，但为了简单起见，我们使用简单的特征
        pass
    
    def estimate(self, waveform, p_pick, s_pick, distance_km=None):
        """
        基于波形和震相拾取估算震级
        
        Args:
            waveform: 波形数据 [3, 时间步]
            p_pick: P波到时 (样本点)
            s_pick: S波到时 (样本点)
            distance_km: 震中距 (km)
            
        Returns:
            估算的震级
        """
        # 方法1: 基于峰值振幅的简单估算
        if p_pick is not None and s_pick is not None:
            # 获取P波后的信号
            p_window = waveform[:, p_pick:s_pick]
            
            # 计算峰值振幅
            peak_amp = np.max(np.abs(p_window))
            
            # 简单的震级模型 (仅用于演示)
            if distance_km is not None:
                # log(A) + 1.1 * log(D) + 0.0
                estimated_mag = np.log10(peak_amp) + 1.1 * np.log10(distance_km) + 0.0
            else:
                # 没有距离信息时的回退方法
                estimated_mag = np.log10(peak_amp) + 2.0
                
            return estimated_mag
        
        # 方法2: 基于频谱特征
        else:
            # 使用整个波形
            # 计算频谱
            from scipy import signal
            f, Pxx = signal.welch(waveform[0], fs=100, nperseg=256)
            
            # 使用低频能量作为震级指标
            low_freq_energy = np.sum(Pxx[f < 5])
            
            # 简单的震级模型
            estimated_mag = np.log10(low_freq_energy) + 1.5
            
            return estimated_mag

# 使用震级估算器
mag_estimator = MagnitudeEstimator()

if len(dataset) > 0:
    sample = dataset[0]
    
    # 使用理论到时进行估算
    theo_mag = mag_estimator.estimate(
        sample["X"], 
        sample["p_arrival_sample"], 
        sample["s_arrival_sample"], 
        sample["source_distance_km"]
    )
    
    # 使用PhaseNet拾取点进行估算
    if pnet_p_picks and pnet_s_picks:
        pnet_mag = mag_estimator.estimate(
            sample["X"], 
            pnet_p_picks[0], 
            pnet_s_picks[0], 
            sample["source_distance_km"]
        )
    else:
        pnet_mag = None
    
    # 使用EQTransformer拾取点进行估算
    if eqt_p_picks and eqt_s_picks:
        eqt_mag = mag_estimator.estimate(
            sample["X"], 
            eqt_p_picks[0], 
            eqt_s_picks[0], 
            sample["source_distance_km"]
        )
    else:
        eqt_mag = None
    
    print("\n震级估算结果:")
    print(f"真实震级: {sample['source_magnitude']} {sample['source_magnitude_type']}")
    print(f"使用理论到时估算震级: {theo_mag:.1f}")
    if pnet_mag is not None:
        print(f"使用PhaseNet拾取点估算震级: {pnet_mag:.1f}")
    if eqt_mag is not None:
        print(f"使用EQTransformer拾取点估算震级: {eqt_mag:.1f}")

批量处理多个样本并评估结果

In [None]:
# 对数据集中的所有样本进行处理和评估
if len(dataset) > 0:
    # 创建评估结果的容器
    results = {
        "event_id": [],
        "true_magnitude": [],
        "estimated_magnitude": [],
        "pnet_p_error_sec": [],
        "pnet_s_error_sec": [],
        "eqt_p_error_sec": [],
        "eqt_s_error_sec": []
    }
    
    # 设置批量大小
    batch_size = min(10, len(dataset))  # 最多处理10个样本
    
    # 处理每个样本
    for i in range(batch_size):
        print(f"\n处理样本 {i+1}/{batch_size}...")
        
        sample = dataset[i]
        waveform = sample["X"]
        waveform_tensor = torch.from_numpy(waveform).float().unsqueeze(0)
        
        # 理论到时
        true_p = sample["p_arrival_sample"]
        true_s = sample["s_arrival_sample"]
        sampling_rate = sample["trace_sampling_rate"]
        
        # 模型推理
        with torch.no_grad():
            phasenet_pred = phasenet(waveform_tensor)
            eqt_pred = eqt(waveform_tensor)
        
        # 提取拾取点
        pnet_p_picks = extract_picks(phasenet_pred["p"].squeeze().numpy())
        pnet_s_picks = extract_picks(phasenet_pred["s"].squeeze().numpy())
        eqt_p_picks = extract_picks(eqt_pred["p"].squeeze().numpy())
        eqt_s_picks = extract_picks(eqt_pred["s"].squeeze().numpy())
        
        # 震级估算
        estimated_mag = mag_estimator.estimate(
            waveform, true_p, true_s, sample["source_distance_km"]
        )
        
        # 计算误差
        pnet_p_error = (pnet_p_picks[0] - true_p) / sampling_rate if pnet_p_picks else np.nan
        pnet_s_error = (pnet_s_picks[0] - true_s) / sampling_rate if pnet_s_picks else np.nan
        eqt_p_error = (eqt_p_picks[0] - true_p) / sampling_rate if eqt_p_picks else np.nan
        eqt_s_error = (eqt_s_picks[0] - true_s) / sampling_rate if eqt_s_picks else np.nan
        
        # 保存结果
        results["event_id"].append(sample["trace_name"])
        results["true_magnitude"].append(sample["source_magnitude"])
        results["estimated_magnitude"].append(estimated_mag)
        results["pnet_p_error_sec"].append(pnet_p_error)
        results["pnet_s_error_sec"].append(pnet_s_error)
        results["eqt_p_error_sec"].append(eqt_p_error)
        results["eqt_s_error_sec"].append(eqt_s_error)
    
    # 转换为DataFrame
    import pandas as pd
    results_df = pd.DataFrame(results)
    
    # 显示结果
    print("\n处理结果:")
    print(results_df)
    
    # 计算统计量
    print("\n统计信息:")
    print(f"震级估算平均误差: {(results_df['estimated_magnitude'] - results_df['true_magnitude']).mean():.2f}")
    print(f"PhaseNet P波拾取平均误差: {results_df['pnet_p_error_sec'].mean():.2f} 秒")
    print(f"PhaseNet S波拾取平均误差: {results_df['pnet_s_error_sec'].mean():.2f} 秒")
    print(f"EQTransformer P波拾取平均误差: {results_df['eqt_p_error_sec'].mean():.2f} 秒")
    print(f"EQTransformer S波拾取平均误差: {results_df['eqt_s_error_sec'].mean():.2f} 秒")
    
    # 可视化结果
    plt.figure(figsize=(12, 8))
    
    # 震级估算散点图
    plt.subplot(221)
    plt.scatter(results_df["true_magnitude"], results_df["estimated_magnitude"])
    plt.plot([min(results_df["true_magnitude"]), max(results_df["true_magnitude"])], 
             [min(results_df["true_magnitude"]), max(results_df["true_magnitude"])], 'r--')
    plt.xlabel("真实震级")
    plt.ylabel("估算震级")
    plt.title("震级估算结果")
    
    # PhaseNet拾取误差直方图
    plt.subplot(222)
    plt.hist(results_df["pnet_p_error_sec"].dropna(), bins=10, alpha=0.5, label="P波")
    plt.hist(results_df["pnet_s_error_sec"].dropna(), bins=10, alpha=0.5, label="S波")
    plt.xlabel("误差 (秒)")
    plt.ylabel("频率")
    plt.title("PhaseNet拾取误差")
    plt.legend()
    
    # EQTransformer拾取误差直方图
    plt.subplot(223)
    plt.hist(results_df["eqt_p_error_sec"].dropna(), bins=10, alpha=0.5, label="P波")
    plt.hist(results_df["eqt_s_error_sec"].dropna(), bins=10, alpha=0.5, label="S波")
    plt.xlabel("误差 (秒)")
    plt.ylabel("频率")
    plt.title("EQTransformer拾取误差")
    plt.legend()
    
    plt.tight_layout()
    plt.show()