In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from data_loader import QuakeDataset, get_dataloader
from models import QuakeModel

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
save_step = 1
n_epochs = 100
learning_rate = 0.001
weight_decay = 1e-5
step_size = 10
gamma = 0.1
n_samples_train = 100000
n_samples_valid = 10000
n_layers = 6
n_heads = 8
time_length = 500
time_size = 300
freq_size = 128
d_model = 64  # 512, 128
d_ff = 256    # 2048, 512
batch_size = 32
num_workers = 4
dropout = 0.1

In [4]:
save_step = 1
n_epochs = 100
learning_rate = 0.0001
weight_decay = 1e-5
step_size = 10
gamma = 0.1
n_samples_train = 100000
n_samples_valid = 10000
n_layers = 3
n_heads = 8
time_length = 500
time_size = 300
freq_size = 128
d_model = 32  # 512, 128
d_ff = 128    # 2048, 512
batch_size = 32
num_workers = 4
dropout = 0.1

In [5]:
input_dir = '/run/media/hoosiki/WareHouse3/mtb/datasets/LANL'
log_dir = './logs'
model_dir = './models'

csv_file_train = 'train_split.csv'
csv_file_valid = 'valid_split.csv'

os.makedirs(log_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

In [6]:
model = QuakeModel(
    n_layers=n_layers,
    n_heads=n_heads,
    freq_size=freq_size,
    d_model=d_model,
    d_ff=d_ff,
    dropout=dropout)

# This was important from their code. 
# Initialize parameters with Glorot / fan_avg.
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

model = model.to(device)

checkpoint = torch.load('./models/model-epoch-02.ckpt')
model.load_state_dict(checkpoint['state_dict'])

In [7]:
criterion = nn.SmoothL1Loss()
#criterion = nn.L1Loss()

In [8]:
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

In [9]:
for epoch in range(n_epochs):
    
    data_loaders, data_size = get_dataloader(
        input_dir=input_dir,
        csv_file_train=csv_file_train,
        csv_file_valid=csv_file_valid,
        n_samples_train=n_samples_train,
        n_samples_valid=n_samples_valid,
        time_length=time_length,
        time_size=time_size,
        batch_size=batch_size,
        num_workers=num_workers)
        
    for phase in ['train', 'valid']:
    
        running_loss = 0.0
        running_diff = 0.0
        running_size = data_size[phase] / batch_size
        
        if phase == 'train':
            scheduler.step()
            model.train()
        else:
            model.eval()
            
        for batch_idx, batch_sample in enumerate(data_loaders[phase]):

            Sxx = batch_sample['Sxx'].to(device)
            target = batch_sample['target'].to(device)
                
            optimizer.zero_grad()
                
            with torch.set_grad_enabled(phase == 'train'):
                    
                output = model(Sxx)
                diff = torch.mean(torch.abs(target-output), dim=0)
                loss = criterion(output, target)
                    
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                        
            running_loss += loss.item()
            running_diff += diff.item()
            
        # Print the average loss and accuracy in an epoch.
        epoch_loss = running_loss / running_size
        epoch_diff = running_diff / running_size

        print('| {} SET | Epoch [{:02d}/{:02d}], Loss: {:.4f}, MAE: {:.4f}'
              .format(phase.upper(), epoch+1, n_epochs, epoch_loss, epoch_diff))

        # Log the loss and accuracy in an epoch.
        with open(os.path.join(log_dir, '{}-log-epoch-{:02}.txt')
                  .format(phase, epoch+1), 'w') as f:
            f.write(str(epoch+1) + '\t' +
                    str(epoch_loss) + '\t' +
                    str(epoch_diff))

    # Save the model check points.
    if (epoch+1) % save_step == 0:
        torch.save({'epoch': epoch+1, 'state_dict': model.state_dict()},
                   os.path.join(model_dir, 'model-epoch-{:02d}.ckpt'.format(epoch+1)))
    print()

| TRAIN SET | Epoch [01/100], Loss: 1.6911, MAE: 2.1396
| VALID SET | Epoch [01/100], Loss: 1.5846, MAE: 2.0628

| TRAIN SET | Epoch [02/100], Loss: 1.6918, MAE: 2.1407
| VALID SET | Epoch [02/100], Loss: 1.3698, MAE: 1.8381



KeyboardInterrupt: 