In [1]:
import numpy as np
import tensornetwork as tn
import tensorflow as tf
from functools import partial

# Layers

In [2]:
class QuantumDMRGLayer(tf.keras.layers.Layer):
    def __init__(self, dimvec, pos_label, nblabels, bond_len, unihigh):
        super(QuantumDMRGLayer, self).__init__()
        self.dimvec = dimvec
        self.pos_label = pos_label
        self.nblabels = nblabels
        self.m = bond_len
        self.unihigh = unihigh
        self.construct_tensornetwork()

    def construct_tensornetwork(self):
        end_node = lambda i: tf.Variable(tf.random.uniform(shape=(2, self.m),
                                                           minval=0,
                                                           maxval=self.unihigh),
                                         name='mps_node_{}'.format(i),
                                         trainable=True)

        label_node = lambda i: tf.Variable(tf.random.uniform(shape=(2, self.m, self.m, self.nblabels),
                                                             minval=0,
                                                             maxval=self.unihigh),
                                           name='mps_node_{}'.format(i),
                                           trainable=True)
        normal_node = lambda i: tf.Variable(tf.random.uniform(shape=(2, self.m, self.m),
                                                              minval=0,
                                                              maxval=self.unihigh),
                                            name='mps_node_{}'.format(i),
                                            trainable=True)
        
        self.mps_tf_vars = [None] * self.dimvec
        for i in range(self.dimvec):
            self.mps_tf_vars[i] = tf.case([(tf.math.logical_or(tf.math.equal(i, 0), tf.math.equal(i, self.dimvec - 1)), partial(end_node, i=i)),
                                           (tf.math.equal(i, self.pos_label), partial(label_node, i=i))],
                                           default=partial(normal_node, i=i)
                                  )

        # model nodes
        self.nodes = [
            tn.Node(self.mps_tf_vars[i], name='node{}'.format(i), backend='tensorflow')
            for i in range(self.dimvec)
        ]

        # input nodes
        cosx = np.random.uniform(size=self.dimvec)
        self.input_nodes = [tn.Node(np.array([cosx[i], np.sqrt(1-cosx[i]*cosx[i])]), name='input{}'.format(i), backend='tensorflow') 
                            for i in range(self.dimvec)]

    @tf.function
    def infer_single_datum(self, input):
        for i in range(self.dimvec):
            self.input_nodes[i].tensor = input[i, :]
        edges = [self.nodes[0][1] ^ self.nodes[1][1]]
        for i in range(1, self.dimvec - 1):
            edges.append(self.nodes[i][2] ^ self.nodes[i + 1][1])

        input_edges = [self.nodes[i][0] ^ self.input_nodes[i][0] for i in range(self.dimvec)]

        final_node = tn.contractors.greedy(self.nodes + self.input_nodes,
                                           output_edge_order=[self.nodes[self.pos_label][3]])
        # final_node = self.nodes[0] @ self.nodes[1]
        # for node in self.nodes[2:]+self.input_nodes:
        #     final_node = final_node @ node
        
        return final_node.tensor

    def call(self, inputs):
        return tf.vectorized_map(self.infer_single_datum, inputs)

In [3]:
dmrg_layer = QuantumDMRGLayer(dimvec=784, pos_label=392, nblabels=10, bond_len=10, unihigh=0.05)

In [4]:
cosx = np.random.uniform(size=784)

inputs = np.array([np.array([cosx, np.sqrt(1-cosx*cosx)]).T])

In [5]:
output = dmrg_layer(inputs)



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



In [6]:
output

<tf.Tensor: shape=(1, 10), dtype=float32, numpy=array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>

In [7]:
quantum_dmrg_model = tf.keras.Sequential([
    tf.keras.Input(shape=(20, 2)),
    QuantumDMRGLayer(dimvec=20, pos_label=10, nblabels=3, bond_len=5, unihigh=0.05),
    tf.keras.layers.Softmax()
])

In [8]:
isinstance(dmrg_layer, tf.keras.layers.Layer)

True

In [12]:
quantum_dmrg_model = tf.keras.Sequential([
    tf.keras.Input(shape=(784, 2)),
    QuantumDMRGLayer(dimvec=784, pos_label=392, nblabels=10, bond_len=20, unihigh=0.05),
    tf.keras.layers.Softmax()
])

# Load MNIST data

In [16]:
import json
import numba

In [15]:
def generate_data(mnist_file):
    for line in mnist_file:
        data = json.loads(line)
        pixels = np.array(data['pixels'])
        digit = data['digit']
        yield pixels, digit

In [30]:
@numba.jit
def convert_pixels_to_tnvector(pixels):
    tnvector = np.array([pixels/256., np.sqrt(1-(pixels/256.)*(pixels/256.))]).T
    return tnvector

In [18]:
alldata = [(pixels, digit) 
           for pixels, digit in generate_data(open('/data/hok/testdata/mnist/mnist_784/mnist_784.json', 'r'))]

In [31]:
convert_pixels_to_tnvector(alldata[0][0])[150:180, :]

Compilation is falling back to object mode WITH looplifting enabled because Function "convert_pixels_to_tnvector" failed type inference due to: Invalid use of Function(<built-in function array>) with argument(s) of type(s): (list(array(float64, 1d, C)))
 * parameterized
In definition 0:
    TypingError: array(float64, 1d, C) not allowed in a homogeneous sequence
    raised from /data/hok/conda/envs/hokdev/lib/python3.7/site-packages/numba/core/typing/npydecl.py:462
In definition 1:
    TypingError: array(float64, 1d, C) not allowed in a homogeneous sequence
    raised from /data/hok/conda/envs/hokdev/lib/python3.7/site-packages/numba/core/typing/npydecl.py:462
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<built-in function array>)
[2] During: typing of call at <ipython-input-30-1094ff1dffba> (3)


File "<ipython-input-30-1094ff1dffba>", line 3:
def convert_pixels_to_tnvector(pixels):

array([[0.        , 1.        ],
       [0.        , 1.        ],
       [0.01171875, 0.99993133],
       [0.0703125 , 0.99752501],
       [0.0703125 , 0.99752501],
       [0.0703125 , 0.99752501],
       [0.4921875 , 0.87048921],
       [0.53125   , 0.84721511],
       [0.68359375, 0.72986272],
       [0.1015625 , 0.99482916],
       [0.6484375 , 0.7612679 ],
       [0.99609375, 0.08830199],
       [0.96484375, 0.26282416],
       [0.49609375, 0.86826896],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.1171875 , 0.99310981],
       [0.140625  , 0.99006293],
       [0.3671875 , 0.93014695],
       [0.6015625 , 0.79882574]])

In [32]:
alldata[0][0][150:180]

array([  0.,   0.,   3.,  18.,  18.,  18., 126., 136., 175.,  26., 166.,
       255., 247., 127.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,  30.,  36.,  94., 154.])