# Train the GANO model using kik-net data

For details of the model implementation, please refer to [Shi et al., 2023](https://arxiv.org/abs/2309.03447) and [Rahman et al., 2022](https://arxiv.org/abs/2205.03017).

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.nn as nn
import pandas as pd
import pickle as pkl 
import sys

sys.path.insert(0, './Python_libs')
# load GP function
from random_fields import *

# load GANO model
# The `imp` module was removed in Python 3.12; importlib.reload is the supported equivalent.
from importlib import reload
import GANO_model
reload(GANO_model)  # re-execute GANO_model so edits to it are picked up in a running kernel
from GANO_model import Generator, Discriminator

# import utils
from dataUtils_3C import SeisData
import os
import timeit

In [2]:
# Adjust the layout of the Jupyter notebook: inject a CSS rule that makes `.container` use the full width
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

## Training parameters of GANO


In [3]:
# GANO model parameters
ndim = 6000      # number of samples in each 1D time history
npad = 400       # zero-padding appended at the end of each time history.
                 # NOTE: ndim + npad = 6400 = 2**8 * 5**2 is NOT a power of 2; it only has small prime factors.
width= 32        # channel width the inputs are lifted to (passed to Generator and Discriminator)
lr = 1e-4        # learning rate

# single GPU training
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Training parameters
epochs = 30      # training epochs
λ_grad = 10.0    # penalty factor (weight of the gradient penalty in the discriminator loss)
n_critic = 10    # train D n_critic times before each G update
batch_size = 64  # decrease batch_size if CUDA runs out of memory

grf = GaussianRF_idct(1,(ndim + npad), alpha=1.5, tau=1.0, cal1d=True, device=device) # 1D GRF (GP) over ndim + npad points

## Dataloader

To use the dataloader, first prepare your dataset. The following files are needed:
1. `Full dataset`: N records of three-component waveforms sharing the same sampling frequency; each record has a length of `ndim` samples and contains the onset of the P wave. (Both the dataset of velocity time histories and the dataset of the corresponding acceleration time histories are needed.)


2. `attribute file`: should contain the meta-information associated with the records, with one column for each variable listed in `condv_names`.


In [4]:
# Settings passed to the SeisData loaders for both the training and the validation split
config_d = {
    
    'data_file': './kik_net_data/vel_60s_final.npy',       # full dataset, shape [N, 3, ndim]
    'attr_file': './kik_net_data/attrs_60s_final.csv',     # attribute file: magnitude, rupture distance, vs30, etc. for each record
    'batch_size': batch_size,                              # same batch size as used by the training loop

    'frac_train': 0.9,                                              # fraction of records used for training
    'condv_names': ['magnitude','rrup', 'vs30', 'tectonic_value'],  # names of the conditional variables
    'condv_min_max' : [(4.5, 8.0), (0, 300), (100, 1100), (0,1)]    # (min, max) scaling range of each conditional variable, same order as condv_names

}

# Reuse the stored train/validation split so that every run works on the same partition
index = np.load('./kik_net_data/index_100Hz_final.npy', allow_pickle=True)
ix_train, ix_val = index[0], index[1]   # indexes of the training / validation records

"""
# Run this part if you don't have the index file:
# shuffle the data and save the train/validation indexes.
# (np.random is not seeded here, so every execution produces a different split.)

Ntot = len(pd.read_csv(config_d['attr_file']))

frac = config_d['frac_train']

# get all indexes
ix_all = np.arange(Ntot)
# get training indexes
Nsel = int(Ntot*frac)

ix_train = np.random.choice(ix_all, size=Nsel, replace=False)
ix_train.sort()
# get validation indexes
ix_val = np.setdiff1d(ix_all, ix_train, assume_unique=True)
ix_val.sort()

# ix_train and ix_val have different lengths. Recent NumPy versions refuse to
# build such a ragged array implicitly, so store them in an explicit object array.
index = np.empty(2, dtype=object)
index[0] = ix_train
index[1] = ix_val
np.save('./kik_net_data/index_100Hz_final.npy', index)
"""

# Data loaders: both splits read the same files and differ only in the selected indexes
loader_args = (
    config_d['data_file'],
    config_d['attr_file'],
    config_d['condv_names'],
    config_d['condv_min_max'],
)

sdat_train = SeisData(*loader_args, batch_size=config_d['batch_size'], isel=ix_train)
print('total Train:', sdat_train.get_Ntrain())

sdat_val = SeisData(*loader_args, batch_size=config_d['batch_size'], isel=ix_val)

# number of generator iterations per epoch in the training loop
n_train_tot = sdat_train.get_Nbatches_tot()

print('total Validation:', sdat_val.get_Ntrain())
# Draw one random validation batch:
#   waveforms [batch, 3, dimension], normalized log10_PGA [batch, 3],
#   conditional variables [[batch], ..., [batch]]
wfs, log10_PGA, cvs = sdat_val.get_rand_batch()
print('shape wfs:', wfs.shape)
print('shape log10_PGA:', log10_PGA.shape)

Loading data ...
Loaded samples:  42481
normalizing data ...
max log pga: 0.09604963680854223 min log pga: -4.779544184431937
--------- magnitude -----------
min magnitude 4.5 scale min 4.5
max magnitude 8.0 scale max 8.0
vc shape (42481, 1)
--------- rrup -----------
min rrup 3.190180152 scale min 0
max rrup 299.9925061 scale max 300
vc shape (42481, 1)
--------- vs30 -----------
min vs30 111.11411939894413 scale min 100
max vs30 1097.5609756097565 scale max 1100
vc shape (42481, 1)
--------- tectonic_value -----------
min tectonic_value 0.0 scale min 0
max tectonic_value 1.0 scale max 1
vc shape (42481, 1)
Number selected samples:  38232
Class init done!
total Train: 38232
Loading data ...
Loaded samples:  42481
normalizing data ...
max log pga: 0.09604963680854223 min log pga: -4.779544184431937
--------- magnitude -----------
min magnitude 4.5 scale min 4.5
max magnitude 8.0 scale max 8.0
vc shape (42481, 1)
--------- rrup -----------
min rrup 3.190180152 scale min 0
max rrup 299.9

## Model initialization

In [5]:
# Discriminator input channels: 6 (3 waveforms + 3 PGAs) + 4 (4 conditional variables)
D = Discriminator(6 + 4, width, ndim=ndim, pad=npad)
D = D.to(device)
# Generator input channels: 1 (GP) + 4 (4 conditional variables)
G = Generator(1 + 4, width, ndim=ndim, pad=npad, training=True)
G = G.to(device)

# Multi-GPU (data parallel) training is not enabled in this notebook;
# it would wrap D and G separately, e.g. with DDP, at this point.

# Report the number of trainable parameters of each network
for _title, _net in (
    ("Number discriminator parameters: ", D),
    ("Number generator parameters: ", G),
):
    nn_params = sum(w.numel() for w in _net.parameters() if w.requires_grad)
    print(_title, nn_params)


Number discriminator parameters:  136627873
Number generator parameters:  136628358


## Optimizer

In [6]:
# Adam optimizers: the same learning rate and weight decay for both networks
adam_opts = {'lr': lr, 'weight_decay': 1e-4}
D_optim = torch.optim.Adam(D.parameters(), **adam_opts)
G_optim = torch.optim.Adam(G.parameters(), **adam_opts)

# Step learning rate: multiply the rate by 0.8 after every 5 scheduler steps
decay_opts = {'step_size': 5, 'gamma': 0.8}
D_scheduler = torch.optim.lr_scheduler.StepLR(D_optim, **decay_opts)
G_scheduler = torch.optim.lr_scheduler.StepLR(G_optim, **decay_opts)

# Put both networks in training mode (G stays last so the cell still displays it)
D.train()
G.train()

Generator(
  (fc0): Linear(in_features=5, out_features=32, bias=True)
  (conv0): SpectralConv1d()
  (conv1): SpectralConv1d()
  (conv2): SpectralConv1d()
  (conv2_1): SpectralConv1d()
  (conv2_9): SpectralConv1d()
  (conv3): SpectralConv1d()
  (conv4): SpectralConv1d()
  (conv5): SpectralConv1d()
  (w0): pointwise_op(
    (conv): Conv1d(32, 48, kernel_size=(1,), stride=(1,))
  )
  (w1): pointwise_op(
    (conv): Conv1d(48, 96, kernel_size=(1,), stride=(1,))
  )
  (w2): pointwise_op(
    (conv): Conv1d(96, 192, kernel_size=(1,), stride=(1,))
  )
  (w2_1): pointwise_op(
    (conv): Conv1d(192, 384, kernel_size=(1,), stride=(1,))
  )
  (w2_9): pointwise_op(
    (conv): Conv1d(384, 192, kernel_size=(1,), stride=(1,))
  )
  (w3): pointwise_op(
    (conv): Conv1d(384, 96, kernel_size=(1,), stride=(1,))
  )
  (w4): pointwise_op(
    (conv): Conv1d(192, 48, kernel_size=(1,), stride=(1,))
  )
  (w5): pointwise_op(
    (conv): Conv1d(96, 32, kernel_size=(1,), stride=(1,))
  )
  (fc1): Linear(in_

## Gradient penalty

In [7]:
def calculate_gradient_penalty(model, real_images, fake_images, label, device, target_norm=None):
    """Calculate the gradient penalty loss for WGAN-GP.

    Points are sampled uniformly on the straight lines between paired real and
    fake samples, and the L2 norm of the critic's gradient at those points is
    pushed towards ``target_norm``.

    Args:
        model: critic, called as ``model(x, label)``.
        real_images: batch of real samples; the first dimension is the batch.
        fake_images: batch of generated samples with the same shape as
            ``real_images``.
        label: conditioning input, forwarded unchanged to ``model``.
        device: device on which the interpolation weights are created.
        target_norm: target value of the per-sample gradient norm. ``None``
            (default) keeps the original setting ``1.0 / ndim``, read from the
            module-level ``ndim``.

    Returns:
        Scalar tensor holding the batch mean of
        ``(||grad||_2 - target_norm) ** 2``.
    """
    if target_norm is None:
        target_norm = 1.0 / ndim

    # One interpolation weight per sample, broadcast over the remaining dimensions.
    # WGAN-GP draws it from U[0, 1) so every interpolate lies between its real and
    # fake sample; a normal draw (randn) would also produce points outside that segment.
    weight_shape = (real_images.size(0),) + (1,) * (real_images.dim() - 1)
    alpha = torch.rand(weight_shape, device=device)
    interpolates_wfs = (alpha * real_images + (1 - alpha) * fake_images).requires_grad_(True)

    model_interpolates = model(interpolates_wfs, label)

    # Gradient of the critic output w.r.t. the interpolates. create_graph=True keeps
    # the graph of this gradient so that the penalty itself can be back-propagated.
    grad_wf = torch.autograd.grad(
        outputs=model_interpolates,
        inputs=interpolates_wfs,
        grad_outputs=torch.ones_like(model_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Flatten everything but the batch dimension to get one gradient vector per sample
    gradients = grad_wf.reshape(grad_wf.size(0), -1)
    gradient_penalty = torch.mean((gradients.norm(2, dim=1) - target_norm) ** 2)

    return gradient_penalty

## Model Training

In [8]:
def _cvs_to_label(cvs, device):
    """Stack the conditional variables into one float tensor on ``device``, axes reordered by permute(1, 2, 0)."""
    return torch.from_numpy(np.asarray(cvs)).permute(1, 2, 0).float().to(device)


def train_WGANO(D, G, epochs, D_optim, G_optim, scheduler=None):
    """Train the discriminator ``D`` and generator ``G`` with the WGAN-GP objective.

    Each iteration performs ``n_critic`` discriminator updates followed by one
    generator update, then evaluates the generator loss on a validation batch.
    Every 200 iterations one synthetic waveform is plotted to ``./plots`` and
    every 10 epochs the generator weights are saved to ``./saved_models``.

    Besides its arguments the function reads these module-level names:
    ``n_train_tot``, ``n_critic``, ``npad``, ``ndim``, ``λ_grad``, ``device``,
    ``grf``, ``sdat_train``, ``sdat_val``, ``D_scheduler``, ``G_scheduler`` and
    ``calculate_gradient_penalty``.

    Args:
        D: discriminator, called as ``D(x, label)``.
        G: generator, called as ``G(noise, label)``.
        epochs: number of training epochs.
        D_optim: optimizer of ``D``.
        G_optim: optimizer of ``G``.
        scheduler: unused; kept for backward compatibility. The module-level
            ``D_scheduler`` and ``G_scheduler`` are stepped once per epoch.

    Returns:
        Tuple ``(losses_D, losses_G, losses_G_val, losses_W)`` of arrays of
        length ``epochs`` holding, per epoch, the mean loss per update:
        discriminator loss (Wasserstein term + gradient penalty), generator
        loss on training labels, generator loss on validation labels, and the
        Wasserstein term alone.
    """
    # record the loss information
    losses_D = np.zeros(epochs)
    losses_G = np.zeros(epochs)
    losses_G_val = np.zeros(epochs)
    losses_W = np.zeros(epochs)

    # Number of terms accumulated per epoch; used to turn the sums into means.
    # (max(..., 1) only guards against a division by zero when there are no batches.)
    n_D_updates = max(n_train_tot * n_critic, 1)
    n_G_updates = max(n_train_tot, 1)

    for i in range(epochs):
        loss_D = 0.0
        loss_G = 0.0
        loss_G_val = 0.0
        loss_W = 0.0
        for j in range(n_train_tot):
            # ---- discriminator: n_critic updates before each generator update ----
            for _ in range(n_critic):
                D_optim.zero_grad()
                # sdat_train is the dataloader for training dataset
                (x, log10_PGA, cvs) = sdat_train.get_rand_batch()
                x = F.pad(torch.Tensor(x), [0, npad]).to(device)

                # label shape [batch, 1, 4]
                label = _cvs_to_label(cvs, device)

                # The discriminator update needs no generator gradients, so the
                # synthetic batch is produced without building G's graph.
                with torch.no_grad():
                    x_syn = G(grf.sample(x.shape[0]).to(device), label)

                # append log10_PGA, repeated along the time axis, as extra channels
                log10_PGA = torch.from_numpy(log10_PGA).unsqueeze(2).float()
                log10_PGA = log10_PGA.repeat(1, 1, (ndim+npad)).to(device)
                x = torch.cat([x, log10_PGA], dim=1)

                # Wasserstein loss with gradient-penalty regularization
                W_loss = -torch.mean(D(x, label)) + torch.mean(D(x_syn, label))
                gradient_penalty = calculate_gradient_penalty(D, x, x_syn, label, device)

                loss = W_loss + λ_grad * gradient_penalty
                loss.backward()

                loss_D += loss.item()
                loss_W += W_loss.item()

                D_optim.step()

            # ---- generator: one update ----
            G_optim.zero_grad()
            # only the labels and the batch size of this batch are used
            (x, _, cvs) = sdat_train.get_rand_batch()
            label = _cvs_to_label(cvs, device)

            x_syn = G(grf.sample(x.shape[0]).to(device), label)

            g_loss = -torch.mean(D(x_syn, label))
            g_loss.backward()
            loss_G += g_loss.item()

            G_optim.step()

            # ---- validation information ----
            with torch.no_grad():
                print("epoch:[{} / {}] batch:[{} / {}],loss_G:{:.4f}".format(i, epochs, j, n_train_tot, g_loss.item()))

                # generator loss on labels drawn from the validation dataset
                (x, _, cvs) = sdat_val.get_rand_batch()
                label = _cvs_to_label(cvs, device)

                x_syn = G(grf.sample(x.shape[0]).to(device), label)

                val_loss = -torch.mean(D(x_syn, label))
                loss_G_val += val_loss.item()

                if j % 200 == 0:
                    # plot the first synthetic waveform together with its (de-normalized) label
                    mag = sdat_train.to_real(cvs[0][0], 'magnitude')
                    dist = sdat_train.to_real(cvs[1][0], 'rrup')
                    vs30 = sdat_train.to_real(cvs[2][0], 'vs30')
                    tectonic_value = sdat_train.to_real(cvs[3][0],'tectonic_value')
                    if tectonic_value == 0.0:
                        tectonic_type = 'Subduction'
                    else:
                        tectonic_type = 'Shallow crustal'
                    fig, ax = plt.subplots(1, 1, figsize=(16,8), tight_layout=True)
                    ax.plot(x_syn[0,0,:].squeeze().detach().cpu().numpy())
                    ax.set_title('M {} , {} km, $Vs_{{30}}$={}m/s, event= {}'.format(mag[0], dist[0], vs30[0], tectonic_type), fontsize=16)
                    plt.savefig('./plots/epoch{}_it{}_GANO'.format(i,j))
                    plt.close(fig)

        # Mean loss per update. The sums hold one term per update, so they are
        # divided by the number of updates (not by batch_size, which is unrelated).
        losses_D[i] = loss_D / n_D_updates
        losses_G[i] = loss_G / n_G_updates
        losses_G_val[i] = loss_G_val / n_G_updates
        losses_W[i] = loss_W / n_D_updates

        D_scheduler.step()
        G_scheduler.step()
        if (i+1) % 10 == 0: #save the model every 10 epochs
            torch.save(G.state_dict(), "./saved_models/G_{}_GANO.pt".format(i+1))

    return losses_D, losses_G, losses_G_val, losses_W

In [9]:
# Create the output folders if they do not exist yet.
# NOTE: train_WGANO writes directly into ./plots and ./saved_models, which are
# created here as the parents of the run-specific sub-folders.
folder = "GANO_kik_net_training"
for _root in ("./saved_models", "./plots"):
    # exist_ok=True makes the call idempotent and avoids a check-then-create race
    os.makedirs(f"{_root}/{folder}", exist_ok=True)

In [None]:
start = timeit.default_timer() # timestamp before training, used to measure the training time
# run the full training; returns one per-epoch array for each tracked loss
losses_D, losses_G, losses_G_val, losses_W = train_WGANO(D, G, epochs, D_optim, G_optim)
stop = timeit.default_timer() 
# print the elapsed training time (in seconds)
print(stop - start)

### Save the loss information (optional)

In [None]:
# Loss histories returned by the training run, keyed by their column name
loss_histories = {
    'losses_D': losses_D,
    'losses_G': losses_G,
    'losses_G_val': losses_G_val,
    'losses_W': losses_W,
}

# One panel per loss, skipping the first 10 epochs
plt.figure(figsize=(24,4))
for _panel, _history in enumerate(loss_histories.values(), start=1):
    plt.subplot(1, 4, _panel)
    plt.plot(_history[10:])

# Collect the complete histories in one table, one column per loss
losses_all = pd.DataFrame()
for _column, _history in loss_histories.items():
    losses_all[_column] = _history

# NOTE(review): the table is written as CSV text although the file name ends in ".npy"
losses_all.to_csv("./losses_GANO.npy")

In [None]:
# 30 epochs 