## Assignment 5

In this assignment we will combine linearized inversion (LSRTM) with full-waveform inversion (FWI) and compare the results with those obtained from FWI alone.

In [None]:
import os
import time

import deepwave
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy import interpolate, signal
from scipy.ndimage import gaussian_filter
from torchvision.transforms import GaussianBlur
%matplotlib inline



def Plot_model(m,par,name=None,**kwargs):
    """
    Plot a 2D model with ``plt.imshow`` and optionally save the figure.

    Arguments
    ----------
    m: 2D numpy array (a torch tensor is converted to numpy first)
         array containing the model; the first axis is drawn vertically
    par : dictionary
        dictionary containing the axis points, increments, and origin points.
        (i.e,par['ox'],par['dx'],par['nx'],par['nz'],par['dz'],par['oz'])
    name: str
          if given, the figure is saved as './Fig/<name>' (the 'Fig' directory
          is created when it does not exist)
    ----------
    Optional (keyword arguments)
    ----------
    vmax: float
          Maximum value for the plot (default: 98th percentile of m)
    vmin: float
          Minimum value for plot (default: 2nd percentile of m)
    cmap: str
          Matplotlib-colormap (default: 'jet')
    title: str
          Title drawn above the plot (default: no title)

    Any other keyword argument is ignored.
    """

    vmin = kwargs.pop('vmin', None)
    vmax = kwargs.pop('vmax', None)
    cmap = kwargs.pop('cmap', 'jet')
    title = kwargs.pop('title', None)
    # NOTE: `name` is an explicit parameter, so it can never be inside kwargs;
    # popping it from kwargs (as done previously) always reset it to None.

    # Accept torch tensors as well as numpy arrays.
    if hasattr(m, 'detach'):
        m = m.detach().cpu().numpy()
    m = np.asarray(m)

    # Clip the colour range to the 2nd/98th percentiles unless given explicitly.
    if vmin is None or vmax is None:
        p_low, p_high = np.percentile(m, [2, 98])
        if vmin is None: vmin = p_low
        if vmax is None: vmax = p_high

    # Axis limits: the far edge is origin + number of samples * increment.
    x_min = par['ox']
    x_max = par['ox'] + par['nx']*par['dx']
    z_min = par['oz']
    z_max = par['oz'] + par['nz']*par['dz']

    plt.figure(figsize=(10,3))
    # extent is [left, right, bottom, top]; bottom > top so depth increases downwards
    plt.imshow(m,cmap=cmap,vmin=vmin,vmax=vmax,extent=[x_min,x_max,z_max,z_min])
    plt.axis('tight')
    plt.xlabel('Distance (km)',fontsize=18,weight='heavy')
    plt.ylabel('Depth (km)',fontsize=18,weight='heavy')
    plt.colorbar(label='km/s')
    if title is not None: plt.title(title)
    if name is not None:
        os.makedirs('./Fig', exist_ok=True)
        plt.savefig('./Fig/'+name,bbox_inches='tight')
    
    

## Setting the parameters

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU so
# that the notebook still runs (slowly) on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')



# Define the model and acquisition parameters
# (spatial values are presumably in km and times in s, matching the plot labels)
par = {'nx':601,   'dx':0.015, 'ox':0,
       'nz':221,   'dz':0.015, 'oz':0,
       'ns':30,    'ds':0.3,   'osou':0,  'sz':0.03,
       'nr':300,   'dr':0.03,  'orec':0,  'rz':0.03,
       'nt':3000,  'dt':0.0013,  'ot':0,
       'freq':15,
       'num_batches':10, # increase this number if you have a CUDA out of memory error
       'FWI_itr': 100,
       'LSRTM_itr': 5,  # inner iteration -- placeholder value, TODO: tune for the assignment
       'num_dims': 2
      }

# Mapping the par dictionary to variables.
# globals() is updated explicitly: writing through locals() is not guaranteed
# to create variables and only happens to work at the top level of a cell.
globals().update(par)

fs = 1/dt # sampling frequency


# Don't change the below two lines
num_sources_per_shot=1
num_dims = 2



## Loading input files 

In [None]:
# Load the true model for forward modelling
path = '../Assignment1/'
velocity_file = path + 'Marm.bin' # true model

# The binary file holds float32 samples that are reshaped to (nz, nx):
# depth is the first axis and distance the second.
m_true = np.fromfile(velocity_file, np.float32).reshape(nz, nx)


# Smooth initial model.
# You can change this if you like but don't use a very small smoothing
m0 = gaussian_filter(m_true, sigma=(15, 5))


# Mask used to exclude the water layer from the updates: zero in the top
# 20 grid rows, one elsewhere.
# NOTE(review): assumes the water layer (velocity 1.5 km/s) occupies exactly
# the first 20 rows of the model -- confirm against the velocity file.
msk = np.ones_like(m_true)
msk[:20, :] = 0


# convert to tensor
m_true = torch.tensor(m_true, dtype=torch.float32)
m0 = torch.tensor(m0, dtype=torch.float32)
msk = torch.tensor(msk, dtype=torch.float32)


Plot_model(m_true, par)
Plot_model(m0, par)

# The arrays were reshaped to (nz, nx), so label the printed shapes accordingly.
print(f'vel shape {m_true.shape} (nz,nx)  || init shape {m0.shape} (nz,nx)')

## This time we also need a scattering (perturbation) model

In [None]:
# Scattering (perturbation) model: all zeros, same shape / dtype / device as m0
dm = torch.zeros(m0.shape, dtype=m0.dtype, device=m0.device)


## The acquisition set-up

 Create arrays containing the source and receiver locations
 
    x_s: Source locations [num_shots, num_sources_per_shot, num_dimensions].
    
    x_r: Receiver locations [num_shots, num_receivers_per_shot, num_dimensions]

In [None]:

# Sources: one per shot, spaced ds apart along the last coordinate,
# all placed at sz along the first coordinate.
x_s = torch.zeros(ns, num_sources_per_shot, num_dims)
x_s[:, 0, 0] = sz
x_s[:, 0, 1] = torch.arange(ns).float() * ds

# Receivers: the same spread of nr receivers (spacing dr) is used for every
# shot, all placed at rz along the first coordinate.
receiver_positions = torch.arange(nr).float() * dr
x_r = torch.zeros(ns, nr, num_dims)
x_r[:, :, 0] = rz
x_r[:, :, 1] = receiver_positions  # broadcast over the shot dimension





## Create source wavelet
    [nt, num_shots, num_sources_per_shot]

I use Deepwave's Ricker wavelet function, which returns a regular PyTorch tensor. You can use any other function to create the wavelet, but its output must be converted to a tensor.

In [None]:


# Single Ricker trace of nt samples, then replicated so that the result has
# shape [nt, ns, num_sources_per_shot].
ricker_trace = deepwave.wavelets.ricker(freq, nt, dt, 1/freq)
source_wavelet = ricker_trace.reshape(-1, 1, 1).repeat(1, ns, num_sources_per_shot)
print(source_wavelet.shape)


# Wavelet of the first shot against time
time_axis = np.arange(0, nt) * dt
plt.plot(time_axis, source_wavelet[:, 0, 0])
plt.xlabel('Time (s)')

## Forward modeling 

In [None]:
def highpass_filter(freq, wavelet, dt):
    """
    Suppress the low-frequency content of a signal along its first axis.

    A 6th-order Butterworth high-pass filter (second-order sections) is
    applied forward and backward with ``scipy.signal.sosfiltfilt``.

    Parameters
    ----------
    freq : :obj:`int`
        Cut-off frequency
    wavelet : :obj:`torch.Tensor`
        Signal to filter; axis 0 is the filtered axis
    dt : :obj:`float32`
        Time sampling

    Returns
    -------
    : :obj:`torch.Tensor`
        float32 tensor holding the high-pass filtered signal
    """
    nyquist = 0.5 * (1 / dt)
    # butter expects the cut-off normalised by the Nyquist frequency
    sections = signal.butter(6, freq / nyquist, 'hp', output='sos')
    filtered = signal.sosfiltfilt(sections, wavelet, axis=0)
    # copy() gives a fresh contiguous array before handing it to torch
    return torch.tensor(filtered.copy(), dtype=torch.float32)



# Create 'true' data: propagate the source wavelet through the true model
# on `device` and bring the recorded data back to the CPU.
prop = deepwave.scalar.Propagator({'vp': m_true.to(device)}, dx)  # create a propagator
data_true = prop(source_wavelet.to(device), x_s.to(device), x_r.to(device), dt)
data_true = data_true.cpu()


# Remove low frequency: the same 4 (cut-off) high-pass filter is applied to
# the wavelet and to the data.
source_wavelet, data_true = (
    highpass_filter(4, unfiltered, dt) for unfiltered in (source_wavelet, data_true)
)

In [None]:
# Plot one shot gather
d_vmin, d_vmax = np.percentile(data_true[:,0].cpu().numpy(), [2,98])

plt.imshow(data_true[:,0,].cpu().numpy(), aspect='auto',
           vmin=-d_vmax, vmax=d_vmax,cmap='bwr')



In [None]:
# Move both unknowns to the compute device and track their gradients so
# that the optimizers can update them. (No copy of the initial model is
# kept here.)
m0 = m0.to(device).requires_grad_(True)

dm = dm.to(device).requires_grad_(True)

##  Set the optimizer and the criterion 


In [None]:
# Data misfit: mean squared error between predicted and observed data
criterion = torch.nn.MSELoss()

# One Adam optimizer per unknown, both with learning rate 0.01
optimizer_m0 = torch.optim.Adam([m0], lr=0.01)   # To update the background
optimizer_dm = torch.optim.Adam([dm], lr=0.01)   # To update the perturbation


## Main inversion loop 

In [None]:


# Iterative inversion loop
#
# Every outer (FWI) iteration:
#   1. runs LSRTM_itr inner iterations (still to be implemented),
#   2. accumulates the FWI gradient of m0 over num_batches shot batches,
#   3. normalises and masks the gradient, then takes one Adam step on m0.
num_shots_per_batch = ns // num_batches
epoch_loss = []   # summed batch losses, one entry per outer iteration
updates = []      # copy of m0 after every outer iteration
gradients = []    # copy of the processed m0 gradient of every outer iteration
msk = msk.to(device)


t_start = time.time()
for epoch in range(FWI_itr):
    running_loss = 0
    optimizer_m0.zero_grad()  # zero out the gradient

    # LSRTM (inner loop)
    for inner_epoch in range(LSRTM_itr):
        # ------------- To do
        pass  # placeholder so that the cell parses until the step is implemented

    # FWI
    # ------- To do ( add the scattering to the background  )

    # FWI iteration as before: shots are split into num_batches interleaved
    # batches and the gradient is accumulated over all of them.
    for it in range(num_batches):
        prop = deepwave.scalar.Propagator({'vp': m0}, dx)
        batch_src_wvl = source_wavelet[:, it::num_batches].to(device)
        batch_data_true = data_true[:, it::num_batches].to(device)
        batch_x_s = x_s[it::num_batches].to(device)
        batch_x_r = x_r[it::num_batches].to(device)
        data_pred = prop(batch_src_wvl, batch_x_s, batch_x_r, dt)
        loss = criterion(data_pred, batch_data_true)
        running_loss += loss.item()
        loss.backward()

    epoch_loss.append(running_loss)

    # Apply some operations to the gradient:
    # normalise by the maximum of the first (masked) gradient and mask the water layer
    if epoch == 0:
        gmax = (torch.abs(m0.grad) * msk).max()
    m0.grad = m0.grad / gmax * msk

    # update the m0
    optimizer_m0.step()
    print('Epoch:', epoch, 'Loss: ', running_loss)

    # save the vel updates and gradients for each iteration.
    # Both are copied: on the CPU `.cpu().numpy()` would otherwise share
    # memory with the live tensors, which later iterations may modify.
    updates.append(m0.detach().clone().cpu().numpy())
    gradients.append(m0.grad.detach().clone().cpu().numpy())

    # plotting every 10 itr
    if epoch % 10 == 0:
        Plot_model(m0.cpu().detach().numpy(), par)
        Plot_model(m0.grad.cpu().detach().numpy(), par, cmap='seismic')
        plt.show()

t_end = time.time()
print('Runtime:', (t_end - t_start)/60 ,'minutes')

In [None]:
# Turn the per-iteration histories into numpy arrays
# (iteration index becomes the first axis).
updates, gradients, obj = (
    np.array(history) for history in (updates, gradients, epoch_loss)
)