In [1]:
import tensorflow as tf
from tensorflow import keras

# Layer

In [3]:
class Linear(keras.layers.Layer):
    """Affine layer computing `inputs @ w + b` with eagerly created weights.

    Both variables are allocated in the constructor, which is why the input
    dimension has to be supplied up front.

    Args:
        units: Number of output features (columns of `w`, length of `b`).
        input_dim: Number of input features (rows of `w`).
    """

    def __init__(self, units=32, input_dim=32):
        super().__init__()

        # Kernel: values drawn from the random-normal initializer.
        kernel_values = tf.random_normal_initializer()(
            shape=(input_dim, units), dtype='float32')
        self.w = tf.Variable(initial_value=kernel_values, trainable=True)

        # Bias: starts out as all zeros.
        bias_values = tf.zeros_initializer()(shape=(units,), dtype='float32')
        self.b = tf.Variable(initial_value=bias_values, trainable=True)

    def call(self, inputs):
        """Return the matrix product of `inputs` and `w`, shifted by `b`."""
        projected = tf.matmul(inputs, self.w)
        return projected + self.b

In [5]:
# Build a layer that maps 2 input features to 4 output features.
linear_layer = Linear(units=4, input_dim=2)

# Applying it to a (2, 2) batch of ones must yield a (2, 4) result.
y = linear_layer(tf.ones(shape=(2, 2)))
assert y.shape == (2, 4)

In [6]:
# The variables assigned as attributes in the constructor show up in
# `weights` without any explicit registration: kernel first, then bias.
assert linear_layer.weights == [linear_layer.w, linear_layer.b]

# Weights

In [7]:
class Linear(keras.layers.Layer):
    """Affine layer computing `inputs @ w + b` with lazily created weights.

    Weight creation is deferred to `build`, so only the number of output
    units has to be given at construction time.

    Args:
        units: Number of output features.
    """

    def __init__(self, units=32):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        """Create `w` and `b`; the last entry of `input_shape` sizes `w`."""
        in_features = input_shape[-1]
        self.w = self.add_weight(
            shape=(in_features, self.units),
            initializer='random_normal',
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer='random_normal',
            trainable=True,
        )

    def call(self, inputs):
        """Return the matrix product of `inputs` and `w`, shifted by `b`."""
        projected = tf.matmul(inputs, self.w)
        return projected + self.b


# Create the lazy layer; its constructor only stores the unit count.
linear_layer = Linear(units=4)

# The first call also runs `build(input_shape)`, which creates the weights.
y = linear_layer(tf.ones(shape=(2, 2)))

# Gradients

In [8]:
# Dataset: MNIST training images flattened to 784-vectors and divided by 255,
# paired with their integer labels, then shuffled and grouped into batches of 64.
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
flat_images = x_train.reshape(60000, 784).astype('float32') / 255
dataset = (
    tf.data.Dataset.from_tensor_slices((flat_images, y_train))
    .shuffle(buffer_size=1024)
    .batch(64)
)

# Model: the Linear layer defined above, with 10 output units.
linear_layer = Linear(10)

# Loss on raw logits with integer targets, and a plain SGD optimizer.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)

# One pass over the batches of the dataset.
for step, (x, y) in enumerate(dataset):

    # Record the forward pass and the loss computation on the tape.
    with tf.GradientTape() as tape:
        logits = linear_layer(x)
        loss = loss_fn(y, logits)

    # Differentiate the loss wrt the layer's weights and apply the update.
    trainable_vars = linear_layer.trainable_weights
    gradients = tape.gradient(loss, trainable_vars)
    optimizer.apply_gradients(zip(gradients, trainable_vars))

    # Report the batch loss every 100 steps.
    if step % 100 == 0:
        print('Step:', step, 'Loss:', float(loss))

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
Step: 0 Loss: 2.467670440673828
Step: 100 Loss: 2.3106894493103027
Step: 200 Loss: 2.2388761043548584
Step: 300 Loss: 2.1261556148529053
Step: 400 Loss: 2.016288995742798
Step: 500 Loss: 1.9093151092529297
Step: 600 Loss: 1.7734405994415283
Step: 700 Loss: 1.8460146188735962
Step: 800 Loss: 1.6618050336837769
Step: 900 Loss: 1.5766392946243286


In [9]:
class ComputeSum(keras.layers.Layer):
    """Keeps a running sum, taken over axis 0, of every input it is called on.

    Args:
        input_dim: Length of the running-total vector.
    """

    def __init__(self, input_dim):
        super().__init__()
        # The accumulator is created as a non-trainable variable.
        zeros = tf.zeros((input_dim,))
        self.total = tf.Variable(initial_value=zeros, trainable=False)

    def call(self, inputs):
        """Add the axis-0 sum of `inputs` to the total and return the total."""
        batch_sum = tf.reduce_sum(inputs, axis=0)
        self.total.assign_add(batch_sum)
        return self.total

# Accumulator over 2 features, fed the same (2, 2) batch of ones twice.
my_sum = ComputeSum(2)
x = tf.ones((2, 2))

# First call: the total goes from zeros to the column sums of `x`.
y = my_sum(x)
print(y.numpy())  # [2. 2.]

# Second call: the same column sums are added on top of the previous total.
y = my_sum(x)
print(y.numpy())  # [4. 4.]

# `total` is listed as a weight, but only among the non-trainable ones.
assert my_sum.weights == [my_sum.total]
assert my_sum.non_trainable_weights == [my_sum.total]
assert my_sum.trainable_weights == []

[2. 2.]
[4. 4.]


In [10]:
# Let's reuse the Linear class
# with a `build` method that we defined above.

class MLP(keras.layers.Layer):
    """Three stacked `Linear` layers (32, 32 and 10 units) with ReLU between them."""

    def __init__(self):
        super().__init__()
        self.linear_1 = Linear(32)
        self.linear_2 = Linear(32)
        self.linear_3 = Linear(10)

    def call(self, inputs):
        """Apply both hidden layers with ReLU, then the final layer without it."""
        hidden = inputs
        for hidden_layer in (self.linear_1, self.linear_2):
            hidden = tf.nn.relu(hidden_layer(hidden))
        return self.linear_3(hidden)

mlp = MLP()

# Calling the `mlp` object for the first time is what creates the weights.
y = mlp(tf.ones((3, 64)))

# Tracking is recursive: 3 nested Linear layers, each with `w` and `b`.
assert len(mlp.weights) == 6

In [11]:
# Built-in Dense layers with the same unit counts as the custom MLP:
# two ReLU layers of 32 units followed by a 10-unit layer without activation.
hidden_layers = [
    keras.layers.Dense(32, activation=tf.nn.relu) for _ in range(2)
]
mlp = keras.Sequential(hidden_layers + [keras.layers.Dense(10)])

# Loss Tracking

In [12]:
class ActivityRegularization(keras.layers.Layer):
    """Pass-through layer that registers `rate * sum(inputs)` as an extra loss.

    Args:
        rate: Multiplier applied to the summed inputs.
    """

    def __init__(self, rate=1e-2):
        super().__init__()
        self.rate = rate

    def call(self, inputs):
        """Record the input-dependent penalty via `add_loss`; return `inputs`."""
        activity_penalty = self.rate * tf.reduce_sum(inputs)
        self.add_loss(activity_penalty)
        return inputs

In [13]:
# Let's use the loss layer in a MLP block.

class SparseMLP(keras.layers.Layer):
    """Two `Linear` layers with an activity penalty on the hidden ReLU output."""

    def __init__(self):
        super().__init__()
        self.linear_1 = Linear(32)
        self.regularization = ActivityRegularization(1e-2)
        self.linear_3 = Linear(10)

    def call(self, inputs):
        """Run hidden layer -> ReLU -> regularization -> output layer."""
        hidden = tf.nn.relu(self.linear_1(inputs))
        hidden = self.regularization(hidden)
        return self.linear_3(hidden)
    

# One forward pass is enough to populate `losses`.
mlp = SparseMLP()
y = mlp(tf.ones(shape=(10, 10)))

# Prints a list holding one float32 scalar.
print(mlp.losses)

[<tf.Tensor: shape=(), dtype=float32, numpy=0.28619078>]


These losses are cleared by the top-level layer at the start of each forward pass -- they don't accumulate. As a result, `layer.losses` always contains only the losses created during the last forward pass. When writing a training loop, you would typically sum these losses and add them to your main loss before computing gradients.

In [14]:
# `losses` reflects only the *last* forward pass: after each of two
# consecutive calls the list still has exactly one entry (no accumulation).
mlp = SparseMLP()
for _call in range(2):
    mlp(tf.ones((10, 10)))
    assert len(mlp.losses) == 1

# Let's demonstrate how to use these losses in a training loop.

# Prepare a dataset: flattened MNIST images divided by 255 with their integer
# labels, shuffled and grouped into batches of 64.
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
dataset = tf.data.Dataset.from_tensor_slices(
    (x_train.reshape(60000, 784).astype('float32') / 255, y_train))
dataset = dataset.shuffle(buffer_size=1024).batch(64)

# A new MLP.
mlp = SparseMLP()

# Loss and optimizer.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)

for step, (x, y) in enumerate(dataset):
  # Only the forward pass and the loss computation belong inside the tape
  # context; everything that follows operates on the recorded result.
  with tf.GradientTape() as tape:

    # Forward pass.
    logits = mlp(x)

    # External loss value for this batch.
    loss = loss_fn(y, logits)

    # Add the losses created during the forward pass.
    loss += sum(mlp.losses)

  # Get gradients of weights wrt the loss, after leaving the tape context.
  gradients = tape.gradient(loss, mlp.trainable_weights)

  # Update the weights of the MLP.
  optimizer.apply_gradients(zip(gradients, mlp.trainable_weights))

  # Logging.
  if step % 100 == 0:
    print('Step:', step, 'Loss:', float(loss))

Step: 0 Loss: 6.676252365112305
Step: 100 Loss: 2.5746729373931885
Step: 200 Loss: 2.4075963497161865
Step: 300 Loss: 2.3879988193511963
Step: 400 Loss: 2.3600149154663086
Step: 500 Loss: 2.351372718811035
Step: 600 Loss: 2.326296806335449
Step: 700 Loss: 2.322831869125366
Step: 800 Loss: 2.3203039169311523
Step: 900 Loss: 2.325289011001587
