In [None]:
import sys
sys.path.append("../")
print(sys.path)
import os
import time
import datetime
import json

import torch
from torch import nn
from torch.utils.data import RandomSampler, DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision.transforms import Compose

from iunets import iUNet
from tqdm.notebook import tqdm
from tqdm import tqdm

from networks import iunet_network
from dataset import FWIDataset

from utils.scheduler import WarmupMultiStepLR
import utils.transforms as T
from utils.pytorch_ssim import *
import utils.utilities as utils

In [None]:
step = 0

device = torch.device("cuda")

dataset = "curvevel-a"
file_size = 500
vis_suffix = False
train_anno = "../train_test_splits/curvevel_a_train_48.txt"
val_anno = "../train_test_splits/curvevel_a_val_48.txt"
sample_temporal = 1
distributed = False
batch_size = 50
workers = 4
lambda_g1v = 1
lambda_g2v = 1
k = 1
output_path = './0_test/evaluation/'
mask_factor = 0.6


model_type = "IUNet"
model_path = "/projects/ml4science/openfwi/openfwi_results/RESULTS_Invertible_XNet/CurveVel-A/Joint/Cycle_Loss_0_Mask_Factor_60/lambda_amp_10_lambda_vel_1/fcn_l1loss_ffb/latest_checkpoint.pth"

# model_type = "Decouple_IUnet"
# model_path = "/projects/ml4science/openfwi/openfwi_results/Decoupled_Invertible_X_Net/CurveVel-A/Debug/Optimizer_Adamax/Cycle_Loss_1_Mask_Factor_80/lambda_amp_10_lambda_vel_1/fcn_l1loss_ffb/latest_checkpoint.pth"
# Comments

In [37]:
with open('../dataset_config.json') as f:
    try:
        ctx = json.load(f)[dataset]
    except KeyError:
        print('Unsupported dataset.')
        sys.exit()

if file_size is not None:
    ctx['file_size'] = file_size

# transform_data, transform_label = utils.get_transforms(args, ctx)
transform_label = T.MinMaxNormalize(ctx['label_min'], ctx['label_max'])
transform_data = T.Normalize(ctx['data_mean'], ctx['data_std']) 

# Create dataset and dataloader
print('Loading data')
print('Loading training data')

dataset_train = FWIDataset(
    train_anno,
    preload=True,
    sample_ratio=sample_temporal,
    file_size=ctx['file_size'],
    transform_data=transform_data,
    transform_label=transform_label,
    mask_factor=mask_factor
)


print('Loading validation data')
dataset_valid = FWIDataset(
    val_anno,
    preload=True,
    sample_ratio=sample_temporal,
    file_size=ctx['file_size'],
    transform_data=transform_data,
    transform_label=transform_label
)

train_sampler = RandomSampler(dataset_train)
valid_sampler = RandomSampler(dataset_valid)
    
dataloader_train = DataLoader(
    dataset_train, batch_size=batch_size,
    sampler=train_sampler, num_workers=workers,
    pin_memory=True, drop_last=True, collate_fn=default_collate)

dataloader_valid = DataLoader(
    dataset_valid, batch_size=batch_size,
    sampler=valid_sampler, num_workers=workers,
    pin_memory=True, collate_fn=default_collate)

Loading data
Loading training data
Data concatenation complete.
Loading validation data
Data concatenation complete.


In [38]:
amp_input_channel = 5
amp_encoder_channel = [8, 16, 32, 64, 128]
amp_decoder_channel = [128, 64, 32, 16, 5]
amp_model = iunet_network.AmpAutoEncoder(amp_input_channel, amp_encoder_channel, amp_decoder_channel).to(device)

# creating velocity cnn
vel_input_channel = 1
vel_encoder_channel = [8, 16, 32, 64, 128]
vel_decoder_channel = [128, 64, 32, 16, 1]
vel_model = iunet_network.VelAutoEncoder(vel_input_channel, vel_encoder_channel, vel_decoder_channel).to(device)

if model_type == "IUNet":
    iunet_model = iUNet(in_channels=128, dim=2, architecture=(4,4,4,4))
    model = iunet_network.IUnetModel(amp_model, vel_model, iunet_model)
    print("IUnet model initialized.")
elif model_type == "Decouple_IUnet":
    amp_iunet_model = iUNet(in_channels=128, dim=2, architecture=(4,4,4,4))
    vel_iunet_model = iUNet(in_channels=128, dim=2, architecture=(4,4,4,4))
    model = iunet_network.Decouple_IUnetModel(amp_model, vel_model, amp_iunet_model, vel_iunet_model)
    print("Decoupled IUnetModel model initialized.")
else:
    print(f"Invalid Model: {model_type}")

checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model'])
model = model.to(device)
model.eval()

IUnet model initialized.


IUnetModel(
  (amp_model): AmpAutoEncoder(
    (encoder_layers): Sequential(
      (0): Sequential(
        (0): Conv2d(5, 8, kernel_size=(7, 1), stride=(3, 1), padding=(3, 0))
        (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (1): Sequential(
        (0): Conv2d(8, 16, kernel_size=(7, 1), stride=(2, 1), padding=(1, 0))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (2): Sequential(
        (0): Conv2d(16, 32, kernel_size=(5, 1), stride=(2, 1), padding=(1, 0))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (3): Sequential(
        (0): Conv2d(32, 64, kernel_size=(5, 1), stride=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [39]:
# # Define loss function
# l1loss = nn.L1Loss()
# l2loss = nn.MSELoss()
    
# def criterion(pred, gt):
#     loss_g1v = l1loss(pred, gt)
#     loss_g2v = l2loss(pred, gt)
#     loss = lambda_g1v * loss_g1v + lambda_g2v * loss_g2v
#     return loss, loss_g1v, loss_g2v

In [40]:
# def evaluate(model, criterion, dataloader, device, ctx, transform_data=None, transform_label=None):
#     l1_loss = nn.L1Loss()
#     l2_loss = nn.MSELoss()

#     model.eval()
#     with torch.no_grad():
#         total_samples = 0
        
#         eval_metrics = ["vel_sum_abs_error", "vel_sum_squared_error", "amp_sum_abs_error", "amp_sum_squared_error"]
#         eval_dict = {}
#         for metric in eval_metrics:
#             eval_dict[metric] = 0

#         val_loss = 0
#         for _, amp, vel in dataloader:
            
#             amp = amp.to(device, non_blocking=True)
#             vel = vel.to(device, non_blocking=True)
#             vel = transform_label.inverse_transform(vel)
#             amp = transform_data.inverse_transform(amp)
            
#             batch_size = amp.shape[0]
#             total_samples += batch_size
            
#             vel_pred = model.inverse(amp)
#             vel_pred = transform_label.inverse_transform(vel_pred)
            
#             amp_pred = model.forward(vel)
#             amp_pred = transform_data.inverse_transform(amp_pred)
            
#             vel_loss, vel_loss_g1v, vel_loss_g2v = criterion(vel_pred, vel)
#             amp_loss, amp_loss_g1v, amp_loss_g2v = criterion(amp_pred, amp)

#             loss = vel_loss + amp_loss
#             val_loss += loss.item()

#             eval_dict["vel_sum_abs_error"] += (l1_loss(vel, vel_pred) * batch_size).item()
#             eval_dict["vel_sum_squared_error"] += (l2_loss(vel, vel_pred) * batch_size).item()
            
#             eval_dict["amp_sum_abs_error"] += (l1_loss(amp, amp_pred) * batch_size).item()
#             eval_dict["amp_sum_squared_error"] += (l2_loss(amp, amp_pred) * batch_size).item()

#         for metric in eval_metrics:
#             eval_dict[metric] /= total_samples 

#     val_loss /= len(dataloader)
#     return val_loss, eval_dict


In [41]:
# test_loss, test_eval_dict = evaluate(model, criterion, dataloader_valid, device, ctx, transform_data, transform_label)

In [42]:
def evaluate(model, criterions, dataloader, device, k, ctx):   
    
    vel_list, vel_pred_list= [], [] # store denormalized velocity predcition & gt in numpy 
    vel_norm_list, vel_pred_norm_list = [], [] # store normalized velocity prediction & gt in tensor
    
    amp_list, amp_pred_list = [], []     # store denormalized waveform predcition & gt in numpy
    amp_norm_list, amp_pred_norm_list = [], []  # store normalized waveform predcition & gt in numpy

    with torch.no_grad():
        batch_idx = 0
        for _, amp, vel in dataloader:
            amp = amp.to(device)
            vel = vel.to(device)
            
            vel_pred = model.inverse(amp)
            amp_pred = model.forward(vel)
            
            vel_np = transform_label.inverse_transform(vel.detach().cpu().numpy())
            vel_list.append(torch.from_numpy(vel_np))
            vel_norm_list.append(vel.detach().cpu())
            
            vel_pred_np = transform_label.inverse_transform(vel_pred.detach().cpu().numpy())
            vel_pred_list.append(torch.from_numpy(vel_pred_np))
            vel_pred_norm_list.append(vel_pred.detach().cpu())
            
            
            amp_norm_list.append(amp.detach().cpu())
            amp_pred_norm_list.append(amp_pred.detach().cpu())
            
            amp_np = transform_data.inverse_transform(amp.detach().cpu().numpy())
            amp_pred_np = transform_data.inverse_transform(amp_pred.detach().cpu().numpy())
            
            amp_list.append(torch.from_numpy(amp_np))
            amp_pred_list.append(torch.from_numpy(amp_pred_np))
            
            batch_idx += 1

    vel, vel_pred = torch.cat(vel_list), torch.cat(vel_pred_list)
    vel_norm, vel_pred_norm = torch.cat(vel_norm_list), torch.cat(vel_pred_norm_list)
    
    amp, amp_pred = torch.cat(amp_list), torch.cat(amp_pred_list)
    amp_norm, amp_pred_norm = torch.cat(amp_norm_list), torch.cat(amp_pred_norm_list)

    for name, criterion in criterions.items():
        print(f'Velocity Normalized {name}: {criterion(vel_norm, vel_pred_norm)}')
        print(f'Waveform Normalized {name}: {criterion(amp_norm, amp_pred_norm)}')
        print(f' * Velocity {name}: {criterion(vel, vel_pred)}')
        print(f' * Waveform {name}: {criterion(amp, amp_pred)}')
        
    ssim_loss = SSIM(window_size=11)
    print(f'Velocity SSIM: {ssim_loss(vel_norm / 2 + 0.5, vel_pred_norm / 2 + 0.5)}') 
    print(f'Waveform SSIM: {ssim_loss(amp_norm, amp_pred_norm)}')


In [43]:
criterions = {
    'MAE': lambda x, y: torch.mean(torch.abs(x - y)),
    'MSE': lambda x, y: torch.mean((x - y) ** 2)
}

val_eval_dict = evaluate(model, criterions, dataloader_valid, device, k, ctx)
print(val_eval_dict)

Velocity Normalized MAE: 0.0537225678563118
Waveform Normalized MAE: 0.04079228267073631
 * Velocity MAE: 80.58385467529297
 * Waveform MAE: 0.06065812334418297
Velocity Normalized MSE: 0.010535511188209057
Waveform Normalized MSE: 0.013184669427573681
 * Velocity MSE: 23704.900390625
 * Waveform MSE: 0.029153531417250633
Velocity SSIM: 0.8821964859962463
Waveform SSIM: 0.7015249729156494
None
