# Rate–distortion experiments with toy sources

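This notebook contains code to train VECVQ and nonlinear transform coding (NTC) models, as well as linear transform coding (LTC) baselines based on the KLT and on fixed orthonormal transforms (DCT, wavelets), using stochastic rate–distortion optimization.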

The Laplace and Banana sources are described in:

> "Nonlinear Transform Coding"<br />
> J. Ballé, P. A. Chou, D. Minnen, S. Singh, N. Johnston, E. Agustsson, S. J. Hwang, G. Toderici<br />
> https://arxiv.org/abs/2007.03034

The Sawbridge process is described in:

> "Neural Networks Optimally Compress the Sawbridge"<br />
> A. B. Wagner, J. Ballé<br />
> https://arxiv.org/abs/2011.05065

This notebook requires TFC v2 (`pip install "tensorflow-compression==2.*"`).


In [None]:
#@title Dependencies for Colab

# Run this cell to install the necessary dependencies when running the notebook
# directly in a Colaboratory hosted runtime from Github.

!pip install tensorflow-compression
![[ -e /tfc ]] || git clone https://github.com/tensorflow/compression /tfc
%cd /tfc/models


In [None]:
#@title Imports

from absl import logging
import numpy as np
import tensorflow as tf
import tensorflow_compression as tfc
import tensorflow_probability as tfp

tfm = tf.math
tfkl = tf.keras.layers
tfpb = tfp.bijectors
tfpd = tfp.distributions


In [None]:
#@title Matplotlib configuration

import cycler
import matplotlib as mpl
import matplotlib.pyplot as plt

# Color cycle for plots: the same ten hex colors as matplotlib's "tab10"
# palette, spelled out explicitly.
_colors = (
    "#1f77b4 #ff7f0e #2ca02c #d62728 #9467bd "
    "#8c564b #e377c2 #7f7f7f #bcbd22 #17becf"
).split()

# Apply the rc settings group by group, in a fixed order.
for _rc_group, _rc_settings in (
    ("axes", dict(facecolor="white", labelsize="large",
                  prop_cycle=cycler.cycler(color=_colors))),
    ("grid", dict(color="black", alpha=.1)),
    ("legend", dict(frameon=True, framealpha=.9, borderpad=.5,
                    handleheight=1, fontsize="large")),
    ("image", dict(cmap="viridis", interpolation="nearest")),
    ("figure", dict(figsize=(16, 8))),
):
  plt.rc(_rc_group, **_rc_settings)
del _rc_group, _rc_settings


In [None]:
#@title Source definitions

from toy_sources import sawbridge
from toy_sources import circle
from toy_sources import ramp
from toy_sources import sinusoid
from toy_sources import stat_sawbridge


def _rotation_2d(degrees):
  """Returns a 2x2 rotation matrix wrapped in a linear operator.

  Args:
    degrees: Rotation angle in degrees.

  Returns:
    A `tf.linalg.LinearOperatorFullMatrix` for `[[cos, -sin], [sin, cos]]`,
    flagged as square and non-singular.
  """
  angle = tf.convert_to_tensor(degrees / 180 * np.pi, dtype=tf.float32)
  cos, sin = tfm.cos(angle), tfm.sin(angle)
  matrix = [[cos, -sin], [sin, cos]]
  return tf.linalg.LinearOperatorFullMatrix(
      matrix, is_non_singular=True, is_square=True)


def get_laplace(loc=0, scale=1):
  """Returns a Laplace source with event shape `[1]`.

  A batch of one scalar Laplace distribution is reinterpreted as a single
  1-D event.
  """
  scalar_laplace = tfpd.Laplace(loc=[loc], scale=[scale])
  return tfpd.Independent(scalar_laplace, reinterpreted_batch_ndims=1)


def get_banana():
  """Returns the 2-D "banana" source.

  A diagonal Gaussian with standard deviations (3, .5) is transformed by the
  inverse of a bijector chain made of a RealNVP coupling (shift
  `.1 * x ** 2`, no log-scale), a 240 degree rotation, and a shift by (1, 1).
  """
  base = tfpd.Independent(tfpd.Normal(loc=[0, 0], scale=[3, .5]), 1)
  bijector_chain = tfpb.Chain([
      tfpb.RealNVP(
          num_masked=1,
          shift_and_log_scale_fn=lambda x, _: (.1 * x ** 2, None)),
      tfpb.ScaleMatvecLinearOperator(_rotation_2d(240)),
      tfpb.Shift([1, 1]),
  ])
  return tfpd.TransformedDistribution(base, tfpb.Invert(bijector_chain))


def get_sawbridge(order=1, stationary=False, num_points=1024):
  """Returns a `sawbridge.Sawbridge` sampled on `num_points` points in [0, 1]."""
  return sawbridge.Sawbridge(
      tf.linspace(0., 1., num_points), stationary=stationary, order=order)

def get_statsawbridge(drop=None, phase=None, order=1, stationary=False,
                      num_points=1024):
  """Returns a `stat_sawbridge.Sawbridge` on `num_points` points in [0, 1].

  `drop`, `phase`, `stationary` and `order` are forwarded unchanged.
  """
  grid = tf.linspace(0., 1., num_points)
  return stat_sawbridge.Sawbridge(
      grid, phase=phase, drop=drop, stationary=stationary, order=order)

def get_circle(width=0.):
  """Returns a `circle.Circle` source with the given `width`."""
  circle_source = circle.Circle(width=width)
  return circle_source

def get_ramp(phase=None, num_points=1024):
  """Returns a `ramp.Ramp` source on `num_points` points in [0, 1]."""
  return ramp.Ramp(tf.linspace(0., 1., num_points), phase=phase)

def get_sinusoid(phase=None, num_points=1024):
  """Returns a `sinusoid.Sinusoid` source on `num_points` points in [0, 1]."""
  return sinusoid.Sinusoid(tf.linspace(0., 1., num_points), phase=phase)

In [None]:
#@title Model definitions

import pywt
import scipy
from toy_sources import ntc
from toy_sources import vecvq


def _get_activation(activation, dtype):
  if not activation:
    return None
  if activation == "gdn":
    return tfc.GDN(dtype=dtype)
  elif activation == "igdn":
    return tfc.GDN(inverse=True, dtype=dtype)
  else:
    return getattr(tf.nn, activation)


def _make_nlp(units, activation, name, input_shape, dtype):
  """Builds a stack of Dense layers (an MLP) as a `tf.keras.Sequential`.

  Layer `i` has `units[i]` units, a bias, and the name f"{name}_{i}". Every
  layer except the last uses `activation` (the same object is passed to each
  of them); the last layer is linear. Only the first layer declares
  `input_shape`.
  """
  specs = [
      {"units": width, "use_bias": True, "activation": activation,
       "name": f"{name}_{index}", "dtype": dtype}
      for index, width in enumerate(units)
  ]
  specs[0]["input_shape"] = input_shape
  specs[-1]["activation"] = None  # Output layer is linear.
  dense_layers = [tf.keras.layers.Dense(**spec) for spec in specs]
  return tf.keras.Sequential(dense_layers, name=name)


def get_ntc_mlp_model(analysis_filters, synthesis_filters,
                      analysis_activation, synthesis_activation,
                      latent_dims, source, dtype=tf.float32, **kwargs):
  """NTC with MLP transforms.

  Builds an analysis MLP mapping source vectors to `latent_dims` latents and a
  synthesis MLP mapping latents back to source space, then wraps both in an
  `ntc.NTCModel`.

  Args:
    analysis_filters: Sequence (list or tuple) of hidden-layer widths of the
      analysis MLP.
    synthesis_filters: Sequence of hidden-layer widths of the synthesis MLP.
    analysis_activation: Activation name for the analysis hidden layers, as
      accepted by `_get_activation`.
    synthesis_activation: Activation name for the synthesis hidden layers.
    latent_dims: Number of latent dimensions.
    source: Source distribution with a rank-1 `event_shape`.
    dtype: dtype of the transforms and of the model.
    **kwargs: Forwarded to `ntc.NTCModel`.

  Returns:
    An `ntc.NTCModel`.
  """
  source_dims, = source.event_shape

  # `list(...)` lets callers pass tuples of widths as well as lists.
  analysis = _make_nlp(
      list(analysis_filters) + [latent_dims],
      _get_activation(analysis_activation, dtype),
      "analysis",
      [source_dims],
      dtype,
  )
  synthesis = _make_nlp(
      list(synthesis_filters) + [source_dims],
      _get_activation(synthesis_activation, dtype),
      "synthesis",
      [latent_dims],
      dtype,
  )

  return ntc.NTCModel(
      analysis=analysis, synthesis=synthesis, source=source, dtype=dtype,
      **kwargs)


def get_ltc_model(latent_dims, source, dtype=tf.float32, **kwargs):
  """LTC: an `ntc.NTCModel` whose transforms are single affine Dense layers.

  Args:
    latent_dims: Number of latent dimensions.
    source: Source distribution with a rank-1 `event_shape`.
    dtype: dtype of the transforms and of the model.
    **kwargs: Forwarded to `ntc.NTCModel`.
  """
  source_dims, = source.event_shape

  encoder_layer = tf.keras.layers.Dense(
      latent_dims, use_bias=True, activation=None, name="analysis",
      input_shape=[source_dims], dtype=dtype)
  decoder_layer = tf.keras.layers.Dense(
      source_dims, use_bias=True, activation=None, name="synthesis",
      input_shape=[latent_dims], dtype=dtype)

  return ntc.NTCModel(
      analysis=tf.keras.Sequential([encoder_layer], name="analysis"),
      synthesis=tf.keras.Sequential([decoder_layer], name="synthesis"),
      source=source, dtype=dtype, **kwargs)


@tf.function
def estimate_klt(source, num_samples, latent_dims):
  """Estimates KLT.

  Estimates the Karhunen-Loeve transform (the eigenvectors of the source
  covariance) from samples, and measures how well it decorrelates the source.
  Samples are drawn in `num_samples[0]` batches of `num_samples[1]` samples
  each. This is done three times with fresh samples: once for the mean, once
  for the covariance, and once for the decorrelation check. The total
  variance, the kept eigenvalues, and the decorrelation error are printed
  with `tf.print`.

  Args:
    source: Distribution with a rank-1 `event_shape` and a `sample` method.
      NOTE(review): samples are assumed to be float32, matching the float32
      accumulators below.
    num_samples: Pair `(num_batches, batch_size)`. `get_ltc_klt_model` passes
      it as a tensor, so the loops below run as graph loops.
    latent_dims: Number of leading eigenvectors to keep.

  Returns:
    A tuple `(eigv, error)`. `eigv` is a `[dims, latent_dims]` matrix whose
    columns are covariance eigenvectors in order of descending eigenvalue.
    `error` is the largest absolute off-diagonal correlation coefficient
    between the coefficients `(samples - mean) @ eigv`.
  """

  dims = source.event_shape[0]

  # Second-moment matrix `samples^T samples`, averaged over one batch.
  def energy(samples):
    c = tf.linalg.matmul(samples, samples, transpose_a=True)
    return c / tf.cast(num_samples[1], tf.float32)

  # Estimate mean.
  mean = tf.zeros([dims])
  for _ in range(num_samples[0]):
    samples = source.sample(num_samples[1])
    mean += tf.reduce_mean(samples, axis=0)
  mean /= tf.cast(num_samples[0], tf.float32)

  # Estimate covariance around the estimated mean (normalized by the number
  # of samples, not by that number minus one).
  covariance = tf.zeros([dims, dims])
  for _ in range(num_samples[0]):
    samples = source.sample(num_samples[1])
    covariance += energy(samples - mean)
  covariance /= tf.cast(num_samples[0], tf.float32)

  # Total variance = trace of the covariance.
  variance = tf.reduce_sum(tf.linalg.diag_part(covariance))
  tf.print("SOURCE VARIANCE:", variance)

  # Compute first latent_dims eigenvalues in descending order. `eigh` returns
  # them in ascending order, so reverse before truncating.
  eig, eigv = tf.linalg.eigh(covariance)
  eig = eig[::-1]
  eigv = eigv[:, ::-1]
  eig = eig[:latent_dims]
  eigv = eigv[:, :latent_dims]
  tf.print("SOURCE EIGENVALUES:", eig)

  # Estimate the covariance of the projected (KLT) coefficients on fresh
  # samples, then normalize it into a matrix of correlation coefficients.
  whitened = tf.zeros([latent_dims, latent_dims])
  for _ in range(num_samples[0]):
    samples = source.sample(num_samples[1])
    whitened += energy(tf.linalg.matmul(samples - mean, eigv))
  whitened /= tf.cast(num_samples[0], tf.float32)
  whitened_var = tf.linalg.diag_part(whitened)
  # The augmented assignment divides by the whole right-hand side, so the
  # 1e-20 is added to the denominator to avoid division by zero.
  whitened /= tf.sqrt(
      whitened_var[:, None] * whitened_var[None, :]) + 1e-20
  # Zero the diagonal (self-correlations) so only cross-correlations count.
  error = tf.linalg.set_diag(abs(whitened), tf.zeros(latent_dims))
  error = tf.reduce_max(error)
  tf.print("MAX. CORRELATION COEFFICIENT:", error)

  return eigv, error


class ScaleAndBias(tf.keras.layers.Layer):
  """Multiplies each channel by a learned scaling factor and adds a bias.

  The factors are stored as their logarithm (`_log_factors`), so they always
  stay positive.

  Args:
    scale_first: If true, computes `inputs * factors + bias`; otherwise
      computes `(inputs + bias) * factors`.
    init_scale: Initial value of every factor. Must be positive.
    **kwargs: Forwarded to `tf.keras.layers.Layer`.

  Raises:
    ValueError: If `init_scale` is not a positive number.
  """

  def __init__(self, scale_first, init_scale=1, **kwargs):
    super().__init__(**kwargs)
    self.scale_first = bool(scale_first)
    self.init_scale = float(init_scale)
    # `not x > 0` also rejects NaN; the log below would be -inf/NaN otherwise.
    if not self.init_scale > 0:
      raise ValueError(f"init_scale must be positive, got {init_scale}.")

  def build(self, input_shape):
    input_shape = tf.TensorShape(input_shape)
    channels = int(input_shape[-1])
    # Use a plain Python float for the initial value. A float32 tensor (as
    # returned by `tf.math.log`) cannot be converted to a different dtype,
    # which would break layers created with e.g. `dtype=tf.float64`.
    self._log_factors = self.add_weight(
        name="log_factors", shape=[channels],
        initializer=tf.keras.initializers.Constant(
            float(np.log(self.init_scale))))
    self.bias = self.add_weight(
        name="bias", shape=[channels],
        initializer=tf.keras.initializers.Zeros())
    super().build(input_shape)

  @property
  def factors(self):
    """Positive per-channel scaling factors, `exp(_log_factors)`."""
    return tf.math.exp(self._log_factors)

  def call(self, inputs):
    if self.scale_first:
      return inputs * self.factors + self.bias
    else:
      return (inputs + self.bias) * self.factors


def get_ltc_klt_model(latent_dims, source, num_samples, tolerance,
                      dtype=tf.float32, **kwargs):
  """LTC constrained to KLT.

  The analysis transform is a fixed (non-trainable) projection onto the
  `latent_dims` leading eigenvectors estimated by `estimate_klt`, followed by
  a learned per-channel scale and bias. Synthesis mirrors it: a learned bias
  and scale, followed by the fixed transposed projection.

  Args:
    latent_dims: Number of KLT coefficients to keep.
    source: Source distribution with a rank-1 `event_shape`.
    num_samples: `(num_batches, batch_size)` pair forwarded to `estimate_klt`.
    tolerance: The largest absolute correlation coefficient between KLT
      coefficients reported by `estimate_klt` must be below this value.
    dtype: dtype of the transforms and of the model.
    **kwargs: Forwarded to `ntc.NTCModel`.

  Returns:
    An `ntc.NTCModel`.

  Raises:
    AssertionError: If the estimated KLT does not decorrelate the source to
      within `tolerance`. It is raised explicitly rather than with an
      `assert` statement, so the check also runs under `python -O`.
  """
  source_dims, = source.event_shape

  # Estimate KLT from samples.
  eigv, error = estimate_klt(
      source, tf.constant(num_samples), tf.constant(latent_dims))
  if not error < tolerance:  # Also rejects a NaN error.
    raise AssertionError(error.numpy())
  eigv = tf.cast(eigv, dtype)

  analysis = tf.keras.Sequential([
      tf.keras.layers.Dense(
          latent_dims, use_bias=False, activation=None, name="klt",
          kernel_initializer=lambda *a, **k: eigv,
          trainable=False, input_shape=[source_dims], dtype=dtype),
      ScaleAndBias(
          scale_first=True, name="klt_scaling", dtype=dtype),
  ], name="analysis")
  synthesis = tf.keras.Sequential([
      ScaleAndBias(
          scale_first=False, name="iklt_scaling",
          input_shape=[latent_dims], dtype=dtype),
      tf.keras.layers.Dense(
          source_dims, use_bias=False, activation=None, name="iklt",
          kernel_initializer=lambda *a, **k: tf.transpose(eigv),
          trainable=False, dtype=dtype),
  ], name="synthesis")

  return ntc.NTCModel(
      analysis=analysis, synthesis=synthesis, source=source, dtype=dtype,
      **kwargs)


def _orthonormal_basis(source_dims, transform):
  """Returns a `[source_dims, source_dims]` transform basis in IO format.

  Row `i` holds the transform coefficients of the `i`-th unit impulse, so
  right-multiplying a row vector by the matrix applies the transform.
  """
  if transform == "dct":
    # `import scipy` alone does not load the `fftpack` submodule on older
    # SciPy versions, so import it explicitly.
    from scipy import fftpack  # pylint:disable=g-import-not-at-top
    return fftpack.dct(np.eye(source_dims), norm="ortho")

  # Any other name is treated as a PyWavelets wavelet, decomposed to
  # `log2(source_dims)` levels; the source dimension must be a power of 2.
  num_levels = int(round(np.log2(source_dims)))
  assert 2 ** num_levels == source_dims
  basis = []
  for impulse in np.eye(source_dims):
    levels = pywt.wavedec(
        impulse, transform, mode="periodization", level=num_levels)
    basis.append(np.concatenate(levels))
  return np.array(basis)


def get_ltc_ortho_model(latent_dims, source, transform, dtype=tf.float32,
                        **kwargs):
  """LTC constrained to fixed orthonormal transforms.

  Args:
    latent_dims: Number of leading basis functions (transform coefficients)
      to keep.
    source: Source distribution with a rank-1 `event_shape`.
    transform: "dct" for the orthonormal DCT, otherwise the name of a
      PyWavelets wavelet (e.g. "db4"); wavelets require a power-of-2 source
      dimension.
    dtype: dtype of the transforms and of the model.
    **kwargs: Forwarded to `ntc.NTCModel`.

  Returns:
    An `ntc.NTCModel` whose fixed projections are followed (analysis) or
    preceded (synthesis) by learned per-channel scales and biases.
  """
  source_dims, = source.event_shape

  basis = _orthonormal_basis(source_dims, transform)

  # `basis` must have IO format, so DC should be in first column.
  assert np.allclose(basis[:, 0], basis[0, 0])
  assert not np.allclose(basis[0, :], basis[0, 0])

  # `basis` should be orthonormal.
  assert np.allclose(np.dot(basis, basis.T), np.eye(source_dims))

  # Only take the first `latent_dims` basis functions.
  basis = tf.constant(basis[:, :latent_dims], dtype=dtype)

  analysis = tf.keras.Sequential([
      tf.keras.layers.Dense(
          latent_dims, use_bias=False, activation=None, name=transform,
          kernel_initializer=lambda *a, **k: basis,
          trainable=False, input_shape=[source_dims], dtype=dtype),
      ScaleAndBias(
          scale_first=True, name=f"{transform}_scaling", dtype=dtype),
  ], name="analysis")
  synthesis = tf.keras.Sequential([
      ScaleAndBias(
          scale_first=False, name=f"i{transform}_scaling",
          input_shape=[latent_dims], dtype=dtype),
      tf.keras.layers.Dense(
          source_dims, use_bias=False, activation=None, name=f"i{transform}",
          kernel_initializer=lambda *a, **k: tf.transpose(basis),
          trainable=False, dtype=dtype),
  ], name="synthesis")

  return ntc.NTCModel(
      analysis=analysis, synthesis=synthesis, source=source, dtype=dtype,
      **kwargs)


def get_vecvq_model(**kwargs):
  """Builds a `vecvq.VECVQModel`, forwarding every keyword argument."""
  vq_model = vecvq.VECVQModel(**kwargs)
  return vq_model


In [None]:
#@title Learning schedule definitions

def get_lr_scheduler(learning_rate, epochs, warmup_epochs=0):
  """Returns a learning rate scheduler function for the given configuration.

  During the first `warmup_epochs` epochs, the rate grows tenfold per epoch
  and reaches `learning_rate` at epoch `warmup_epochs`. It then stays at
  `learning_rate` until half of `epochs` and drops by 10x at 1/2, 3/4 and 7/8
  of `epochs`, ending at `learning_rate * 1e-3`. The returned function has the
  `(epoch, lr) -> new_lr` signature used by
  `tf.keras.callbacks.LearningRateScheduler`; the incoming `lr` is ignored.
  """
  # (fraction of `epochs`, multiplier applied while below that fraction),
  # consulted after the initial constant phase.
  decay_table = ((3/4, 1e-1), (7/8, 1e-2))

  def scheduler(epoch, lr):
    del lr  # unused
    if epoch < warmup_epochs:
      return learning_rate * 10. ** (epoch - warmup_epochs)
    if epoch < 1/2 * epochs:
      return learning_rate
    for fraction, multiplier in decay_table:
      if epoch < fraction * epochs:
        return learning_rate * multiplier
    return learning_rate * 1e-3

  return scheduler


class AlphaScheduler(tf.keras.callbacks.Callback):
  """Alpha parameter scheduler.

  At the start of every epoch, assigns `schedule(epoch)` to
  `model.force_alpha`. At the end of every epoch, if the model uses soft
  rounding, stores the model's current `alpha` in the epoch logs under the key
  "alpha". Models without an `alpha` attribute are silently ignored.

  Args:
    schedule: Function mapping an epoch index to the value assigned to
      `model.force_alpha`.
    verbose: Stored on the callback; not otherwise used here.
  """

  def __init__(self, schedule, verbose=0):
    super().__init__()
    self.schedule = schedule
    self.verbose = verbose

  def on_epoch_begin(self, epoch, logs=None):
    if not hasattr(self.model, "alpha"):
      # Silently ignore models that don't have an alpha parameter.
      return
    self.model.force_alpha = self.schedule(epoch)

  def on_epoch_end(self, epoch, logs=None):
    if logs is None:
      # Nothing to report into (e.g. when the hook is called directly).
      return
    if not hasattr(self.model, "alpha"):
      # Silently ignore models that don't have an alpha parameter.
      return
    if not hasattr(self.model, "soft_round") or not any(self.model.soft_round):
      # Silently ignore models that don't use soft rounding.
      return
    logs["alpha"] = self.model.alpha


def get_alpha_scheduler(epochs):
  """Returns an alpha scheduler function for the given configuration.

  For epochs in the first quarter of training, alpha ramps up linearly as
  `3 * (epoch + 1) / (epochs / 4 + 1)`; from then on the scheduler returns
  `None`.
  """
  ramp_end = 1/4 * epochs
  ramp_denominator = epochs/4 + 1

  def scheduler(epoch):
    if epoch >= ramp_end:
      return None
    return 3. * (epoch + 1) / ramp_denominator

  return scheduler


In [None]:
#@title Tensorboard logging callback

class LogCallback(tf.keras.callbacks.Callback):
  """Logs metrics to TensorBoard.

  Training summaries go to `<log_path>/train`, validation summaries to
  `<log_path>/val`. All summaries use `model.epoch` as the step; this
  callback assigns the current epoch index to it at the start of each epoch,
  so the model must carry an integer `tf.Variable` named `epoch`. The graph
  of the first training batch and of the first test batch is traced and
  exported once.

  Args:
    log_path: Root directory for the summary writers.
  """

  def __init__(self, log_path):
    super().__init__()
    self.log_path = log_path
    # Graph-tracing state: None -> "tracing" -> "traced".
    self._train_graph = None
    self._test_graph = None

  def on_train_begin(self, logs=None):
    del logs  # unused
    if not hasattr(self, "train_writer"):
      self.train_writer = tf.summary.create_file_writer(
          self.log_path + "/train")
    self.log_variables()

  def on_test_begin(self, logs=None):
    del logs  # unused
    if not hasattr(self, "test_writer"):
      self.test_writer = tf.summary.create_file_writer(
          self.log_path + "/val")

  def on_test_end(self, logs=None):
    # Log test metrics.
    self.log_tensorboard(
        self.test_writer,
        {"metrics/" + l: v for l, v in (logs or {}).items()})
    self.test_writer.flush()

  def on_epoch_begin(self, epoch, logs=None):
    del logs  # unused
    self.model.epoch.assign(epoch)

  def on_epoch_end(self, epoch, logs=None):
    del epoch  # unused; the summary step comes from `model.epoch`.
    logs = dict(logs or {})
    # `LearningRateScheduler` reports the rate as "lr" in TF2 Keras and as
    # "learning_rate" in newer Keras; pop both so neither is logged as a
    # metric, and tolerate runs without a scheduler.
    lr = logs.pop("lr", None)
    lr_alt = logs.pop("learning_rate", None)
    if lr is None:
      lr = lr_alt
    alpha = logs.pop("alpha", None)

    # Log training metrics; validation metrics reach the val writer through
    # `on_test_end`.
    train_logs = {l: v for l, v in logs.items() if not l.startswith("val_")}
    self.log_tensorboard(
        self.train_writer, {"metrics/" + l: v for l, v in train_logs.items()})

    # Log learning rate and alpha, when available.
    schedule_logs = {}
    if lr is not None:
      schedule_logs["learning rate"] = lr
    if alpha is not None:
      schedule_logs["alpha"] = alpha
    self.log_tensorboard(self.train_writer, schedule_logs)

    self.train_writer.flush()

  def on_train_batch_begin(self, batch, logs=None):
    del logs  # unused
    if batch == 0 and not self._train_graph:
      with self.train_writer.as_default():
        tf.summary.trace_on(graph=True, profiler=False)
      self._train_graph = "tracing"

  def on_train_batch_end(self, batch, logs=None):
    del batch, logs  # unused
    if self._train_graph == "tracing":
      with self.train_writer.as_default():
        tf.summary.trace_export("step", step=self.model.epoch.value())
      self._train_graph = "traced"

  def on_test_batch_begin(self, batch, logs=None):
    del logs  # unused
    if batch == 0 and not self._test_graph:
      with self.test_writer.as_default():
        tf.summary.trace_on(graph=True, profiler=False)
      self._test_graph = "tracing"

  def on_test_batch_end(self, batch, logs=None):
    del batch, logs  # unused
    if self._test_graph == "tracing":
      with self.test_writer.as_default():
        tf.summary.trace_export("step", step=self.model.epoch.value())
      self._test_graph = "traced"

  def log_tensorboard(self, writer, logs):
    """Logs the values in `logs` to the summary writer."""
    with writer.as_default():
      for label, value in logs.items():
        tf.summary.scalar(label, value, step=self.model.epoch.value())

  def log_variables(self):
    """Logs name, dtype and shape of the model's trainable variables."""
    model = self.model
    var_format = lambda v: f"{v.name} {v.dtype} {v.shape}"
    logging.info(
        "TRAINABLE VARIABLES:\n%s\n",
        "\n".join(var_format(v) for v in model.trainable_variables))
 

# Laplace source

In [None]:
#@title VECVQ

# VECVQ on the 1-D Laplace source.
work_path = "/tmp/toy_sources/laplace/vecvq"

# Optimization and validation settings.
epochs, steps_per_epoch, batch_size = 50, 1000, 1024
validation_size, validation_batch_size = 10000000, 65536
learning_rate = 1e-3

# Model settings.
codebook_size, lmbda = 128, 1.

# Uncomment to stop at the first NaN/Inf while debugging:
# tf.debugging.enable_check_numerics()

source = get_laplace()
# The learning rate is set every epoch by `LearningRateScheduler` below.
optimizer = tf.keras.optimizers.Adam()
model = get_vecvq_model(
    codebook_size=codebook_size,
    initialize="uniform-40",
    source=source,
    lmbda=lmbda,
    distortion_loss="sse",
)
model.compile(optimizer=optimizer)

# Epoch counter stored on the model; checkpoints and TensorBoard steps use it.
model.epoch = tf.Variable(0, trainable=False, dtype=tf.int64)

lr_scheduler = get_lr_scheduler(learning_rate, epochs)
alpha_scheduler = get_alpha_scheduler(epochs)
callback_list = [
    # One weights-only checkpoint per epoch.
    tf.keras.callbacks.ModelCheckpoint(
        work_path + "/checkpoints/ckpt-{epoch:04d}", save_weights_only=True),
    # Allows an interrupted run to resume.
    tf.keras.callbacks.BackupAndRestore(work_path + "/backup"),
    tf.keras.callbacks.LearningRateScheduler(lr_scheduler),
    AlphaScheduler(alpha_scheduler),
    LogCallback(work_path),
]

model.fit(
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    batch_size=batch_size,
    validation_size=validation_size,
    validation_batch_size=validation_batch_size,
    verbose=2,
    callbacks=tf.keras.callbacks.CallbackList(callback_list, model=model),
)


In [None]:
#@title NTC

# Nonlinear transform coding of the 1-D Laplace source with MLP transforms.
work_path = "/tmp/toy_sources/laplace/ntc"

# Optimization and validation settings.
epochs, steps_per_epoch, batch_size = 50, 1000, 1024
validation_size, validation_batch_size = 10000000, 65536
learning_rate = 1e-3

# Transform architecture.
latent_dims = 1
analysis_filters, analysis_activation = [50, 50], "softplus"
synthesis_filters, synthesis_activation = [50, 50], "softplus"

# Entropy model / quantization settings, passed through to `ntc.NTCModel`.
prior_type = "deep"
dither, soft_round, guess_offset = (1, 1, 0, 0), (1, 0), False
lmbda = 1.

# Uncomment to stop at the first NaN/Inf while debugging:
# tf.debugging.enable_check_numerics()

source = get_laplace()
# The learning rate is set every epoch by `LearningRateScheduler` below.
optimizer = tf.keras.optimizers.Adam()
model = get_ntc_mlp_model(
    latent_dims=latent_dims,
    analysis_filters=analysis_filters,
    analysis_activation=analysis_activation,
    synthesis_filters=synthesis_filters,
    synthesis_activation=synthesis_activation,
    prior_type=prior_type,
    dither=dither,
    soft_round=soft_round,
    guess_offset=guess_offset,
    source=source,
    lmbda=lmbda,
    distortion_loss="sse",
)
model.compile(optimizer=optimizer)

# Epoch counter stored on the model; checkpoints and TensorBoard steps use it.
model.epoch = tf.Variable(0, trainable=False, dtype=tf.int64)

lr_scheduler = get_lr_scheduler(learning_rate, epochs)
alpha_scheduler = get_alpha_scheduler(epochs)
callback_list = [
    # One weights-only checkpoint per epoch.
    tf.keras.callbacks.ModelCheckpoint(
        work_path + "/checkpoints/ckpt-{epoch:04d}", save_weights_only=True),
    # Allows an interrupted run to resume.
    tf.keras.callbacks.BackupAndRestore(work_path + "/backup"),
    tf.keras.callbacks.LearningRateScheduler(lr_scheduler),
    AlphaScheduler(alpha_scheduler),
    LogCallback(work_path),
]

model.fit(
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    batch_size=batch_size,
    validation_size=validation_size,
    validation_batch_size=validation_batch_size,
    verbose=2,
    callbacks=tf.keras.callbacks.CallbackList(callback_list, model=model),
)


In [None]:
# Plot the quantization learned by the Laplace NTC model trained above.
# NOTE(review): the argument appears to be one `(start, stop, num_points)`
# grid per source dimension (the 2-D banana cell passes two) — confirm
# against `NTCModel.plot_quantization`.
model.plot_quantization([(-5, 5, 1000)])


In [None]:
# Plot the transfer function of the Laplace NTC model trained above.
# NOTE(review): argument presumably uses the same per-dimension grid format
# as `plot_quantization` — confirm against `NTCModel.plot_transfer`.
model.plot_transfer([(-5, 5, 1000)])


# Banana source

In [None]:
#@title VECVQ

# VECVQ on the 2-D banana source.
work_path = "/tmp/toy_sources/banana/vecvq"

# Optimization and validation settings.
epochs, steps_per_epoch, batch_size = 100, 1000, 1024
validation_size, validation_batch_size = 10000000, 65536
learning_rate = 1e-3

# Model settings.
codebook_size, lmbda = 256, 1.

# Uncomment to stop at the first NaN/Inf while debugging:
# tf.debugging.enable_check_numerics()

source = get_banana()
# The learning rate is set every epoch by `LearningRateScheduler` below.
optimizer = tf.keras.optimizers.Adam()
model = get_vecvq_model(
    codebook_size=codebook_size,
    initialize="sample",
    source=source,
    lmbda=lmbda,
    distortion_loss="sse",
)
model.compile(optimizer=optimizer)

# Epoch counter stored on the model; checkpoints and TensorBoard steps use it.
model.epoch = tf.Variable(0, trainable=False, dtype=tf.int64)

lr_scheduler = get_lr_scheduler(learning_rate, epochs)
alpha_scheduler = get_alpha_scheduler(epochs)
callback_list = [
    # One weights-only checkpoint per epoch.
    tf.keras.callbacks.ModelCheckpoint(
        work_path + "/checkpoints/ckpt-{epoch:04d}", save_weights_only=True),
    # Allows an interrupted run to resume.
    tf.keras.callbacks.BackupAndRestore(work_path + "/backup"),
    tf.keras.callbacks.LearningRateScheduler(lr_scheduler),
    AlphaScheduler(alpha_scheduler),
    LogCallback(work_path),
]

model.fit(
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    batch_size=batch_size,
    validation_size=validation_size,
    validation_batch_size=validation_batch_size,
    verbose=2,
    callbacks=tf.keras.callbacks.CallbackList(callback_list, model=model),
)


In [None]:
#@title NTC

# Nonlinear transform coding of the 2-D banana source with MLP transforms.
work_path = "/tmp/toy_sources/banana/ntc"

# Optimization and validation settings.
epochs, steps_per_epoch, batch_size = 100, 1000, 1024
validation_size, validation_batch_size = 10000000, 65536
learning_rate = 1e-3

# Transform architecture.
latent_dims = 2
analysis_filters, analysis_activation = [100, 100], "softplus"
synthesis_filters, synthesis_activation = [100, 100], "softplus"

# Entropy model / quantization settings, passed through to `ntc.NTCModel`.
prior_type = "deep"
dither, soft_round, guess_offset = (1, 1, 0, 0), (1, 0), False
lmbda = 1.

# Uncomment to stop at the first NaN/Inf while debugging:
# tf.debugging.enable_check_numerics()

source = get_banana()
# The learning rate is set every epoch by `LearningRateScheduler` below.
optimizer = tf.keras.optimizers.Adam()
model = get_ntc_mlp_model(
    latent_dims=latent_dims,
    analysis_filters=analysis_filters,
    analysis_activation=analysis_activation,
    synthesis_filters=synthesis_filters,
    synthesis_activation=synthesis_activation,
    prior_type=prior_type,
    dither=dither,
    soft_round=soft_round,
    guess_offset=guess_offset,
    source=source,
    lmbda=lmbda,
    distortion_loss="sse",
)
model.compile(optimizer=optimizer)

# Epoch counter stored on the model; checkpoints and TensorBoard steps use it.
model.epoch = tf.Variable(0, trainable=False, dtype=tf.int64)

lr_scheduler = get_lr_scheduler(learning_rate, epochs)
alpha_scheduler = get_alpha_scheduler(epochs)
callback_list = [
    # One weights-only checkpoint per epoch.
    tf.keras.callbacks.ModelCheckpoint(
        work_path + "/checkpoints/ckpt-{epoch:04d}", save_weights_only=True),
    # Allows an interrupted run to resume.
    tf.keras.callbacks.BackupAndRestore(work_path + "/backup"),
    tf.keras.callbacks.LearningRateScheduler(lr_scheduler),
    AlphaScheduler(alpha_scheduler),
    LogCallback(work_path),
]

model.fit(
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    batch_size=batch_size,
    validation_size=validation_size,
    validation_batch_size=validation_batch_size,
    verbose=2,
    callbacks=tf.keras.callbacks.CallbackList(callback_list, model=model),
)


In [None]:
# Plot the quantization learned by the banana NTC model trained above.
# NOTE(review): presumably one `(start, stop, num_points)` grid per source
# dimension, hence two identical grids for the 2-D source — confirm against
# `NTCModel.plot_quantization`.
model.plot_quantization(2 * [(-5, 5, 1000)])


# Sawbridge source

In [None]:
#@title VECVQ

# VECVQ on the Sawbridge process (1024 sample points per realization).
work_path = "/tmp/toy_sources/sawbridge/vecvq"

# Optimization and validation settings.
epochs, steps_per_epoch, batch_size = 200, 1000, 1024
validation_size, validation_batch_size = 10000000, 4096
learning_rate = 1e-3

# Model settings.
codebook_size, lmbda = 50, 1.

# Uncomment to stop at the first NaN/Inf while debugging:
# tf.debugging.enable_check_numerics()

source = get_sawbridge()
# The learning rate is set every epoch by `LearningRateScheduler` below.
optimizer = tf.keras.optimizers.Adam()
model = get_vecvq_model(
    codebook_size=codebook_size,
    initialize="sample-.1",
    source=source,
    lmbda=lmbda,
    distortion_loss="mse",
)
model.compile(optimizer=optimizer)

# Epoch counter stored on the model; checkpoints and TensorBoard steps use it.
model.epoch = tf.Variable(0, trainable=False, dtype=tf.int64)

lr_scheduler = get_lr_scheduler(learning_rate, epochs)
alpha_scheduler = get_alpha_scheduler(epochs)
callback_list = [
    # One weights-only checkpoint per epoch.
    tf.keras.callbacks.ModelCheckpoint(
        work_path + "/checkpoints/ckpt-{epoch:04d}", save_weights_only=True),
    # Allows an interrupted run to resume.
    tf.keras.callbacks.BackupAndRestore(work_path + "/backup"),
    tf.keras.callbacks.LearningRateScheduler(lr_scheduler),
    AlphaScheduler(alpha_scheduler),
    LogCallback(work_path),
]

model.fit(
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    batch_size=batch_size,
    validation_size=validation_size,
    validation_batch_size=validation_batch_size,
    verbose=2,
    callbacks=tf.keras.callbacks.CallbackList(callback_list, model=model),
)


In [None]:
#@title NTC

# Nonlinear transform coding of the Sawbridge process with MLP transforms.
work_path = "/tmp/toy_sources/sawbridge/ntc"

# Optimization and validation settings.
epochs, steps_per_epoch, batch_size = 200, 1000, 1024
validation_size, validation_batch_size = 10000000, 4096
learning_rate = 1e-3

# Transform architecture.
latent_dims = 10
analysis_filters, analysis_activation = [100, 100], "leaky_relu"
synthesis_filters, synthesis_activation = [100, 100], "leaky_relu"

# Entropy model / quantization settings, passed through to `ntc.NTCModel`.
prior_type = "deep"
dither, soft_round, guess_offset = (1, 1, 0, 0), (1, 0), False
lmbda = 1.

# Uncomment to stop at the first NaN/Inf while debugging:
# tf.debugging.enable_check_numerics()

source = get_sawbridge()
# The learning rate is set every epoch by `LearningRateScheduler` below.
optimizer = tf.keras.optimizers.Adam()
model = get_ntc_mlp_model(
    latent_dims=latent_dims,
    analysis_filters=analysis_filters,
    analysis_activation=analysis_activation,
    synthesis_filters=synthesis_filters,
    synthesis_activation=synthesis_activation,
    prior_type=prior_type,
    dither=dither,
    soft_round=soft_round,
    guess_offset=guess_offset,
    source=source,
    lmbda=lmbda,
    distortion_loss="mse",
)
model.compile(optimizer=optimizer)

# Epoch counter stored on the model; checkpoints and TensorBoard steps use it.
model.epoch = tf.Variable(0, trainable=False, dtype=tf.int64)

lr_scheduler = get_lr_scheduler(learning_rate, epochs)
alpha_scheduler = get_alpha_scheduler(epochs)
callback_list = [
    # One weights-only checkpoint per epoch.
    tf.keras.callbacks.ModelCheckpoint(
        work_path + "/checkpoints/ckpt-{epoch:04d}", save_weights_only=True),
    # Allows an interrupted run to resume.
    tf.keras.callbacks.BackupAndRestore(work_path + "/backup"),
    tf.keras.callbacks.LearningRateScheduler(lr_scheduler),
    AlphaScheduler(alpha_scheduler),
    LogCallback(work_path),
]

model.fit(
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    batch_size=batch_size,
    validation_size=validation_size,
    validation_batch_size=validation_batch_size,
    verbose=2,
    callbacks=tf.keras.callbacks.CallbackList(callback_list, model=model),
)


In [None]:
#@title KLT (dither)

# Linear transform coding of the Sawbridge process with a fixed, estimated
# KLT basis and dithered quantization.
work_path = "/tmp/toy_sources/sawbridge/klt_dither"

# Optimization and validation settings.
epochs, steps_per_epoch, batch_size = 200, 1000, 1024
validation_size, validation_batch_size = 10000000, 4096
learning_rate = 1e-3

# KLT estimation: (number of batches, samples per batch), and the largest
# allowed correlation coefficient between KLT coefficients.
num_samples, tolerance = (1000, 10000), 1e-2
latent_dims = 50

# Entropy model / quantization settings, passed through to `ntc.NTCModel`.
prior_type = "deep"
dither, soft_round, guess_offset = (1, 1, 1, 1), (0, 0), False
lmbda = 1.

# Uncomment to stop at the first NaN/Inf while debugging:
# tf.debugging.enable_check_numerics()

source = get_sawbridge()
# The learning rate is set every epoch by `LearningRateScheduler` below.
optimizer = tf.keras.optimizers.Adam()
model = get_ltc_klt_model(
    num_samples=num_samples,
    tolerance=tolerance,
    latent_dims=latent_dims,
    prior_type=prior_type,
    dither=dither,
    soft_round=soft_round,
    guess_offset=guess_offset,
    source=source,
    lmbda=lmbda,
    distortion_loss="mse",
)
model.compile(optimizer=optimizer)

# Epoch counter stored on the model; checkpoints and TensorBoard steps use it.
model.epoch = tf.Variable(0, trainable=False, dtype=tf.int64)

lr_scheduler = get_lr_scheduler(learning_rate, epochs)
alpha_scheduler = get_alpha_scheduler(epochs)
callback_list = [
    # One weights-only checkpoint per epoch.
    tf.keras.callbacks.ModelCheckpoint(
        work_path + "/checkpoints/ckpt-{epoch:04d}", save_weights_only=True),
    # Allows an interrupted run to resume.
    tf.keras.callbacks.BackupAndRestore(work_path + "/backup"),
    tf.keras.callbacks.LearningRateScheduler(lr_scheduler),
    AlphaScheduler(alpha_scheduler),
    LogCallback(work_path),
]

model.fit(
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    batch_size=batch_size,
    validation_size=validation_size,
    validation_batch_size=validation_batch_size,
    verbose=2,
    callbacks=tf.keras.callbacks.CallbackList(callback_list, model=model),
)


In [None]:
#@title Daubechies 4-tap

# Linear transform coding of the Sawbridge process with a fixed orthonormal
# wavelet basis (PyWavelets "db4") and dithered quantization.
work_path = "/tmp/toy_sources/sawbridge/daub4"

# Optimization and validation settings.
epochs, steps_per_epoch, batch_size = 200, 1000, 1024
validation_size, validation_batch_size = 10000000, 4096
learning_rate = 1e-3

# Transform: wavelet name and number of coefficients kept.
transform, latent_dims = "db4", 50

# Entropy model / quantization settings, passed through to `ntc.NTCModel`.
prior_type = "deep"
dither, soft_round, guess_offset = (1, 1, 1, 1), (0, 0), False
lmbda = 1.

# Uncomment to stop at the first NaN/Inf while debugging:
# tf.debugging.enable_check_numerics()

source = get_sawbridge()
# The learning rate is set every epoch by `LearningRateScheduler` below.
optimizer = tf.keras.optimizers.Adam()
model = get_ltc_ortho_model(
    transform=transform,
    latent_dims=latent_dims,
    prior_type=prior_type,
    dither=dither,
    soft_round=soft_round,
    guess_offset=guess_offset,
    source=source,
    lmbda=lmbda,
    distortion_loss="mse",
)
model.compile(optimizer=optimizer)

# Epoch counter stored on the model; checkpoints and TensorBoard steps use it.
model.epoch = tf.Variable(0, trainable=False, dtype=tf.int64)

lr_scheduler = get_lr_scheduler(learning_rate, epochs)
alpha_scheduler = get_alpha_scheduler(epochs)
callback_list = [
    # One weights-only checkpoint per epoch.
    tf.keras.callbacks.ModelCheckpoint(
        work_path + "/checkpoints/ckpt-{epoch:04d}", save_weights_only=True),
    # Allows an interrupted run to resume.
    tf.keras.callbacks.BackupAndRestore(work_path + "/backup"),
    tf.keras.callbacks.LearningRateScheduler(lr_scheduler),
    AlphaScheduler(alpha_scheduler),
    LogCallback(work_path),
]

model.fit(
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    batch_size=batch_size,
    validation_size=validation_size,
    validation_batch_size=validation_batch_size,
    verbose=2,
    callbacks=tf.keras.callbacks.CallbackList(callback_list, model=model),
)
