import os

# Restrict CUDA to physical GPUs 0, 1 and 3. Set before importing torch so the
# setting is guaranteed to be in place before any CUDA initialisation.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,3"

import torch
import matplotlib.pyplot as plt
import deepwave
from deepwave import elastic
import numpy as np
import torch.fft

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Prop(torch.nn.Module):
    """Elastic forward modelling with y-component sources and receivers.

    The vp/vs/rho models are stored on the module instead of being passed to
    ``forward``. ``torch.nn.DataParallel`` splits every tensor argument of
    ``forward`` along dim 0 across the GPUs, so passing the models as
    arguments would give each replica only a slice of the model. Module
    parameters and buffers are instead replicated whole onto every device;
    only the shot-indexed inputs (amplitudes, locations) are split.

    Args:
        vp, vs, rho: 2-D model tensors, indexed by the source/receiver
            locations as ``[dim0, dim1]``.
        dx: grid spacing passed to ``deepwave.elastic``.
        dt: time step passed to ``deepwave.elastic``.
        freq: frequency passed as ``pml_freq``.
        requires_grad: if true, register the models as trainable
            ``Parameter``s (for inversion); otherwise as buffers.
    """

    def __init__(self, vp, vs, rho, dx, dt, freq, requires_grad=False):
        super().__init__()
        for name, value in (('vp', vp), ('vs', vs), ('rho', rho)):
            if requires_grad:
                setattr(self, name, torch.nn.Parameter(value))
            else:
                self.register_buffer(name, value)
        self.dx = dx
        self.dt = dt
        self.freq = freq

    def forward(self, source_amplitudes, source_locations, receiver_locations):
        """Simulate the given shots and return ``elastic``'s second-to-last output.

        NOTE(review): ``out[-2]`` is assumed to be the y-component receiver
        data for the installed deepwave version — confirm against its docs.
        """
        out = elastic(
            *deepwave.common.vpvsrho_to_lambmubuoyancy(self.vp, self.vs, self.rho),
            self.dx,
            self.dt,
            source_amplitudes_y=source_amplitudes,
            source_locations_y=source_locations,
            receiver_locations_y=receiver_locations,
            pml_freq=self.freq,
        )
        return out[-2]


ny = 192
nx = 128
dx = 10
n_shots = 19
nt = 2000
dt = 0.001
epoches = 10      # number of frequency stages (freq = 7, 8, ..., 16)
n_offsets = 10    # source-offset stages per frequency (first_source = 0..9)
n_epochs = 2      # LBFGS steps per stage

# Acquisition geometry, in grid cells (dx = 10 m).
n_sources_per_shot = 1
d_source = 10             # shot spacing: 10 cells = 100 m
source_depth = 1          # 1 cell = 10 m
n_receivers_per_shot = 191
d_receiver = 1            # 1 cell = 10 m
first_receiver = 0
receiver_depth = 0        # top row of the grid


def read_model(path):
    """Read a raw (ny, nx) model of the default torch dtype onto ``device``."""
    return torch.from_file(path, size=ny * nx).reshape(ny, nx).to(device)


def write_model(model, path):
    """Write a modelling-layout (transposed) model back in on-disk (ny, nx) layout."""
    model.detach().T.contiguous().cpu().numpy().tofile(path)


# The true models never change: load them once and transpose from the
# on-disk (ny, nx) layout to the layout used for modelling.
vp_true = read_model('vptrue.bin').T.contiguous()
vs_true = read_model('vstrue.bin').T.contiguous()
rho_true = read_model('rotrue.bin').T.contiguous()

# Receiver positions are identical for every stage and shot.
receiver_locations = torch.zeros(n_shots, n_receivers_per_shot, 2,
                                 dtype=torch.long, device=device)
receiver_locations[..., 0] = receiver_depth
receiver_locations[:, :, 1] = (
    (torch.arange(n_receivers_per_shot, device=device) * d_receiver + first_receiver)
    .repeat(n_shots, 1)
)

loss_fn = torch.nn.MSELoss()

for m in range(epoches):
    freq = 7 + m
    peak_time = 1.5 / freq
    # The wavelet depends only on the frequency stage.
    source_amplitudes = (
        deepwave.wavelets.ricker(freq, nt, dt, peak_time)
        .repeat(n_shots, n_sources_per_shot, 1).to(device)
    )

    for i in range(n_offsets):
        # Starting model: homogeneous for the very first stage, the last
        # result of the previous frequency at the start of each new frequency,
        # otherwise the result of the previous offset stage.
        if m == 0 and i == 0:
            vp_background = torch.full((ny, nx), 1500.0, device=device)
            vs_background = torch.full((ny, nx), 900.0, device=device)
            rho_background = torch.full((ny, nx), 2200.0, device=device)
        elif i == 0:
            vp_background = read_model(f'vp{n_offsets}.{freq - 1}.bin')
            vs_background = read_model(f'vs{n_offsets}.{freq - 1}.bin')
            rho_background = read_model(f'rho{n_offsets}.{freq - 1}.bin')
        else:
            vp_background = read_model(f'vp{i}.{freq}.bin')
            vs_background = read_model(f'vs{i}.{freq}.bin')
            rho_background = read_model(f'rho{i}.{freq}.bin')

        print(vp_true.shape)

        # Source positions shift by one cell per offset stage.
        first_source = i
        source_locations = torch.zeros(n_shots, n_sources_per_shot, 2,
                                       dtype=torch.long, device=device)
        source_locations[..., 0] = source_depth
        source_locations[:, 0, 1] = (
            torch.arange(n_shots, device=device) * d_source + first_source
        )

        # Create observed data using the true models (no gradients needed).
        true_prop = torch.nn.DataParallel(
            Prop(vp_true, vs_true, rho_true, dx, dt, freq).to(device)
        )
        with torch.no_grad():
            observed_data = true_prop(source_amplitudes, source_locations,
                                      receiver_locations)

        # Inverted model, held as trainable parameters in modelling layout.
        model = Prop(
            vp_background.T.contiguous(),
            vs_background.T.contiguous(),
            rho_background.T.contiguous(),
            dx, dt, freq, requires_grad=True,
        ).to(device)
        prop2 = torch.nn.DataParallel(model)
        optimiser = torch.optim.LBFGS(model.parameters(), max_iter=20)

        def closure():
            optimiser.zero_grad()
            out = prop2(source_amplitudes, source_locations, receiver_locations)
            # Presumably scaled because the raw misfit is tiny relative to
            # LBFGS's tolerance_grad / tolerance_change thresholds.
            loss = 1e22 * loss_fn(out, observed_data)
            print(loss.item() / 1e22)
            loss.backward()
            return loss

        for epoch in range(n_epochs):
            optimiser.step(closure)

        write_model(model.vp, f"vp{i + 1}.{freq}.bin")
        write_model(model.vs, f"vs{i + 1}.{freq}.bin")
        write_model(model.rho, f"rho{i + 1}.{freq}.bin")

    print("=======", freq)