In [1]:
import glob
import os

# Must be set before TensorFlow is imported (directly or via data_utils) to take effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import joblib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense

from data_utils.mg_sg_generator import get_dataset_ids, get_dataset_for
from data_utils.mg_sg_generator import MotiongramSpectrogramGenerator

In [2]:
# Build the datasets for nfft=256. Each entry of `fout` is indexed below as
# v[0] (a count) and v[1] (the location it was written to).
fout = get_dataset_for(nfft=256)
for k in fout:
    v = fout[k]
    print(f'[{k}] Wrote {v[0]} files to {v[1]}')

[train] Wrote 1056 files to /home/sbol13/sbol_data/datasets/mg_sg_pair_129x128_train_1056
[validation] Wrote 227 files to /home/sbol13/sbol_data/datasets/mg_sg_pair_129x128_validation_227
[test] Wrote 227 files to /home/sbol13/sbol_data/datasets/mg_sg_pair_129x128_test_227


In [44]:
def parse_tfr_example(example):
    """Decode one serialized record into a flat (motiongram, spectrogram) pair.

    Args:
        example: A serialized example, as yielded by a `tf.data.TFRecordDataset`.

    Returns:
        A tuple `(mg_feature, sg_feature)` of float32 tensors, each reshaped
        to a 1-D vector of length 129*128.
    """
    # Both features are stored as scalar byte strings holding a serialized tensor.
    feature_spec = {
        'motiongram': tf.io.FixedLenFeature([], tf.string),
        'spectrogram': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)

    flat_shape = [129*128]
    mg_tensor = tf.io.parse_tensor(parsed['motiongram'], out_type=tf.float32)
    sg_tensor = tf.io.parse_tensor(parsed['spectrogram'], out_type=tf.float32)
    return (
        tf.reshape(mg_tensor, shape=flat_shape),
        tf.reshape(sg_tensor, shape=flat_shape),
    )

def get_dataset_small(filename):
    """Open a TFRecord source and decode every record with `parse_tfr_example`.

    Args:
        filename: Passed unchanged to `tf.data.TFRecordDataset`.

    Returns:
        A `tf.data.Dataset` of `(motiongram, spectrogram)` pairs.
    """
    return tf.data.TFRecordDataset(filename).map(parse_tfr_example)

def input_fn(filename, batch_size):
    """Build a shuffled, batched, endlessly repeating dataset.

    Args:
        filename: Passed through to `get_dataset_small`.
        batch_size: Number of examples per batch.

    Returns:
        A `tf.data.Dataset` that shuffles with a 10000-element buffer, then
        batches, then repeats. Because it repeats indefinitely, consumers
        must bound iteration themselves (e.g. with a step count).
    """
    # Prefetching is currently disabled: .prefetch(tf.data.AUTOTUNE)
    return (
        get_dataset_small(filename)
        .shuffle(10000)
        .batch(batch_size)
        .repeat()
    )

### Set up a simple test model

In [45]:
# Simple fully connected baseline: a flattened 129x128 input vector is mapped
# to a flattened 129x128 output vector through one 100-unit hidden layer.
# Neither Dense layer specifies an activation.
n_features = 129 * 128

inp = Input(shape=(n_features, ))
x = Dense(100)(inp)
out = Dense(n_features)(x)

model = Model(inputs=[inp], outputs=[out])
# The model is trained as a regression with an MSE loss, so classification
# accuracy is not a meaningful metric; report mean absolute error instead.
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

In [46]:
# Number of examples per batch, shared by the training and validation pipelines.
batch_size = 100

# The last element of each `fout` entry is what gets handed to the TFRecord reader.
train_source = fout["train"][-1]
validation_source = fout["validation"][-1]

ds_train = input_fn(train_source, batch_size=batch_size)
ds_validation = input_fn(validation_source, batch_size=batch_size)

In [49]:
# Both datasets repeat forever, so Keras needs explicit step counts.
# Derive them from the counts reported by get_dataset_for (index 0 of each
# entry) and the batch size in use, instead of repeating the literals
# 1056, 227 and 100. Floor division is kept, so a trailing partial batch is
# not counted as a step; max(1, ...) guards against a zero step count when a
# split is smaller than one batch.
steps_per_epoch = max(1, fout["train"][0] // batch_size)
validation_steps = max(1, fout["validation"][0] // batch_size)

model.fit(
    ds_train,
    steps_per_epoch=steps_per_epoch,
    validation_data=ds_validation,
    validation_steps=validation_steps,
    epochs=20
)

Train for 10 steps, validate for 2 steps
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<tensorflow.python.keras.callbacks.History at 0x7f71ec6b3d90>

In [None]:
# Pull one batch of (motiongram, spectrogram) pairs from the training dataset.
# The original cell read from `gen`, which is never defined in this notebook.
mg, sg = next(iter(ds_train))
mg, sg = mg.numpy(), sg.numpy()

# Restore the first and last example of the batch from flat vectors of
# length 129*128 to 129x128 images for plotting.
mg_0 = mg[0, :].reshape(129, 128)
sg_0 = sg[0, :].reshape(129, 128)
mg_1 = mg[-1, :].reshape(129, 128)
sg_1 = sg[-1, :].reshape(129, 128)

In [None]:
# Top row: the two motiongrams; bottom row: their paired spectrograms.
# The spectrograms are flipped vertically before display
# (presumably so low frequencies sit at the bottom — TODO confirm).
panels = [
    (mg_0, "(a) Input Motiongram"),
    (mg_1, "(b) Input Motiongram"),
    (np.flipud(sg_0), "(c) Target Spectrogram"),
    (np.flipud(sg_1), "(d) Target Spectrogram"),
]

fig, axs = plt.subplots(2, 2, figsize=(10, 10))
# axs.flat walks the grid row by row, matching the order of `panels`.
for ax, (image, label) in zip(axs.flat, panels):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel(label)
    ax.imshow(image, aspect="auto", cmap="Spectral_r", interpolation="bicubic")

plt.show()