In [1]:
import numpy as np
import tensorflow as tf
import tensornetwork as tn

In [10]:
class SimpleTNLayer(tf.keras.layers.Layer):
    """Keras layer that contracts each input sample with a trainable MPS chain.

    The layer owns ``n_sites`` trainable tensors arranged as an open chain
    (a matrix product state). Axis 0 of every tensor is contracted with one
    row of the input sample, neighbouring tensors are joined through their
    bond axes, and the single uncontracted axis -- the extra axis carried by
    the tensor at ``output_site`` -- becomes the layer output.

    With the default arguments the tensors have shapes
    ``(2, 10)``, ``(2, 10, 10)``, ``(2, 10, 10, 2)``, ``(2, 10, 10)``, ``(2, 10)``
    and are named ``node_0`` ... ``node_4``.

    Args:
        n_sites: Number of tensors in the chain (and of input rows consumed
            per sample). Must be at least 3 so that the output tensor is an
            interior site with two bond axes.
        site_dim: Size of axis 0 of every tensor; must match the length of
            each input row.
        bond_dim: Size of the axes joining neighbouring tensors.
        output_dim: Size of the dangling output axis.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (``name``, ``dtype``,
            ...), which is also what ``from_config`` relies on.

    Raises:
        ValueError: If ``n_sites < 3`` or any dimension is smaller than 1.
    """

    def __init__(self, n_sites=5, site_dim=2, bond_dim=10, output_dim=2,
                 **kwargs):
        super(SimpleTNLayer, self).__init__(**kwargs)

        if n_sites < 3:
            raise ValueError(
                'n_sites must be at least 3, got {}'.format(n_sites))
        if min(site_dim, bond_dim, output_dim) < 1:
            raise ValueError(
                'site_dim, bond_dim and output_dim must all be >= 1')

        self.n_sites = n_sites
        self.site_dim = site_dim
        self.bond_dim = bond_dim
        self.output_dim = output_dim
        # The tensor in the middle of the chain carries the output axis
        # (index 2 for the default five-site chain).
        self.output_site = n_sites // 2

        # Kept as a plain list attribute of tf.Variable objects, initialised
        # with tf.random.uniform's default range, exactly as before.
        self.mps_tensors = [
            tf.Variable(tf.random.uniform(shape=self._site_shape(site)),
                        name='node_{}'.format(site),
                        trainable=True)
            for site in range(n_sites)
        ]

    def _site_shape(self, site):
        """Return the shape of the MPS tensor located at index ``site``."""
        if site == 0 or site == self.n_sites - 1:
            # Boundary tensors have a single bond axis.
            return (self.site_dim, self.bond_dim)
        if site == self.output_site:
            # Interior tensor with the additional dangling output axis.
            return (self.site_dim, self.bond_dim, self.bond_dim,
                    self.output_dim)
        return (self.site_dim, self.bond_dim, self.bond_dim)

    def infer_single(self, input):
        """Contract one sample with the MPS chain.

        Args:
            input: A single sample indexed as ``input[i, :]``; only rows
                ``0 .. n_sites - 1`` are read. (The name shadows the
                ``input`` builtin but is kept so existing keyword callers
                keep working.)

        Returns:
            The tensor of the fully contracted network, whose only remaining
            axis is the output axis of the tensor at ``output_site``.
        """
        n_sites = len(self.mps_tensors)

        # Fresh nodes are required on every call: connecting and contracting
        # edges mutates the nodes.
        nodes = [
            tn.Node(tensor, backend='tensorflow')
            for tensor in self.mps_tensors
        ]
        input_nodes = [
            tn.Node(input[i, :], backend='tensorflow')
            for i in range(n_sites)
        ]

        # Axis 0 of every MPS tensor absorbs the matching input row.
        for mps_node, input_node in zip(nodes, input_nodes):
            mps_node[0] ^ input_node[0]

        # The first tensor has a single bond (axis 1); every later tensor
        # receives its left neighbour on axis 1 and, if it is not the last
        # one, reaches its right neighbour through axis 2.
        nodes[0][1] ^ nodes[1][1]
        for i in range(1, n_sites - 1):
            nodes[i][2] ^ nodes[i + 1][1]

        final_node = tn.contractors.auto(
            nodes + input_nodes,
            output_edge_order=[nodes[self.output_site][3]])
        return final_node.tensor

    def call(self, inputs):
        """Apply ``infer_single`` to every element along the leading axis.

        Args:
            inputs: Batch of samples; ``tf.vectorized_map`` maps over its
                first axis.

        Returns:
            The per-sample results stacked along the first axis.
        """
        return tf.vectorized_map(self.infer_single, inputs)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(SimpleTNLayer, self).get_config()
        config.update({
            'n_sites': self.n_sites,
            'site_dim': self.site_dim,
            'bond_dim': self.bond_dim,
            'output_dim': self.output_dim,
        })
        return config


In [11]:
# Assemble the model step by step: a (5, 2) input per sample, the tensor
# network layer, then a softmax over the layer's output.
tn_model = tf.keras.Sequential()
tn_model.add(tf.keras.Input(shape=(5, 2)))
tn_model.add(SimpleTNLayer())
tn_model.add(tf.keras.layers.Softmax())

In [12]:
# Display the variables Keras tracks as trainable for the model (the
# `node_*` tensors created by SimpleTNLayer), including their full values.
tn_model.trainable_variables

[<tf.Variable 'node_0:0' shape=(2, 10) dtype=float32, numpy=
 array([[0.54234827, 0.3999573 , 0.06133616, 0.5923351 , 0.49804282,
         0.7855842 , 0.16452289, 0.08325756, 0.18443346, 0.507283  ],
        [0.80793905, 0.3609389 , 0.12488151, 0.0431869 , 0.92400813,
         0.23356175, 0.37550926, 0.69794166, 0.37650466, 0.6860131 ]],
       dtype=float32)>,
 <tf.Variable 'node_1:0' shape=(2, 10, 10) dtype=float32, numpy=
 array([[[0.69739485, 0.01937354, 0.3424921 , 0.7881689 , 0.3193506 ,
          0.3351705 , 0.48932981, 0.90750694, 0.3534441 , 0.10870409],
         [0.37587988, 0.6129323 , 0.25032854, 0.74853194, 0.58265173,
          0.28360653, 0.562902  , 0.06100035, 0.03496718, 0.02816164],
         [0.78186476, 0.44200313, 0.88129926, 0.31243563, 0.7608769 ,
          0.3218944 , 0.76513445, 0.5259353 , 0.6599312 , 0.26157248],
         [0.02150238, 0.05553484, 0.39944792, 0.25459933, 0.8007219 ,
          0.6304586 , 0.2837217 , 0.4519105 , 0.07094193, 0.18903255],
       

In [13]:
# Print the per-layer output shapes and parameter counts of the model.
tn_model.summary()

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
simple_tn_layer_4 (SimpleTNL (None, 2)                 840       
_________________________________________________________________
softmax_4 (Softmax)          (None, 2)                 0         
Total params: 840
Trainable params: 840
Non-trainable params: 0
_________________________________________________________________
