In [1]:
import numpy as np
import tensornetwork as tn
import tensorflow as tf
from tqdm import tqdm
from functools import partial

# Layers

In [2]:
class QuantumDMRGLayer(tf.keras.layers.Layer):
    def __init__(self, dimvec, pos_label, nblabels, bond_len, unihigh):
        super(QuantumDMRGLayer, self).__init__()
        self.dimvec = dimvec
        self.pos_label = pos_label
        self.nblabels = nblabels
        self.m = bond_len
        self.unihigh = unihigh
        self.construct_tensornetwork()

    def construct_tensornetwork(self):
        end_node = lambda i: tf.Variable(tf.random.uniform(shape=(2, self.m),
                                                           minval=0,
                                                           maxval=self.unihigh),
                                         name='mps_node_{}'.format(i),
                                         trainable=True)

        label_node = lambda i: tf.Variable(tf.random.uniform(shape=(2, self.m, self.m, self.nblabels),
                                                             minval=0,
                                                             maxval=self.unihigh),
                                           name='mps_node_{}'.format(i),
                                           trainable=True)
        normal_node = lambda i: tf.Variable(tf.random.uniform(shape=(2, self.m, self.m),
                                                              minval=0,
                                                              maxval=self.unihigh),
                                            name='mps_node_{}'.format(i),
                                            trainable=True)
        
        self.mps_tf_vars = [None] * self.dimvec
        for i in range(self.dimvec):
            self.mps_tf_vars[i] = tf.case([(tf.math.logical_or(tf.math.equal(i, 0), tf.math.equal(i, self.dimvec - 1)), partial(end_node, i=i)),
                                           (tf.math.equal(i, self.pos_label), partial(label_node, i=i))],
                                           default=partial(normal_node, i=i)
                                  )

        # model nodes
        self.nodes = [
            tn.Node(self.mps_tf_vars[i], name='node{}'.format(i), backend='tensorflow')
            for i in range(self.dimvec)
        ]

        # input nodes
        cosx = np.random.uniform(size=self.dimvec)
        self.input_nodes = [tn.Node(np.array([cosx[i], np.sqrt(1-cosx[i]*cosx[i])]), name='input{}'.format(i), backend='tensorflow') 
                            for i in range(self.dimvec)]

    @tf.function
    def infer_single_datum(self, input):
        for i in range(self.dimvec):
            self.input_nodes[i].tensor = input[i, :]
        edges = [self.nodes[0][1] ^ self.nodes[1][1]]
        for i in range(1, self.dimvec - 1):
            edges.append(self.nodes[i][2] ^ self.nodes[i + 1][1])

        input_edges = [self.nodes[i][0] ^ self.input_nodes[i][0] for i in range(self.dimvec)]

        final_node = tn.contractors.greedy(self.nodes + self.input_nodes,
                                           output_edge_order=[self.nodes[self.pos_label][3]])
        # final_node = self.nodes[0] @ self.nodes[1]
        # for node in self.nodes[2:]+self.input_nodes:
        #     final_node = final_node @ node
        
        return final_node.tensor

    def call(self, inputs):
        return tf.vectorized_map(self.infer_single_datum, inputs)

In [3]:
dmrg_layer = QuantumDMRGLayer(dimvec=784, pos_label=392, nblabels=10, bond_len=10, unihigh=0.05)

In [4]:
cosx = np.random.uniform(size=784)

inputs = np.array([np.array([cosx, np.sqrt(1-cosx*cosx)]).T])

In [5]:
output = dmrg_layer(inputs)

W0612 13:51:25.667677 4559089088 base_layer.py:2081] Layer quantum_dmrg_layer is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because it's dtype defaults to floatx.


To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



In [6]:
output

<tf.Tensor: shape=(1, 10), dtype=float32, numpy=array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>

In [7]:
quantum_dmrg_model = tf.keras.Sequential([
    tf.keras.Input(shape=(20, 2)),
    QuantumDMRGLayer(dimvec=20, pos_label=10, nblabels=3, bond_len=5, unihigh=0.05),
    tf.keras.layers.Softmax()
])

In [27]:
isinstance(dmrg_layer, tf.keras.layers.Layer)

True