In [57]:
import tensorflow as tf
from datetime import datetime

%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [58]:
class SimpleModule(tf.Module):

    def __init__(self, name=None):
        super().__init__(name=name)
        self.a_variable = tf.Variable(5, dtype=tf.dtypes.float32, name='train_me')
        self.non_trainable_variable = tf.Variable(5, trainable=False, dtype=tf.dtypes.float32, name='do_not_train_me')

    def __call__(self, x):
        return self.a_variable * x + self.non_trainable_variable

In [59]:
simple_module = SimpleModule(name='Simple')
simple_module(tf.constant(5.0)).numpy()

30.0

In [60]:
print("Trainable variables:", simple_module.trainable_variables)
print("All variables:", simple_module.variables)

Trainable variables: (<tf.Variable 'train_me:0' shape=() dtype=float32, numpy=5.0>,)
All variables: (<tf.Variable 'train_me:0' shape=() dtype=float32, numpy=5.0>, <tf.Variable 'do_not_train_me:0' shape=() dtype=float32, numpy=5.0>)


In [61]:
class Dense(tf.Module):

    def __init__(self, in_features, out_features, name=None):
        super().__init__(name=name)
        self.w = tf.Variable(
            tf.random.normal([in_features, out_features], name='w')
        )
        self.b = tf.Variable(tf.ones([out_features]), name='b')

    def __call__(self, x):
        y = tf.matmul(x, self.w) + self.b
        return tf.nn.relu(y)

In [62]:
class SequentialModule(tf.Module):

    def __init__(self, name=None):
        super().__init__(name=name)
        self.dense1 = Dense(in_features=3, out_features=3)
        self.dense2 = Dense(in_features=3, out_features=2)

    def __call__(self, x):
        x = self.dense1(x)
        return self.dense2(x)

In [63]:
mymodel = SequentialModule(name='mymodel')
mymodel([[1.0, 2.0, 3.0]]).numpy()

array([[7.5454993, 0.       ]], dtype=float32)

In [64]:
print(mymodel.submodules)
print()
for vrs in mymodel.variables:
    print(vrs)

(<__main__.Dense object at 0x14913cbe0>, <__main__.Dense object at 0x149140f28>)

<tf.Variable 'b:0' shape=(3,) dtype=float32, numpy=array([1., 1., 1.], dtype=float32)>
<tf.Variable 'Variable:0' shape=(3, 3) dtype=float32, numpy=
array([[-1.0082649 , -0.30613193, -0.07526284],
       [ 0.2806238 ,  0.23274003,  0.26898396],
       [ 3.1036172 ,  0.06940007, -0.5029811 ]], dtype=float32)>
<tf.Variable 'b:0' shape=(2,) dtype=float32, numpy=array([1., 1.], dtype=float32)>
<tf.Variable 'Variable:0' shape=(3, 2) dtype=float32, numpy=
array([[ 0.7739295 , -1.4055985 ],
       [-0.7958861 ,  0.75101995],
       [-0.9784334 , -0.9110446 ]], dtype=float32)>


In [65]:
class FlexibleDense(tf.Module):

    def __init__(self, out_features, name=None):
        super().__init__(name=name)
        self.is_built = False
        self.out_features = out_features

    def __call__(self, x):
        if not self.is_built:
            self.in_features = x.shape[-1]
            self.w = tf.Variable(
                tf.random.normal([self.in_features, self.out_features]), 
                name='w'
            )
            self.b = tf.Variable(tf.ones([self.out_features]), name='b')
            self.is_built = True

        y = tf.matmul(x, self.w) + self.b
        return tf.nn.relu(y)



In [66]:
class MySequentialModule(tf.Module):

    def __init__(self, name=None):
        super().__init__(name=name)
        self.dense1 = FlexibleDense(out_features=3, name='dense1')
        self.dense2 = FlexibleDense(out_features=2, name='dense2')

    def __call__(self, x):
        x = self.dense1(x)
        return self.dense2(x)


In [68]:
mymodel2 = MySequentialModule(name='mymodel2')
print(mymodel2(tf.constant([[1.0, 2.0, 3.0, 4.0]])).numpy())

[[6.908826  1.8998265]]
