In [1]:
import numpy as np
import tensorflow as tf
import tensornetwork as tn

In [10]:
class SimpleTNLayer(tf.keras.layers.Layer):
    """Keras layer that contracts a 3-row input with a 3-tensor chain network.

    The layer owns three trainable tensors (stored in ``mps_tensors``):

    * ``node_0`` with shape ``(feature_dim, bond_dim)``
    * ``node_1`` with shape ``(feature_dim, bond_dim, bond_dim, output_dim)``
    * ``node_2`` with shape ``(feature_dim, bond_dim)``

    For a single example, row ``i`` of the input is contracted with axis 0 of
    tensor ``i``; the ``bond_dim`` axes join ``node_0 - node_1 - node_2`` into
    a chain, and the last axis of ``node_1`` is left open as the output.

    Args:
        feature_dim: Length of every input row (size of axis 0 of each
            tensor). Defaults to 2.
        bond_dim: Size of the internal axes connecting neighbouring tensors.
            Defaults to 10.
        output_dim: Size of the open output axis on the middle tensor.
            Defaults to 2.
        **kwargs: Standard Keras layer arguments (for example ``name``),
            forwarded to ``tf.keras.layers.Layer``.

    With the default arguments the layer has 20 + 400 + 20 = 440 parameters.
    """

    def __init__(self, feature_dim=2, bond_dim=10, output_dim=2, **kwargs):
        super().__init__(**kwargs)

        # Kept as attributes so that `get_config` can report them.
        self.feature_dim = feature_dim
        self.bond_dim = bond_dim
        self.output_dim = output_dim

        site_shapes = [
            (feature_dim, bond_dim),
            (feature_dim, bond_dim, bond_dim, output_dim),
            (feature_dim, bond_dim),
        ]
        # Values are drawn from tf.random.uniform with its default range;
        # no seed is set here, so they differ between kernel restarts.
        self.mps_tensors = [
            tf.Variable(tf.random.uniform(shape=shape),
                        name='node_{}'.format(site),
                        trainable=True)
            for site, shape in enumerate(site_shapes)
        ]

    def infer_single(self, input):
        """Contract one (unbatched) example with the layer's tensors.

        Args:
            input: Rank-2 tensor for a single example; row ``i`` is paired
                with tensor ``i``, so it needs one row per tensor (three) and
                ``feature_dim`` columns.

        Returns:
            A tensor of shape ``(output_dim,)``: the open axis of ``node_1``
            after every other axis has been contracted.
        """
        nodes = [
            tn.Node(tensor, backend='tensorflow')
            for tensor in self.mps_tensors
        ]
        input_nodes = [
            tn.Node(input[site, :], backend='tensorflow')
            for site in range(len(nodes))
        ]

        # `a ^ b` connects two dangling edges; the expression is evaluated
        # only for that side effect.
        for node, input_node in zip(nodes, input_nodes):
            node[0] ^ input_node[0]
        nodes[0][1] ^ nodes[1][1]
        nodes[1][2] ^ nodes[2][1]

        # The only edge still dangling is axis 3 of the middle tensor.
        final_node = tn.contractors.auto(nodes + input_nodes,
                                         output_edge_order=[nodes[1][3]])
        return final_node.tensor

    def call(self, inputs):
        """Apply `infer_single` to every example along the batch axis."""
        return tf.vectorized_map(self.infer_single, inputs)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            'feature_dim': self.feature_dim,
            'bond_dim': self.bond_dim,
            'output_dim': self.output_dim,
        })
        return config


In [11]:
# Each example is a (3, 2) array: one 2-vector per tensor of SimpleTNLayer.
# The layer's output is passed through a softmax.
tn_model = tf.keras.Sequential(
    layers=[
        tf.keras.Input(shape=(3, 2)),
        SimpleTNLayer(),
        tf.keras.layers.Softmax(),
    ]
)

In [12]:
# Show the model's trainable variables (the node_* tensors created by SimpleTNLayer).
tn_model.trainable_variables

[<tf.Variable 'node_0:0' shape=(2, 10) dtype=float32, numpy=
 array([[0.38165748, 0.8039684 , 0.7374232 , 0.40349054, 0.23843908,
         0.28980172, 0.7949959 , 0.08051372, 0.68883395, 0.2583654 ],
        [0.6471882 , 0.04205132, 0.5704813 , 0.7546179 , 0.2662834 ,
         0.22476602, 0.20511842, 0.58492196, 0.75513875, 0.48147857]],
       dtype=float32)>,
 <tf.Variable 'node_1:0' shape=(2, 10, 10, 2) dtype=float32, numpy=
 array([[[[0.6897962 , 0.99940157],
          [0.23827744, 0.3852216 ],
          [0.23650181, 0.31728435],
          [0.20951784, 0.5390173 ],
          [0.0698669 , 0.7405951 ],
          [0.24169123, 0.5721173 ],
          [0.29905832, 0.732903  ],
          [0.94367623, 0.46130157],
          [0.74908876, 0.27011156],
          [0.26887453, 0.01655221]],
 
         [[0.24014008, 0.8287698 ],
          [0.29903197, 0.60097957],
          [0.91802394, 0.8654877 ],
          [0.26715314, 0.66168165],
          [0.38975966, 0.58461916],
          [0.8166213 , 0.

In [5]:
# Print the per-layer output shapes and parameter counts of the model.
tn_model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
simple_tn_layer (SimpleTNLay (None, 2)                 440       
_________________________________________________________________
softmax (Softmax)            (None, 2)                 0         
Total params: 440
Trainable params: 440
Non-trainable params: 0
_________________________________________________________________
