In [None]:
# @author: wuyuping (ypwu@stu.hit.edu.cn)

import torch
import matplotlib.pyplot as plt

In [None]:
import scipy.io
import numpy as np
data_path = ''  # directory prefix for every .bin file read or written below

# Model size in grid cells: ny along the surface (plotted horizontally after
# transposing), nx in depth.
ny = 340
nx = 130

# Headerless binary velocity model; torch.from_file reads ny * nx values using
# torch's default dtype (float32 unless it has been changed).
v_true = torch.from_file(data_path + 'mar_big_vp_130_340.bin', size=nx * ny)
v_true = v_true.reshape(ny, nx)
print(v_true.shape)

In [None]:
from scipy.ndimage import gaussian_filter

In [None]:
# Starting model: Gaussian-smooth the slowness (1 / velocity) with sigma = 30
# grid cells, then convert back to velocity.
v_init = torch.tensor(
    1 / gaussian_filter(1 / v_true.cpu().numpy(), sigma=30)
)

In [None]:
plt.figure(figsize=(10, 5))

# Move x ticks and their labels to the top edge.  This edits the global
# rcParams, so the setting also applies to every figure created afterwards.
plt.rcParams.update({
    'xtick.bottom': False, 'xtick.labelbottom': False,
    'xtick.top': True, 'xtick.labeltop': True,
})
# With the x axis on top, the title doubles as its axis label.
plt.title("Position (km)", fontsize=20)

# Smoothed starting model on a fixed velocity colour range; ticks convert grid
# cells (22.5 m spacing) to kilometres.
plt.imshow(v_init.T, aspect='auto', cmap='jet', vmin=1100, vmax=4700)
plt.yticks(np.arange(0, 2.925, 1) * 1000.0 / 22.5, [0, 1, 2], fontsize=15)
plt.xticks(np.arange(0, 7.650, 2) * 1000.0 / 22.5, [0, 2, 4, 6], fontsize=15)
plt.colorbar()

# Each well gets a dash-dot line through the full depth plus a surface marker
# that carries its legend entry.  All lines are drawn before the markers.
wells = ((60, 'o', 'Well1'), (207, 's', 'Well2'), (290, 'v', 'Well3'))
for well_x, well_marker, well_label in wells:
    plt.plot(torch.ones(130) * well_x, np.arange(0, 130), color='k', linestyle='-.')
for well_x, well_marker, well_label in wells:
    plt.scatter(well_x, 2, marker=well_marker, color="#d6641e", zorder=2, s=50,
                label=well_label)
plt.legend(bbox_to_anchor=(0.75, 0.01), ncol=3, edgecolor='k',
           handletextpad=0.1, fontsize=13, borderpad=0.3)

plt.ylabel("Depth (km)", fontsize=20)
plt.tight_layout()
# plt.savefig('M2-init.png',dpi=100,transparent=True,bbox_inches='tight')
# plt.savefig('M2-init.eps',dpi=100,transparent=True,bbox_inches='tight')

In [None]:
plt.figure(figsize=(10, 5))

# Top-edge x ticks via the global rcParams (persists for later figures too).
plt.rcParams.update({
    'xtick.bottom': False, 'xtick.labelbottom': False,
    'xtick.top': True, 'xtick.labeltop': True,
})
plt.title("Position (km)", fontsize=20)  # acts as the top x-axis label

# True model, same colour range and km ticks as the starting-model figure.
plt.imshow(v_true.T, aspect='auto', cmap='jet', vmin=1100, vmax=4700)
plt.yticks(np.arange(0, 2.925, 1) * 1000.0 / 22.5, [0, 1, 2], fontsize=15)
plt.xticks(np.arange(0, 7.650, 2) * 1000.0 / 22.5, [0, 2, 4, 6], fontsize=15)
plt.colorbar()

# Well traces first, then the labelled surface markers.
for well_x in (60, 207, 290):
    plt.plot(torch.ones(130) * well_x, np.arange(0, 130), color='k', linestyle='-.')
for well_x, well_marker, well_label in zip((60, 207, 290), 'osv',
                                           ('Well1', 'Well2', 'Well3')):
    plt.scatter(well_x, 2, marker=well_marker, color="#d6641e", zorder=2, s=50,
                label=well_label)
plt.legend(bbox_to_anchor=(0.75, 0.01), ncol=3, edgecolor='k',
           handletextpad=0.1, fontsize=13, borderpad=0.3)

plt.ylabel("Depth (km)", fontsize=20)
plt.tight_layout()
# plt.savefig('M2-true.png',dpi=100,transparent=True,bbox_inches='tight')
# plt.savefig('M2-true.eps',dpi=100,transparent=True,bbox_inches='tight')
# plt.savefig('denoise_data_nor.png')
# plt.close()

In [None]:
import deepwave

In [None]:
# Grid spacing (m); deepwave is given this single spacing for both axes.
dx = 22.5

# Sources: n_shots shots with one source each, spaced d_source cells apart
# starting at cell first_source, all at depth cell source_depth.
n_shots = 30
n_sources_per_shot = 1
first_source = 5
d_source = 11
source_depth = 1

# Receivers: the same spread for every shot, one receiver every d_receiver
# cells from cell first_receiver, at depth cell receiver_depth.
n_receivers_per_shot = 339
first_receiver = 0
d_receiver = 1
receiver_depth = 1

# Time axis and Ricker wavelet: nt samples of dt seconds, peak frequency freq,
# wavelet peak delayed by one period.
freq = 7
nt = 2000
dt = 0.002
peak_time = 1.0 / freq

# Check the chosen dt against deepwave's CFL condition for the maximum velocity.
print(deepwave.common.cfl_condition(dy=dx, dx=dx, dt=dt, max_vel=4700))

In [None]:
from scipy.signal import butter
from torchaudio.functional import biquad

In [None]:
# Source positions, indexed [shot, source, space].  Space index 0 is the
# horizontal cell (one shot every d_source cells), index 1 the depth cell.
source_locations = torch.zeros(n_shots, n_sources_per_shot, 2, dtype=torch.long)
source_locations[:, 0, 0] = first_source + d_source * torch.arange(n_shots)
source_locations[..., 1] = source_depth

# Receiver positions, indexed [shot, receiver, space].  Every shot records on
# the same fixed spread, so the 1-D position vector broadcasts over shots.
receiver_locations = torch.zeros(n_shots, n_receivers_per_shot, 2,
                                 dtype=torch.long)
receiver_locations[..., 1] = receiver_depth
receiver_locations[:, :, 0] = (
    first_receiver + d_receiver * torch.arange(n_receivers_per_shot)
)

# Single Ricker wavelet (nt samples); it is tiled to [shot, source, time] later.
source_amplitudes = deepwave.wavelets.ricker(freq, nt, dt, peak_time)

# 6th-order Butterworth high-pass at 5 Hz as second-order sections.  Each row
# (b0, b1, b2, a0, a1, a2) becomes a tensor in the wavelet's dtype.
sos = [
    torch.tensor(section).to(source_amplitudes.dtype)
    for section in butter(6, 5, 'hp', fs=1/dt, output='sos')
]

def filt(x):
    """Apply the high-pass filter held in the module-level list `sos` to `x`.

    Every second-order section in `sos` is applied in cascade with torchaudio's
    `biquad`, each row unpacked as (b0, b1, b2, a0, a1, a2).  Looping over all
    sections, instead of hard-coding sections 0-2, keeps the filter complete if
    the Butterworth order used to build `sos` is changed.

    NOTE(review): torchaudio's `biquad` calls `lfilter` with its default
    clamp=True, so each stage's output is clipped to [-1, 1]. Confirm this is
    acceptable for the wavelet amplitudes used here.
    """
    for section in sos:
        x = biquad(x, *section)
    return x

# Filter the single wavelet once, then tile it to [shot, source, time].
source_amplitudes_filt = filt(source_amplitudes).repeat(
    n_shots, n_sources_per_shot, 1
)

In [None]:
# Expected: (n_shots, n_sources_per_shot, nt).
source_amplitudes_filt.size()

In [None]:
# Device for the modelling tensors (the .to(device) calls below use it).
device = torch.device('cpu')

In [None]:
# Forward-model the clean shot gathers in the true model.  Only the last
# element of deepwave.scalar's returned tuple is kept (presumably the receiver
# amplitudes, [shot, receiver, time]; check the deepwave docs).
observed_data = deepwave.scalar(
    v_true.to(device),
    dx,
    dt,
    source_amplitudes=source_amplitudes_filt.to(device),
    source_locations=source_locations.to(device),
    receiver_locations=receiver_locations.to(device),
    max_vel=4700,                # same bound as the CFL check
    accuracy=8,                  # finite-difference accuracy order
    pml_freq=freq,
    pml_width=[20, 20, 0, 20],   # 0 on one edge: presumably a free surface at depth 0; confirm edge order
)[-1]
observed_data.shape

In [None]:
def data_noise(observed_data, scale, generator=None):
    """Return Gaussian noise scaled to each shot's own RMS amplitude.

    For shot ``i`` the noise is ``N(0, 1) * scale * rms_i``, where
    ``rms_i = sqrt(mean(observed_data[i] ** 2))`` is taken over axes 1 and 2.

    Args:
        observed_data: 3-D tensor whose first axis indexes shots (used here as
            [shot, receiver, time]).
        scale: Noise RMS as a fraction of each shot's RMS.
        generator: Optional ``torch.Generator`` for reproducible noise. The
            default ``None`` uses torch's global RNG.

    Returns:
        Noise tensor with the same shape, dtype and device as ``observed_data``.
        The caller adds it to the data.
    """
    noise = torch.zeros_like(observed_data).normal_(generator=generator)
    # keepdim=True keeps shape (n_shots, 1, 1), which broadcasts over the other
    # axes so shot i is scaled by its own RMS.  The previous repeat()+reshape()
    # interleaved the per-shot RMS values across shots.
    shot_rms = torch.sqrt(torch.mean(observed_data ** 2, dim=(1, 2), keepdim=True))
    return noise * (shot_rms * scale)

In [None]:
# Seed torch's RNG so the added noise is reproducible on re-run.  That noise
# feeds the SNR below and the saved noise file.
torch.manual_seed(0)
# scale=0.5: noise amplitude relative to the data RMS (see data_noise).
noise = data_noise(observed_data, scale=0.5)
observed_data_noise = observed_data + noise
observed_data_noise.shape

In [None]:
plt.figure(figsize=(10, 5))

# Top-edge x ticks via the global rcParams (persists for later figures too).
plt.rcParams.update({
    'xtick.bottom': False, 'xtick.labelbottom': False,
    'xtick.top': True, 'xtick.labeltop': True,
})
plt.title("Position (km)", fontsize=20)  # acts as the top x-axis label

# Shot 15, clean.  The colour scale is clipped symmetrically at the 95th
# percentile; vmin (the 5th percentile) is computed but not used.
vmin, vmax = torch.quantile(observed_data[15],
                            torch.tensor([0.05, 0.95]).to(device))
plt.imshow(observed_data[15].cpu().T, aspect='auto', cmap='seismic',
           vmin=-vmax, vmax=vmax)

plt.ylabel("Time (s)", fontsize=20)
# One tick every 500 samples, labelled in seconds (0.002 s per sample).
plt.yticks(np.arange(0, 2001, 500), np.arange(0, 2001, 500) * 0.002, fontsize=15)
# Receivers sit on consecutive cells, so trace index maps to km like the model.
plt.xticks(np.arange(0, 7.650, 2) * 1000.0 / 22.5, [0, 2, 4, 6], fontsize=15)
plt.colorbar()
plt.tight_layout()
# plt.savefig('seismic_data.png',dpi=100,transparent=True,bbox_inches='tight')
# plt.savefig('seismic_data.eps',dpi=100,transparent=True,bbox_inches='tight')

In [None]:
plt.figure(figsize=(10, 5))

# Top-edge x ticks via the global rcParams (persists for later figures too).
plt.rcParams.update({
    'xtick.bottom': False, 'xtick.labelbottom': False,
    'xtick.top': True, 'xtick.labeltop': True,
})
plt.title("Position (km)", fontsize=20)  # acts as the top x-axis label

# Shot 15 with added noise, clipped at its own 95th percentile (vmin unused).
vmin, vmax = torch.quantile(observed_data_noise[15],
                            torch.tensor([0.05, 0.95]).to(device))
plt.imshow(observed_data_noise[15].cpu().T, aspect='auto', cmap='seismic',
           vmin=-vmax, vmax=vmax)

plt.ylabel("Time (s)", fontsize=20)
plt.yticks(np.arange(0, 2001, 500), np.arange(0, 2001, 500) * 0.002, fontsize=15)
plt.xticks(np.arange(0, 7.650, 2) * 1000.0 / 22.5, [0, 2, 4, 6], fontsize=15)
plt.colorbar()
plt.tight_layout()
# plt.savefig('noisy_seismic_data.png',dpi=100,transparent=True,bbox_inches='tight')
# plt.savefig('noisy_seismic_data.eps',dpi=100,transparent=True,bbox_inches='tight')

In [None]:
#observed_data_noise,observed_data,noise

In [None]:
# Mean per-shot signal-to-noise ratio (dB) of the noisy data:
#   SNR_i = 10 * log10( sum(clean_i ** 2) / sum((noisy_i - clean_i) ** 2) )
# The signal term is the CLEAN gather.  Putting the noisy gather in the
# numerator (as before) counts the noise power as signal and overstates SNR.
snr = 0.0
for i in range(n_shots):
    clean_shot = observed_data[i, :, :]
    noisy_shot = observed_data_noise[i, :, :]
    signal_power = torch.sum(clean_shot ** 2)
    noise_power = torch.sum((noisy_shot - clean_shot) ** 2)
    s = 10 * torch.log10(signal_power / noise_power)
    snr = snr + s
snr = snr / n_shots
print(snr)

In [None]:
# Dump the clean shot gathers as raw binary.  ndarray.tofile writes no header,
# so readers must know the dtype and shape.
observed_data.cpu().numpy().tofile(
    data_path + 'marmousi2_130_340_data_experiments3_3_filt.bin'
)

In [None]:
# Dump the added noise (not the noisy data) as raw binary, with no header.
noise.cpu().numpy().tofile(
    data_path + 'marmousi2_130_340_data_noise_experiments3_3_filt.bin'
)

In [None]:
# small offset acquisition setting

In [None]:
def _draw_velocity_panel(ax, model, tag, show_depth_axis):
    """Draw a velocity model with the three well locations on `ax`.

    `model` is shown transposed (first axis horizontal, second axis depth).
    Ticks convert grid cells to km using the grid spacing `dx`.  The position
    axis goes on top; depth ticks and label appear only when `show_depth_axis`
    is true.  Returns the image.
    """
    image = ax.imshow(model.T, aspect='auto', cmap='jet', vmin=1100, vmax=4700)
    cells_per_km = 1000.0 / dx
    ax.set_xticks(np.arange(0, 7.650, 2) * cells_per_km, [0, 2, 4, 6], fontsize=12)
    if show_depth_axis:
        ax.set_yticks(np.arange(0, 2.925, 1) * cells_per_km, [0, 1, 2], fontsize=12)
        ax.set_ylabel("Depth (km)", fontsize=12)
    else:
        # Right-hand panel: no depth ticks or labels.
        ax.yaxis.set_major_formatter(plt.NullFormatter())
        ax.set_yticks([])
    ax.figure.colorbar(image, ax=ax)

    # Dash-dot line through the full depth for each well, then a labelled
    # surface marker.  All lines are drawn before the markers.
    depth_cells = np.arange(model.shape[1])
    wells = ((60, 'o', 'Well1'), (207, 's', 'Well2'), (290, 'v', 'Well3'))
    for well_x, _marker, _label in wells:
        ax.plot(np.full(depth_cells.shape, well_x), depth_cells,
                color='k', linestyle='-.')
    for well_x, marker, label in wells:
        ax.scatter(well_x, 2, marker=marker, color="#d6641e", zorder=2, s=50,
                   label=label)
    ax.legend(bbox_to_anchor=(0.9, 0.01), ncol=3, edgecolor='k',
              handletextpad=0.1, fontsize=10, borderpad=0.3)

    # Position label and ticks on the top edge.
    ax.xaxis.set_label_position('top')
    ax.xaxis.tick_top()
    ax.set_xlabel("Position (km)", fontsize=12)
    ax.text(-30, -5, tag, fontsize=18, weight='bold')
    return image


def _draw_shot_panel(ax, gather, tag, tag_x, show_time_axis):
    """Draw one shot gather (shown transposed, time downwards) on `ax`.

    The colour scale is clipped symmetrically at the gather's 95th percentile.
    The colorbar is built from THIS panel's image, so it reflects this
    gather's scale.  Trace ticks are hidden.  Time ticks (every 500 samples,
    labelled in seconds via `dt`) appear only when `show_time_axis` is true.
    Returns the image.
    """
    clip = torch.quantile(gather, 0.95).item()
    image = ax.imshow(gather.cpu().T, aspect='auto', cmap='seismic',
                      vmin=-clip, vmax=clip)
    if show_time_axis:
        time_ticks = np.arange(0, nt + 1, 500)
        ax.set_yticks(time_ticks, time_ticks * dt, fontsize=12)
        ax.set_ylabel("Time (s)", fontsize=12)
    else:
        ax.yaxis.set_major_formatter(plt.NullFormatter())
        ax.set_yticks([])
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    ax.set_xticks([])
    ax.figure.colorbar(image, ax=ax)
    ax.text(tag_x, -10, tag, fontsize=18, weight='bold')
    return image


# 2x2 summary figure: (a) true model, (b) starting model, (c) clean shot 15,
# (d) noisy shot 15.  Previously panel (d)'s colorbar reused panel (c)'s image
# and so displayed the clean data's colour scale.
fig, axes = plt.subplots(2, 2, figsize=(12, 7))
fig.subplots_adjust(wspace=0.05)

_draw_velocity_panel(axes[0, 0], v_true, "a)", show_depth_axis=True)
_draw_velocity_panel(axes[0, 1], v_init, "b)", show_depth_axis=False)
_draw_shot_panel(axes[1, 0], observed_data[15], "c)", tag_x=-60,
                 show_time_axis=True)
_draw_shot_panel(axes[1, 1], observed_data_noise[15], "d)", tag_x=-30,
                 show_time_axis=False);

# plt.savefig('add_m2.png',dpi=100,transparent=True,bbox_inches='tight')
# plt.savefig('add_m2.eps',dpi=100,transparent=True,bbox_inches='tight')
# plt.savefig('add_m2.pdf',dpi=100,transparent=True,bbox_inches='tight')