In [None]:
# Reload edited modules (e.g. the project's `thesis.*` helpers) automatically
# before each cell executes.
%load_ext autoreload
%autoreload 2

import os

# NOTE: this alias shadows the top-level `tqdm` package name; `tqdm.trange` /
# `tqdm.tqdm` in later cells therefore refer to the notebook (widget) variants.
import tqdm.notebook as tqdm
import tensorflow as tf
import gin
from ddsp.training.data import TFRecordProvider
import ddsp.training
from thesis.notebook_util import play_audio, specplot

# Presumably enabled so that re-running cells / autoreloading modules that
# register gin configurables does not fail on duplicate registration — confirm.
gin.enter_interactive_mode()

In [None]:
# Directory of the trained run (checkpoints + operative config), relative to
# this notebook.
model_dir = "../data/train_newt5"
gin_file = os.path.join(model_dir, "operative_config-0.gin")

# Re-apply the operative gin config that was saved with the checkpoint, so the
# model constructed afterwards presumably picks up the training-time bindings.
# The config is unlocked only for the duration of the parse; skip_unknown=True
# tolerates bindings for configurables not registered in this process.
with gin.unlock_config():
    gin.parse_config_file(config_file=gin_file, skip_unknown=True)

In [None]:
# Build the model; get_model() takes no arguments here, so the model class is
# presumably selected by the gin bindings parsed above.
model = ddsp.training.models.get_model()
# Loads the latest checkpoint found in `model_dir`.
model.restore(model_dir)

In [None]:
# Glob pattern for the violin TFRecord shards. The default is the original
# author's absolute local path; set the VIOLIN_TFRECORD_PATTERN environment
# variable to point at the data on another machine instead of editing this cell.
# TODO: consider deriving this from a shared data dir like `model_dir` does.
tfrecord_pattern = os.environ.get(
    "VIOLIN_TFRECORD_PATTERN",
    "/Users/vaclav/prog/thesis/data/violin/violin.tfrecord*",
)
data_provider = TFRecordProvider(file_pattern=tfrecord_pattern)

In [None]:
batch_size = 1
# Number of passes over the data passed to get_batch; -1 repeats the dataset
# indefinitely, so repeated next() calls on its iterator never exhaust it.
# (Previously this was a dead `True` while the call hard-coded -1.)
repeats = -1

dataset = data_provider.get_batch(batch_size=batch_size,
                                  shuffle=False,  # keep the on-disk order
                                  repeats=repeats)

In [None]:
dataset_iter = iter(dataset)

# Profile ten inference steps of the real model into ./logdir.
# Each iteration is tagged as step `i`. The batch is fetched inside the trace,
# so input-pipeline time is counted as part of 'predict'.
# NOTE(review): TF's profiler examples also pass `_r=1` so the trace is treated
# as a step marker; it is deliberately left disabled here.
with tf.profiler.experimental.Profile('logdir'):
    for i in tqdm.tqdm(range(10)):
        with tf.profiler.experimental.Trace('predict', step_num=i):
            batch = next(dataset_iter)
            outputs, losses = model(batch, training=False, return_losses=True)

In [None]:
class FakeModel(tf.keras.Model):
    """Tiny two-layer dense network, used as a stand-in when profiling.

    The input goes through Dense(16, ReLU) named "bar" and then Dense(8) named
    "baz". Extra positional/keyword arguments (including `training`) are
    accepted so the call signature matches other models, and are ignored.
    """

    def __init__(self):
        super(FakeModel, self).__init__()
        # Creation order is kept as-is: Keras tracks sublayers in assignment order.
        self.l1 = tf.keras.layers.Dense(16, activation=tf.nn.relu, name="bar")
        self.l2 = tf.keras.layers.Dense(8, name="baz")

    def call(self, x, *args, training=False, **kwargs):
        hidden = self.l1(x)
        return self.l2(hidden)

In [None]:
# Instantiate the toy model. Its Dense layers have no input shape, so their
# weights are only created on the first call.
fake_model = FakeModel()

In [None]:
# Profile ten forward passes of the toy model into the same ./logdir.
# The random input is generated outside the Trace, so only the model call is
# inside the traced region.
with tf.profiler.experimental.Profile('logdir'):
    for i in tqdm.tqdm(range(10)):
        fake_batch = tf.random.normal((8, 32))  # unseeded (8, 32) input
        # Trace metadata kwargs kept in their original order (step_num, graph_type).
        with tf.profiler.experimental.Trace("TraceContext", step_num=i, graph_type="train"):  # , _r=1
            fake_model(fake_batch)

In [None]:
# Show which entries the model call in the profiling cell returned in `outputs`.
outputs.keys()

In [None]:
# Listen to the synthesized audio and to the filtered-noise signal on its own.
# NOTE(review): only a cell's last expression is auto-displayed; if play_audio
# returns a display object instead of calling display() itself, the first clip
# will not appear — confirm against thesis.notebook_util.
play_audio(outputs["audio_synth"])
play_audio(outputs["filtered_noise"]["signal"])

In [None]:
import matplotlib.pyplot as plt

# Noise magnitudes for the first example in the batch, transposed before plotting
# (assumed 2-D per example — TODO confirm the axis meaning).
# tf.transpose is used instead of `.T`: TF tensors do not expose `.T` unless
# NumPy behavior is enabled, whereas tf.transpose accepts tensors and ndarrays.
plt.matshow(tf.transpose(outputs["noise_magnitudes"][0]))
plt.title('outputs["noise_magnitudes"][0] (transposed)')
plt.show()


In [None]:
# Plot the filtered-noise signal on its own (presumably a spectrogram, per the
# helper's name).
specplot(outputs["filtered_noise"]["signal"])

In [None]:
import datetime

# Record a graph + profiler trace of a single model call and export it to a
# timestamped directory under logdir/func for TensorBoard.
stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = f'logdir/func/{stamp}'
writer = tf.summary.create_file_writer(logdir)

# NOTE(review): graph capture only records tf.function graphs; if the model's
# call runs eagerly, the export may contain profile data but no graph.
tf.summary.trace_on(graph=True, profiler=True)
try:
    model(batch, return_losses=True, training=False)
except BaseException:
    # trace_export below is never reached on failure; stop the runtime-wide
    # trace so it does not keep collecting data in later cells.
    tf.summary.trace_off()
    raise

with writer.as_default():
    tf.summary.trace_export(
        name="my_func_trace",
        step=0,
        profiler_outdir=logdir)

In [None]:
# Render a diagram of the model architecture.
# NOTE(review): plot_model requires pydot + Graphviz to be installed. If the
# ddsp model is a subclassed Keras model (likely), the diagram may show little
# of its internal structure — consider model.summary() as well.
tf.keras.utils.plot_model(model)