# Assignment 2  

The goal of this notebook is to understand the importance of low wavenumbers in the inversion. We therefore start from a model whose velocity increases linearly with depth (i.e., a $v(z)$ model).


### Tasks: 

- Apply FWI using the full data spectrum starting from 3 Hz, then apply a multi-scale approach to improve the inversion of $v(z)$. Explain your steps and your observations.

- What are the minimum and maximum wavenumbers expected for the first frequency range at depths of 1, 2, and 3 km?

     

In [None]:
# Standard library
import os
import time

# Third-party
import deepwave
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy import signal
from scipy.ndimage import gaussian_filter
from torchvision.transforms import GaussianBlur

# Local
from fwi import FWI

%matplotlib inline

####  Functions

In [None]:
def Plot_model(m,par,name=None,**kwargs):
    """
    Plot a 2D model as an image with depth increasing downwards.

    Arguments
    ----------
    m: 2D numpy array or torch tensor
         Model to plot; rows are depth samples (nz) and columns are
         horizontal samples (nx), matching the (nz, nx) reshape used when loading.
    par : dictionary
        Axis origins, increments and sizes:
        par['ox'], par['dx'], par['nx'], par['oz'], par['dz'], par['nz'].
    name: str, optional
        If given, the figure is saved as './Fig/<name>' (the directory is
        created if needed).
    ----------
    Optional keyword arguments
    ----------
    vmax: float
          Maximum value of the colour scale (default: 98th percentile of m)
    vmin: float
          Minimum value of the colour scale (default: 2nd percentile of m)
    cmap: str
          Matplotlib colormap (default: 'jet')
    """
    vmax = kwargs.pop('vmax', None)
    vmin = kwargs.pop('vmin', None)
    cmap = kwargs.pop('cmap', 'jet')
    # `name` is a named parameter, so it can never appear in kwargs; the old
    # `kwargs.pop('name', None)` silently reset it to None and disabled saving.

    # Accept tensors (e.g. an inverted model that still tracks gradients).
    if hasattr(m, 'detach'):
        m = m.detach().cpu().numpy()
    m = np.asarray(m)

    # Clip the colour scale to the 2nd/98th percentiles unless the caller fixed it.
    if vmin is None or vmax is None:
        lo, hi = np.percentile(m, [2, 98])
        vmin = lo if vmin is None else vmin
        vmax = hi if vmax is None else vmax

    extent = [par['ox'], par['ox'] + par['dx'] * par['nx'],
              par['oz'] + par['dz'] * par['nz'], par['oz']]

    plt.figure(figsize=(10,3))
    plt.imshow(m, cmap=cmap, vmin=vmin, vmax=vmax, extent=extent)
    plt.axis('tight')
    plt.xlabel('Distance (km)',fontsize=18,weight='heavy')
    plt.ylabel('Depth (km)',fontsize=18,weight='heavy')
    plt.colorbar(label='km/s')
    if name is not None:
        os.makedirs('./Fig', exist_ok=True)
        plt.savefig(os.path.join('./Fig', name), bbox_inches='tight')
    
    
def plot_shot(data,idx,par):
    """
    Plot a single shot gather as a grayscale image.

    Arguments
    ----------
    data: 3D torch tensor or numpy array
         Shot gathers with time on the first axis and shots on the second,
         so that data[:, idx] is a (time, receiver) panel; presumably laid
         out as (nt, ns, nr) — TODO confirm against FWI.forward_modelling.
    idx : int
        Index of the shot gather to plot.
    par: dictionary
        Acquisition parameters: 'orec', 'ds', 'dr', 'nr' (receiver axis)
        and 'ot', 'dt', 'nt' (time axis).
    """
    gather = data[:, idx]
    # Convert once; detach() also allows tensors that still track gradients.
    if hasattr(gather, 'detach'):
        gather = gather.detach().cpu().numpy()
    gather = np.asarray(gather)

    # Clip to the 2nd/98th percentiles so a few strong arrivals don't wash out the image.
    vmin, vmax = np.percentile(gather, [2, 98])

    x0 = par['orec'] + idx * par['ds']
    extent = [x0, x0 + par['dr'] * par['nr'],
              par['ot'] + par['nt'] * par['dt'], par['ot']]

    plt.figure()
    plt.imshow(gather, aspect='auto', vmin=vmin, vmax=vmax, cmap='gray', extent=extent)
    plt.ylabel('Time (s)')
    plt.xlabel('Distance (km)')
    
    
def mask(m,value):
    """
    Build a binary mask from a model.

    Returns an integer array shaped like ``m`` holding 1 where ``m`` is
    strictly greater than ``value`` and 0 everywhere else.
    """
    return (m > value).astype(int)



#### Define the parameters and I/O files 

In [None]:
# Use the first GPU when available; fall back to the CPU so the notebook
# still runs (slowly) on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


# Load the true model for forward modelling
path = '../Assignment1/'
velocity_file = path + 'Marm.bin'  # true model

# Define the model and acquisition parameters
# (lengths in km and times in s, matching the plot axis labels)
par = {'nx':601,   'dx':0.015, 'ox':0,
       'nz':221,   'dz':0.015, 'oz':0,
       'ns':30,    'ds':0.3,   'osou':0,  'sz':0.03,
       'nr':300,   'dr':0.03,  'orec':0,  'rz':0.03,
       'nt':3000,  'dt':0.0013,  'ot':0,
       'freq':15,
       'num_batches':10, # increase this number if you get a CUDA out-of-memory error
       'FWI_itr': 100,
       'num_dims': 2
      }

# Expose every entry of `par` as a notebook-level variable (nx, dz, nt, dt, ...).
# globals() is explicit: writing through locals() only works by accident at
# top level and silently does nothing inside a function.
globals().update(par)

fs = 1 / dt  # sampling frequency (Hz)




#### Loading the input files

In [None]:

# Load the true velocity model: raw float32 binary reshaped to (nz, nx)
vel = np.fromfile(velocity_file, np.float32).reshape(nz, nx)

# Starting model for the inversion (the v(z) model)
vel_init = np.load('./input_files/vz.npy')
assert vel_init.shape == vel.shape, (
    f'initial model shape {vel_init.shape} does not match true model shape {vel.shape}')

Plot_model(vel,par)

Plot_model(vel_init,par)

print(f'vel shape {vel.shape} (nz,nx)  || init shape {vel_init.shape} (nz,nx)')


#### Convert arrays to tensor 

In [None]:
# Mask for the water layer (water velocity = 1.5 km/s): 0 over the top
# WATER_NZ depth samples (WATER_NZ * dz = 0.3 km), 1 below. It is passed to
# run_inversion, presumably to keep the water velocity fixed — confirm in fwi.FWI.
WATER_NZ = 20
msk_water = np.ones(tuple(vel.shape), dtype=np.float32)
msk_water[:WATER_NZ] = 0

# Convert to tensors. torch.as_tensor makes this cell safe to re-run: it accepts
# both the NumPy arrays from the loading cell and tensors from a previous run.
vel = torch.as_tensor(vel, dtype=torch.float32)
vel_init = torch.as_tensor(vel_init, dtype=torch.float32)







### Initialize FWI class 

In [None]:

# initialize the fwi class
# Built from the shared parameter dictionary; this object provides the Ricker
# wavelet, forward-modelling and inversion routines used in the cells below.
inversion = FWI(par)


#### Forward modelling 
    
You don't need to run this step: it has already been run to create the observed data.

In [None]:
def highpass_filter(freq, wavelet, dt, order=6):
    """
    Remove frequencies below ``freq`` with a zero-phase Butterworth high-pass filter.

    Parameters
    ----------
    freq : float
        Cut-off frequency (Hz).
    wavelet : torch.Tensor or numpy.ndarray
        Signal(s) to filter, with time along axis 0. Tensors may live on the
        GPU and/or track gradients.
    dt : float
        Time sampling interval (s).
    order : int, optional
        Butterworth filter order (default 6, the previously hard-coded value).

    Returns
    -------
    torch.Tensor
        Filtered signal as a float32 CPU tensor with the same shape as ``wavelet``.
    """
    if isinstance(wavelet, torch.Tensor):
        # SciPy needs a host NumPy array; implicit conversion fails for CUDA
        # tensors and for tensors that require grad.
        wavelet = wavelet.detach().cpu().numpy()

    nyquist = 0.5 / dt
    sos = signal.butter(order, freq / nyquist, 'hp', output='sos')
    # Forward-backward filtering gives zero phase, so arrivals are not shifted in time.
    filtered = signal.sosfiltfilt(sos, wavelet, axis=0)
    # .copy() guarantees a contiguous array before handing it to torch.
    return torch.tensor(filtered.copy(), dtype=torch.float32)



# Ricker source wavelet (freq from par, presumably its peak frequency)
wavl = inversion.Ricker(freq)

# Forward modelling: one copy of the wavelet per shot. No pre-allocation is
# needed because forward_modelling returns the data.
data = inversion.forward_modelling(vel,wavl.repeat(1,ns,1),device)

# Remove frequencies below LOWCUT_HZ from both the wavelet and the data so the
# source used in the inversion has the same bandwidth as the observed data.
LOWCUT_HZ = 4
wavl = highpass_filter(LOWCUT_HZ, wavl, dt)
data = highpass_filter(LOWCUT_HZ, data, dt)

In [None]:
# Display the middle shot gather of the high-passed observed data
mid_shot = par['ns'] // 2
plot_shot(data, mid_shot, par)
print(data.shape)


#### Apply the multi-scale approach

In [None]:
# ## To do 




# plot_shot(data_filtered,5,par)

# plt.figure()
# plt.magnitude_spectrum(wavl_filtered[:,0,0],fs)
# plt.xlim([0,10])

### Run inversion 


In [None]:
# Run FWI from the v(z) starting model against the high-passed observed data,
# using the matching high-passed wavelet and the water mask. FWI_itr is
# presumably the number of iterations; v_inv and loss are presumably the
# inverted model and the objective-function history — confirm in fwi.FWI.
v_inv,loss  = inversion.run_inversion(vel_init,data,wavl,msk_water,
                                      FWI_itr,device) 




### Plot the objective function and the inversion results, and save the inverted model


In [None]:

# To Do 

