In [None]:
import sys
sys.path.append("../")

import os
import sys
import time
import datetime
import json

import torch
from torch import nn
from torch.utils.data import RandomSampler, DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision.transforms import Compose

from iunets import iUNet

from tqdm.notebook import tqdm
from tqdm import tqdm


from dataset import FWIDataset
from networks import forward_network, inverse_network, iunet_network


import utils.utilities as utils
import utils.transforms as T
from utils.pytorch_ssim import *
from utils.scheduler import WarmupMultiStepLR

In [2]:
step = 0

device = torch.device("cuda")

dataset = "curvevel-a"
file_size = 500
vis_suffix = False
train_anno = "../train_test_splits/curvevel_a_train_48.txt"
val_anno = "../train_test_splits/curvevel_a_val_48.txt"
sample_temporal = 1
distributed = False
batch_size = 50
workers = 4
lambda_g1v = 1
lambda_g2v = 1
k = 1
output_path = './0_test/evaluation/'
mask_factor = 0.0

# model_type = "IUNet"
model_type = "IUnetForwardModel"
model_path = "/projects/ml4science/openfwi/openfwi_results/RESULTS_Invertible_XNet/CurveVel-A/Forward/mask_factor_0/lambda_1/fcn_l1loss_ffb/latest_checkpoint.pth"


In [3]:
with open('../dataset_config.json') as f:
    try:
        ctx = json.load(f)[dataset]
    except KeyError:
        print('Unsupported dataset.')
        sys.exit()

if file_size is not None:
    ctx['file_size'] = file_size

# transform_data, transform_label = utils.get_transforms(args, ctx)
transform_label = T.MinMaxNormalize(ctx['label_min'], ctx['label_max'])
transform_data = T.Normalize(ctx['data_mean'], ctx['data_std']) 

# Create dataset and dataloader
print('Loading data')
print('Loading training data')

dataset_train = FWIDataset(
    train_anno,
    preload=True,
    sample_ratio=sample_temporal,
    file_size=ctx['file_size'],
    transform_data=transform_data,
    transform_label=transform_label,
    mask_factor=mask_factor
)


print('Loading validation data')
dataset_valid = FWIDataset(
    val_anno,
    preload=True,
    sample_ratio=sample_temporal,
    file_size=ctx['file_size'],
    transform_data=transform_data,
    transform_label=transform_label
)

train_sampler = RandomSampler(dataset_train)
valid_sampler = RandomSampler(dataset_valid)
    
dataloader_train = DataLoader(
    dataset_train, batch_size=batch_size,
    sampler=train_sampler, num_workers=workers,
    pin_memory=True, drop_last=True, collate_fn=default_collate)

dataloader_valid = DataLoader(
    dataset_valid, batch_size=batch_size,
    sampler=valid_sampler, num_workers=workers,
    pin_memory=True, collate_fn=default_collate)

Loading data
Loading training data
Data concatenation complete.
Loading validation data
Data concatenation complete.


In [5]:
model = forward_network.model_dict[model_type](**forward_network.forward_params).to(device)
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model'])
model = model.to(device)
model.eval()

IUnetForwardModel(
  (iunet_model): iUNet(
    (encoder_modules): ModuleList(
      (0): ModuleList(
        (0): InvertibleModuleWrapper(
          (_fn): StandardAdditiveCoupling(
            (F): StandardBlock(
              (seq): ModuleList(
                (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                (1): LeakyReLU(negative_slope=0.01, inplace=True)
                (2): GroupNorm(1, 64, eps=0.001, affine=True)
              )
              (F): Sequential(
                (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                (1): LeakyReLU(negative_slope=0.01, inplace=True)
                (2): GroupNorm(1, 64, eps=0.001, affine=True)
              )
            )
          )
        )
        (1): InvertibleModuleWrapper(
          (_fn): StandardAdditiveCoupling(
            (F): StandardBlock(
              (seq): ModuleList(
                (0): Conv2d(64, 64, kernel_size=(3, 3)

In [8]:
def evaluate(model, criterions, dataloader, device, k, ctx):   
    
#     vel_list, vel_pred_list= [], [] # store denormalized velocity predcition & gt in numpy 
#     vel_norm_list, vel_pred_norm_list = [], [] # store normalized velocity prediction & gt in tensor
    
    amp_list, amp_pred_list = [], []     # store denormalized waveform predcition & gt in numpy
    amp_norm_list, amp_pred_norm_list = [], []  # store normalized waveform predcition & gt in numpy

    with torch.no_grad():
        batch_idx = 0
        for _, amp, vel in dataloader:
            amp = amp.to(device)
            vel = vel.to(device)
            
#             vel_pred = model(amp)
            amp_pred = model(vel)
            
#             vel_np = transform_label.inverse_transform(vel.detach().cpu().numpy())
#             vel_list.append(torch.from_numpy(vel_np))
#             vel_norm_list.append(vel.detach().cpu())
            
#             vel_pred_np = transform_label.inverse_transform(vel_pred.detach().cpu().numpy())
#             vel_pred_list.append(torch.from_numpy(vel_pred_np))
#             vel_pred_norm_list.append(vel_pred.detach().cpu())
            
            
            amp_norm_list.append(amp.detach().cpu())
            amp_pred_norm_list.append(amp_pred.detach().cpu())
            
            amp_np = transform_data.inverse_transform(amp.detach().cpu().numpy())
            amp_pred_np = transform_data.inverse_transform(amp_pred.detach().cpu().numpy())
            
            amp_list.append(torch.from_numpy(amp_np))
            amp_pred_list.append(torch.from_numpy(amp_pred_np))
            
            batch_idx += 1

#     vel, vel_pred = torch.cat(vel_list), torch.cat(vel_pred_list)
#     vel_norm, vel_pred_norm = torch.cat(vel_norm_list), torch.cat(vel_pred_norm_list)
    
    amp, amp_pred = torch.cat(amp_list), torch.cat(amp_pred_list)
    amp_norm, amp_pred_norm = torch.cat(amp_norm_list), torch.cat(amp_pred_norm_list)

    for name, criterion in criterions.items():
#         print(f'Velocity Normalized {name}: {criterion(vel_norm, vel_pred_norm)}')
        print(f'Waveform Normalized {name}: {criterion(amp_norm, amp_pred_norm)}')
#         print(f' * Velocity {name}: {criterion(vel, vel_pred)}')
        print(f' * Waveform {name}: {criterion(amp, amp_pred)}')
        
    ssim_loss = SSIM(window_size=11)
#     print(f'Velocity SSIM: {ssim_loss(vel_norm / 2 + 0.5, vel_pred_norm / 2 + 0.5)}') 
    print(f'Waveform SSIM: {ssim_loss(amp_norm, amp_pred_norm)}')


In [9]:
criterions = {
    'MAE': lambda x, y: torch.mean(torch.abs(x - y)),
    'MSE': lambda x, y: torch.mean((x - y) ** 2)
}

val_eval_dict = evaluate(model, criterions, dataloader_valid, device, k, ctx)
print(val_eval_dict)

Waveform Normalized MAE: 0.04896041005849838
 * Waveform MAE: 0.07280413061380386
Waveform Normalized MSE: 0.019604051485657692
 * Waveform MSE: 0.04334786906838417
Waveform SSIM: 0.674527108669281
None
