"""Shot-by-shot least-squares reverse-time migration with deepwave.

For each outer iteration ``i`` a single shot is inverted. A scattering model
``scatter`` is refined with L-BFGS so that Born-modelled data
(``deepwave.scalar_born`` in a smoothed background velocity) match the
observed scattered data read from disk. The updated model is written to
``./lsrtm-10/scatter{i+1}.bin`` and used as the starting model for the next
shot.
"""
import os

import torch
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt
import deepwave
from deepwave import scalar, scalar_born
import numpy as np
import torch.fft

print("===========")
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Model grid: ny x nx cells of dx metres; nt time samples of dt seconds.
ny = 2301
nx = 801
dx = 4.0
nt = 750
dt = 0.004
n_shots = 1     # shots modelled per outer iteration
epoches = 115   # number of outer iterations (one shot each)
freq = 26       # Ricker peak frequency, also passed as pml_freq

# Acquisition geometry, in grid cells (dx = 4 m).
n_sources_per_shot = 1
d_source = 20               # 20 * 4m = 80m between successive shots
source_depth = 2            # 2 * 4m = 8m
n_receivers_per_shot = 384
d_receiver = 6              # 6 * 4m = 24m
first_receiver = 0          # 0 * 4m = 0m
receiver_depth = 2          # 2 * 4m = 8m
peak_time = 1.5 / freq

output_dir = './lsrtm-10'
scatter_template = './lsrtm-10/scatter{}.bin'
observed_template = './quzhidabo_data/observed_scattered-{}.bin'

# ---- Loop-invariant inputs: computed once instead of on every shot. ----

v_true = torch.from_file('vel10.bin', size=ny * nx).reshape(ny, nx)
# Background (migration) velocity: Gaussian-smooth the slowness 1/v
# (sigma = 5 cells), invert back to velocity, and move it to the device.
smooth2 = (torch.tensor(1 / gaussian_filter(1 / v_true.numpy(), 5))
           .to(device))

# The receiver spread is identical for every shot.
receiver_locations = torch.zeros(n_shots, n_receivers_per_shot, 2,
                                 dtype=torch.long, device=device)
receiver_locations[..., 1] = receiver_depth
receiver_locations[:, :, 0] = (
    (torch.arange(n_receivers_per_shot) * d_receiver + first_receiver)
    .repeat(n_shots, 1)
)

# The same Ricker source wavelet is used for every shot.
source_amplitudes = (
    deepwave.wavelets.ricker(freq, nt, dt, peak_time)
    .repeat(n_shots, n_sources_per_shot, 1)
    .to(device)
)

loss_fn = torch.nn.MSELoss()
n_epochs = 10  # optimiser.step() calls per shot (each runs up to max_iter L-BFGS iterations)

# Create the output directory up front. Otherwise the first tofile() would
# fail only after a full (expensive) inversion has already run.
os.makedirs(output_dir, exist_ok=True)

for i in range(epoches):
    print(i)

    # Shot i is at cell 10 + d_source * i (40 m + 80 m * i) along axis 0.
    first_source = 10 + d_source * i
    source_locations = torch.zeros(n_shots, n_sources_per_shot, 2,
                                   dtype=torch.long, device=device)
    source_locations[..., 1] = source_depth
    source_locations[:, 0, 0] = (torch.arange(n_shots) * d_source
                                 + first_source)

    # Observed files are 1-indexed: shot i pairs with observed_scattered-{i+1}.
    # (They could alternatively be modelled with deepwave.scalar in v_true.)
    observed_scattered_data = (
        torch.from_file(observed_template.format(i + 1),
                        size=n_shots * n_receivers_per_shot * nt)
        .reshape(n_shots, n_receivers_per_shot, nt)
        .to(device)
    )

    # Warm start: shot 0 begins from a zero model; every later shot resumes
    # from the model saved by the previous shot.
    if i == 0:
        scatter = torch.zeros_like(smooth2)
    else:
        scatter = torch.from_file(scatter_template.format(i),
                                  size=ny * nx).reshape(ny, nx)
    scatter = scatter.to(device)
    scatter.requires_grad_()

    optimiser = torch.optim.LBFGS([scatter])

    def closure():
        """Re-run Born modelling, backpropagate the misfit, return it as a float."""
        optimiser.zero_grad()
        out = scalar_born(
            smooth2, scatter, dx, dt,
            source_amplitudes=source_amplitudes,
            source_locations=source_locations,
            receiver_locations=receiver_locations,
            accuracy=8,
            pml_freq=freq,
        )
        # out[-1] is compared against the observed receiver data.
        # NOTE(review): the 1e20 factor presumably lifts the tiny
        # scattered-data misfit above L-BFGS's default tolerance_grad /
        # tolerance_change thresholds. Confirm before changing it.
        loss = 1e20 * loss_fn(out[-1], observed_scattered_data)
        loss.backward()
        print(epoch, loss)
        return loss.item()

    for epoch in range(n_epochs):
        optimiser.step(closure)

    scatter.detach().cpu().numpy().tofile(scatter_template.format(i + 1))