# Iterative research

In this advanced notebook, we apply the [carcass interpolation model](./01_Demo_E.ipynb), as well as the [horizon extension](./../Horizon_extension/Demo_E.ipynb) and horizon enhancement models, in quick succession with the help of [research](./../Research_template.ipynb). We advise checking out our notebooks on these techniques before reading this one.

In [None]:
# Necessary imports
import os
import sys
import shutil
import random
import warnings

from copy import copy
from glob import glob
from tqdm.auto import tqdm
from datetime import date

import torch
import numpy as np

sys.path.append('../..')
from seismiqb.batchflow import Pipeline, Dataset, C, V, R, P, B, D
from seismiqb.batchflow.models.torch import EncoderDecoder, ResBlock
from seismiqb.batchflow.research import Research, Option, Domain, Results, FileLogger
from seismiqb.batchflow.research import RP, RC, KV

from seismiqb import Interpolator, Enhancer, Extender

warnings.filterwarnings("ignore")

In [None]:
# Global parameters
DEVICES = [0, 1, 2, 3, 4, 5, 6, 7]         # physical device numbers
WORKERS = len(DEVICES)                      # one research worker per listed device

RESEARCH_NAME = 'Research_iterative'        # directory for research logs and results
# Name of the folder (inside `HORIZONS_DUMP` next to each cube) for the final horizons: '<date>_iterative'.
# NOTE(review): the date format is year-day-month (not ISO year-month-day); kept as is so that
# the names of already existing dump folders stay consistent — confirm whether this is intended.
DUMP_NAME = date.today().strftime("%Y-%d-%m") + RESEARCH_NAME[len('Research'):]
N_REPS = 2                                  # number of repetitions of every experiment
SUPPORTS = 100                              # passed as `supports` to every `evaluate` call
OVERLAP_FACTOR = 2.                         # overlap factor for detector / enhancer inference
ITERATIONS = 1                              # number of extend + enhance rounds after detection
FREQUENCIES = [200, 200]                    # passed as `frequencies` to detector training


# Detection parameters
DETECTION_CROP_SHAPE = (1, 256, 256)       # shape of sampled 3D crops
DETECTION_ITERS = 500                      # number of train iterations
DETECTION_BATCH_SIZE = 64                  # number of crops inside one batch


# Extension parameters
EXTENSION_CROP_SHAPE = (1, 64, 64)         # shape of sampled 3D crops
EXTENSION_ITERS = 500                      # number of train iterations
EXTENSION_BATCH_SIZE = 64                  # number of crops inside one batch
EXTENSION_STRIDE = 32                      # step size for extension
EXTENSION_STEPS = 50                       # number of boundary extensions


# Enhancing parameters
ENHANCE_CROP_SHAPE = (1, 256, 256)         # shape of sampled 3D crops
ENHANCE_ITERS = 500                        # number of train iterations
ENHANCE_BATCH_SIZE = 64                    # number of crops inside one batch

In [None]:
# Model architecture of the detection (carcass interpolation) network
DETECTION_MODEL_CONFIG = {
    # Model layout
    'initial_block': dict(
        base_block=ResBlock,
        filters=16,
        kernel_size=5,
        downsample=False,
        attention='scse',
    ),
    'body/encoder': dict(
        num_stages=4,
        order='sbd',
        blocks=dict(base=ResBlock, n_reps=1, filters=[32, 64, 128, 256], attention='scse'),
    ),
    'body/embedding': dict(base=ResBlock, n_reps=1, filters=256, attention='scse'),
    'body/decoder': dict(
        num_stages=4,
        upsample=dict(layout='tna', kernel_size=2),
        blocks=dict(base=ResBlock, filters=[128, 64, 32, 16], attention='scse'),
    ),
    'head': dict(base_block=ResBlock, filters=[16, 8], attention='scse'),
    'output': 'sigmoid',

    # Train configuration
    'loss': 'bdice',
    'optimizer': dict(name='Adam', lr=0.01),
    'decay': dict(name='exp', gamma=0.1, frequency=150),
    'microbatch': 4,
}

In [None]:
# Model architecture of the extension network (no separate initial block)
EXTENSION_MODEL_CONFIG = {
    # Model layout
    'body/encoder': dict(
        num_stages=4,
        order='sbd',
        blocks=dict(base=ResBlock, n_reps=1, filters=[32, 64, 128, 256], attention='scse'),
    ),
    'body/embedding': dict(base=ResBlock, n_reps=1, filters=256, attention='scse'),
    'body/decoder': dict(
        num_stages=4,
        upsample=dict(layout='tna', kernel_size=2),
        blocks=dict(base=ResBlock, filters=[128, 64, 32, 16], attention='scse'),
    ),
    'head': dict(base_block=ResBlock, filters=[16, 8], attention='scse'),
    'output': 'sigmoid',

    # Train configuration: note the smaller learning rate compared to the other models
    'loss': 'bdice',
    'optimizer': dict(name='Adam', lr=0.005),
    'decay': dict(name='exp', gamma=0.1, frequency=150),
    'microbatch': 4,
}

In [None]:
# Model architecture of the enhancement network (no separate initial block)
ENHANCE_MODEL_CONFIG = {
    # Model layout
    'body/encoder': dict(
        num_stages=4,
        order='sbd',
        blocks=dict(base=ResBlock, n_reps=1, filters=[32, 64, 128, 256], attention='scse'),
    ),
    'body/embedding': dict(base=ResBlock, n_reps=1, filters=256, attention='scse'),
    'body/decoder': dict(
        num_stages=4,
        upsample=dict(layout='tna', kernel_size=2),
        blocks=dict(base=ResBlock, filters=[128, 64, 32, 16], attention='scse'),
    ),
    'head': dict(base_block=ResBlock, filters=[16, 8], attention='scse'),
    'output': 'sigmoid',

    # Train configuration
    'loss': 'bdice',
    'optimizer': dict(name='Adam', lr=0.01),
    'decay': dict(name='exp', gamma=0.1, frequency=150),
    'microbatch': 4,
}

In [None]:
# Pairs of (path to a cube, glob pattern of its horizon files) to run the research on
paths = [
    (
        '/data/seismic/CUBE_2/M_cube.hdf5',
        '/data/seismic/CUBE_2/RAW/*',
    ),
]

In [None]:
# All (cube path, horizon path) pairs: every horizon glob pattern is expanded into files
unrolled = [
    (cube, horizon_file)
    for cube, pattern in paths
    for horizon_file in glob(pattern)
]

# Research options: the value is the pair itself, the alias is '<cube name>+<horizon name>',
# where a name is the file name cut at its first dot
options = [
    KV(
        (cube, horizon_file),
        f"{cube.split('/')[-1].split('.')[0]}+{horizon_file.split('/')[-1].split('.')[0]}",
    )
    for cube, horizon_file in unrolled
]
# Unseeded shuffle of the options, presumably so that experiments are taken in random order
random.shuffle(options)

domain = Option('cube_and_horizon', options)

In [None]:
def _short_name(path):
    """Return the file name of `path` cut at its first dot, e.g. '/a/b/M_cube.hdf5' -> 'M_cube'."""
    return path.split('/')[-1].split('.')[0]


# Names of the metrics collected by `perform_one_experiment`, in the order of its return value.
# Should match the `returns` names passed to the research.
_METRIC_NAMES = ('coverages', 'window_rates', 'corrs', 'local_corrs')


def _record_metrics(return_value, horizon, info):
    """Append metrics of one stage to the per-metric lists of `return_value`.

    `return_value` holds four lists: coverages, window rates, support corrs and local corrs.
    `horizon` must have a `coverage` attribute; `info` must have the
    'window_rate', 'corrs' and 'local_corrs' keys.
    """
    return_value[0].append(horizon.coverage)
    return_value[1].append(info['window_rate'])
    return_value[2].append(info['corrs'])
    return_value[3].append(info['local_corrs'])


def perform_one_experiment(config, ppl):
    """Detect a horizon, then run `ITERATIONS` rounds of extension + enhancement on it.

    Parameters
    ----------
    config : research config wrapper
        Its `.config()` must return a mapping with 'cube_and_horizon' -- a pair of
        (cube path, horizon path) -- and 'repetition' -- the index of the repetition.
    ppl : Pipeline
        Fake pipeline passed by the research; not used here.

    Returns
    -------
    list of four lists
        Coverages, window rates, support corrs and local corrs. Each list has one entry per stage:
        detection first, then extension and enhancement for every iteration.

    Notes
    -----
    Every model gets a sub-directory of `RESEARCH_NAME/custom_results/<cube>/<horizon>/<repetition>`
    as its `save_dir`. The final horizon is dumped to `<cube dir>/HORIZONS_DUMP/DUMP_NAME/`, and
    messages are written through `detector.log`.
    """
    # Parse the research config
    config = config.config()
    train_cube, horizon = config['cube_and_horizon']
    n_rep = config['repetition']

    # Directory to save results to
    results_dir = os.path.join(RESEARCH_NAME, 'custom_results')
    short_name_cube = _short_name(train_cube)
    short_name_horizon = _short_name(horizon)
    alias = os.path.join(short_name_cube, short_name_horizon, f'{n_rep}')
    save_dir = os.path.join(results_dir, alias)

    return_value = [[], [], [], []]  # coverages, window ratios, support corrs, local corrs

    # DETECTION: train the interpolation model on the given horizon
    detector = Interpolator(
        batch_size=DETECTION_BATCH_SIZE,
        crop_shape=DETECTION_CROP_SHAPE,
        model_config=DETECTION_MODEL_CONFIG,
        save_dir=save_dir, bar=False
    )
    train_dataset = detector.make_dataset(train_cube, {short_name_cube: [horizon]})

    detector.train(dataset=train_dataset,
                   frequencies=FREQUENCIES,
                   n_iters=DETECTION_ITERS,
                   width=5, batch_size_multiplier=1,
                   rebatch_threshold=0.9)

    # Inference on the same cube to interpolate horizon on whole spatial range
    detector.inference(dataset=train_dataset,
                       batch_size_multiplier=0.1,
                       version=1, orientation='ix',
                       overlap_factor=OVERLAP_FACTOR)

    info = detector.evaluate(n=1, add_prefix=False, dump=True, supports=SUPPORTS)[0]
    # From now on `horizon` is the predicted horizon object, not the path to the source one
    horizon = detector.predictions[0]
    _record_metrics(return_value, horizon, info)

    for i in range(ITERATIONS):
        # EXTEND: fill the holes and exterior of the current horizon
        torch.cuda.empty_cache()

        extender = Extender(
            batch_size=EXTENSION_BATCH_SIZE,
            crop_shape=EXTENSION_CROP_SHAPE,
            model_config=EXTENSION_MODEL_CONFIG,
            save_dir=os.path.join(save_dir, f'extended_{i}'), bar=False
        )
        extender.train(horizon, n_iters=EXTENSION_ITERS, width=5)
        extender.inference(horizon,
                           n_steps=EXTENSION_STEPS,
                           stride=EXTENSION_STRIDE)

        # The extended horizon is taken from `predictions`; evaluate it against the detector targets
        horizon = extender.predictions[0]
        extender.targets = detector.targets
        info = extender.evaluate(n=1, add_prefix=False, dump=True, supports=SUPPORTS)[0]
        _record_metrics(return_value, horizon, info)

        # ENHANCE: try to make every crop of the current horizon a touch better
        torch.cuda.empty_cache()

        enhancer = Enhancer(
            batch_size=ENHANCE_BATCH_SIZE,
            crop_shape=ENHANCE_CROP_SHAPE,
            model_config=ENHANCE_MODEL_CONFIG,
            save_dir=os.path.join(save_dir, f'enhanced_{i}'), bar=False
        )
        enhancer.train(horizon, n_iters=ENHANCE_ITERS, width=5)
        enhancer.inference(horizon,
                           batch_size_multiplier=0.1,
                           version=1, orientation='ix',
                           overlap_factor=OVERLAP_FACTOR)

        enhancer.targets = detector.targets
        info = enhancer.evaluate(n=1, add_prefix=False, dump=True, supports=SUPPORTS)[0]
        horizon = enhancer.predictions[0]
        _record_metrics(return_value, horizon, info)

    # SAVE NEXT TO CUBE: <cube dir>/HORIZONS_DUMP/<DUMP_NAME>/<horizon name>
    cube_dir = os.path.dirname(horizon.geometry.path)
    dump_dir = os.path.join(cube_dir, 'HORIZONS_DUMP', DUMP_NAME)
    os.makedirs(dump_dir, exist_ok=True)
    # Strip the stage prefixes from the name and mark the horizon with a leading '+'
    horizon.name = '+' + horizon.name.replace('enhanced_', '').replace('extended_', '')
    savepath = os.path.join(dump_dir, horizon.name)
    horizon.dump(savepath, add_height=False)
    detector.log(f'Dumped horizon to {savepath}')

    # RETURNS: log every metric list and return them
    msg = ''.join(f'        {name} -> {value}\n' for name, value in zip(_METRIC_NAMES, return_value))
    detector.log(msg)
    return return_value


def clear_previous_results(res_name):
    """Delete the directory `res_name` with all its contents; do nothing if the path is absent."""
    if not os.path.exists(res_name):
        return
    shutil.rmtree(res_name)

In [None]:
# Start from scratch: remove logs and results of a previous run with the same name
clear_previous_results(RESEARCH_NAME)

# Names under which the values returned by `perform_one_experiment` are stored, in return order
returned_values = [
    'coverages', 'window_rates', 'corrs', 'local_corrs',
]

# Fake pipeline is needed to pass parameters around
fake_ppl = (
    Pipeline()
    .set_dataset(Dataset(10))
    .run_later(1, n_iters=1)
)

research = Research()
research = research.add_logger(FileLogger)
research = research.init_domain(domain, n_reps=N_REPS)
research = research.add_pipeline(fake_ppl, run=True, name='fake')
research = research.add_callable(
    perform_one_experiment,             # callable to run
    returns=returned_values,            # names of returned results
    execute='#0',                       # execute immediately (at iteration #0)
    config=RC('fake'),                  # pass config of the 'fake' pipeline to the callable
    ppl=RP('fake'),                     # pass the 'fake' pipeline itself to the callable
    name='perform_one_experiment'       # name to be shown in the dataframe
)

research.run(
    n_iters=1,
    name=RESEARCH_NAME,
    bar=True,
    workers=WORKERS,
    devices=DEVICES,
    timeout=10000,
)