In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')
import gc
import os
import time
import datetime
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import scipy as sp
import cv2
from tqdm.notebook import tqdm
from skimage.metrics import structural_similarity as ssim
import deepwave

import sys
import os
sys.path.append(os.path.abspath(".."))
from deepinvhessian import fwi
from deepinvhessian.utilities import *
from deepinvhessian.filters import *
from deepinvhessian.train import *
from deepinvhessian.masks import *
from unet import *

In [None]:
# Seed the RNGs through the project helper so the run is repeatable.
set_seed(14)

# Prefer CUDA when it is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)
print()

# GPU name and memory counters are only meaningful on a CUDA device.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    for mem_label, mem_bytes in (('Allocated:', torch.cuda.memory_allocated(0)),
                                 ('Cached:   ', torch.cuda.memory_reserved(0))):
        print(mem_label, round(mem_bytes / 1024**3, 1), 'GB')

In [None]:
def get_dir(directory):
    """
    Create ``directory`` (and any missing parents) if it does not exist yet.

    Parameters
    ----------
    directory : str or os.PathLike
        Directory to create.

    Returns
    -------
    The ``directory`` argument, unchanged, so the call can be used inline.

    Raises
    ------
    FileExistsError
        If ``directory`` already exists but is not a directory (e.g. a file).
    """
    # exist_ok=True removes the race between a separate exists() check and
    # makedirs(): another process creating the directory in between no longer
    # raises, while a non-directory at that path still does.
    os.makedirs(directory, exist_ok=True)
    return directory

def clear_dir(directory):
    """
    Remove every file and sub-directory inside ``directory``.

    The directory itself is kept. Entries that cannot be removed are reported
    with ``print`` and skipped, so one locked file does not abort the clean-up.

    Parameters
    ----------
    directory : str
        Path of the directory to empty.

    Raises
    ------
    Exception
        If ``directory`` is not a string, is not an existing directory, or
        resolves to the current, parent or filesystem-root directory.
    """
    # Local import: shutil is not imported at the top of the notebook; without
    # it the rmtree branch raised NameError, which the except below swallowed.
    import shutil

    # isinstance() rather than type(): later cells may rebind the global name
    # `type`, which would make `type(directory)` fail. Check type first so the
    # path functions below only ever see strings.
    if not isinstance(directory, str):
        raise Exception("string type required for directory: %s" % (directory,))
    if not os.path.isdir(directory):
        raise Exception("%s is not a directory" % (directory,))
    # Refuse obviously dangerous targets, including equivalent spellings such
    # as "./." that resolve to the current, parent or root directory.
    protected = {os.path.realpath(p) for p in (os.curdir, os.pardir, os.sep)}
    if (directory in ["..", ".", "", "/", "./", "../", "*"]
            or os.path.realpath(directory) in protected):
        raise Exception("trying to delete current directory, probably bad idea?!")

    for entry in os.listdir(directory):
        path = os.path.join(directory, entry)
        try:
            # Symlinks (even to directories) are unlinked, never followed:
            # rmtree refuses to operate on a symlink.
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path)
            else:
                os.remove(path)
        except OSError as e:
            # Best effort: report and keep going with the remaining entries.
            print(f"could not remove {path}: {e}")

## Prepare data

In [None]:
# Model grid, acquisition geometry and FWI settings shared by the cells below.
par = dict(
    # Model grid: number of cells, spacing and origin along x and z
    nx=601, dx=0.01, ox=0,
    nz=221, dz=0.01, oz=0,
    # Sources: count, spacing, origin and depth
    num_shots=30, ds=0.2, os=0, sz=0,
    # Receivers: count per shot, spacing, origin and depth
    num_receivers_per_shot=300, dr=0.02, orec=0, rz=0,
    # Time axis: number of samples, sampling interval and origin
    nt=4000, dt=0.001, ot=0,
    # Wavelet frequency and modelling options
    freq=20, num_sources_per_shot=1, num_dims=2,
    num_batches=30,
    FWI_itr=100,
)


In [None]:
# Experiment folders and input files.
# The method label is stored in `method_name`, not `type`: binding the global
# name `type` would shadow the builtin type() in every later cell.
exp_name = './Exp_Marmousi_Unet_dm/'
method_name = 'UNet'
velocity_model = "Marmousi"

# Observed data lives under <exp_name>/seismic_data/ (created if missing).
seismic_path = exp_name + "seismic_data/"
get_dir(seismic_path)
obs_dir = seismic_path + f"shot_{velocity_model}_born_big"

# Results of this method go under <exp_name>/velocity_<method_name>/.
output_path = exp_name + f"velocity_{method_name}/"
get_dir(output_path)
model_dir = output_path + f"velocity_{velocity_model}"

# True velocity model (flat binary file).
vel_path = "../data/"
vel_dir = vel_path + "Marm.bin"

print("obs_dir:", obs_dir)
print("vel_dir:", vel_dir)
print("model_dir:", model_dir)

In [None]:
# Load the true velocity model (float32 flat binary) onto the (nz, nx) grid.
# Each file is read exactly once; the previous version loaded both the model
# and the observed data twice.
model_true = (np.fromfile(vel_dir, np.float32)
              .reshape(par['nz'], par['nx']))
# Smooth background: the true model blurred by a Gaussian (sigma = 2 samples).
model_init = sp.ndimage.gaussian_filter(model_true, sigma=[2, 2])

# Relative perturbation of the true model w.r.t. the smooth background.
# The FWI loop scores its result against this array with SSIM.
model_sclar = (model_true - model_init) / model_true

# Observed data (presumably Born-modelled, judging by the file name), stored
# as one flat file and reshaped to (num_shots, num_receivers_per_shot, nt).
data_true = (
    torch.from_file(obs_dir,
                    size=par['num_shots']*par['num_receivers_per_shot']*par['nt'])
    .reshape(par['num_shots'], par['num_receivers_per_shot'], par['nt'])
).to(device)

def make_mask(model, value, n_top=21):
    """
    Build a 0/1 integer mask from ``model``.

    Cells where ``model > value`` are set to 1, all others to 0. The first
    ``n_top`` rows along axis 0 are always zeroed; for ``model_true`` axis 0
    is depth (it is reshaped to (nz, nx)), so these are the shallowest rows,
    presumably the water layer.

    Parameters
    ----------
    model : array_like
        Model to threshold.
    value : scalar
        Threshold; only cells strictly greater than it are kept.
    n_top : int, default 21
        Number of leading rows forced to zero.

    Returns
    -------
    numpy.ndarray of int, same shape as ``model``.
    """
    mask_arr = (np.asarray(model) > value).astype(int)
    mask_arr[:n_top] = 0
    return mask_arr


# The function is named make_mask so that binding the resulting array to
# `mask` (used by later cells) no longer clobbers the function itself.
mask = make_mask(model_true, 0)

In [None]:
# Colour limits from the 2nd/98th percentiles so a few extreme values do not
# wash out the images.
v_vmin, v_vmax = np.percentile(model_true, [2, 98])
m_vmin, m_vmax = np.percentile(model_sclar, [2, 98])
# imshow extent (left, right, bottom, top): grid size * spacing * 1000,
# using dz for the vertical axis.
model_extent = (0, par['nx']*par['dx']*1000, par['nz']*par['dz']*1000, 0)

# Both panels show numpy arrays on the (nz, nx) grid.
fig, ax = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: the true velocity model.
im1 = ax[0].imshow(model_true, cmap='jet', vmin=v_vmin, vmax=v_vmax,
                   extent=model_extent)
ax[0].set_title('Marmousi true model')
fig.colorbar(im1, ax=ax[0], fraction=0.046, pad=0.04)

# Right panel: (model_true - model_init) / model_true, i.e. the reference the
# inversion is compared against (not a prediction).
im2 = ax[1].imshow(model_sclar, cmap='jet', vmin=m_vmin, vmax=m_vmax,
                   extent=model_extent)
ax[1].set_title('True relative perturbation')
fig.colorbar(im2, ax=ax[1], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()

# Middle shot gather, clipped at its 1st/99th percentiles. Quantiles are taken
# on the CPU copy and converted to plain floats for matplotlib.
shot_gather = data_true[par['num_shots'] // 2].detach().cpu()
vpmin, vpmax = torch.quantile(shot_gather, torch.tensor([0.01, 0.99])).tolist()

plt.figure(figsize=(6, 5))
plt.imshow(shot_gather.numpy().T, aspect='auto', cmap='gray', vmin=vpmin, vmax=vpmax)
plt.xlabel("Receiver")
plt.ylabel("Time sample")
plt.title("Observed data (middle shot)")
plt.colorbar()
plt.show()


In [None]:
# Ricker source wavelet: frequency par['freq'], par['nt'] samples at spacing
# par['dt']. The 4th argument (one period, 1/freq) is presumably the peak time
# in deepwave's ricker() signature — confirm against the deepwave docs.
source_wavelet = deepwave.wavelets.ricker(par['freq'], par['nt'], par['dt'], 1/par['freq'])
# Build the project's FWI parameter object from `par` and the wavelet.
# NOTE(review): meaning of the trailing `1` is defined in deepinvhessian.fwi — TODO confirm.
# NOTE(review): if ricker() already returns a tensor, torch.tensor() copies it
# and emits a UserWarning (hidden by the warnings filter in the first cell).
params = fwi.FWIParams(par, torch.tensor(source_wavelet), 1)
# Source (x_s1) and receiver (x_r1) coordinates.
# NOTE(review): meaning of the argument `1` — TODO confirm in deepinvhessian.fwi.
x_s1, x_r1 = params.get_coordinate(1)
# Create a wavelet for every source
source_amplitudes = params.create_wavelet(torch.tensor(source_wavelet))

## Run FWI with the proposed approach: network-predicted model update $\delta m$

In [None]:
# # Compute source illumination
# device = 'cpu'
# SI = fwi.source_illumination(torch.tensor(model_init), source_amplitudes, par['dx'], par['dt'], x_s1, device=device)
# # clear memory
# torch.cuda.empty_cache()
# gc.collect()
# # Visualize the source illumination
# simin, simax = np.percentile(SI.cpu(), [2,98])
# show_model(SI.cpu(), cmap='bwr', vmin=simin, vmax=simax, figsize=(12, 5), extent=(0, par['nx']*par['dx']*1000, par['nz']*par['dx']*1000, 0),
#            title='Source Illumination')
# np.savez(f'{output_path}/source_illumination', si=SI.detach().clone().cpu())

In [None]:
# Source illumination (SI), used below to precondition the FWI gradient and
# dm1. It is cached as an .npz file; when the cache is missing it is computed
# here so the notebook survives Restart & Run All on a fresh checkout.
si_file = f'{output_path}/source_illumination.npz'
if os.path.exists(si_file):
    # The context manager closes the file handle held by the NpzFile.
    with np.load(si_file) as data:
        si_array = data['si']
    SI = torch.tensor(si_array, device=device)
else:
    # Mirrors the commented-out computation cell above (computed on the CPU,
    # then GPU cache/garbage freed), without rebinding the global `device`.
    SI = fwi.source_illumination(torch.tensor(model_init), source_amplitudes,
                                 par['dx'], par['dt'], x_s1, device='cpu')
    torch.cuda.empty_cache()
    gc.collect()
    np.savez(f'{output_path}/source_illumination', si=SI.detach().cpu().numpy())
    SI = SI.detach().to(device)

# Display with colour limits at the 2nd/98th percentiles.
simin, simax = np.percentile(SI.cpu(), [2, 98])
show_model(SI.cpu(), cmap='bwr', vmin=simin, vmax=simax, figsize=(8, 5),
           title='Source Illumination')

In [None]:
# Run FWI with the proposed method. Each iteration:
#   1. computes the FWI gradient of the Born misfit,
#   2. computes dm1 with fwi.compute_dm1, using the gradient as the perturbation,
#   3. trains the U-Net to map dm1 -> gradient,
#   4. applies the trained U-Net to the gradient itself to obtain dm,
#   5. takes an SGD step with dm in place of the gradient.
scatter = torch.zeros_like(torch.tensor(model_true).float())
# Velocity passed to the Born routines; this loop never updates it.
model_velocity = torch.tensor(model_true)
# The unknown being inverted starts at zero on the compute device.
model = scatter.clone().to(device)
model.requires_grad = True
# data_true is already a tensor on `device`; .float() only guarantees float32
# (torch.tensor(tensor) would copy it and warn).
data_true = data_true.float()
# as_tensor accepts the numpy mask from the data-loading cell as well as an
# already-converted tensor, so re-running this cell keeps working.
mask = torch.as_tensor(mask, device=device)
# Per-iteration history
gradients, dm1s, gradients_pred, dms, updates, fwi_loss, ssim_list, network_loss, alphas = [], [], [], [], [], [], [], [], []

data_range = model_sclar.max() - model_sclar.min()  # SSIM dynamic range of the reference
loss_fn = torch.nn.MSELoss()  # Misfit function for FWI and Born modelling
optimizer = torch.optim.SGD([{'params': [model], 'lr': 1e-2}])  # Plain SGD on the model, step size lr
# Network (created once, retrained every FWI iteration), its optimizer and loss
network = UNet(n_channels=1, n_classes=1, hidden_channels=256).to(device)
optimizer_unet = torch.optim.Adam(network.parameters(), lr=1e-4)
l2_norm = torch.nn.MSELoss()
network_iter_init = 1000  # Training epochs in the first FWI iteration
network_iter_fin = 300    # Training epochs in every later FWI iteration
tsamples = 0              # Number of leading time samples excluded from the misfit
FWI_iter = par['FWI_itr']  # Number of FWI iterations, taken from `par`
t_start = time.time()
for iteration in tqdm(range(FWI_iter)):
    # SSIM between the current model and the reference relative perturbation
    ssim_metric = ssim(model.detach().cpu().numpy(), model_sclar, data_range=data_range)
    ssim_list.append(ssim_metric)
    # Compute FWI gradient
    optimizer.zero_grad()
    grad, iter_loss = fwi.compute_gradient_born(params, model_velocity, model, data_true, loss_fn, tsamples, device)
    fwi_loss.append(iter_loss)
    print(f'FWI iteration: {iteration} loss = {fwi_loss[-1]}, ssim = {ssim_list[-1]}')
    # Clip model.grad (if populated) at the 98th percentile of |grad|.
    # NOTE(review): model.grad is replaced by dm before the step, so this only
    # matters if `grad` aliases model.grad — confirm in fwi.compute_gradient_born.
    torch.nn.utils.clip_grad_value_(model, torch.quantile(grad.detach().abs(), 0.98))
    # Precondition with the source illumination
    grad = grad / SI
    # Normalise by the first-iteration maximum (a fixed scale across
    # iterations) and apply the mask built in the data-loading cell.
    if iteration == 0:
        gmax0 = torch.abs(grad.detach()).max()
    grad = (grad / gmax0) * mask
    gradients.append(grad.cpu().detach().numpy())
    # Compute dm1 with the normalised gradient as the perturbation
    dm1 = grad.detach().clone().to(device)
    dm1.requires_grad = True
    dm1 = fwi.compute_dm1(params, model_velocity.detach().clone(), dm1, loss_fn, tsamples, device)
    # Precondition dm1 with the source illumination
    dm1 = dm1 / SI
    # Normalise by 10x the first-iteration maximum (empirical factor) and mask
    if iteration == 0:
        dm1max0 = 1e1 * torch.abs(dm1.detach()).max()
    dm1 = (dm1 / dm1max0) * mask
    dm1s.append(dm1.cpu().detach().numpy())
    # Train the network to map dm1 -> grad; unsqueeze adds batch and channel dims
    training_pair = {'x': dm1.clone().unsqueeze(0).unsqueeze(0),
                     'y': grad.clone().unsqueeze(0).unsqueeze(0)}
    network_iter = network_iter_init if iteration == 0 else network_iter_fin
    lossn = train(network, training_pair, optimizer_unet, l2_norm, network_iter, use_scheduler=True, device=device)
    network_loss.extend(lossn)
    with torch.no_grad():
        # Network applied to its training input: should reproduce grad
        g = network(training_pair['x']).squeeze() * mask
        # Network applied to the gradient itself: the model update dm
        dm = network(training_pair['y']).squeeze() * mask
    gradients_pred.append(g.cpu().detach().numpy())
    dms.append(dm.cpu().detach().numpy())
    # Optional Barzilai-Borwein step length (currently disabled):
    # if iteration > 0:
    #     delta_model = model.detach().clone() - previous_model
    #     delta_grad = grad.detach().clone() - previous_grad
    #     alpha = fwi.bb_step(delta_model, delta_grad, 'short')
    #     alphas.append(alpha)
    #     print(f"[iter {iteration}] lr_old = {optimizer.param_groups[-1]['lr']}")
    #     optimizer.param_groups[-1]['lr'] = alpha
    #     print(f"[iter {iteration}] lr_new = {optimizer.param_groups[-1]['lr']}")
    # previous_model = model.detach().clone()
    # previous_grad = grad.detach().clone()
    # Use dm as the search direction: plain SGD then performs model -= lr * dm.
    # Assigning .grad (instead of writing into model.grad.data) also works when
    # model.grad is None, e.g. after zero_grad(set_to_none=True).
    model.grad = dm.detach().clone().to(model.dtype)
    optimizer.step()
    updates.append(model.detach().clone().cpu().numpy())
    # Plot the results
    show_one_iter_dm(grad.cpu(), dm1.cpu(), g.cpu(), dm.cpu(), model.detach().cpu(), lossn, iteration=iteration,
                     cmap='bwr', vmin=m_vmin, vmax=m_vmax, extent=(0, par['nx']*par['dx']*1000, par['nz']*par['dx']*1000, 0), save_path=f'{exp_name}')
t_end = time.time()
t_delta = t_end - t_start
print(f'Runtime:{datetime.timedelta(seconds=t_delta)}')
# Save the results
np.savez(f'{output_path}/losses', fwi_loss=np.array(fwi_loss),
                               network_loss=np.array(network_loss),
                               ssim=np.array(ssim_list),
                               )
np.savez(f'{output_path}/results', updates=np.array(updates),
                            gradients=np.array(gradients),
                            dm1s=np.array(dm1s),
                            gradients_pred=np.array(gradients_pred),
                            dms=np.array(dms),)
# Save the network weights
torch.save(network.state_dict(), f'{output_path}/network_weights.pth')

In [None]:
# Convergence curves: FWI data misfit (left) and SSIM against the reference (right).
fig, axs = plt.subplots(1, 2, figsize=(12,4))
for panel, series, label in zip(axs, (fwi_loss, ssim_list), ('Data Loss', 'SSIM')):
    panel.plot(series)
    panel.set_title(label)
    panel.set_xlabel('Iteration')
    # Minimal frame: hide the right and top axis lines.
    for side in ('right', 'top'):
        panel.spines[side].set_visible(False)
plt.savefig(f'{exp_name}/losses.png',  bbox_inches='tight', dpi=300)

In [None]:
# Re-plot the tensors left over from the last pass of the FWI loop.
show_one_iter_dm(
    grad.cpu(), dm1.cpu(), g.cpu(), dm.cpu(), model.detach().cpu(), lossn,
    extent=(0, par['nx']*par['dx']*1000, par['nz']*par['dx']*1000, 0),
    vmin=m_vmin,
    vmax=m_vmax,
    cmap='bwr',
    iteration=FWI_iter,
)

In [None]:
# Model after the first FWI iteration, on the reference colour scale.
show_model(
    updates[0],
    title='First update',
    figsize=(12, 5),
    extent=(0, par['nx']*par['dx']*1000, par['nz']*par['dx']*1000, 0),
    cmap='jet',
    vmin=m_vmin,
    vmax=m_vmax,
)

In [None]:
# Model after the final FWI iteration, on the reference colour scale.
show_model(
    updates[-1],
    title='Last update',
    figsize=(12, 5),
    extent=(0, par['nx']*par['dx']*1000, par['nz']*par['dx']*1000, 0),
    cmap='gray',
    vmin=m_vmin,
    vmax=m_vmax,
)

In [None]:
# Side-by-side comparison of the final inverted model and the reference
# relative perturbation, on a shared colour scale taken from the reference's
# 2nd/98th percentiles. (v_vmin/v_vmax are no longer overwritten here: they
# were unused in this cell and hold the velocity-model colour limits.)
m_vmin, m_vmax = np.percentile(model_sclar, [2, 98])
model_final = model.detach().cpu().numpy()
model_extent = (0, par['nx']*par['dx']*1000, par['nz']*par['dz']*1000, 0)

fig, ax = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: result of the FWI loop (previously mislabelled as the true model).
im1 = ax[0].imshow(model_final, cmap='gray', vmin=m_vmin, vmax=m_vmax,
                   extent=model_extent)
ax[0].set_title('Inverted model (final iteration)')
fig.colorbar(im1, ax=ax[0], fraction=0.046, pad=0.04)

# Right panel: the reference (model_true - model_init) / model_true
# (previously mislabelled as the predicted model).
im2 = ax[1].imshow(model_sclar, cmap='gray', vmin=m_vmin, vmax=m_vmax,
                   extent=model_extent)
ax[1].set_title('True relative perturbation')
fig.colorbar(im2, ax=ax[1], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()
