In [None]:
# @author: wuyuping (ypwu@stu.hit.edu.cn)

import torch
import numpy as np
import matplotlib.pyplot as plt
import deepwave
from deepwave import scalar
import scipy
import torch.nn.functional as F

In [None]:
def sliding_average_filter(signal, size=5, std=1, mode='gau'):
    """
    Smooth a signal with a normalized sliding-window filter.

    The input is flattened to one dimension, padded at both ends by
    replicating its edge samples, and convolved with the window, so the
    result always has as many samples as the flattened input (for even
    window sizes too).

    :param signal: tensor to smooth; it is flattened to 1-D.
    :param size: number of taps in the window (must be >= 1).
    :param std: standard deviation, in samples, of the Gaussian window
        (only used when ``mode == 'gau'``).
    :param mode: ``'gau'`` selects a Gaussian window; any other value
        selects a uniform (boxcar) window.
    :return: 1-D tensor containing the smoothed signal.
    :raises ValueError: if ``size`` is smaller than 1.
    """
    if size < 1:
        raise ValueError("size must be a positive integer, got {}".format(size))

    # Build the window with the signal's dtype/device so conv1d accepts the pair
    # (integer inputs are promoted to the default floating dtype).
    dtype = signal.dtype if signal.is_floating_point() else torch.get_default_dtype()
    flat = signal.reshape(1, 1, -1).to(dtype)

    if mode == 'gau':
        # Gaussian taps centred on the middle of the window
        offsets = torch.arange(size, dtype=dtype, device=flat.device) - (size - 1) / 2
        weights = torch.exp(-offsets ** 2 / (2 * std ** 2))
    else:
        weights = torch.ones(size, dtype=dtype, device=flat.device)

    # Normalize so the taps sum to one (a constant signal stays unchanged)
    weights = weights / weights.sum()

    # left + right == size - 1 keeps the output length equal to the input
    # length; for odd sizes both sides are equal, as before.
    left = (size - 1) // 2
    right = size // 2
    padded = F.pad(flat, (left, right), mode='replicate')

    return F.conv1d(padded, weights.view(1, 1, size), bias=None, stride=1, padding=0)[0, 0]

# High-pass filtering
def filter_highpass(signal, low_l, low_cut, nt, dt):
    """
    High-pass filter a real signal by tapering its one-sided spectrum.

    :param signal: real tensor; the FFT is taken along its last axis.
    :param low_l: left end of the transition band (presumably in Hz, given
        that it is converted to a bin index as ``low_l * dt * nt``).
    :param low_cut: right end of the transition band (same unit as ``low_l``).
    :param nt: sample count used to convert frequencies to bin indices.
    :param dt: sampling interval used to convert frequencies to bin indices.
    :return: filtered signal with the same last-axis length as ``signal``.
    """
    n_samples = signal.shape[-1]

    # Forward real FFT
    spectrum = torch.fft.rfft(signal)

    # Filter coefficients: low_l is the left end, low_cut the right end
    taper = co_filter_highpass(low_l, low_cut, nt, dt, spectrum.shape[-1])
    # The taper is built on the default device; move it next to the spectrum
    taper = taper.to(device=spectrum.device)

    # Pass n explicitly: without it irfft returns 2 * (bins - 1) samples, which
    # is one short for odd-length input.
    return torch.fft.irfft(spectrum * taper, n=n_samples)

# High-pass filter coefficients
def _linear_ramp(start, stop, count):
    """Return ``count`` evenly spaced values from ``start`` (inclusive) up to, but excluding, ``stop``."""
    start = float(start)
    stop = float(stop)
    # An integer arange always yields exactly ``count`` samples; a float-step
    # arange can yield one more or one less because of rounding.
    return start + (stop - start) * torch.arange(count, dtype=torch.get_default_dtype()) / count


def co_filter_highpass(low_l, low_cut, nt, dt, nt_fft):
    """
    Build the per-bin gain of a high-pass filter for a one-sided spectrum.

    The gain is 0 up to the ``low_l`` bin, rises linearly to 1 at the
    ``low_cut`` bin, and is 1 above it. Both corners are then blended with
    three successive linear ramps (half-widths 4, 3 and 2 bins) and the whole
    curve is smoothed with ``sliding_average_filter`` using its defaults.

    :param low_l: left end of the transition band; mapped to a bin index with
        ``int(low_l * dt * nt)``.
    :param low_cut: right end of the transition band; mapped the same way.
    :param nt: sample count used in the frequency-to-bin conversion.
    :param dt: sampling interval used in the frequency-to-bin conversion.
    :param nt_fft: number of bins of the one-sided spectrum (output length).
    :return: 1-D tensor of ``nt_fft`` gain values.
    :raises ValueError: if the transition band is empty or the corner
        smoothing windows would fall outside ``[0, nt_fft)``.
    """
    # Bin index = frequency * (dt * nt)
    low_cut_index = int(low_cut * dt * nt)  # end of the ramp
    low_l_index = int(low_l * dt * nt)      # start of the ramp

    # Half-widths (in bins) of the successive corner-smoothing windows
    half_widths = (4, 3, 2)
    reach = sum(half_widths)  # the windows extend this far beyond each corner

    if low_cut_index <= low_l_index:
        raise ValueError(
            "low_cut must map to a higher bin than low_l "
            "(got bins {} and {})".format(low_cut_index, low_l_index))
    if low_l_index - reach < 0 or low_cut_index + reach > nt_fft:
        raise ValueError(
            "corner smoothing needs {} bins below low_l and above low_cut; "
            "bins {}..{} do not fit in a spectrum of {} bins".format(
                reach, low_l_index, low_cut_index, nt_fft))

    co_signal = torch.ones(nt_fft)
    co_signal[0:low_l_index + 1] = 0

    # Main ramp, e.g. 8 bins -> [0.000, 0.125, ..., 0.875]
    co_signal[low_l_index:low_cut_index] = _linear_ramp(0.0, 1.0, low_cut_index - low_l_index)

    # Upper corner: re-draw a ramp from the current value at the window's left
    # edge towards 1, then shift the centre up by the half-width just used.
    center = low_cut_index
    for half_width in half_widths:
        lo = center - half_width
        co_signal[lo:center + half_width] = _linear_ramp(co_signal[lo], 1.0, 2 * half_width)
        center += half_width

    # Lower corner: re-draw a ramp from 0 towards the current value just past
    # the window's right edge, then shift the centre down.
    center = low_l_index
    for half_width in half_widths:
        hi = center + half_width
        co_signal[center - half_width:hi] = _linear_ramp(0.0, co_signal[hi], 2 * half_width)
        center -= half_width

    return sliding_average_filter(co_signal)

In [None]:
# Use 32-bit floats as the default floating-point dtype
torch.set_default_dtype(torch.float32)

# Seed the PyTorch and NumPy random number generators so draws are repeatable
torch.manual_seed(1234)
np.random.seed(1234)

# Dimensions of the stored velocity model: ny entries along axis 0, nx along axis 1
ny, nx = 340, 130

# Read the raw binary model, shape it to (ny, nx), then drop 70 entries from
# each end of the first axis
v_true = torch.from_file('mar_big_vp_130_340.bin', size=ny * nx).reshape(ny, nx)[70:-70]

In [None]:
# Quick look at the velocity model; the transpose puts v_true's first axis
# along the horizontal image axis.
plt.imshow(v_true.T, cmap='jet', aspect='auto')

In [None]:
# Velocity model with the first shot's source and receiver positions overlaid.
# NOTE(review): this cell reads source_locations and receiver_locations_tuolan,
# which are only assigned in a later cell, so it raises NameError under
# "Restart Kernel -> Run All". Move it below the acquisition-geometry cell.
plt.figure(figsize=(10, 5))

# Put the x ticks and their labels on the top edge; the title then serves as
# the x-axis label. These rcParams changes are global and persist for later
# figures until they are reset.
plt.rcParams['xtick.bottom'] = plt.rcParams['xtick.labelbottom'] = False
plt.rcParams['xtick.top'] = plt.rcParams['xtick.labeltop'] = True
plt.title("Position (km)", fontsize=20)

plt.imshow(v_true.T, aspect='auto', cmap='jet', vmin=1500, vmax=4700)
# Tick positions are grid indices; labels are index * 12.5 / 1000 (km)
plt.yticks(np.arange(0,130,40), np.arange(0,130,40)*12.5/1000.0, fontsize=15)
plt.xticks(np.arange(0,200,40), np.arange(0,200,40)*12.5/1000.0, fontsize=15)
plt.colorbar()

# Only shot 0 is drawn: column 0 of the last axis is used as x, column 1 as y
plt.scatter(source_locations[0,:,0], source_locations[0,:,1], c='r', s=10, label='source')
plt.scatter(receiver_locations_tuolan[0,:,0], receiver_locations_tuolan[0,:,1], c='m', s=1, label='receivers')

# plt.xlabel("X")
plt.ylabel("Depth (km)", fontsize=20)
# plt.title("outputs")
plt.tight_layout()
plt.legend(fontsize=15)
# plt.savefig('M2-true.png',dpi=100,transparent=True,bbox_inches='tight')
# plt.savefig('M2-true-tuolan.eps',dpi=100,transparent=True,bbox_inches='tight')
# plt.savefig('denoise_data_nor.png')
# plt.close()

In [None]:
# 1 where the velocity equals 1500 exactly (presumably the water layer), 0 elsewhere
velocity_index = (v_true == 1500).long()
# Number of flagged cells along axis 1 for every position on axis 0
index_seafloor = velocity_index.sum(dim=1)

In [None]:
# Velocity range and grid shape of the cropped model
print(v_true.max(), v_true.min(), v_true.shape)

In [None]:
# Ask deepwave's CFL helper what it returns for this grid spacing, time step
# and maximum velocity
print(
    deepwave.common.cfl_condition(
        dy=12.5,
        dx=12.5,
        dt=0.001,
        max_vel=4700,
    )
)

In [None]:
# Acquisition geometry for a towed-streamer ("tuolan") survey. All positions
# and spacings below are in grid cells; dx = 12.5 is the cell size (metres,
# judging by the "* 12.5 / 1000" km tick labels used in the plotting cells).
dx = 12.5
n_shots = 20
n_sources_per_shot = 1
d_source = 10  # shot spacing: 10 cells * 12.5 m = 125 m
first_source = 4  # first shot at cell 4: 4 * 12.5 m = 50 m
source_depth = 2  # 2 cells * 12.5 m = 25 m

n_receivers_per_shot = 200
d_receiver = 1  # receiver spacing: 1 cell = 12.5 m
first_receiver = 0  # first receiver at cell 0
receiver_depth = 3  # 3 cells * 12.5 m = 37.5 m

freq = 15
nt = 4000
dt = 0.001
peak_time = 6.0 / freq

print(deepwave.common.cfl_condition(dy = dx, dx = dx, dt = dt, max_vel = 4700))

# source_locations, [shot, source, space]
source_locations = torch.zeros(n_shots, n_sources_per_shot, 2,
                               dtype=torch.long)
# space index 1 holds the depth, space index 0 the horizontal position
source_locations[..., 1] = source_depth
source_locations[:, 0, 0] = torch.arange(n_shots) * d_source + first_source

# receiver_locations [shot, receiver, space]
# Start from a full spread: every shot sees receivers at cells 0..199
receiver_locations = torch.zeros(n_shots, n_receivers_per_shot, 2,
                                 dtype=torch.long)
receiver_locations[..., 1] = receiver_depth
receiver_locations[:, :, 0] = (
    (torch.arange(n_receivers_per_shot) * d_receiver + first_receiver)
           .repeat(n_shots, 1))

# Towed-streamer setup: for every shot keep only a window of receivers next to
# the source and mark all others as inactive by zeroing their x coordinate.
# The slices below use x positions as receiver indices, which is only valid
# because d_receiver == 1 and first_receiver == 0.
for index in range(n_shots):
    # Decide the streamer direction from the available room:
    # if the shot position (source_locations[index,0,0]) plus the streamer
    # extent (80 + 24 cells) stays inside the 200-cell model, lay the streamer
    # on the increasing-x side of the source.
    if source_locations[index,0,0] + 80 + 24 < 200:
        # keep receiver indices begin .. end-1 (76 receivers, nearest offset 4 cells)
        begin = source_locations[index,0,0] + 4
        end = source_locations[index,0,0] + 4 + 76
        receiver_locations[index,0:begin,0] = 0
        receiver_locations[index,end:,0] = 0
    else:
        # otherwise lay it on the decreasing-x side:
        # keep receiver indices end+1 .. begin (76 receivers, nearest offset 4 cells)
        begin = source_locations[index,0,0] - 4
        end = source_locations[index,0,0] - 4 - 76
        receiver_locations[index,begin+1:,0] = 0
        receiver_locations[index,0:end+1,0] = 0

# Active receivers are those with x > 0, so a receiver at x == 0 can never be
# active with this encoding.
n_receivers_per_shot_tuolan = torch.sum(torch.where(receiver_locations[:,:,0]>0,1,0),1)
# receiver_locations [shot, receiver, space]
# Sized from shot 0's count; this assumes every shot has the same number of
# active receivers.
receiver_locations_tuolan = torch.zeros(n_shots, n_receivers_per_shot_tuolan[0], 2,
                                 dtype=torch.long)
receiver_locations_tuolan[..., 1] = receiver_depth
# Placeholder x positions; the loop below overwrites them for every shot
receiver_locations_tuolan[:, :, 0] = (
    (torch.arange(n_receivers_per_shot_tuolan[0]) * d_receiver + first_receiver)
           .repeat(n_shots, 1))

for index in range(n_shots):
    # Find the indices of the active receivers and copy their x positions
    index_receivers = torch.argwhere(receiver_locations[index,:,0] > 0)
    receiver_locations_tuolan[index, :, 0] = receiver_locations[index, index_receivers, 0][:,0]

In [None]:
# Source and active receiver positions of shot 0, in grid cells
plt.scatter(source_locations[0,:,0], source_locations[0,:,1], s=10, label='source')
plt.scatter(receiver_locations_tuolan[0,:,0], receiver_locations_tuolan[0,:,1], s=1, label='receivers')
plt.ylim(0,100)
plt.xlim(-5,200)
# Flip the y axis after setting its limits so that depth increases downwards
plt.gca().invert_yaxis()
plt.legend()
plt.xlabel('Width')
plt.ylabel('Depth')
plt.title('Acquisition System')

In [None]:
# Ricker wavelet built from the peak frequency, sample count, time step and peak time
source_amplitudes = deepwave.wavelets.ricker(freq, nt, dt, peak_time)
print(peak_time)

In [None]:
# High-pass the wavelet with a transition band running from 12 to 13
source_amplitudes_filtered = filter_highpass(
    source_amplitudes,
    low_l=12,
    low_cut=13,
    nt=nt,
    dt=dt,
)

In [None]:
# Frequency axis of the one-sided spectrum for nt samples spaced dt apart
freq_list = torch.fft.rfftfreq(nt, d=dt)

# Amplitude spectrum of the original wavelet
trace_observed_fft = torch.fft.rfft(source_amplitudes)
trace_observed_fft_amp = trace_observed_fft.abs()

# Amplitude spectrum of the high-pass filtered wavelet
filtered_highpass_fft = torch.fft.rfft(source_amplitudes_filtered)
filtered_highpass_fft_amp = filtered_highpass_fft.abs()

In [None]:
# Amplitude spectra of the wavelet before and after high-pass filtering
plt.figure(figsize=(10, 5))

# Restore the x ticks and their labels to the bottom edge (an earlier plotting
# cell moved them to the top through the global rcParams)
plt.rcParams['xtick.bottom'] = plt.rcParams['xtick.labelbottom'] = True
plt.rcParams['xtick.top'] = plt.rcParams['xtick.labeltop'] = False

# Only the first 200 frequency bins are drawn
plt.plot(freq_list[0:200], trace_observed_fft_amp.cpu().detach().numpy()[0:200], label='Original')
plt.plot(freq_list[0:200], filtered_highpass_fft_amp[0:200], label='Filtered')
# plt.title('Trace-Spectrum')
plt.xlabel("Frequency", fontsize=20)
plt.ylabel("Amplitude", fontsize=20)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.legend(fontsize=15)
# plt.savefig('wavelet-freq-tuolan.eps',dpi=100,transparent=True,bbox_inches='tight')
plt.show()

In [None]:
# First 1000 samples of the wavelet before and after high-pass filtering
plt.figure(figsize=(10, 5))
plt.xlabel("Time (s)", fontsize=20)
plt.ylabel("Amplitude", fontsize=20)
plt.plot(source_amplitudes[0:1000], label='Original')
plt.plot(source_amplitudes_filtered[0:1000], label='Filtered')
plt.yticks(fontsize=15)
# Ticks sit at sample indices; labels are index / 1000, i.e. seconds for dt = 0.001
plt.xticks(np.arange(0,1001,250), np.arange(0,1001,250)/1000.0,fontsize=15)
plt.legend(fontsize=15)
# plt.savefig('wavelet_tuolan.eps',dpi=100,transparent=True,bbox_inches='tight')

In [None]:
# Three-panel summary figure:
#   a) velocity model with the first shot's source and receivers,
#   b) wavelet amplitude spectra before/after filtering,
#   c) wavelet in time before/after filtering.
fig, axes = plt.subplots(1, 3, figsize=(15, 3.5))
fig.subplots_adjust(wspace=0.1)

# Panel labels are placed in data coordinates, so their positions depend on
# each panel's axis ranges.
axes[0].text(-35.,6,"a)",fontsize = 18, weight='bold')
im = axes[0].imshow(v_true.T, aspect='auto', cmap='jet', vmin=1500, vmax=4700)
# Tick positions are grid indices; labels are index * 12.5 / 1000 (km)
axes[0].set_yticks(np.arange(0,130,40), np.arange(0,130,40)*12.5/1000.0, fontsize=12)
axes[0].set_xticks(np.arange(0,200,40), np.arange(0,200,40)*12.5/1000.0, fontsize=12)
# axes[0].colorbar()

axes[0].scatter(source_locations[0,:,0], source_locations[0,:,1], c='r', s=10, label='source')
axes[0].scatter(receiver_locations_tuolan[0,:,0], receiver_locations_tuolan[0,:,1], c='m', s=1, label='receivers')

# plt.xlabel("X")
# Put the x-axis label at the top
axes[0].xaxis.set_label_position('top')
# Move the x ticks to the top as well; without this call the tick labels stay at the bottom
axes[0].xaxis.tick_top()
axes[0].set_ylabel("Depth (km)", fontsize=12)
axes[0].set_xlabel("Position (km)", fontsize=12)
# plt.title("outputs")
axes[0].legend(fontsize=11)

axes[1].text(-7.,27.5,"b)",fontsize = 18, weight='bold')
# Only the first 200 frequency bins are drawn
axes[1].plot(freq_list[0:200], trace_observed_fft_amp.cpu().detach().numpy()[0:200], label='Original')
axes[1].plot(freq_list[0:200], filtered_highpass_fft_amp[0:200], label='Filtered')
# plt.title('Trace-Spectrum')
axes[1].set_xlabel("Frequency", fontsize=12)
# Hide the y ticks and their labels: only the curve shapes are of interest
axes[1].yaxis.set_major_formatter(plt.NullFormatter())
axes[1].set_yticks([])
# axes[1].set_ylabel("Amplitude", fontsize=12)
axes[1].tick_params(axis='y',labelsize=12)
axes[1].tick_params(axis='x',labelsize=12)
axes[1].legend(fontsize=12)


axes[2].text(-140,1,"c)",fontsize = 18, weight='bold')
axes[2].yaxis.set_major_formatter(plt.NullFormatter())
axes[2].set_yticks([])
axes[2].set_xlabel("Time (s)", fontsize=12)
# ax_right.set_ylabel("Amplitude", fontsize=12)
axes[2].plot(source_amplitudes[0:1000], label='Original')
axes[2].plot(source_amplitudes_filtered[0:1000], label='Filtered')
axes[2].tick_params(axis='y',labelsize=12)
# Ticks sit at sample indices; labels are index / 1000, i.e. seconds for dt = 0.001
axes[2].set_xticks(np.arange(0,1001,250), np.arange(0,1001,250)/1000.0,fontsize=12)
axes[2].legend(fontsize=12)

# Dedicated horizontal colorbar axes in figure coordinates [left, bottom, width, height]
# NOTE(review): with an explicit cax, `ax` and `fraction` are presumably
# ignored by matplotlib -- confirm against the Figure.colorbar documentation.
position=fig.add_axes([0.15, 0.08, 0.2, 0.02])
cbar = fig.colorbar(im, ax=[axes[0]], 
                    cax=position,orientation='horizontal', fraction=0.025)

# plt.savefig('ae_m2_tuolan.png',dpi=100,transparent=True,bbox_inches='tight')
# plt.savefig('ae_m2_tuolan.eps',dpi=100,transparent=True,bbox_inches='tight')
# plt.savefig('ae_m2_tuolan.pdf',dpi=300,transparent=True,bbox_inches='tight')

In [None]:
# Tile the filtered wavelet so every shot and every source gets its own copy:
# the result is indexed [shot, source, time]
filtered_highpass_f = source_amplitudes_filtered.repeat(
    n_shots,
    n_sources_per_shot,
    1,
)

In [None]:
# Run on the first CUDA GPU when one is available; otherwise fall back to the
# CPU so the notebook still runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Forward-model the survey with deepwave's scalar propagator and keep the last
# element of its return value (presumably the receiver recordings, indexed
# [shot, receiver, time] -- confirm against the deepwave documentation).
observed_data = scalar(
    v_true.to(device),
    dx,
    dt,
    source_amplitudes=filtered_highpass_f.to(device),
    source_locations=source_locations.to(device),
    receiver_locations=receiver_locations_tuolan.to(device),
    accuracy=8,
    max_vel=4700,
    pml_freq=freq,
    # one width per model edge; the third edge gets no absorbing layer
    pml_width=[20, 20, 0, 20],
)[-1]

In [None]:
# Dimensions of the modelled data
observed_data.size()

In [None]:
# Shot 0 of the modelled data, transposed so the last data axis runs vertically
figsize = (12, 6)
# Apply the requested size: figsize used to be assigned but never passed to
# matplotlib, so the figure was drawn at the default size.
plt.figure(figsize=figsize)
plt.imshow(observed_data[0].cpu().T, aspect='auto', cmap='seismic',vmin=-4, vmax=4)
plt.colorbar()
plt.savefig('obs_data.png')
plt.show()

In [None]:
# Frequency axis for a trace of nt samples (nt replaces a hard-coded 4000 so
# the axis follows the configured record length)
freq_list = torch.fft.rfftfreq(nt, dt)

# Amplitude spectrum of one modelled trace: shot 10, receiver 30.
# Note: this reuses the names that earlier held the filtered wavelet's spectrum.
filtered_highpass_fft = torch.fft.rfft(observed_data[10].cpu()[30,:])
filtered_highpass_fft_amp = torch.abs(filtered_highpass_fft)

In [None]:
# Amplitude spectrum of the selected modelled trace (first 100 frequency bins)
figsize = (12, 6)
# Apply the requested size: figsize used to be assigned but never passed to
# matplotlib, so the figure was drawn at the default size.
plt.figure(figsize=figsize)
# plt.plot(freq_list[0:100], trace_observed_fft_amp.cpu().detach().numpy()[0:100], label='original_fft_amp')
plt.plot(freq_list[0:100], filtered_highpass_fft_amp[0:100], label='filtered_fft_amp')
plt.title('Trace-Spectrum')
plt.xlabel('Frequency Hz')
plt.ylabel('Amplitude')
plt.legend()
# plt.savefig('trace_freq.png')
plt.show()

In [None]:
def data_noise(observed_data, scale):
    """
    Draw Gaussian noise whose standard deviation is set per gather.

    For every index along the first axis (a shot in this notebook) the noise
    standard deviation is ``scale`` times the RMS amplitude of that gather,
    computed over axes 1 and 2.

    :param observed_data: 3-D tensor; the first axis separates the gathers.
    :param scale: noise standard deviation as a fraction of each gather's RMS.
    :return: noise tensor with the same shape, dtype and device as
        ``observed_data`` (the noise only; it is not added to the data).
    """
    # keepdim=True leaves a (n, 1, 1) tensor that broadcasts one RMS value over
    # each whole gather. Tiling with repeat() and reshaping instead interleaves
    # the values of different gathers across neighbouring samples.
    rms = torch.sqrt(torch.mean(observed_data ** 2, dim=[1, 2], keepdim=True))

    return torch.randn_like(observed_data) * (rms * scale)

In [None]:
# Noise with a standard deviation of half of each shot's RMS amplitude
noise = data_noise(observed_data, scale=0.5)
# Noisy data = clean data + noise
observed_data_noise = torch.add(observed_data, noise)
observed_data_noise.size()

In [None]:
# Shot 0 of the noisy data, transposed so the last data axis runs vertically
figsize = (12, 6)
# Apply the requested size: figsize used to be assigned but never passed to
# matplotlib, so the figure was drawn at the default size.
plt.figure(figsize=figsize)
plt.imshow(observed_data_noise[0].cpu().T, aspect='auto', cmap='seismic',vmin=-5, vmax=5)
plt.colorbar()
plt.savefig('obs_data_noise.png')
plt.show()

In [None]:
# Mean signal-to-noise ratio over all shots, in dB:
#   SNR = 10 * log10(signal energy / noise energy)
# where the signal is the clean data and the noise is the difference between
# the noisy and the clean data.
snr = 0.0
for i in range(n_shots):
    rec_ind = observed_data[i,:,:]            # clean shot (signal)
    target_ind  = observed_data_noise[i,:,:]  # noisy shot (signal + noise)
    # The numerator is the clean signal's energy; using the noisy data's energy
    # instead adds the noise energy to it and overstates the SNR.
    s      = 10*torch.log10(torch.sum(rec_ind**2)/torch.sum((rec_ind-target_ind)**2))
    snr    = snr + s
snr = snr/n_shots
print(snr)

In [None]:
# Write the clean data and the noise as raw binary files (values only: no
# header, shape or dtype information is stored)
(observed_data
    .cpu()
    .numpy()
    .tofile('marmousi2_130_200_data_experiments_filter_12_13_fs.bin'))
(noise
    .cpu()
    .numpy()
    .tofile('marmousi2_130_200_data_noise_experiments_filter_12_13_fs.bin'))