In [None]:
# some imports
import sys
import warnings
warnings.filterwarnings("ignore")

from copy import copy

import numpy as np
import torch.nn as nn
from tqdm.notebook import tqdm_notebook

sys.path.append('../../..')

from seismiqb import *
from seismiqb.src.controllers.torch_models import ExtensionModel

from seismiqb.batchflow import FilesIndex, Pipeline
from seismiqb.batchflow import D, B, V, P, R, L

In [None]:
rm -r '/data/seismic_data/seismic_interpretation/CUBE_16_PSDM//INPUTS/FAULTS/HDF5/faults.hdf5'

In [None]:
%%time

cube_path = '/data/seismic_data/seismic_interpretation/CUBE_16_PSDM/amplitudes_16_PSDM.hdf5'

dataset = SeismicCubeset(FilesIndex(path=cube_path, no_ext=True))

dataset.load(label_dir='/INPUTS/FAULTS/NPY/*', labels_class=Fault, width=1)
dataset.modify_sampler(dst='train_sampler', finish=True)

#CPU times: user 16min 42s, sys: 26.9 s, total: 17min 9s
#Wall time: 17min 9s

In [None]:
# Disabled cleanup: uncomment to delete the existing faults.hdf5 label file.
# ! rm -r /data/seismic_data/seismic_interpretation/CUBE_16_PSDM/INPUTS/FAULTS/HDF5/faults.hdf5

In [None]:
# Disabled: uncomment to dump the loaded labels to INPUTS/FAULTS/HDF5 in hdf5 format.
# dataset.dump_labels('/INPUTS/FAULTS/HDF5', fmt='hdf5')

# Map of faults

In [None]:
# Plot what 'train_sampler' produces (the "Map of faults" announced above).
# The result is bound to `_` so that it is not echoed as the cell output.
_ = dataset.show_slices(
    src_sampler='train_sampler',
    normalize=True,
    shape=(1, 128, 128),
    adaptive_slices=False,
    cmap='Reds',
    interpolation='bilinear',
    figsize=(8, 6),
)

In [None]:
# For every cube, display up to three randomly chosen faults. Each one is shown
# on the slide given by the first coordinate of the fault's first point, zoomed
# to the range spanned by all faults of that cube.
for i in range(len(dataset)):
    faults = dataset.labels[i]
    if len(faults) == 0:
        # Nothing to draw; min()/max() below would raise on an empty sequence.
        continue

    # Extent of all faults along the third column of `points`
    # (presumably the depth axis -- TODO confirm the axis order).
    bounds = (min(fault.points[:, 2].min() for fault in faults),
              max(fault.points[:, 2].max() for fault in faults))

    # Sampling without replacement raises ValueError when more items are
    # requested than exist, so cap the sample at the number of faults.
    n_shown = min(3, len(faults))
    fault_ids = np.random.choice(len(faults), n_shown, replace=False)
    for fault_id in fault_ids:
        dataset.show_slide(faults[fault_id].points[0, 0], idx=i,
                           figsize=(20, 20),
                           zoom_slice=(slice(None, None), slice(*bounds)))

In [None]:
# Value the 'train_sampler' is called with when crop points are drawn in the `load` pipeline.
BATCH_SIZE = 64
# Passed as `shape` to the `crop` action of the `load` pipeline.
CROP_SHAPE = (1, 256, 256)
# The three values below are not referenced in the visible cells --
# presumably training / inference settings used further down; TODO confirm.
NUM_ITERS = 300
N_STEPS = 8
STRIDE = 20

In [None]:
# Expression for the crop points: the dataset's 'train_sampler' called with BATCH_SIZE.
sampled_points = D('train_sampler')(BATCH_SIZE)

# Template pipeline (bound to a dataset later via `<<`):
# crop -> create masks into 'masks' -> load cube data into 'images'.
load = (
    Pipeline()
    .crop(points=sampled_points, shape=CROP_SHAPE, side_view=False)
    .create_masks(dst='masks', width=5)
    .load_cubes(dst='images')
)

In [None]:
# Bind the template pipeline to the dataset and fetch one batch for inspection.
show_pipeline = load << dataset
batch = show_pipeline.next_batch(1)

In [None]:
# Plot the first batch item's 'images' and 'masks' components in 'separate' mode.
batch.plot_components('images', 'masks', idx=0, mode='separate')

In [None]:
# Plot the same item again, this time in 'overlap' mode.
batch.plot_components('images', 'masks', idx=0, mode='overlap')