In [9]:
import numpy as np
import h5py
import scipy.signal
import matplotlib.pyplot as plt

In [39]:
class Hdf5DataGenerator:
    """Yield ``(waveforms, targets)`` batches from an HDF5 file of picked events.

    Instances are callable as ``gen(filename, batchsize)`` and return a
    generator, so they can be handed to e.g. ``tf.data.Dataset.from_generator``.

    The file must contain the datasets ``waveforms`` (events along the first
    axis, samples along the second), ``p_start`` and ``s_start`` (one pick
    sample index per event).

    Targets have shape ``(batch, waveform_length, 3)`` with channels
    ``[P, S, noise]``:

    * P / S: a Gaussian window of ``pick_width`` samples centred on the pick,
      clipped where it runs past either end of the trace.
    * noise: ``1 - P - S``.

    Only full batches are yielded; a trailing partial batch is dropped.
    """

    def __init__(self, pick_width=100, pick_std=12):
        # Width (samples) and standard deviation of the Gaussian pick label.
        self.pick_width = pick_width
        self.pick_std = pick_std

    def __call__(self, filename, batchsize):
        """Generate batches from ``filename``.

        Args:
            filename: path to the HDF5 file (``str`` or ``bytes``).
            batchsize: number of events per yielded batch (must be >= 1).

        Yields:
            ``(data, targets)``: ``data`` is ``waveforms[istart:istop]`` and
            ``targets`` is a float array of shape ``(batchsize, waveform_length, 3)``.

        Raises:
            ValueError: if ``batchsize < 1``.
            KeyError: if a required dataset is missing from the file.
        """
        if batchsize < 1:
            raise ValueError(f"batchsize must be >= 1, got {batchsize}")

        # When using with tensorflow datasets, string args are passed as bytes.
        # Convert them back.
        if isinstance(filename, bytes):
            filename = filename.decode()

        pick = scipy.signal.windows.gaussian(self.pick_width, self.pick_std)

        with h5py.File(filename, "r") as fin:
            # Index (not .get) so a missing dataset fails loudly with KeyError
            # instead of returning None and breaking later.
            waveforms = fin['waveforms']
            p_start = fin['p_start']
            s_start = fin['s_start']

            n_events = len(waveforms)
            waveform_length = waveforms.shape[1]

            # Last valid start index is n_events - batchsize, so every batch is full.
            for istart in range(0, n_events - batchsize + 1, batchsize):
                istop = istart + batchsize
                data = waveforms[istart:istop]
                # Slice the pick arrays with the SAME window as the waveforms so
                # each event is labelled with its own picks.
                targets = self._make_targets(
                    p_start[istart:istop], s_start[istart:istop], waveform_length, pick
                )
                yield (data, targets)

    def _make_targets(self, p_batch, s_batch, waveform_length, pick):
        """Build the ``(len(p_batch), waveform_length, 3)`` [P, S, noise] target array."""
        targets = np.zeros((len(p_batch), waveform_length, 3))
        for row, (p_pos, s_pos) in enumerate(zip(p_batch, s_batch)):
            # targets[row, :, k] is a view, so the helper writes into `targets`.
            self._insert_pick(targets[row, :, 0], int(p_pos), pick)
            self._insert_pick(targets[row, :, 1], int(s_pos), pick)
        # NOTE(review): where the P and S windows overlap, noise can go negative;
        # consider clipping if the loss expects a proper distribution.
        targets[:, :, 2] = 1.0 - targets[:, :, 0] - targets[:, :, 1]
        return targets

    @staticmethod
    def _insert_pick(trace, pos, pick):
        """Write ``pick`` into ``trace`` starting at ``pos - len(pick) // 2``, clipped to the trace."""
        lo = pos - len(pick) // 2
        hi = lo + len(pick)
        t_lo = max(lo, 0)
        t_hi = min(hi, len(trace))
        if t_lo >= t_hi:
            return  # window lies entirely outside the trace
        trace[t_lo:t_hi] = pick[t_lo - lo:t_hi - lo]

                

            

        

In [40]:
gen = Hdf5DataGenerator()

# Pull one single-event batch and plot its waveform next to its P/S/noise targets.
for d, t in gen('selected_events.h5', 1):

    print('d.shape:', d.shape)
    print('t.shape:', t.shape)

    fig, ax = plt.subplots(4, 1, sharex=True, figsize=(10, 8))

    xvals = np.arange(d.shape[1])

    # Top three panels: one per waveform component along the last axis.
    # NOTE(review): assumes 3 components (matches the printed d.shape) — confirm channel order.
    for comp, comp_ax in enumerate(ax[:-1]):
        comp_ax.plot(xvals, d[0, :, comp], lw=0.5)
        comp_ax.set_ylabel(f'comp {comp}')

    # Bottom panel: the three target channels built by the generator.
    for idx, label in enumerate(('P', 'S', 'noise')):
        ax[-1].plot(xvals, t[0, :, idx], label=label)
    ax[-1].set_ylabel('target')
    ax[-1].set_xlabel('sample')
    ax[-1].legend(loc='upper right')

    fig.suptitle('First event: waveform and pick targets')
    plt.show()
    break

targets: [array([[[0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        ...,
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.]]], shape=(1, 6000, 3))]
d.shape: (1, 6000, 3)
t.shape: (1, 6000, 3)
