In [3]:
import numpy as np
import tensornetwork as tn
import tensorflow as tf

# Layers

In [19]:
class QuantumDMRGLayer(tf.keras.layers.Layer):
    def __init__(self, dimvec, pos_label, nblabels, bond_len, unihigh):
        super(QuantumDMRGLayer, self).__init__()
        self.dimvec = dimvec
        self.pos_label = pos_label
        self.nblabels = nblabels
        self.m = bond_len
        self.unihigh = unihigh

        self.mps_tensors = [tf.Variable(tf.random.uniform(shape=self.mps_tensor_shape(i),
                                                          minval=0,
                                                          maxval=self.unihigh),
                                        trainable=True,
                                        name='mps_tensors_{}'.format(i))
                            for i in range(self.dimvec)]

    def mps_tensor_shape(self, idx):
        if idx == 0 or idx == self.dimvec-1:
            return (2, self.dimvec)
        elif idx == self.pos_label:
            return (2, self.dimvec, self.dimvec, self.nblabels)
        else:
            return (2, self.dimvec, self.dimvec)

    def infer_single(self, input):
        assert input.shape[0] == self.dimvec
        assert input.shape[1] == 2
        
        nodes = [
            tn.Node(self.mps_tensors[i], backend='tensorflow')
            for i in range(self.dimvec)
        ]
        input_nodes = [
            tn.Node(input[i, :], backend='tensorflow')
            for i in range(self.dimvec)
        ]

        for i in range(self.dimvec):
            nodes[i][0] ^ input_nodes[i][0]
        nodes[0][1] ^ nodes[1][1]
        for i in range(1, self.dimvec-1):
            nodes[i][2] ^ nodes[i+1][1]

        final_node = tn.contractors.auto(nodes + input_nodes,
                                         output_edge_order=[nodes[self.pos_label][3]])
        return final_node.tensor

    def call(self, inputs):
        return tf.vectorized_map(self.infer_single, inputs)


In [20]:
quantum_dmrg_model = tf.keras.Sequential([
    tf.keras.Input(shape=(10, 2)),
    QuantumDMRGLayer(dimvec=10, pos_label=5, nblabels=3, bond_len=5, unihigh=0.05),
    # tf.keras.layers.Dense(3),
    # tf.keras.layers.Softmax()
])
quantum_dmrg_model.compile(optimizer='adam', loss=tf.keras.losses.CategoricalCrossentropy())

In [21]:
quantum_dmrg_model.trainable_variables

7431630e-02, 3.07178330e-02, 2.24122293e-02],
          [5.85508347e-03, 4.28246856e-02, 2.30904873e-02],
          [2.34359987e-02, 4.45082858e-02, 1.40963318e-02],
          [4.84372340e-02, 2.82031782e-02, 4.33179736e-02],
          [4.18792330e-02, 3.07191610e-02, 2.37146802e-02],
          [4.04254459e-02, 3.65561135e-02, 2.41609048e-02],
          [3.21437381e-02, 1.56651735e-02, 2.29396168e-02],
          [1.56141100e-02, 7.39735365e-03, 1.72497686e-02],
          [1.58353690e-02, 2.00548172e-02, 4.94232178e-02],
          [1.94491446e-02, 3.14465575e-02, 5.78522682e-04]],
 
         [[2.43466794e-02, 4.52120081e-02, 1.16047328e-02],
          [7.98987783e-03, 3.29398103e-02, 2.70730909e-02],
          [4.86647002e-02, 3.62484753e-02, 2.79381704e-02],
          [3.52992788e-02, 1.08714942e-02, 1.69741996e-02],
          [2.96542943e-02, 1.13124195e-02, 1.75402761e-02],
          [2.14590435e-03, 4.38815951e-02, 3.67909670e-02],
          [1.00823641e-02, 4.04539099e-03, 3.376704

In [8]:
quantum_dmrg_model.layers[0].non_trainable_variables

[]

In [11]:
cosx = np.random.uniform(size=10)

inputs = np.array([np.array([cosx, np.sqrt(1-cosx*cosx)]).T])

In [12]:
output = quantum_dmrg_model.predict(inputs)
output

False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
True
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False


array([[4.5197291e-07, 4.2527077e-07, 4.5495318e-07]], dtype=float32)

In [13]:
[quantum_dmrg_model.layers[0].mps_tensors[i].shape for i in range(3)]

[TensorShape([2, 10]), TensorShape([2, 10, 10]), TensorShape([2, 10, 10])]

In [14]:
quantum_dmrg_model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
quantum_dmrg_layer_1 (Quantu (None, 3)                 0         
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
