In [1]:
import sys
import warnings
import numpy as np

warnings.filterwarnings('ignore')

sys.path.insert(0, '../../../seismiqb')
sys.path.insert(0, '..')

from seismiqb import SeismicDataset, SeismicSampler
from batchflow import B, F

In [2]:
# Small SEG-Y cube used by every check below. The path is relative, so it is
# resolved against the current working directory: run the notebook from its own folder.
cube_path = 'geometry_test_files/test_cube.sgy'

In [3]:
# Dataset over the single test cube, and a sampler that generates crop locations
# of shape (1, 128, 128) from its fields (consumed by `make_locations` below).
ds = SeismicDataset([cube_path])
sampler = SeismicSampler(ds.fields, crop_shape=(1, 128, 128))
# Prepare per-field helpers needed by the later cells: they read
# `ds[0].normalization_stats` and `ds[0].quantizer.estimated_absolute_error`,
# presumably created by `make_normalizer` / `make_quantizer` respectively — confirm.
# NOTE(review): exact semantics of `clip=False, ranges=0.` are defined by seismiqb's
# `make_quantizer`; presumably "no clipping, no range restriction" — verify.
for item in ds:
    item.make_normalizer()
    item.make_quantizer(clip=False, ranges=0.)

In [4]:
# Round-trip check for `normalize` / `denormalize` with three sources of statistics:
#   'field'         - stats stored on the field (`ds[0].normalization_stats`);
#   'images'        - `normalize` gets `stats=None`; expected values use per-crop mean/std of `images`;
#   'biased_images' - per-crop mean/std of `images` shifted by a random per-crop bias.
# NOTE(review): `biased_images` is normalized first, presumably so that its stats are
# recorded and can be referenced by name in the following `normalize` call — confirm.
BATCH_SIZE = 16                  # crops per batch; the bias vector must have the same length
rng = np.random.default_rng(42)  # fixed seed so the random biases are reproducible across runs

for stats_src in ['field', 'images', 'biased_images']:
    # One offset per crop, broadcast over the remaining three crop axes
    bias = rng.normal(size=BATCH_SIZE)
    p = (ds.p
         .make_locations(generator=sampler, batch_size=BATCH_SIZE)
         .load_seismic(dst='images')
         .update(B('biased_images'), B('images') + bias.reshape(-1, 1, 1, 1))
         .normalize(src='biased_images', dst='biased_images_normalized')
         .normalize(src='images', dst='images_normalized', stats=stats_src if stats_src != 'images' else None)
         .denormalize(src='images_normalized', dst='denormalized', stats=stats_src)
         .update(B.diff, F(np.abs)(B.denormalized - B.images))
        )

    batch = p.next_batch(1)
    # The expected normalization below is (x - mean) / std, which is affine, so the
    # mean of a normalized crop equals (mean(crop) - mean) / std: comparing per-crop means suffices.
    normalized = batch.images_normalized.mean(axis=(1, 2, 3))
    images = batch.images.mean(axis=(1, 2, 3))

    if stats_src == 'field':
        stats = (ds[0].normalization_stats['mean'], ds[0].normalization_stats['std'])
    else:
        src = getattr(batch, stats_src)
        stats = (src.mean(axis=(1, 2, 3)), src.std(axis=(1, 2, 3)))

    # Relative reconstruction error; for an all-zero batch fall back to the absolute error
    # instead of dividing by zero (nan would make the check fail spuriously).
    images_scale = np.abs(batch.images).max()
    reconstruction_error = batch.diff.max()
    if images_scale > 0:
        reconstruction_error = reconstruction_error / images_scale
    assert reconstruction_error < 1e-5, f'images reconstructed incorrectly (stats={stats_src!r})'
    assert np.abs((images - stats[0]) / stats[1] - normalized).max() < 1e-5, \
        f'images normalized incorrectly (stats={stats_src!r})'

In [5]:
# Quantize freshly loaded crops, dequantize them back, and require the element-wise
# reconstruction error to stay strictly below the quantizer's own error estimate.
p = (ds.p
     .make_locations(generator=sampler, batch_size=16)
     .load_seismic(dst='images')
     .quantize(src='images', dst='images_quantized')
     .dequantize(src='images_quantized', dst='images_recovered')
     .update(B('diff'), F(np.abs)(B('images_recovered') - B('images')))
    )
batch = p.next_batch(1)

max_abs_error = batch.diff.max()
error_bound = ds[0].quantizer.estimated_absolute_error
assert max_abs_error < error_bound, 'images reconstructed incorrectly'