In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')
import gc
import os
import time
import datetime
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import scipy as sp
import cv2
from tqdm.notebook import tqdm
from skimage.metrics import structural_similarity as ssim
import deepwave

import sys
import os
sys.path.append(os.path.abspath(".."))
from deepinvhessian import fwi, fwi_lbfgs
from deepinvhessian.utilities import *
from deepinvhessian.filters import *
from deepinvhessian.train import *
from deepinvhessian.masks import *
from unet import *

In [None]:
# Fix the random seed so repeated runs are comparable.
# (set_seed is not defined in this notebook; presumably it comes from one of the star imports.)
set_seed(14)

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

# Extra diagnostics that only exist for CUDA devices.
if device.type == 'cuda':
    bytes_per_gib = 1024**3
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    for label, query in (('Allocated:', torch.cuda.memory_allocated),
                         ('Cached:   ', torch.cuda.memory_reserved)):
        print(label, round(query(0)/bytes_per_gib, 1), 'GB')

In [None]:
def get_dir(directory):
    """
    Creates the given directory (including missing parents) if it does not exist.

    Args:
        directory: Path of the directory to create.

    Returns:
        The same `directory` value, so the call can be used inline.

    Raises:
        FileExistsError: If the path already exists but is not a directory.
    """
    # exist_ok=True makes creation atomic with respect to the existence check,
    # so a concurrent creator can no longer make this call fail.
    os.makedirs(directory, exist_ok=True)
    return directory

def clear_dir(directory):
    """
    Removes every entry (files, symlinks and sub-directories) inside the given directory.

    The directory itself is kept. Failures on individual entries are printed and
    skipped so that one bad entry does not stop the clean-up.

    Args:
        directory: Path (str) of the directory to empty.

    Raises:
        Exception: If `directory` is not a string, names an obviously dangerous
            location (current/parent/root directory), or is not an existing directory.
    """
    # Imported locally because the notebook's import cell does not import shutil.
    import shutil

    # Validate the type first; isinstance also keeps working if the builtin
    # `type` has been rebound elsewhere in the notebook.
    if not isinstance(directory, str):
        raise Exception("string type required for directory: %s" % (directory,))
    if directory in ["..", ".", "", "/", "./", "../", "*"]:
        raise Exception("trying to delete current directory, probably bad idea?!")
    if not os.path.isdir(directory):
        raise Exception("%s is not a directory" % (directory,))

    for entry in os.listdir(directory):
        path = os.path.join(directory, entry)
        try:
            # Symlinks are unlinked, never followed: rmtree refuses symlinks and
            # following them would delete data outside `directory`.
            if os.path.islink(path) or os.path.isfile(path):
                os.remove(path)
            elif os.path.isdir(path):
                shutil.rmtree(path)
        except Exception as e:
            print(e)

In [None]:
# Model and acquisition parameters shared by the whole notebook.
# Key meanings are inferred from their names only - confirm against fwi.FWIParams.
par = dict(
    # x axis of the model grid
    nx=601, dx=0.015, ox=0,
    # z axis of the model grid
    nz=221, dz=0.015, oz=0,
    # shots
    num_shots=30, ds=0.3, os=0, sz=0,
    # receivers
    num_receivers_per_shot=300, dr=0.03, orec=0, rz=0,
    # time axis
    nt=4000, dt=0.001, ot=0,
    # source / problem description
    freq=10, num_sources_per_shot=1, num_dims=2,
    num_batches=30,
    FWI_itr=100,
)


In [None]:
# Experiment layout: observed data and inversion results live under exp_name.
exp_name = './Exp_Marmousi_lbfgs/'
# Tag of the inversion variant, used in the output folder name.
# Deliberately NOT called `type`: rebinding the builtin breaks every later
# `type(...)` call in the notebook (clear_dir and print_device_info use it).
inversion_tag = 'lbfgs-5'
velocity_model = "Marmousi"

# Observed seismic data (file name only; the file is read in a later cell).
seismic_path = exp_name + "seismic_data/"
get_dir(seismic_path)
obs_dir = seismic_path + f"shot_{velocity_model}_born"

# Output location for the inverted models.
output_path = exp_name + f"velocity_{inversion_tag}/"
get_dir(output_path)
model_dir = output_path + f"velocity_{velocity_model}"

# Binary file holding the true velocity model.
vel_path = "../data/"
vel_dir = vel_path + "Marm.bin"

print("obs_dir:", obs_dir)
print("vel_dir:", vel_dir)
print("model_dir:", model_dir)

In [None]:
# Read the true model from the raw float32 binary and shape it to (nz, nx).
model_shape = (par['nz'], par['nx'])
model_true = np.fromfile(vel_dir, dtype=np.float32).reshape(model_shape)

# Smoothed copy of the true model (Gaussian filter, sigma of 2 samples per axis).
model_init = sp.ndimage.gaussian_filter(model_true, sigma=[2, 2])

# Relative difference between the true and the smoothed model.
model_sclar = (model_true - model_init) / model_true

# Observed data: one flat binary file reshaped to (shots, receivers, time samples).
num_data_samples = par['num_shots'] * par['num_receivers_per_shot'] * par['nt']
data_true = torch.from_file(obs_dir, size=num_data_samples)
data_true = data_true.reshape(par['num_shots'], par['num_receivers_per_shot'], par['nt'])
data_true = data_true.to(device)

In [None]:
# Helper that builds the update mask (the original comment calls it the water layer mask).
def water_layer_mask(model, value, top_rows=21):
    """
    Return an integer mask that is 1 where `model > value` and 0 elsewhere,
    with the first `top_rows` rows forced to 0.

    Args:
        model: NumPy array holding the model (rows are indexed first).
        value: Threshold; entries strictly greater than it are kept.
        top_rows: Number of leading rows to zero out (default 21, the value
            previously hard-coded; presumably the water layer - confirm).

    Returns:
        Integer NumPy array with the same shape as `model`.
    """
    keep = (model > value).astype(int)
    keep[:top_rows] = 0
    return keep

# The helper has its own name so that binding the result to `mask` no longer
# overwrites the function, which made this cell fail when run a second time.
mask = water_layer_mask(model_true, 0)

In [None]:
# Colour limits from the 2nd/98th percentiles so outliers do not dominate the scale.
v_vmin, v_vmax = np.percentile(model_true, [2, 98])
m_vmin, m_vmax = np.percentile(model_sclar, [2, 98])

# Physical extent of the model; the factor 1000 matches the metre axis labels.
# The depth axis uses dz (the original used dx, which only worked because dx == dz).
model_extent = (0, par['nx']*par['dx']*1000, par['nz']*par['dz']*1000, 0)

# model_true and model_sclar are both NumPy arrays.
fig, ax = plt.subplots(1, 2, figsize=(12, 5))

# First panel: true velocity model.
im1 = ax[0].imshow(model_true, cmap='jet', vmin=v_vmin, vmax=v_vmax,
                   extent=model_extent)
ax[0].set_title('Marmousi true model')
ax[0].set_xlabel('x [m]')
ax[0].set_ylabel('z [m]')
fig.colorbar(im1, ax=ax[0], fraction=0.046, pad=0.04)

# Second panel: true relative perturbation (model_sclar). It is derived from the
# true model, not predicted, so the former 'Predicted model' title was wrong.
im2 = ax[1].imshow(model_sclar, cmap='jet', vmin=m_vmin, vmax=m_vmax,
                   extent=model_extent)
ax[1].set_title('Marmousi true perturbation model')
ax[1].set_xlabel('x [m]')
ax[1].set_ylabel('z [m]')
fig.colorbar(im2, ax=ax[1], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()


# Middle shot of the observed data, clipped at 10% of its 1%/99% quantiles.
vpmin, vpmax = torch.quantile(data_true[par['num_shots']//2], torch.tensor([0.01, 0.99]).to(device))

plt.figure(figsize=(6, 5))
plt.imshow(data_true[par['num_shots']//2].cpu().detach().numpy().T, aspect='auto', cmap='gray', vmin=vpmin*0.1, vmax=vpmax*0.1)
plt.xlabel("Receiver")
plt.ylabel("Time sample")
plt.title("Observed VX")
plt.colorbar()
plt.show()


In [None]:
# Build the Ricker source wavelet; its peak is delayed by one period of the source frequency.
peak_time = 1 / par['freq']
source_wavelet = deepwave.wavelets.ricker(par['freq'], par['nt'], par['dt'], peak_time)

# Set up the FWI parameter helper.
# NOTE(review): the meaning of the trailing `1` is not visible here - confirm in fwi.FWIParams.
params = fwi.FWIParams(par, torch.tensor(source_wavelet), 1)

# Source and receiver coordinates (the argument `1` is likewise undocumented here).
x_s1, x_r1 = params.get_coordinate(1)

# Wavelet replicated for every source.
source_amplitudes = params.create_wavelet(torch.tensor(source_wavelet))

In [None]:
# Plot the wavelet of the first source against time.
time_axis = np.arange(0, par['nt']) * par['dt']
wavelet_fig, wavelet_ax = plt.subplots()
wavelet_ax.plot(time_axis, source_amplitudes[0, 0, :])
wavelet_ax.set_xlabel('Time (s)')
wavelet_ax.set_title('Source wavelet')
plt.show()

In [None]:
# Zero starting model for the inverted perturbation: float32, same shape as the
# true model, kept on the CPU because it is later handed to scipy as a NumPy array.
scatter = torch.zeros(model_true.shape, dtype=torch.float32)
# Background velocity on the compute device.
model_velocity = torch.tensor(model_true).to(device)
# Trainable copy of the starting model on the compute device.
# clone() replaces torch.tensor(<tensor>), the copy-construct pattern PyTorch warns about.
model = scatter.clone().to(device).requires_grad_(True)
# data_true is already a tensor; detach() + float() avoids the extra full copy
# that torch.tensor(data_true) made of the whole data volume.
data_true = data_true.detach().float()
# Accepts a NumPy array (first run) or a tensor (re-run of this cell) alike.
mask = torch.as_tensor(mask).clone().to(device)
# Create lists to save results
gradients, updates, fwi_loss, ssim_list, alphas = [], [], [], [], []

In [None]:
def print_device_info(*args, **kwargs):
    """
    Print, for every argument, the device it lives on.

    Positional arguments are labelled ``arg1``, ``arg2``, ... in call order;
    keyword arguments are labelled with their keyword. Values that are not
    torch tensors are reported together with their Python type instead.
    """
    # Positional values first (1-based labels), then keyword values in call order.
    labelled = [(f"arg{pos}", value) for pos, value in enumerate(args, start=1)]
    labelled += list(kwargs.items())

    for label, value in labelled:
        if not torch.is_tensor(value):
            print(f"{label} is not a torch.Tensor (type={type(value)})")
            continue
        print(f"{label} device: {value.device}")

In [None]:
# Print the device of every tensor handed to FWI_LBFGS (and of the model inputs).
# Keyword arguments make each output line carry the variable's name instead of
# an opaque "argN" label; source_amplitudes is listed once (it was passed twice).
print_device_info(
    data_true=data_true,
    source_amplitudes=source_amplitudes,
    s_cor=params.s_cor,
    r_cor=params.r_cor,
    mask=mask,
    scatter=scatter,
    model_velocity=model_velocity,
)


In [None]:
# Initialize the FWI class
# NOTE(review): this rebinds `fwi`, which the import cell bound to the
# `deepinvhessian.fwi` module. After this cell runs, `fwi.FWIParams` (used when
# `params` is built) is no longer reachable until the imports are re-run.
fwi = fwi_lbfgs.FWI_LBFGS(data_true, source_amplitudes, params.s_cor, params.r_cor, params.dx, params.dt, 
                          params.num_batches, model_true.shape, mask.to(device), scaling=1., device=device)
# Evaluate the forward call once with the true velocity and the zero starting model,
# both flattened to NumPy arrays; the returned value is shown as the cell output
# (presumably the initial data misfit - confirm in FWI_LBFGS.forward).
fwi.forward(model_velocity.detach().cpu().numpy().ravel(), scatter.detach().cpu().numpy().ravel(), scipy=True)

In [None]:
# Flatten the inputs to NumPy arrays for the scipy-style interface.
velocity_flat = model_velocity.detach().cpu().numpy().ravel()
scatter_flat = scatter.detach().cpu().numpy().ravel()

# Gradient at the starting model (last element of what fwi.grad returns).
grad = fwi.grad(velocity_flat, scatter_flat, scipy=True)[-1]

# Use the largest masked gradient value as the gradient scaling.
mask_flat = mask.cpu().numpy().ravel()
fwi.scaling = (grad * mask_flat).max()
print(fwi.scaling)

In [None]:
# Callables handed to the scipy optimizer; x is the flattened model being inverted.
def fun(x):
    """Objective: forward call for model x on the fixed background velocity."""
    return fwi.forward(model_velocity.detach().cpu().numpy().ravel(), x, scipy=True)


def grad(x):
    """Gradient of the objective at x (last element of what fwi.grad returns)."""
    return fwi.grad(model_velocity.detach().cpu().numpy().ravel(), x, scipy=True)[-1]


def callback(x):
    """Per-iteration hook: passes x, the reference model and the history lists to fwi.callback."""
    return fwi.callback(x, model_sclar.ravel(), model_velocity.detach().cpu().numpy().ravel(), MSSIM, data_residual, nWE)

In [None]:
from scipy.optimize import minimize

# History lists; `callback` receives them on every iteration.
MSSIM, data_residual, nWE = [], [], []

# Starting model as a flat float64 NumPy array, the type scipy works with
# (previously a torch tensor was passed and converted implicitly).
model0 = scatter.detach().cpu().numpy().ravel().astype(np.float64)

# Run FWI using Scipy L-BFGS.
# The iteration budget comes from the shared parameter dict instead of a second
# hard-coded 100. Other L-BFGS-B options (maxfun, maxcor, ftol, gtol, maxls)
# can be added to `options` to trade accuracy for run time.
nl = minimize(
    fun,
    model0,
    jac=grad,
    method='L-BFGS-B',
    callback=callback,
    options=dict(maxiter=par['FWI_itr']),
)

# Freeze the histories as arrays for printing / plotting.
MSSIM = np.array(MSSIM)
nWE = np.array(nWE)
data_residual = np.array(data_residual)

In [None]:
# Reshape the flat optimizer result back to the (nz, nx) grid and store it as .npy.
grid_shape = (par['nz'], par['nx'])
update = nl.x.reshape(grid_shape)
np.save(file=f'{exp_name}/update', arr=update)

In [None]:
# Physical extent in metres; the depth axis uses dz (the original used dx,
# which only worked because dx == dz).
result_extent = (0, par['nx']*par['dx']*1000, par['nz']*par['dz']*1000, 0)

# Inverted update next to the true perturbation, on the same colour scale.
fig, axs = plt.subplots(1, 2, figsize=(12, 8))
im1 = axs[0].imshow(update, cmap='gray', vmin=m_vmin, vmax=m_vmax, extent=result_extent)
axs[0].set_title(r'Inverted model', fontsize=14)
axs[0].set_xlabel(r'x [m]')
axs[0].set_ylabel(r'z [m]')
# This panel shows model_sclar (derived from the true model), so it is labelled
# as the true model; both panels used to carry the title 'Inverted model'.
im2 = axs[1].imshow(model_sclar, cmap='gray', vmin=m_vmin, vmax=m_vmax, extent=result_extent)
axs[1].set_title(r'True model', fontsize=14)
axs[1].set_xlabel(r'x [m]')
axs[1].set_ylabel(r'z [m]')
# One shared colour bar is enough because both images use the same limits.
fig.colorbar(im1, ax=axs, shrink=0.46, pad=0.02)
plt.savefig(f'{exp_name}/model_updated', bbox_inches='tight', dpi=300)
plt.show()

# Display the colour limits used above as the cell output.
m_vmin, m_vmax

In [None]:
# Show the data_residual history: the list is handed to fwi.callback during the
# optimisation and converted to a NumPy array afterwards (what exactly is stored
# per iteration is decided inside fwi.callback - not visible here).
print(data_residual)

In [None]:
# Show the MSSIM history: the list is handed to fwi.callback during the
# optimisation and converted to a NumPy array afterwards (presumably a structural
# similarity score against model_sclar per iteration - confirm in fwi.callback).
print(MSSIM)