In [1]:
import os

import warnings
warnings.filterwarnings('ignore')
# Silence TensorFlow's native (C++) logging. The variable is placed in the
# environment BEFORE TensorFlow is imported so that it is already in effect
# when the native runtime is loaded.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# NB: this `logging` is tensorflow's module, not the standard-library one.
from tensorflow import logging
logging.set_verbosity(logging.ERROR)


import shutil

import sys
import datetime
import pprint


import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf




sys.path.append('../../..')

from seismicpro.batchflow import Pipeline, V, B, L, I, W, C
# from seismicpro.batchflow.models.tf.layers import conv_block
# from seismicpro.batchflow.models.tf import UNet

from seismicpro.src import (SeismicDataset, FieldIndex, TraceIndex,
                           # , statistics_plot,
#                            seismic_plot, spectrum_plot, 
                            merge_segy_files
                           )
from seismicpro.models import UnetAtt, UnetAttGauss1, attention_loss, attention_loss_gauss, FieldMetrics


from seismicpro.batchflow.batchflow.research import Research, Option, KV
from seismicpro.batchflow.batchflow.utils import plot_results_by_config


from metric_utils import get_windowed_spectrogram_dists

%env CUDA_VISIBLE_DEVICES=3

env: CUDA_VISIBLE_DEVICES=3


In [2]:
# Base model configuration shared by every model trained in this notebook.
# train_n_save() merges a 'loss' entry into a copy of this dict before passing
# it to init_model, and the whole dict is dumped into the model's README.txt.
model_config = {
    'initial_block/inputs': 'trace_raw',
    # Both inputs have shape (3000, 1): traces are loaded with
    # tslice=np.arange(3000) and exp_stack() appends a trailing axis of size 1.
    'inputs': dict(trace_raw={'shape': (3000, 1)},
                   lift={'name': 'targets', 'shape': (3000, 1)}),

    # The loss is intentionally not set here; see train_n_save().
#     'loss': (C('loss'), {'balance': 0.05}),
    'optimizer': ('Adam', {'learning_rate': 0.0001}),
    'common/data_format': 'channels_last',
    # Two sub-configurations with the same encoder/decoder layout.
    # NOTE(review): presumably 'main' is the signal branch and 'attn' the
    # attention branch of UnetAtt/UnetAttGauss1 -- confirm against the models.
    'body': {
        'main': {
            'encoder/num_stages': 5,
            'encoder/blocks': dict(layout='ca ca',
                                   filters=[16, 32, 64, 128, 256],
                                   kernel_size=[7, 5, 5, 5, 5],
                                   activation=tf.nn.elu),
            'encoder/downsample': dict(layout='pd',
                                       pool_size=2,
                                       pool_strides=2,
                                       dropout_rate=0.05),

            # Decoder lists are the encoder lists reversed ([::-1]).
            'decoder/blocks': dict(layout='ca ca',
                                   filters=[16, 32, 64, 128, 256][::-1],
                                   kernel_size=[7, 5, 5, 5, 5][::-1],
                                   activation=tf.nn.elu),
            # NOTE(review): layout is 'tad' here but 'ta d' in the 'attn'
            # branch below -- confirm whether the difference is intentional.
            'decoder/upsample': dict(layout='tad',
                                     kernel_size=[7, 5, 5, 5, 5][::-1],
                                     strides=2,
                                     dropout_rate=0.05,
                                     activation=tf.nn.elu,),
          },
        # Same structure as 'main' with half the filters per stage and a
        # constant kernel size of 3.
        'attn': {
              'encoder/num_stages': 5,
              'encoder/blocks': dict(layout='ca ca',
                                     filters=[8, 16, 32, 64, 128],
                                     kernel_size=3,
                                     activation=tf.nn.elu),
              'encoder/downsample': dict(layout='pd',
                                         pool_size=2,
                                         pool_strides=2,
                                         dropout_rate=0.05),

              'decoder/blocks': dict(layout='ca ca',
                                     filters=[8, 16, 32, 64, 128][::-1],
                                     kernel_size=3,
                                     activation=tf.nn.elu),
              'decoder/upsample': dict(layout='ta d',
                                       kernel_size=3,
                                       strides=2,
                                       dropout_rate=0.05,
                                       activation=tf.nn.elu),
        },
    },
    # NOTE(review): the meaning of 'scale' is defined by the model's head -- confirm.
    'head': {'scale': 1.5},
    # Two named train steps, each bound to its own scope. train_n_save()
    # fetches 'loss_step_main' when training.
    'train_steps': {
        'step_main': {'scope': 'main_branch'},
        'step_attention': {'scope': 'attention_branch'},

    }
}


In [3]:
# Batch size constant; 64 is also the default batch size used by train_n_save().
BATCH_SIZE=64

def exp_stack(x):
    """Stack the items of `x` vertically and append a trailing axis of size 1."""
    stacked = np.vstack(x)
    return stacked[..., np.newaxis]

def make_data(batch, **kwargs):
    """Build the training feed: stacked `raw` and `lift` components of the batch.

    Extra keyword arguments are accepted and ignored.
    """
    feed = {}
    feed['trace_raw'] = exp_stack(batch.raw)
    feed['lift'] = exp_stack(batch.lift)
    return feed

def make_data_inference(batch, **kwargs):
    """Build the inference feed: only the stacked `raw` component of the batch.

    Extra keyword arguments are accepted and ignored.
    """
    raw_traces = exp_stack(batch.raw)
    return {'trace_raw': raw_traces}


In [4]:
def train_n_save(model, loss, train_set, model_path='./saved_models', **kwargs):
    """Train `model` on the traces of `train_set`, save it and write a README next to it.

    Parameters
    ----------
    model : class
        Model class passed to ``init_model``; its ``__name__`` is written to the README.
    loss : callable
        Loss put into the model config as ``(loss, {'balance': 0.05})``.
    train_set : dataset
        Dataset whose ``index`` is converted to a ``TraceIndex`` for training.
    model_path : str
        Directory where the model and ``README.txt`` are stored.
    **kwargs
        ``batch_size`` (default ``BATCH_SIZE``), ``n_epochs`` and ``n_iters`` control
        the run. If neither ``n_epochs`` nor ``n_iters`` is given, one epoch is run.
        All keyword arguments are also recorded in the README.

    Returns
    -------
    tuple
        ``(loss_history, path, fi)``: array with the 'loss_step_main' values collected
        during training, the directory the model was saved to, and ``train_set.indices``.
    """
    fi = train_set.indices

    tindex = TraceIndex(train_set.index)
    t_train_set = SeismicDataset(tindex)

    # model_config has no 'loss' entry (it is commented out there), so it is added here.
    config = {**model_config, 'loss': (loss, {'balance': 0.05})}

    train_pipeline = (t_train_set.p
                      .load(components=('raw', 'lift'), fmt='segy', tslice=np.arange(3000))
                      .init_model('dynamic', model, name='unet', config=config)
                      .init_variable('loss', default=list())
                      .train_model('unet', make_data=make_data, fetches='loss_step_main',
                                   save_to=V('loss', 'a'))
                     )

    batch_size = kwargs.get('batch_size', BATCH_SIZE)

    if 'n_epochs' in kwargs or 'n_iters' in kwargs:
        n_epochs = kwargs.get('n_epochs', None)
        n_iters = kwargs.get('n_iters', None)
    else:
        # Nothing requested explicitly: train for a single epoch.
        n_epochs, n_iters = 1, None

    train_pipeline = train_pipeline.run(batch_size=batch_size, n_epochs=n_epochs, n_iters=n_iters,
                                        drop_last=True, shuffle=True, bar=True,
                                        bar_desc=W(V('loss')[-1].format('Current loss is: {:7.7}')))

#     tz = datetime.timezone(datetime.timedelta(hours=3))
#     path = os.path.join(model_path, str(datetime.datetime.now(tz=tz)).replace(' ', '_'))
#     print(path)
    path = model_path

    train_pipeline.save_model_now('unet', path)

    # A separate name is used so that the `loss` argument is not overwritten.
    loss_history = np.array(train_pipeline.get_variable('loss'))

    # Make sure the target directory exists before the README is written.
    os.makedirs(path, exist_ok=True)
    readme_path = os.path.join(path, 'README.txt')
    with open(readme_path, 'w') as readme:
        readme.write("Model name: {}\n".format(model.__name__))
        readme.write("Loss: {}\n".format(getattr(loss, '__name__', loss)))
        readme.write("Avg final loss (100 points): {}\n".format(np.mean(loss_history[-100:])))
        readme.write("\nConfig:\n")
        readme.write(pprint.pformat(model_config, compact=True))
        readme.write("\n\nAdditional Info:\n")
        # 'fields' is set last so that a 'fields' keyword argument cannot raise or override it.
        readme.write(pprint.pformat({**kwargs, 'fields': list(fi)}))

    return loss_history, path, fi


def inference(model, model_path, test_set, output_path=None, tmp_dump_path='tmp', clear=False,
              batch_size=1000):
    """Run a saved model over the traces of `test_set` and merge the predictions into one SEG-Y file.

    Parameters
    ----------
    model : class
        Model class passed to ``load_model``.
    model_path : str
        Path the model is loaded from.
    test_set : dataset
        Dataset whose ``index`` is converted to a ``TraceIndex`` for prediction.
    output_path : str or None
        Path of the merged file. If None, ``<tmp_dump_path>/merged.sgy`` is used and
        `clear` is forced to False so that the result is not deleted.
    tmp_dump_path : str
        Directory for per-batch dumps. It is removed (if present) and re-created on entry.
    clear : bool
        Whether to remove `tmp_dump_path` after merging.
    batch_size : int
        Number of traces per inference batch (one temporary .sgy file is dumped per batch).

    Returns
    -------
    str
        Path of the merged SEG-Y file.
    """
    # Start from an empty dump directory so that stale files are not merged.
    if os.path.exists(tmp_dump_path):
        shutil.rmtree(tmp_dump_path)

    os.makedirs(tmp_dump_path)

    tindex = TraceIndex(test_set.index)
    t_test_set = SeismicDataset(tindex)

    # Predictions overwrite the 'raw' component, which is then dumped to a file
    # named after the value of I().
    inference_ppl = (t_test_set.p
                     .load_model("dynamic", model, 'unet', path=model_path)
                     .init_variable('res')
                     .load(components='raw', fmt='segy', tslice=np.arange(3000))
                     .predict_model('unet', make_data=make_data_inference,
                                    fetches=['out_lift'], save_to=B('raw'))
                     .dump(path=L(lambda x: os.path.join(tmp_dump_path, str(x) + '.sgy'))(I()),
                           src='raw', fmt='segy', split=False)
                 )
    inference_ppl.run(batch_size, n_epochs=1, drop_last=False, shuffle=False, bar=True)

    if output_path is None:
        # The merged file lives inside the dump directory, so it must not be cleared.
        clear = False
        output_path = os.path.join(tmp_dump_path, 'merged.sgy')

    print("merging .sgy")
    merge_segy_files(output_path=output_path, extra_headers='all', path=os.path.join(tmp_dump_path, '*.sgy'))

    if clear and os.path.exists(tmp_dump_path):
        shutil.rmtree(tmp_dump_path)

    return output_path


def eval_mt(batch, *args):
    """Return ``(mae, corr_coef)`` of FieldMetrics built from the first `lift` and `ml` items."""
    metrics = FieldMetrics(batch.lift[0], batch.ml[0])
    mae = metrics.mae()
    corr = metrics.corr_coef()
    return mae, corr

def eval_dist(batch, *args):
    """Return the mean windowed-spectrogram distance over the first 200 traces.

    Compares the first `lift` item of the batch with the first `ml` item.
    """
    n_use_traces = 200
    head = slice(0, n_use_traces)
    lift_traces = batch.lift[0][head]
    ml_traces = batch.ml[0][head]
    distances = get_windowed_spectrogram_dists(lift_traces, ml_traces)
    return np.mean(distances)

def _test(path_lift, model_out, fi):
    """Compare the model output file with the reference file field by field.

    Returns the mean MAE, mean correlation coefficient and mean spectrogram
    distance over all processed batches (one item per batch). `fi` is accepted
    but not used.
    """
    ml_index = FieldIndex(name='ml', path=model_out)
    lift_index = FieldIndex(name='lift', path=path_lift)
    merged_index = ml_index.merge(lift_index)

    dataset = SeismicDataset(merged_index)
    template = (Pipeline()
                .init_variable('mt', default=[])
                .init_variable('dist', default=[])
                .load(components=('ml', 'lift'), fmt='segy', tslice=np.arange(3000))
                .call(eval_mt, save_to=V('mt', mode='a'))
                .call(eval_dist, save_to=V('dist', mode='a')))

    executed = (template << dataset).run(batch_size=1, n_epochs=1, drop_last=False,
                                         shuffle=False, bar=True)

    # Each 'mt' entry is a (mae, corr) pair, so column 0 is MAE and column 1 is correlation.
    mae_corr = np.vstack(executed.get_variable('mt'))
    distances = np.asarray(executed.get_variable('dist'))

    mean_mae = np.mean(mae_corr[:, 0])
    mean_corr = np.mean(mae_corr[:, 1])
    return mean_mae, mean_corr, np.mean(distances)
    

In [None]:
def process(model, loss_fn, splits, dataset_name, path_lift):
    """Train, run inference and evaluate `model` on every cross-validation split.

    Returns
    -------
    tuple
        ``(losses, save_res)`` where `losses` holds
        ``(model name, dataset name, split number, loss history)`` tuples and
        `save_res` holds ``[model name, dataset name, split number, mae, corr, dist]`` rows.
    """
    model_name = model.__name__
    losses, save_res = [], []

    for i, split in enumerate(splits):
        print("Processing dataset {}, model {}, cv_split {}".format(dataset_name, model_name, i))

        # NB split.test is passed as train and vice versa
        save_dir = './saved_models/{}/{}/{}'.format(model_name, dataset_name, i)
        loss, model_path, fi = train_n_save(model, loss_fn, split.test,
                                            model_path=save_dir, n_epochs=3)
        losses.append((model_name, dataset_name, i, loss))
        print("mean final loss:", np.mean(loss[-100:]))

        merged_name = '{}_out.sgy'.format(dataset_name)
        output_path = inference(model, model_path, split.train,
                                output_path=os.path.join(model_path, merged_name),
                                clear=True)
        print(output_path)

        mae, corr, dist = _test(path_lift, output_path, fi)
        print("mae, corr, dist")
        print(mae, corr, dist)

        save_res.append([model_name, dataset_name, i, mae, corr, dist])

    return losses, save_res

In [None]:
# Dataset label; process() uses it in saved-model paths, output file names and result rows.
dataset_name = 'H1_WZ'
# Model input ('raw' component) and reference output ('lift' component).
# NOTE(review): absolute paths tie the notebook to one machine -- consider a configurable data directory.
path_raw = '/notebooks/data/H1_WZ/1_NA-gr_input_DN01_norm2.sgy'
path_lift = '/notebooks/data/H1_WZ/1_NA-gr_output_DN03_norm2.sgy'

# Index both files together; the 'offset' header is additionally requested for the raw file.
index = (FieldIndex(name='raw', extra_headers=['offset'], path=path_raw)
             .merge(FieldIndex(name='lift', path=path_lift)))
    
# NOTE(review): `index` is already the result of merging FieldIndex objects, so
# this re-wrap may be redundant -- confirm.
findex = FieldIndex(index)
ds = SeismicDataset(findex)
# Presumably creates the ds.cv0 .. ds.cv4 folds collected below -- confirm the
# defaults (number of folds, shuffling, seed) in the dataset documentation.
ds.cv_split()
splits = [ds.cv0, ds.cv1, ds.cv2, ds.cv3, ds.cv4]

In [None]:
# Run the full train / inference / evaluation cycle for every (model, loss) pair
# and collect loss histories and metric rows across all cross-validation splits.
losses, save_res = [], []

experiments = [(UnetAtt, attention_loss), (UnetAttGauss1, attention_loss_gauss)]
for model, loss in experiments:
    model_losses, model_results = process(model, loss, splits, dataset_name, path_lift)
    losses += model_losses
    save_res += model_results
    

  0%|          | 0/1707 [00:00<?, ?it/s]

Processing dataset H1_WZ, model UnetAtt, cv_split 0


Current loss is: 0.08191239:  20%|█▉        | 338/1707 [16:58<1:08:04,  2.98s/it]

In [None]:
# Plot the training loss curve of every (model, cv split) run; the legend shows
# the mean of the last 100 loss values of each run.
# Loop variables are named so that the notebook-level `model`, `dataset_name`
# and `loss` are not overwritten by this cell.
for model_name, _ds_name, split_no, loss_history in losses:
    final_loss = np.mean(loss_history[-100:])
    plt.plot(loss_history, label=model_name + "_{} {}".format(split_no, final_loss))

plt.xlabel('training iteration')
plt.ylabel('loss')
plt.legend();

In [None]:
import pandas as pd

# One row per (model, cv split) with the metrics returned by _test().
result_columns = ['Model', 'DS_train', 'split_no', 'MAE', 'Corr', 'Dist']
res = pd.DataFrame.from_records(save_res, columns=result_columns)
res

In [None]:
# Raw metric rows: [model name, dataset name, split number, mae, corr, dist].
save_res