In [1]:
import numpy as np
import tensornetwork as tn
import tensorflow as tf
from functools import partial

# Layers

In [2]:
class QuantumDMRGLayer(tf.keras.layers.Layer):
    def __init__(self, dimvec, pos_label, nblabels, bond_len, unihigh):
        super(QuantumDMRGLayer, self).__init__()
        self.dimvec = dimvec
        self.pos_label = pos_label
        self.nblabels = nblabels
        self.m = bond_len
        self.unihigh = unihigh

        # end_node = lambda i: tf.Variable(tf.random.uniform(shape=(2, self.m),
        #                                                    minval=0,
        #                                                    maxval=self.unihigh),
        #                                  name='mps_node_{}'.format(i),
        #                                  trainable=True)

        # label_node = lambda i: tf.Variable(tf.random.uniform(shape=(2, self.m, self.m, self.nblabels),
        #                                                      minval=0,
        #                                                      maxval=self.unihigh),
        #                                    name='mps_node_{}'.format(i),
        #                                    trainable=True)
        # normal_node = lambda i: tf.Variable(tf.random.uniform(shape=(2, self.m, self.m),
        #                                                       minval=0,
        #                                                       maxval=self.unihigh),
        #                                     name='mps_node_{}'.format(i),
        #                                     trainable=True)
        
        # self.mps_tf_vars = [None] * self.dimvec
        # for i in range(self.dimvec):
        #     self.mps_tf_vars[i] = tf.case([(tf.math.logical_or(tf.math.equal(i, 0), tf.math.equal(i, self.dimvec - 1)), partial(end_node, i=i)),
        #                                    (tf.math.equal(i, self.pos_label), partial(label_node, i=i))],
        #                                    default=partial(normal_node, i=i)
        #                           )
        self.mps_tf_vars = [None] * self.dimvec
        for i in range(self.dimvec):
            if i == 0 or i == self.dimvec-1:
                self.mps_tf_vars[i] = tf.Variable(tf.random.uniform(shape=(2, self.m),
                                                                    minval=0,
                                                                    maxval=self.unihigh),
                                                  name='mps_node_{}'.format(i),
                                                  trainable=True)
            elif i == self.pos_label:
                self.mps_tf_vars[i] = tf.Variable(tf.random.uniform(shape=(2, self.m, self.m, self.nblabels),
                                                                    minval=0,
                                                                    maxval=self.unihigh),
                                                  name='mps_node_{}'.format(i),
                                                  trainable=True)
            else:
                self.mps_tf_vars[i] = tf.Variable(tf.random.uniform(shape=(2, self.m, self.m),
                                                                    minval=0,
                                                                    maxval=self.unihigh),
                                                  name='mps_node_{}'.format(i),
                                                  trainable=True)

    def infer_single_datum(self, input, mps_tf_vars):
        nodes = [
            tn.Node(mps_tf_vars[i], backend='tensorflow')
            for i in range(self.dimvec)
        ]

        input_nodes = [None]*self.dimvec

        for i in range(self.dimvec):
            input_nodes[i] = tn.Node(input[i, :], backend='tensorflow')
        nodes[0][1] ^ nodes[1][1]
        for i in range(1, self.dimvec - 1):
            nodes[i][2] ^ nodes[i + 1][1]
        for i in range(self.dimvec):
            nodes[i][0] ^ input_nodes[i][0]

        # final_node = tn.contractors.auto(self.nodes + self.input_nodes,
        #                                    output_edge_order=[self.nodes[self.pos_label][3]])
        final_node = nodes[0] @ nodes[1]
        for node in nodes[2:]+input_nodes:
            final_node = final_node @ node
        
        return final_node.tensor

    def call(self, inputs):
        return tf.vectorized_map(lambda input: self.infer_single_datum(input, self.mps_tf_vars), inputs)

In [9]:
quantum_dmrg_model = tf.keras.Sequential([
    tf.keras.Input(shape=(3, 2)),
    QuantumDMRGLayer(dimvec=3, pos_label=1, nblabels=2, bond_len=5, unihigh=0.05),
    # tf.keras.layers.Dense(3),
    tf.keras.layers.Softmax()
])

In [10]:
quantum_dmrg_model.compile(optimizer='adam', loss=tf.keras.losses.CategoricalCrossentropy())

In [11]:
cosx = np.random.uniform(size=3)

inputs = np.array([np.array([cosx, np.sqrt(1-cosx*cosx)]).T])

In [12]:
output = quantum_dmrg_model.predict(inputs)

In [13]:
output

array([[0.49999058, 0.5000095 ]], dtype=float32)

In [14]:
quantum_dmrg_model.trainable_variables

[]

In [16]:
quantum_dmrg_model.layers[0].non_trainable_variables

[]

In [17]:
quantum_dmrg_model.layers[0].mps_tf_vars[0].trainable

True

In [18]:
quantum_dmrg_model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
quantum_dmrg_layer_1 (Quantu (None, 2)                 0         
_________________________________________________________________
softmax_1 (Softmax)          (None, 2)                 0         
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
