In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')
import gc
import os
import time
import datetime
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import scipy as sp
import cv2
from tqdm.notebook import tqdm
from skimage.metrics import structural_similarity as ssim
import deepwave

import sys
import os
sys.path.append(os.path.abspath(".."))
from deepinvhessian import fwi, fwi_bgp
from deepinvhessian.utilities import *
from deepinvhessian.filters import *
from deepinvhessian.train import *
from deepinvhessian.masks import *
from unet import *

In [None]:
def get_dir(directory):
    """
    Create ``directory`` (and any missing parent directories) if it does not exist.

    Parameters
    ----------
    directory : str or os.PathLike
        Path of the directory to create.

    Returns
    -------
    The ``directory`` argument unchanged, so the call can be used inline.
    """
    if not os.path.exists(directory):
        # exist_ok=True closes the check-then-create race: if another process
        # creates the directory between the exists() test and this call, we
        # no longer fail with FileExistsError.
        os.makedirs(directory, exist_ok=True)
    return directory

def clear_dir(directory):
    """
    Remove every file, symlink and sub-directory inside ``directory``.

    The directory itself is kept.

    Parameters
    ----------
    directory : str or os.PathLike
        Directory whose contents should be deleted.

    Raises
    ------
    Exception
        If ``directory`` is not a string/path-like object, is not an existing
        directory, or names the current, parent or root directory.

    Notes
    -----
    Deletion is best-effort. An entry that cannot be removed is reported with
    ``print`` and the remaining entries are still processed.
    """
    # shutil is not imported at the top of the notebook; without it every
    # sub-directory removal raised NameError, which the except below swallowed.
    import shutil

    # Accept pathlib.Path and other path-like objects.
    if isinstance(directory, os.PathLike):
        directory = os.fspath(directory)
    # Validate the type before touching the filesystem. os.path.isdir on an int
    # would treat it as a file descriptor.
    if not isinstance(directory, str):
        raise Exception("string type required for directory: %s" % (directory,))
    if not os.path.isdir(directory):
        raise Exception("%s is not a directory" % (directory,))
    # Refuse obviously dangerous targets, also after normalisation (e.g. "./.").
    if (directory in ["..", ".", "", "/", "./", "../", "*"]
            or os.path.normpath(directory) in (os.curdir, os.pardir, os.sep)):
        raise Exception("trying to delete current directory, probably bad idea?!")

    for entry in os.listdir(directory):
        path = os.path.join(directory, entry)
        try:
            # Remove links themselves (including broken ones or links to
            # directories); rmtree refuses to operate on a symlink.
            if os.path.islink(path) or os.path.isfile(path):
                os.remove(path)
            elif os.path.isdir(path):
                shutil.rmtree(path)
        except Exception as e:
            print(e)

In [None]:
set_seed(14)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)
print()

# Extra diagnostics that only make sense on a CUDA device (sizes reported in GiB).
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1), 'GB')

In [None]:
# Model, acquisition and inversion parameters used by the cells below.
# NOTE(review): distances look like km (plot extents multiply by 1000) and the
# time step is in seconds (the wavelet plot labels nt*dt as "Time (s)") — confirm.
par = {
    # Model grid: samples, spacing and origin along x and z.
    'nx': 601, 'dx': 0.015, 'ox': 0,
    'nz': 221, 'dz': 0.015, 'oz': 0,
    # Sources: count, spacing, origin and (presumably) depth.
    'num_shots': 30, 'ds': 0.3, 'os': 0, 'sz': 0,
    # Receivers per shot: count, spacing, origin and (presumably) depth.
    'num_receivers_per_shot': 300, 'dr': 0.03, 'orec': 0, 'rz': 0,
    # Time axis: samples, sampling interval and origin.
    'nt': 4000, 'dt': 0.001, 'ot': 0,
    # Ricker peak frequency, sources per shot and number of spatial dimensions.
    'freq': 10, 'num_sources_per_shot': 1, 'num_dims': 2,
    # Batching of the shots and number of FWI iterations (used outside this view).
    'num_batches': 30,
    'FWI_itr': 100,
}


In [None]:
# Output locations for this experiment.
exp_name = './Exp_Marmousi_dm/'
velocity_model = "Marmousi"
seismic_path = f"{exp_name}seismic_data"
get_dir(seismic_path)  # ensure the output folder exists before anything is written there
# Despite its name, obs_dir is a single raw binary file (later cells call .tofile(obs_dir)).
obs_dir = f"{seismic_path}/shot_{velocity_model}"

# Location of the true velocity model on disk.
vel_path = "../data/"
vel_dir = f"{vel_path}Marm.bin"

print("obs_dir:", obs_dir)
print("vel_dir:", vel_dir)

In [None]:
# Load the true velocity model: raw float32 samples reshaped to an (nz, nx) grid.
model_true = np.fromfile(vel_dir, dtype=np.float32).reshape((par['nz'], par['nx']))

# Helper that builds the water-layer mask.
def water_layer_mask(model, value, n_water_rows=21):
    """
    Return an int mask that is 1 where ``model > value`` and 0 elsewhere.

    The first ``n_water_rows`` rows are always zeroed, so they are excluded
    regardless of their values.

    Parameters
    ----------
    model : np.ndarray
        Velocity model. Rows are indexed first (the true model here is (nz, nx)).
    value : float
        Threshold; samples strictly greater than it are set to 1.
    n_water_rows : int, default 21
        Number of leading rows forced to 0. The default is the value that was
        previously hard-coded.

    Returns
    -------
    np.ndarray of int with the same shape as ``model``.
    """
    keep = (model > value).astype(int)
    keep[:n_water_rows] = 0
    return keep


# Later cells use the name ``mask`` for the resulting array. Binding it here,
# instead of rebinding a function called ``mask``, avoids reusing one name for
# two different kinds of object.
# NOTE(review): 1.5 presumably is the water velocity (km/s) — confirm.
mask = water_layer_mask(model_true, 1.5)

In [None]:
# Clip the colour scale to the 2nd-98th percentile so outliers do not dominate.
m_vmin, m_vmax = np.percentile(model_true, [2, 98])
# NOTE(review): extent is presumably forwarded to imshow as (left, right, bottom, top),
# and *1000 presumably converts km to m — confirm in show_model.
# The depth extent must use the vertical spacing dz, not dx; the two only
# coincide in the current configuration.
show_model(model_true, cmap='jet', vmin=m_vmin, vmax=m_vmax, figsize=(12, 5),
           extent=(0, par['nx'] * par['dx'] * 1000, par['nz'] * par['dz'] * 1000, 0),
           title='Marmousi true model')

In [None]:
# Ricker source wavelet: peak frequency par['freq'], nt samples spaced dt, peak time 1/freq.
source_wavelet = deepwave.wavelets.ricker(
    par['freq'], par['nt'], par['dt'], 1 / par['freq']
)

# FWI helper built from the parameter dict and a tensor copy of the wavelet.
# NOTE(review): the meaning of the trailing ``1`` is defined in fwi.FWIParams — confirm.
params = fwi.FWIParams(par, torch.tensor(source_wavelet), 1)

# Source and receiver coordinates from the helper.
x_s1, x_r1 = params.get_coordinate(1)

# Replicate the wavelet for every source (via a fresh tensor copy).
source_amplitudes = params.create_wavelet(torch.tensor(source_wavelet))

In [None]:
# Plot the first wavelet trace against time; the last axis of source_amplitudes
# is plotted against nt samples, so it is presumably time — confirm.
fig, ax = plt.subplots()
ax.plot(np.arange(par['nt']) * par['dt'], source_amplitudes[0, 0, :])
ax.set_xlabel('Time (s)')
ax.set_title('Source wavelet')
plt.show()

In [None]:
# Simulate the observed data in the true model on the selected device.
data_true = fwi.forward_modelling(params, torch.tensor(model_true, dtype=torch.float32), device)

In [None]:
# Visualize the unmasked modelled data for shots 0, 14 and 29 (first, middle, last of 30).
show_3_shots(
    data_true.cpu(),
    [0, 14, 29],
    clip=0.01,
    extent=(0, int(par['nx'] * par['dx']), int(par['nt'] * par['dt']), 0),
    ylim=(int(par['nt'] * par['dt']), 0),
)

In [None]:
# Save the unmasked gathers under their own name. obs_dir itself is overwritten
# with the masked data further down, so anything written there now would be lost.
data_true.cpu().numpy().tofile(f"{obs_dir}_unmasked")

In [None]:
# Data-domain masks from the FWI helper (window_size, v_direct and ot are passed
# through to params.create_masks), stored as a float32 tensor on the compute device.
masks_d = torch.tensor(
    params.create_masks(window_size=1000, v_direct=1.5, ot=600),
    dtype=torch.float32,
    device=device,
)

In [None]:
# Visualize the data-domain masks for shots 0, 14 and 29 (same layout as the data plots).
# NOTE(review): clip=0.01 was carried over from the seismic-data plots; check that it
# suits the mask values.
show_3_shots(masks_d.cpu(), [0, 14, 29], clip=0.01,
             extent=(0, int(par['nx'] * par['dx']), int(par['nt'] * par['dt']), 0),
             ylim=(int(par['nt'] * par['dt']), 0))

In [None]:
# Apply the masks to the modelled data (element-wise product).
data = torch.mul(data_true, masks_d)

In [None]:
# Visualize the masked data (data_true * masks_d) for shots 0, 14 and 29.
show_3_shots(data.cpu(), [0, 14, 29], clip=0.01,
             extent=(0, int(par['nx'] * par['dx']), int(par['nt'] * par['dt']), 0),
             ylim=(int(par['nt'] * par['dt']), 0))

In [None]:
print(data_true.size())  # same torch.Size as .shape; axis order is set by fwi.forward_modelling

In [None]:
# Write the masked gathers to obs_dir as raw binary (no header; native dtype and byte order).
np.asarray(data.cpu()).tofile(obs_dir)

In [None]:
print(f"{obs_dir}")  # path of the file that holds the masked observed data