In [1]:
import tensorflow as tf

%load_ext tensorboard

In [5]:
class SimpleModule(tf.Module):
    """Toy module computing ``a_variable * inputs + b_variable``.

    Both variables start at 5.0. ``a_variable`` is a trainable variable, while
    ``b_variable`` is created with ``trainable=False`` to show how a module
    tracks trainable and non-trainable state side by side.

    Args:
        name: Optional module name forwarded to ``tf.Module``.
    """

    def __init__(self, name=None):
        super().__init__(name=name)
        # Slope term: an ordinary trainable variable.
        self.a_variable = tf.Variable(5.0, name='trainable')
        # Offset term: still owned by the module, but marked non-trainable.
        self.b_variable = tf.Variable(5.0, trainable=False, name='untrainable')

    def __call__(self, inputs):
        """Return ``a_variable * inputs + b_variable`` elementwise.

        Accepts Python scalars as well as tensors; tensor inputs broadcast
        against the scalar variables.
        """
        scaled = self.a_variable * inputs
        return scaled + self.b_variable

simple_module = SimpleModule()
# Call the same module on a Python float, a scalar tensor and a vector tensor
# to show that all three input kinds are accepted.
for label, value in (
    ('5.', 5.),
    ('tf.constant(5.)', tf.constant(5.)),
    ('tf.constant([5., 5.])', tf.constant([5., 5.])),
):
    print(f'{label}: {simple_module(value)}')

5.: 30.0
tf.constant(5.): 30.0
tf.constant([5., 5.]): [30. 30.]


In [11]:
class Dense(tf.Module):
    """Fully connected layer computing ``relu(inputs @ w + b)``.

    Args:
        in_features: Size of the last dimension of ``inputs`` (rows of ``w``).
        out_features: Number of output units (columns of ``w``, length of ``b``).
        name: Optional module name forwarded to ``tf.Module``.
    """

    def __init__(self, in_features, out_features, name=None):
        super().__init__(name=name)
        # Weights are sampled from a standard normal; biases start at zero.
        weight_shape = [in_features, out_features]
        self.w = tf.Variable(tf.random.normal(weight_shape), name='w')
        self.b = tf.Variable(tf.zeros([out_features]), name='b')

    def __call__(self, inputs):
        """Apply the affine transform, then a ReLU activation."""
        # NOTE(review): tf.matmul needs rank >= 2, so `inputs` is presumably
        # (batch, in_features) as in the example call in this notebook.
        pre_activation = tf.matmul(inputs, self.w) + self.b
        return tf.nn.relu(pre_activation)

class SequentialModule(tf.Module):
    """Two stacked ``Dense`` layers mapping 3 input features to 2 outputs.

    Args:
        name: Optional module name forwarded to ``tf.Module``.
    """

    def __init__(self, name=None):
        super().__init__(name=name)
        # Hidden layer keeps the width at 3; the output layer projects to 2.
        self.layer_1 = Dense(in_features=3, out_features=3)
        self.layer_2 = Dense(in_features=3, out_features=2)

    def __call__(self, inputs):
        """Feed ``inputs`` through ``layer_1`` and then ``layer_2``."""
        hidden = inputs
        for layer in (self.layer_1, self.layer_2):
            hidden = layer(hidden)
        return hidden

# Seed TensorFlow's global RNG before the layers draw their random weights, so
# the model (and the output of this cell) is reproducible after a kernel restart.
SEED = 42
tf.random.set_seed(SEED)
simple_model = SequentialModule()
simple_model(tf.constant([[1., 2., 3.]]))

<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.      , 4.229596]], dtype=float32)>

In [12]:
# A tf.Module collects its nested Dense sub-modules and every variable they own.
for tracked in (simple_model.submodules, simple_model.variables):
    print(tracked)

(<__main__.Dense object at 0x7f99b41195b0>, <__main__.Dense object at 0x7f99b4119790>)
(<tf.Variable 'b:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>, <tf.Variable 'w:0' shape=(3, 3) dtype=float32, numpy=
array([[-0.7205934 , -1.3891398 , -0.22253664],
       [-0.6661983 , -0.4855898 , -0.5871777 ],
       [ 2.0263572 ,  1.1751839 , -0.06157124]], dtype=float32)>, <tf.Variable 'b:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>, <tf.Variable 'w:0' shape=(3, 2) dtype=float32, numpy=
array([[-0.13722058,  0.64954734],
       [ 0.14585947,  1.3873916 ],
       [ 1.2594806 , -0.02584443]], dtype=float32)>)
