In [1]:
import numpy as np
import tensorflow as tf
import tensornetwork as tn

In [2]:
class TNLayer(tf.keras.layers.Layer):
 
  def __init__(self):
    super(TNLayer, self).__init__()
    # Create the variables for the layer.
    self.a_var = tf.Variable(tf.random.normal(
            shape=(32, 32, 2), stddev=1.0/32.0),
             name="a", trainable=True)
    self.b_var = tf.Variable(tf.random.normal(shape=(32, 32, 2), stddev=1.0/32.0),
                             name="b", trainable=True)
    self.bias = tf.Variable(tf.zeros(shape=(32, 32)), name="bias", trainable=True)
 
  def call(self, inputs):
    # Define the contraction.
    # We break it out so we can parallelize a batch using
    # tf.vectorized_map (see below).
    def f(input_vec, a_var, b_var, bias_var):
      # Reshape to a matrix instead of a vector.
      input_vec = tf.reshape(input_vec, (32,32))
 
      # Now we create the network.
      a = tn.Node(a_var, backend="tensorflow")
      b = tn.Node(b_var, backend="tensorflow")
      x_node = tn.Node(input_vec, backend="tensorflow")
      a[1] ^ x_node[0]
      b[1] ^ x_node[1]
      a[2] ^ b[2]
 
      # The TN should now look like this
      #   |     |
      #   a --- b
      #    \   /
      #      x
 
      # Now we begin the contraction.
      c = a @ x_node
      result = (c @ b).tensor
 
      # To make the code shorter, we also could've used Ncon.
      # The above few lines of code is the same as this:
      # result = tn.ncon([x, a_var, b_var], [[1, 2], [-1, 1, 3], [-2, 2, 3]])
 
      # Finally, add bias.
      return result + bias_var
  
    # To deal with a batch of items, we can use the tf.vectorized_map
    # function.
    # https://www.tensorflow.org/api_docs/python/tf/vectorized_map
    result = tf.vectorized_map(
        lambda vec: f(vec, self.a_var, self.b_var, self.bias), inputs)
    return tf.nn.relu(tf.reshape(result, (-1, 1024)))

In [3]:
tn_model = tf.keras.Sequential(
  [
    tf.keras.Input(shape=(2,)),
    tf.keras.layers.Dense(1024, activation=tf.nn.relu),
    # Here use a TN layer instead of the dense layer.
    TNLayer(),
    tf.keras.layers.Dense(1, activation=None)
  ]
)

In [4]:
tn_model.trainable_variables

[<tf.Variable 'dense/kernel:0' shape=(2, 1024) dtype=float32, numpy=
 array([[-0.03344038,  0.01820505, -0.06670009, ..., -0.03890723,
         -0.06396094,  0.05283929],
        [-0.03297286, -0.04029611, -0.02663104, ...,  0.03087538,
          0.01266984, -0.07342357]], dtype=float32)>,
 <tf.Variable 'dense/bias:0' shape=(1024,) dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>,
 <tf.Variable 'a:0' shape=(32, 32, 2) dtype=float32, numpy=
 array([[[ 0.0191182 ,  0.01152568],
         [ 0.0265045 , -0.04198615],
         [ 0.03249674, -0.01269511],
         ...,
         [ 0.00286332,  0.00638864],
         [-0.0183975 ,  0.01920737],
         [ 0.02486627,  0.01960747]],
 
        [[ 0.01296654,  0.02279922],
         [-0.01134439,  0.02230362],
         [-0.01087074,  0.02630333],
         ...,
         [ 0.00394398,  0.01329732],
         [ 0.00394495,  0.00955237],
         [-0.01647504,  0.02851471]],
 
        [[-0.01736645,  0.01870213],
         [ 0.070

In [5]:
tn_model.layers[1].trainable_variables

[<tf.Variable 'a:0' shape=(32, 32, 2) dtype=float32, numpy=
 array([[[ 0.0191182 ,  0.01152568],
         [ 0.0265045 , -0.04198615],
         [ 0.03249674, -0.01269511],
         ...,
         [ 0.00286332,  0.00638864],
         [-0.0183975 ,  0.01920737],
         [ 0.02486627,  0.01960747]],
 
        [[ 0.01296654,  0.02279922],
         [-0.01134439,  0.02230362],
         [-0.01087074,  0.02630333],
         ...,
         [ 0.00394398,  0.01329732],
         [ 0.00394495,  0.00955237],
         [-0.01647504,  0.02851471]],
 
        [[-0.01736645,  0.01870213],
         [ 0.0701299 ,  0.00992214],
         [-0.01273874,  0.01527259],
         ...,
         [-0.00610562,  0.03217464],
         [-0.02004386,  0.02719508],
         [-0.0067326 , -0.01807734]],
 
        ...,
 
        [[ 0.03232703, -0.00918395],
         [-0.02499166, -0.00664612],
         [-0.02528954,  0.01147025],
         ...,
         [-0.0254052 , -0.01283271],
         [-0.00835265, -0.0770442 ],
         

In [6]:
tn_model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 1024)              3072      
_________________________________________________________________
tn_layer (TNLayer)           (None, 1024)              5120      
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 1025      
Total params: 9,217
Trainable params: 9,217
Non-trainable params: 0
_________________________________________________________________
