In [None]:
import ddsp.training.models
%load_ext autoreload
%autoreload 2

import gin
import os

gin.enter_interactive_mode()

In [None]:
import seaborn as sns

# Apply seaborn's default theme (its documented defaults, spelled out here)
# to every matplotlib figure drawn later in the notebook.
sns.set(context="notebook", style="darkgrid", palette="deep")

In [None]:
# Parent directory holding one sub-directory per trained model (checkpoint +
# operative gin config). Set $DDSP_MODELS_DIR to run this on another machine.
base_dir = os.environ.get("DDSP_MODELS_DIR", "/Users/vaclav/prog/thesis/data/models")
good_model_name = "0721-ddspae-cnn-8"
bad_model_name = "0728-ddspae-cnn"

# Only the *good* model's operative config is parsed. Any model built later from
# the gin config (including the bad one) therefore presumably uses this setup.
# NOTE(review): if the bad model was trained with a different config, its
# checkpoint may not match this architecture — confirm.
with gin.unlock_config():
    gin.parse_config_file(os.path.join(base_dir, good_model_name, "operative_config-0.gin"))


In [None]:
import ddsp.training

# Glob matching the TFRecord shards to read; set $DDSP_DATASET_PATTERN to override.
dataset_pattern = os.environ.get(
    "DDSP_DATASET_PATTERN",
    "/Users/vaclav/prog/thesis/data/datasets/transfer4/transfer4.tfrecord*",
)
data_provider = ddsp.training.data.TFRecordProvider(
    dataset_pattern,
    frame_rate=50,
    centered=True,
    with_jukebox=False,
)
dataset = data_provider.get_batch(batch_size=1, shuffle=True, repeats=1)
# One example batch, used by `load_model` below to run each model once.
batch = next(iter(dataset))


def load_model(name, model_dir=None, example_batch=None):
    """Build a model from the current gin config and restore its checkpoint.

    Args:
        name: Sub-directory of `model_dir` that holds the model's checkpoint.
        model_dir: Parent directory of the model folders. Defaults to the
            notebook-level `base_dir`, looked up at call time.
        example_batch: Batch passed once through the model after restoring.
            Defaults to the notebook-level `batch`, looked up at call time.

    Returns:
        The restored model.
    """
    if model_dir is None:
        model_dir = base_dir
    if example_batch is None:
        example_batch = batch

    model = ddsp.training.models.get_model()
    model.restore(os.path.join(model_dir, name))

    # One forward pass; the output is discarded. NOTE(review): presumably this
    # creates the model's variables so a deferred checkpoint restore is
    # applied to them — confirm.
    model(example_batch)

    return model

In [None]:
# Restore both checkpoints with the same helper (good first, then bad).
good_model, bad_model = (
    load_model(model_name) for model_name in (good_model_name, bad_model_name)
)

In [None]:
import matplotlib.pyplot as plt
plt.style.use({'figure.facecolor': 'gray'})

# Compare the two models' decoder weights variable by variable. Both models are
# presumably built from the same gin config, so the weight lists should line up
# one-to-one in order.
assert len(good_model.decoder.weights) == len(bad_model.decoder.weights)

for gw, bw in zip(good_model.decoder.weights, bad_model.decoder.weights):
    good_values = gw.numpy()
    bad_values = bw.numpy()
    # Variable names may differ between the two models (e.g. uniquified layer
    # names), so the pairing relies on order. At least ensure the shapes agree.
    assert good_values.shape == bad_values.shape, (gw.name, bw.name)

    print(gw.name, good_values.reshape(-1)[:5])
    if good_values.size == 1:
        # Scalar weight: one bar per model. `bar` needs both x positions and heights.
        fig, ax = plt.subplots()
        ax.bar(["good", "bad"], [good_values.item(), bad_values.item()])
        ax.set_title(gw.name)
        ax.set_ylabel("value")
    else:
        good_flat = good_values.reshape(-1)
        bad_flat = bad_values.reshape(-1)
        # Give both histograms the same bin range so their bars are comparable.
        value_range = (
            min(good_flat.min(), bad_flat.min()),
            max(good_flat.max(), bad_flat.max()),
        )
        fig, axes = plt.subplots(2, 1, sharex=True)
        axes[0].hist(good_flat, range=value_range, alpha=0.5)
        axes[1].hist(bad_flat, range=value_range, alpha=0.5)
        axes[0].set_title(f"good: {gw.name}")
        axes[1].set_title(f"bad: {bw.name}")
        axes[1].set_xlabel("weight value")
        for ax in axes:
            ax.set_ylabel("count")
        fig.tight_layout()
    plt.show()

In [None]:
import tensorflow as tf

# Read the good model's checkpoint directly, independent of the Keras objects
# built above. This presumably cross-checks the weight values printed earlier.
reader = tf.train.load_checkpoint(os.path.join(base_dir, good_model_name))
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

# Sorted so the listing is deterministic and easy to diff between runs.
for k in sorted(shape_from_key):
    # Keep decoder variables only; keys containing "optimizer/" are skipped
    # (presumably optimizer slot variables).
    if k.startswith("model/decoder") and "optimizer/" not in k:
        print(
            k,
            shape_from_key[k],
            dtype_from_key[k].name,
            reader.get_tensor(k).reshape(-1)[:5],
        )