In [1]:
import numpy as np
import tensornetwork as tn
import tensorflow as tf
import numba
import json

# Layers

In [2]:
class QuantumDMRGLayer(tf.keras.layers.Layer):
    """Keras layer that contracts a chain of trainable MPS tensors with its input.

    Each sample is a ``(dimvec, 2)`` array; row ``i`` is contracted with the
    first axis of the ``i``-th tensor in ``self.mps_tensors``. Neighbouring
    tensors are joined through bond axes of size ``bond_len``. The only edge
    left dangling after contraction is the label edge of size ``nblabels``,
    so each sample yields a vector of length ``nblabels``.

    Args:
        dimvec: number of sites in the chain (and rows of each input sample).
        pos_label: index of the site that carries the label edge. Must be an
            interior site, i.e. ``1 <= pos_label <= dimvec - 2``.
        nblabels: size of the dangling label edge (number of outputs).
        bond_len: size of the bond axes between neighbouring tensors.
        nearzero_std: standard deviation of the Gaussian used to initialise
            the tensors (added as noise to the boundary tensors).
        isolated_labelnode: if True, the label edge lives on a separate
            ``(bond_len, bond_len, nblabels)`` tensor inserted between sites
            ``pos_label`` and ``pos_label + 1``; otherwise the tensor at
            ``pos_label`` itself gets a fourth axis of size ``nblabels``.
    """

    def __init__(self, dimvec, pos_label, nblabels, bond_len, nearzero_std=1e-9, isolated_labelnode=True):
        super(QuantumDMRGLayer, self).__init__()
        self.dimvec = dimvec
        self.pos_label = pos_label
        self.nblabels = nblabels
        self.m = bond_len
        self.isolated_label = isolated_labelnode

        # The two boundary tensors are rank 2 (no second bond, no label axis),
        # and the isolated label node is wired between pos_label and
        # pos_label + 1. Both layouts therefore need an interior site;
        # accepting 0 or dimvec - 1 here would only fail later, inside
        # infer_single, with an opaque axis/index error.
        assert 1 <= self.pos_label <= self.dimvec - 2, (
            'pos_label must be an interior site (1 <= pos_label <= dimvec - 2); '
            'got pos_label={}, dimvec={}'.format(self.pos_label, self.dimvec)
        )

        # NOTE(review): every interior tensor starts at ~nearzero_std, so the
        # contraction multiplies on the order of `dimvec` such factors. With
        # the default 1e-9 and a long chain this presumably underflows in
        # float32 -- confirm that gradients are non-zero during training.
        self.mps_tensors = [
            tf.Variable(self.mps_tensor_initial_values(i, nearzero_std=nearzero_std),
                        trainable=True,
                        name='mps_tensors_{}'.format(i))
            for i in range(self.dimvec)
        ]
        if self.isolated_label:
            self.output_tensor = tf.Variable(
                tf.random.normal((self.m, self.m, self.nblabels),
                                 mean=0.0,
                                 stddev=nearzero_std),
                trainable=True,
                name='mps_output_node')

    def mps_tensor_initial_values(self, idx, nearzero_std=1e-9):
        """Return the initial value of the tensor at site ``idx``.

        * Boundary sites (first and last): a ``(2, bond_len)`` slice of the
          identity plus Gaussian noise.
        * The label site when the label node is not isolated:
          ``(2, bond_len, bond_len, nblabels)`` Gaussian noise.
        * Every other site: ``(2, bond_len, bond_len)`` Gaussian noise.
        """
        if idx == 0 or idx == self.dimvec - 1:
            # A rectangular identity: ones at (0, 0) and, when bond_len >= 2,
            # at (1, 1).
            mat = tf.eye(2, num_columns=self.m)
            return mat + tf.random.normal(mat.shape, mean=0.0, stddev=nearzero_std)
        if not self.isolated_label and idx == self.pos_label:
            return tf.random.normal((2, self.m, self.m, self.nblabels),
                                    mean=0.0,
                                    stddev=nearzero_std)
        return tf.random.normal((2, self.m, self.m),
                                mean=0.0,
                                stddev=nearzero_std)

    def infer_single(self, input):
        """Contract the network for one sample.

        Args:
            input: tensor of shape ``(dimvec, 2)``.

        Returns:
            The tensor of the fully contracted network; its only remaining
            axis is the label edge (size ``nblabels``).
        """
        assert input.shape[0] == self.dimvec
        assert input.shape[1] == 2

        nodes = [
            tn.Node(self.mps_tensors[i], backend='tensorflow')
            for i in range(self.dimvec)
        ]
        output_node = (tn.Node(self.output_tensor, backend='tensorflow')
                       if self.isolated_label else None)
        input_nodes = [
            tn.Node(input[i, :], backend='tensorflow')
            for i in range(self.dimvec)
        ]

        # Axis 0 of every site tensor is the physical axis fed by the input row.
        for i in range(self.dimvec):
            nodes[i][0] ^ input_nodes[i][0]

        # Bond wiring. The first tensor has a single bond (axis 1); interior
        # tensors use axis 1 towards the left and axis 2 towards the right.
        nodes[0][1] ^ nodes[1][1]
        for i in range(1, self.dimvec - 1):
            if self.isolated_label and i == self.pos_label:
                # Splice the separate label tensor into the bond that leaves
                # the label site.
                nodes[i][2] ^ output_node[0]
                output_node[1] ^ nodes[i + 1][1]
            else:
                nodes[i][2] ^ nodes[i + 1][1]

        if self.isolated_label:
            final_node = tn.contractors.auto(nodes + input_nodes + [output_node],
                                             output_edge_order=[output_node[2]])
        else:
            final_node = tn.contractors.auto(nodes + input_nodes,
                                             output_edge_order=[nodes[self.pos_label][3]])
        return final_node.tensor

    def call(self, inputs):
        """Apply ``infer_single`` to every sample along the leading (batch) axis."""
        return tf.vectorized_map(self.infer_single, inputs)

In [3]:
def generate_data(mnist_file):
    """Yield ``(pixels, digit)`` pairs from an iterable of JSON-encoded lines.

    Args:
        mnist_file: any iterable of strings (e.g. an open text file), each
            holding one JSON object with the keys ``'pixels'`` and ``'digit'``.

    Yields:
        A tuple ``(pixels, digit)`` where ``pixels`` is ``data['pixels']``
        converted with ``np.array`` and ``digit`` is ``data['digit']`` as
        stored in the JSON object.
    """
    for line in mnist_file:
        # Blank or whitespace-only lines (typically a trailing newline at the
        # end of the file) are not valid JSON; skip them instead of crashing.
        if not line.strip():
            continue
        data = json.loads(line)
        pixels = np.array(data['pixels'])
        digit = data['digit']
        yield pixels, digit


@numba.njit(numba.float64[:, :](numba.float64[:]))
def convert_pixels_to_tnvector(pixels):
    """Map each pixel value to a (cos, sin) pair.

    For a 1-D float64 array of length n, returns an (n, 2) float64 array
    whose row i is (cos(a_i), sin(a_i)) with a_i = 0.5 * pi * pixels[i] / 256.
    """
    angles = 0.5*np.pi*pixels/256.
    # Row 0 holds the cosines and row 1 the sines; the transpose turns that
    # into one (cos, sin) row per pixel.
    stacked = np.empty((2, angles.shape[0]), dtype=np.float64)
    stacked[0, :] = np.cos(angles)
    stacked[1, :] = np.sin(angles)
    return stacked.T


def convert_pixels(datum):
    """Add one ``'pixel<i>'`` entry per pixel to ``datum`` and return it.

    Args:
        datum: mapping with a ``'pixels'`` entry holding a flat sequence of
            numeric pixel values. It is modified in place.

    Returns:
        The same ``datum`` object, with ``datum['pixel<i>']`` set to the
        two-element list produced by ``convert_pixels_to_tnvector`` for
        pixel ``i``.
    """
    # convert_pixels_to_tnvector is compiled for a 1-D float64 array only, so
    # pass exactly that. Wrapping the list in another list (as before) builds
    # a 2-D (1, n) array, which the compiled signature rejects.
    pixels = np.asarray(datum['pixels'], dtype=np.float64)
    for i, pixel in enumerate(convert_pixels_to_tnvector(pixels)):
        datum['pixel{}'.format(i)] = list(pixel)
    return datum

In [22]:
def QuantumKerasModel(dimvec, pos_label, nblabels, bond_len, nearzero_std=1e-9, optimizer='adam'):
    """Build and compile a Sequential model around a single QuantumDMRGLayer.

    The stack is: input of shape (dimvec, 2) -> QuantumDMRGLayer ->
    LayerNormalization -> Softmax. The model is compiled with the given
    optimizer and a categorical cross-entropy loss, then returned.
    """
    # Layers are instantiated in stacking order.
    layer_stack = [
        tf.keras.Input(shape=(dimvec, 2)),
        QuantumDMRGLayer(dimvec=dimvec,
                         pos_label=pos_label,
                         nblabels=nblabels,
                         bond_len=bond_len,
                         nearzero_std=nearzero_std),
        # beta and gamma both start from a random-uniform initialiser
        # (a variant with beta_constraint='non_neg' was tried previously).
        tf.keras.layers.LayerNormalization(beta_initializer='RandomUniform',
                                           gamma_initializer='RandomUniform'),
        tf.keras.layers.Softmax(),
    ]
    quantum_dmrg_model = tf.keras.Sequential(layer_stack)

    cross_entropy = tf.keras.losses.CategoricalCrossentropy()
    quantum_dmrg_model.compile(optimizer=optimizer, loss=cross_entropy)
    return quantum_dmrg_model

In [5]:
def DenseTNKerasModel(dimvec, hidden_dim, nblabels, bond_len, nearzero_std=1e-9, optimizer='adam'):
    """Build and compile a model that puts a dense projection before the MPS layer.

    The (dimvec, 2) input is flattened, mapped by a linear Dense layer to
    hidden_dim * 2 values, reshaped to (hidden_dim, 2) and fed to a
    QuantumDMRGLayer with hidden_dim sites whose label site is
    hidden_dim // 2, followed by a Softmax. The model is compiled with the
    given optimizer and a categorical cross-entropy loss, then returned.
    """
    # Layers are instantiated in stacking order.
    layer_stack = [
        tf.keras.Input(shape=(dimvec, 2)),
        tf.keras.layers.Reshape((dimvec*2,)),
        tf.keras.layers.Dense(hidden_dim*2, activation=None),
        tf.keras.layers.Reshape((hidden_dim, 2)),
        QuantumDMRGLayer(dimvec=hidden_dim,
                         pos_label=hidden_dim // 2,
                         nblabels=nblabels,
                         bond_len=bond_len,
                         nearzero_std=nearzero_std),
        tf.keras.layers.Softmax(),
    ]
    tn_model = tf.keras.Sequential(layer_stack)

    cross_entropy = tf.keras.losses.CategoricalCrossentropy()
    tn_model.compile(optimizer=optimizer, loss=cross_entropy)
    return tn_model

In [6]:
# Synthetic single-sample batch used to smoke-test the model.
# Despite its name, `cosx` holds raw uniform draws from [0, 1); the cosine and
# sine are taken below after scaling by 0.5 * pi / 256.
cosx = np.random.uniform(size=784)

# Pair cos and sin along the last axis -> (784, 2), then prepend a batch axis
# of length 1 -> (1, 784, 2).
inputs = np.stack(
    (np.cos(0.5*np.pi*cosx/256.), np.sin(0.5*np.pi*cosx/256.)),
    axis=-1,
)[np.newaxis, ...]
inputs

array([[[9.99992416e-01, 3.89465704e-03],
        [9.99997184e-01, 2.37317318e-03],
        [9.99995432e-01, 3.02272545e-03],
        ...,
        [9.99989133e-01, 4.66205259e-03],
        [9.99999146e-01, 1.30703263e-03],
        [9.99999854e-01, 5.40475390e-04]]])

In [7]:
 # model parameters (passed to QuantumKerasModel and used to size X / Y)
dimvec = 784      # number of sites; second dimension of X
pos_label = 392   # site index that carries the label edge
nblabels = 10     # number of output classes; second dimension of Y
bond_len = 2      # bond size between neighbouring MPS tensors
nbdata = 70000    # number of rows preallocated for X and Y

# training and CV parameters
nb_epochs = 4         # epochs passed to fit()
cv_fold = 5           # number of distinct fold labels drawn for cv_labels
batch_size = 100      # batch size passed to fit()
std = 1e-9            # passed as nearzero_std (initialisation noise)
learning_rate = 1e-4  # learning rate of the Adam optimizer

In [8]:
# Prepare for cross-validation: draw one fold label in [0, cv_fold) per data row.
# NOTE(review): no random seed is set in the visible cells, so the fold
# assignment presumably differs on every run -- confirm whether that is intended.
cv_labels = np.random.choice(range(cv_fold), size=nbdata)

In [9]:
# Reading the data
# Digits are stored as strings in the file; map them to column indices of Y.
label_dict = {str(i): i for i in range(10)}
X = np.zeros((nbdata, dimvec, 2))
Y = np.zeros((nbdata, nblabels))
# Use a context manager so the file handle is closed once loading finishes
# (or fails); the bare open() used before was never closed.
# NOTE(review): X and Y are preallocated with nbdata rows -- assumes the file
# holds at most nbdata records; confirm.
with open('/data/hok/PyProjects/tensornetwork-learn/experiments/mnist_784/mnist_784.json', 'r') as mnist_file:
    for i, (pixels, label) in enumerate(generate_data(mnist_file)):
        X[i, :, :] = convert_pixels_to_tnvector(pixels)
        Y[i, label_dict[label]] = 1.  # one-hot encode the label

In [10]:
cv_idx = 0

# Select the rows whose fold label equals cv_idx.
# NOTE(review): this uses the single fold `cv_idx` as the training set
# (roughly 1 / cv_fold of the data) rather than the remaining folds --
# confirm that this is intended.
fold_mask = (cv_labels == cv_idx)
trainX = X[fold_mask]
trainY = Y[fold_mask]

In [11]:
# tn_model = DenseTNKerasModel(dimvec=dimvec, hidden_dim=64,  nblabels=nblabels, bond_len=bond_len,
#                              nearzero_std=std, optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate))

In [23]:
# Build and compile the MPS classifier with the notebook-level parameters and
# an Adam optimizer using `learning_rate`.
quantum_dmrg_model = QuantumKerasModel(dimvec=dimvec, pos_label=pos_label, nblabels=nblabels, bond_len=bond_len,
                                       nearzero_std=std, optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate))

In [24]:
# Print the layer-by-layer summary (output shapes and parameter counts).
quantum_dmrg_model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
quantum_dmrg_layer_1 (Quantu (None, 10)                6304      
_________________________________________________________________
layer_normalization_1 (Layer (None, 10)                20        
_________________________________________________________________
softmax (Softmax)            (None, 10)                0         
Total params: 6,324
Trainable params: 6,324
Non-trainable params: 0
_________________________________________________________________


In [25]:
# Keep a handle on the first layer of the model (the QuantumDMRGLayer) so its
# variables can be inspected before and after training.
tn_layer = quantum_dmrg_model.layers[0]

In [26]:
# Show the first three trainable variables of the layer before training.
tn_layer.trainable_variables[0:3]

[<tf.Variable 'mps_tensors_0:0' shape=(2, 2) dtype=float32, numpy=
 array([[ 1.0000000e+00, -1.4978621e-09],
        [-1.6002278e-09,  1.0000000e+00]], dtype=float32)>,
 <tf.Variable 'mps_tensors_1:0' shape=(2, 2, 2) dtype=float32, numpy=
 array([[[ 1.2845859e-09, -5.1590737e-10],
         [-1.1658096e-09,  6.8301648e-10]],
 
        [[-1.6309616e-10,  1.5843733e-10],
         [-5.5312405e-10, -1.5131182e-09]]], dtype=float32)>,
 <tf.Variable 'mps_tensors_2:0' shape=(2, 2, 2) dtype=float32, numpy=
 array([[[-1.5412845e-09, -3.8406842e-10],
         [ 1.4566774e-09, -1.2042070e-09]],
 
        [[-3.3812675e-10, -7.3959028e-10],
         [-1.8769619e-09,  2.8743155e-10]]], dtype=float32)>]

In [27]:
# Prediction of the untrained model on the synthetic one-sample batch.
output = quantum_dmrg_model.predict(inputs)
output

array([[0.10163032, 0.09602131, 0.10381068, 0.10201379, 0.10338927,
        0.10157745, 0.09881987, 0.10003183, 0.09667202, 0.0960335 ]],
      dtype=float32)

In [28]:
# Train on the selected fold.
# NOTE(review): the recorded values of mps_tensors_0..2 displayed after this
# cell are identical to those displayed before it, so the MPS tensors may not
# be receiving any effective update -- investigate (e.g. the initialisation scale).
quantum_dmrg_model.fit(trainX, trainY, epochs=nb_epochs, batch_size=batch_size)

Train on 13812 samples
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<tensorflow.python.keras.callbacks.History at 0x2aabe6974d10>

In [31]:
# Prediction on the same synthetic batch after training, for comparison with
# the pre-training output.
output = quantum_dmrg_model.predict(inputs)
output

array([[0.10067101, 0.09875087, 0.10343433, 0.10198036, 0.10286111,
        0.10009173, 0.09848643, 0.10098644, 0.09636264, 0.09637505]],
      dtype=float32)

In [32]:
# Show the same three trainable variables again after training.
tn_layer.trainable_variables[0:3]

[<tf.Variable 'mps_tensors_0:0' shape=(2, 2) dtype=float32, numpy=
 array([[ 1.0000000e+00, -1.4978621e-09],
        [-1.6002278e-09,  1.0000000e+00]], dtype=float32)>,
 <tf.Variable 'mps_tensors_1:0' shape=(2, 2, 2) dtype=float32, numpy=
 array([[[ 1.2845859e-09, -5.1590737e-10],
         [-1.1658096e-09,  6.8301648e-10]],
 
        [[-1.6309616e-10,  1.5843733e-10],
         [-5.5312405e-10, -1.5131182e-09]]], dtype=float32)>,
 <tf.Variable 'mps_tensors_2:0' shape=(2, 2, 2) dtype=float32, numpy=
 array([[[-1.5412845e-09, -3.8406842e-10],
         [ 1.4566774e-09, -1.2042070e-09]],
 
        [[-3.3812675e-10, -7.3959028e-10],
         [-1.8769619e-09,  2.8743155e-10]]], dtype=float32)>]