In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')
import gc
import os
import time
import datetime
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import scipy as sp
import cv2
from tqdm.notebook import tqdm
from skimage.metrics import structural_similarity as ssim
import deepwave

import sys
import os
sys.path.append(os.path.abspath(".."))
from deepinvhessian import fwi
from deepinvhessian.utilities import *
from deepinvhessian.filters import *
from deepinvhessian.train import *
from deepinvhessian.masks import *
from unet import *

In [None]:
def get_dir(directory):
    """
    Create ``directory`` (including any missing parents) if needed and return it.

    Parameters
    ----------
    directory : str or os.PathLike
        Path of the directory to create.

    Returns
    -------
    The ``directory`` argument, unchanged, so the call can be used inline.

    Raises
    ------
    FileExistsError
        If ``directory`` already exists but is not a directory.
    """
    # exist_ok=True avoids the race of a separate os.path.exists() check followed by
    # makedirs, while still raising when the path exists as a regular file.
    os.makedirs(directory, exist_ok=True)
    return directory

def clear_dir(directory):
    """
    Remove every entry (files, symlinks and sub-directories) inside ``directory``.

    The directory itself is kept. Removal is best-effort: an entry that cannot be
    removed is reported with ``print`` and skipped.

    Parameters
    ----------
    directory : str
        Path of the directory to empty.

    Raises
    ------
    Exception
        If ``directory`` is not a ``str``; if it names the current, parent or root
        directory (or the literal ``"*"``); or if it is not an existing directory.
    """
    # shutil is not imported explicitly in the notebook's import cell; import it
    # here so sub-directory removal does not fail with a swallowed NameError.
    import shutil

    # Check the type first: os.path.isdir would otherwise be called on arbitrary
    # objects (an int is even interpreted as a file descriptor).
    if not isinstance(directory, str):
        raise Exception("string type required for directory: %s" % (directory,))
    # Refuse dangerous targets both as typed and after normalisation ("./." -> ".").
    forbidden = ["..", ".", "", "/", "./", "../", "*"]
    if directory in forbidden or os.path.normpath(directory) in ("..", ".", "/"):
        raise Exception("trying to delete current directory, probably bad idea?!")
    if not os.path.isdir(directory):
        raise Exception("%s is not a directory" % (directory))

    for entry in os.listdir(directory):
        path = os.path.join(directory, entry)
        try:
            # Symlinks (even ones pointing at directories) are unlinked, never
            # followed: shutil.rmtree refuses to operate on a symlink.
            if os.path.islink(path) or os.path.isfile(path):
                os.remove(path)
            elif os.path.isdir(path):
                shutil.rmtree(path)
        except Exception as e:
            print(e)

In [None]:
# set_seed presumably seeds the RNGs; it comes from one of the deepinvhessian star imports.
set_seed(14)

# Use the first CUDA GPU when one is available, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device, end='\n\n')

# CUDA-only diagnostics: GPU name plus allocated / reserved memory in GB (1 decimal).
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0) / 2**30, 1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0) / 2**30, 1), 'GB')

In [None]:
# Model grid, acquisition geometry, time axis and run settings shared by the cells below.
par = dict(
    # model grid (x then z)
    nx=601, dx=0.01, ox=0,
    nz=221, dz=0.01, oz=0,
    # sources
    num_shots=30, ds=0.2, os=0, sz=0,
    # receivers
    num_receivers_per_shot=300, dr=0.02, orec=0, rz=0,
    # time axis
    nt=4000, dt=0.001, ot=0,
    # wavelet frequency (used for the Ricker wavelet) and simulation options
    freq=20, num_sources_per_shot=1, num_dims=2,
    num_batches=30,
    FWI_itr=100,
)


In [None]:
# Experiment output locations; the seismic_data directory is created if missing.
exp_name = './Exp_marmousi_attention_noise/'
velocity_model = "Marmousi"
seismic_path = os.path.join(exp_name, "seismic_data")
get_dir(seismic_path)
# Raw binary files for the clean and noisy shot gathers written further down.
obs_dir = os.path.join(seismic_path, f"shot_{velocity_model}_born_big")
obs_dir_noise = os.path.join(seismic_path, f"shot_{velocity_model}_born_noise_14_big")

# True velocity model on disk; os.path.join avoids the "../data//Marm.bin" double slash.
vel_path = "../data/"
vel_dir = os.path.join(vel_path, "Marm.bin")

print("obs_dir:", obs_dir)
print("vel_dir:", vel_dir)

In [None]:
# Load the true velocity model: raw float32 values from vel_dir reshaped to (nz, nx).
model_true = np.fromfile(vel_dir, dtype=np.float32).reshape((par['nz'], par['nx']))

In [None]:
# Smooth background: Gaussian blur of the true model with sigma = 2 samples on both axes.
model_init = sp.ndimage.gaussian_filter(model_true, sigma=(2, 2))

# Relative perturbation (true - smooth) / true; passed to fwi.forward_modelling_born below.
model_sclar = np.divide(model_true - model_init, model_true)



In [None]:
# Side-by-side view of the true velocity model and the relative perturbation
# model_sclar = (model_true - model_init) / model_true. Each panel's colour scale
# is clipped to that array's 2nd-98th percentile range.
m_vmin, m_vmax = np.percentile(model_true, [2, 98])
s_vmin, s_vmax = np.percentile(model_sclar, [2, 98])
fig, ax = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: true velocity model.
im1 = ax[0].imshow(model_true, cmap='jet', vmin=m_vmin, vmax=m_vmax)
ax[0].set_title('Marmousi true model')
ax[0].set(xlabel='x (sample index)', ylabel='z (sample index)')
fig.colorbar(im1, ax=ax[0], fraction=0.046, pad=0.04)

# Right panel: the relative perturbation that is fed to Born modelling
# (it is not a predicted model, so the title says what is actually shown).
im2 = ax[1].imshow(model_sclar, cmap='jet', vmin=s_vmin, vmax=s_vmax)
ax[1].set_title('Relative perturbation (true - smooth) / true')
ax[1].set(xlabel='x (sample index)', ylabel='z (sample index)')
fig.colorbar(im2, ax=ax[1], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()


In [None]:
# Ricker wavelet: dominant frequency par['freq'], par['nt'] samples at interval par['dt'];
# the 4th argument (deepwave's peak time) is one dominant period, 1 / freq.
source_wavelet = deepwave.wavelets.ricker(
    par['freq'], par['nt'], par['dt'], 1 / par['freq']
)

# FWI helper object built from the parameter dict and a tensor copy of the wavelet.
params = fwi.FWIParams(par, torch.tensor(source_wavelet), 1)

# Source and receiver coordinates (argument 1 forwarded unchanged).
x_s1, x_r1 = params.get_coordinate(1)

# One copy of the wavelet for every source.
source_amplitudes = params.create_wavelet(torch.tensor(source_wavelet))

In [None]:
# Visualize the source wavelet: source_amplitudes[0, 0, :] plotted against time.
# Assumes the last axis of source_amplitudes holds the nt time samples -- TODO confirm.
time_axis = np.arange(par['nt']) * par['dt']
fig_wavelet, ax_wavelet = plt.subplots()
ax_wavelet.plot(time_axis, np.asarray(source_amplitudes[0, 0, :]))
ax_wavelet.set(xlabel='Time (s)', ylabel='Amplitude', title='Source wavelet')
plt.show()

In [None]:
print(np.shape(model_true))  # expected (nz, nx) after the reshape above

In [None]:
# Simulate the clean "observed" data by Born modelling with the true model and the
# relative perturbation, both handed over as float32 tensors.
data_true = fwi.forward_modelling_born(
    params,
    torch.tensor(model_true, dtype=torch.float32),
    torch.tensor(model_sclar, dtype=torch.float32),
    device,
)

In [None]:
# Visualize the clean data for shots 0, 14 and 29.
# NOTE(review): int() truncates nx*dx = 6.01 to 6 in the extent.
show_3_shots(
    data_true.cpu(),
    [0, 14, 29],
    clip=0.1,
    extent=(0, int(par['nx'] * par['dx']), int(par['nt'] * par['dt']), 0),
    ylim=(int(par['nt'] * par['dt']), 0),
)

In [None]:
# Dump the clean data as headerless raw binary; reloading requires data_true's shape and dtype.
data_true.to('cpu').numpy().tofile(obs_dir)

In [None]:
print(data_true.size())  # torch.Size of the simulated data

In [None]:
# Report the size of the file just written ("file size after writing (bytes)") and
# check it: a raw tofile dump must be exactly numel * bytes-per-element long.
written_bytes = os.path.getsize(obs_dir)
print("写入后的文件大小（字节）:", written_bytes)
expected_bytes = data_true.numel() * data_true.element_size()
assert written_bytes == expected_bytes, (
    f"{obs_dir}: {written_bytes} bytes on disk, expected {expected_bytes}"
)

In [None]:
print(f"{obs_dir}")  # path of the clean-data file

In [None]:
def add_noise(data, noise_level, kind="gauss", device="cpu", verbose=False):
    """
    Add random noise to a tensor and report the resulting signal-to-noise ratio.

    Parameters
    ----------
    data : torch.Tensor
        Clean data. The noise is drawn with the same shape, dtype and device.
    noise_level : float
        Multiplier applied to standard-normal samples, i.e. the standard
        deviation of the added noise in the units of ``data``.
    kind : str
        Noise distribution. Only ``"gauss"`` is implemented.
    device : str or torch.device
        Unused; kept so existing calls that pass it keep working. The result
        lives on ``data``'s device.
    verbose : bool
        If True, print the min/max of the data and of the generated noise.

    Returns
    -------
    noisy_data : torch.Tensor
        ``data + noise``.
    noise_info : dict
        ``{'SNR': tensor}`` where SNR = 20 * log10(||data|| / ||noise||) in dB,
        detached and moved to the CPU.

    Raises
    ------
    NotImplementedError
        For ``kind="laplace"`` (not implemented yet) or any other unknown kind.
    """
    # Validate the kind before drawing any noise. Raising for "laplace" replaces the
    # implicit ``return None``, which made callers' tuple unpacking fail with an
    # unrelated TypeError.
    if kind == "laplace":
        raise NotImplementedError("Laplace noise is not implemented yet")
    if kind != "gauss":
        raise NotImplementedError(f"No such kind of noise: {kind!r}")

    noise = noise_level * torch.randn_like(data)

    if verbose:
        print(f"[DEBUG] data min={data.min().item():.4e}, max={data.max().item():.4e}")
        print(f"[DEBUG] noise min={noise.min().item():.4e}, max={noise.max().item():.4e}")

    # Amplitude-ratio SNR in dB; the epsilon keeps it finite when noise_level == 0.
    snr = 20 * torch.log10(torch.norm(data) / (torch.norm(noise) + 1e-12))
    noise_info = {'SNR': snr.detach().cpu()}
    return data + noise, noise_info

In [None]:
# Absolute noise amplitude: the standard deviation of the Gaussian noise add_noise adds.
NOISE_LEVEL = 7e-7
noisy_data, noise_info = add_noise(data_true, NOISE_LEVEL, kind="gauss", device=device)

In [None]:
print(f"{noise_info}")  # {'SNR': tensor(...)} in dB

In [None]:
# Visualize the noisy data for the same three shots as the clean data.
# NOTE(review): int() truncates nx*dx = 6.01 to 6 in the extent.
show_3_shots(
    noisy_data.cpu(),
    [0, 14, 29],
    clip=0.1,
    extent=(0, int(par['nx'] * par['dx']), int(par['nt'] * par['dt']), 0),
    ylim=(int(par['nt'] * par['dt']), 0),
)

In [None]:
# Dump the noisy data as headerless raw binary next to the clean file.
noisy_data.to('cpu').numpy().tofile(obs_dir_noise)

In [None]:
print(f"{obs_dir_noise}")  # path of the noisy-data file