In [1]:
import numpy as np
import tensornetwork as tn
import tensorflow as tf
import numba
import json

# Layers

In [34]:
class QuantumDMRGLayer(tf.keras.layers.Layer):
    def __init__(self, dimvec, pos_label, nblabels, bond_len, nearzero_std=1e-9, isolated_labelnode=True):
        super(QuantumDMRGLayer, self).__init__()
        self.dimvec = dimvec
        self.pos_label = pos_label
        self.nblabels = nblabels
        self.m = bond_len
        self.isolated_label = isolated_labelnode

        assert self.pos_label >= 0 and self.pos_label < self.dimvec

        self.mps_tensors = [tf.Variable(self.mps_tensor_initial_values(i, nearzero_std=nearzero_std),
                                        trainable=True,
                                        name='mps_tensors_{}'.format(i))
                            for i in range(self.dimvec)]
        if self.isolated_label:
            self.output_tensor = tf.Variable(tf.random.normal((self.m, self.m, self.nblabels),
                                                              mean=0.0,
                                                              stddev=nearzero_std),
                                             trainable=True,
                                             name='mps_output_node')

    def mps_tensor_initial_values(self, idx, nearzero_std=1e-9):
        if idx == 0 or idx == self.dimvec - 1:
            tempmat = tf.eye(max(2, self.m))
            mat = tempmat[0:2, :] if 2 < self.m else tempmat[:, 0:self.m]
            return mat + tf.random.normal(mat.shape, mean=0.0, stddev=nearzero_std)
        elif not self.isolated_label and idx == self.pos_label:
            return tf.random.normal((2, self.m, self.m, self.nblabels),
                                    mean=0.0,
                                    stddev=nearzero_std)
        else:
            return tf.random.normal((2, self.m, self.m),
                                    mean=0.0,
                                    stddev=nearzero_std)

    def infer_single(self, input):
        assert input.shape[0] == self.dimvec
        assert input.shape[1] == 2

        nodes = [
            tn.Node(self.mps_tensors[i], backend='tensorflow')
            for i in range(self.dimvec)
        ]
        if self.isolated_label:
            output_node = tn.Node(self.output_tensor, backend='tensorflow')
        input_nodes = [
            tn.Node(input[i, :], backend='tensorflow')
            for i in range(self.dimvec)
        ]

        for i in range(self.dimvec):
            nodes[i][0] ^ input_nodes[i][0]
        if self.isolated_label:
            nodes[0][1] ^ nodes[1][1]
            for i in range(1, self.pos_label):
                nodes[i][2] ^ nodes[i + 1][1]
            nodes[self.pos_label][2] ^ output_node[0]
            output_node[1] ^ nodes[self.pos_label + 1][1]
            for i in range(self.pos_label + 1, self.dimvec - 1):
                nodes[i][2] ^ nodes[i + 1][1]
        else:
            nodes[0][1] ^ nodes[1][1]
            for i in range(1, self.dimvec-1):
                nodes[i][2] ^ nodes[i + 1][1]

        if self.isolated_label:
            final_node = tn.contractors.auto(nodes + input_nodes + [output_node],
                                             output_edge_order=[output_node[2]])
        else:
            final_node = tn.contractors.auto(nodes + input_nodes,
                                             output_edge_order=[nodes[self.pos_label][3]])
        return final_node.tensor

    def call(self, inputs):
        return tf.vectorized_map(self.infer_single, inputs)

# Generating Data

In [10]:
def generate_random_data(size, std=0.2):
    centroids = {0: np.array([0.5, 0.5, 0, 1]), 
                 1: np.array([-0.5, 0.5, 0, -1]),
                 2: np.array([-0.5, -0.5, 1, 0]),
                 3: np.array([0.5, -0.5, -1, 0])}
    for _ in range(size):
        quadrant = np.random.choice(range(4))
        yield np.random.normal(loc=centroids[quadrant], scale=std), quadrant
        

@numba.njit(numba.float64[:, :](numba.float64[:]))
def convert_pixels_to_tnvector(vector):
    tnvector = np.concatenate((
        np.array([[vector[0], np.sign(vector[0])*(1-np.abs(vector[0]))]]),
        np.array([[vector[1], np.sign(vector[1])*(1-np.abs(vector[1]))]]),
        np.array([[vector[2], np.sign(vector[2])*(1-np.abs(vector[0]))]]),
        np.array([[vector[3], np.sign(vector[1])*(1-np.abs(vector[1]))]])

    ), axis=0)
    return tnvector


In [11]:
size = 1000

X = np.zeros((size, 4, 2))
Y = np.zeros((size, 4))

for i, (thisX, thisY) in enumerate(generate_random_data(size)):
    X[i, :, :] = convert_pixels_to_tnvector(thisX)
    Y[i, thisY] = 1.

# Model

In [12]:
def QuantumKerasModel(dimvec, pos_label, nblabels, bond_len, nearzero_std=1e-9, optimizer='adam'):
    quantum_dmrg_model = tf.keras.Sequential([
        tf.keras.Input(shape=(dimvec, 2)),
        QuantumDMRGLayer(dimvec=dimvec,
                         pos_label=pos_label,
                         nblabels=nblabels,
                         bond_len=bond_len,
                         nearzero_std=nearzero_std),
#         tf.keras.layers.LayerNormalization(beta_initializer='RandomUniform', gamma_initializer='RandomUniform', beta_constraint='non_neg')
        tf.keras.layers.LayerNormalization(beta_initializer='RandomUniform', gamma_initializer='RandomUniform'),
        tf.keras.layers.Softmax()
    ])
    quantum_dmrg_model.compile(optimizer=optimizer, loss=tf.keras.losses.CategoricalCrossentropy())
    return quantum_dmrg_model

In [58]:
 # model parameters
dimvec = 4
pos_label = 2
nblabels = 4
bond_len = 2
nbdata = 1000

# training and CV parameters
nb_epochs = 1000
cv_fold = 5
batch_size = 10
std = 1e-9
learning_rate = 1e-2

In [18]:
# Prepare for cross-validation
cv_labels = np.random.choice(range(cv_fold), size=nbdata)

In [19]:
cv_idx = 0

trainX = X[cv_labels==cv_idx, :, :]
trainY = Y[cv_labels==cv_idx, :]

In [59]:
quantum_dmrg_model = QuantumKerasModel(dimvec=dimvec, pos_label=pos_label, nblabels=nblabels, bond_len=bond_len,
                                       nearzero_std=std, optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate))

In [60]:
quantum_dmrg_model.summary()

Model: "sequential_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
quantum_dmrg_layer_7 (Quantu (None, 4)                 40        
_________________________________________________________________
layer_normalization_7 (Layer (None, 4)                 8         
_________________________________________________________________
softmax_7 (Softmax)          (None, 4)                 0         
Total params: 48
Trainable params: 48
Non-trainable params: 0
_________________________________________________________________


In [61]:
tn_layer = quantum_dmrg_model.layers[0]
tn_layer.trainable_variables[0:3]

[<tf.Variable 'mps_tensors_0:0' shape=(2, 2) dtype=float32, numpy=
 array([[ 1.0000000e+00, -2.4472582e-10],
        [-3.1551026e-10,  1.0000000e+00]], dtype=float32)>,
 <tf.Variable 'mps_tensors_1:0' shape=(2, 2, 2) dtype=float32, numpy=
 array([[[-1.7868844e-09,  1.5684126e-10],
         [-4.5113749e-10,  8.0341755e-10]],
 
        [[ 2.9506839e-10, -1.2818342e-09],
         [ 1.5821333e-09, -5.1799121e-10]]], dtype=float32)>,
 <tf.Variable 'mps_tensors_2:0' shape=(2, 2, 2) dtype=float32, numpy=
 array([[[-1.2069160e-09, -7.2580462e-11],
         [ 1.1222002e-09, -2.5530027e-09]],
 
        [[ 8.7495028e-10,  9.6011843e-10],
         [ 8.5000867e-10,  2.0877959e-09]]], dtype=float32)>]

In [62]:
X.shape

(1000, 4, 2)

In [63]:
output = quantum_dmrg_model.predict(X[0:, :, :])
output

array([[0.25478882, 0.24265279, 0.25557625, 0.24698208],
       [0.25478882, 0.24265279, 0.25557625, 0.24698208],
       [0.25478882, 0.24265279, 0.25557625, 0.24698208],
       ...,
       [0.25478882, 0.24265279, 0.25557625, 0.24698208],
       [0.25478882, 0.24265279, 0.25557625, 0.24698208],
       [0.25478882, 0.24265279, 0.25557625, 0.24698208]], dtype=float32)

In [64]:
quantum_dmrg_model.fit(trainX, trainY, epochs=nb_epochs, batch_size=batch_size)

Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
 1/21 [>.............................] - ETA: 0s - loss: 1.3645

W0628 20:05:07.647438 4521258432 callbacks.py:307] Method (on_train_batch_end) is slow compared to the batch update (0.101582). Check your callbacks.


Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
 1/21 [>.............................] - ETA: 0s - loss: 1.3726

W0628 20:05:16.056790 4521258432 callbacks.py:307] Method (on_train_batch_end) is slow compared to the batch update (0.170837). Check your callbacks.


Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Epoch 68/1000
Epoch 69/1000
Epoch 70/1000
Epoch 71/1000
Epoch 72/1000
Epoch 73/1000
Epoch 74/1000
Epoch 75/1000
Epoch 76/1000
Epoch 77/1000
Epoch 78/1000
Epoch 79/1000
Epoch 80/1000
Epoch 81/1000
Epoch 82/1000
Epoch 83/1000
Epoch 84/1000
Epoch 85/1000
Epoch 86/1000
Epoch 87/1000
Epoch 88/1000
Epoch 89/1000
Epoch 90/1000
Epoch 91/1000
Epoch 92/1000
Epoch 93/1000
Epoch 94/1000
Epoch 95/1000
Epoch 96/1000
Epoch 97/1000
Epoch 98/1000
Epoch 99/1000
Epoch 100/1000
Epoch 101/1000
Epoch 102/1000
Epoch 103/1000
Epoch 104/1000
Epoch 105/1000
Epoch 106/1000
Epoch 107/1000
Epoch 108/1000
Epoch 109/1000
Epoch 110/1000
Epoch 111/1000
Epoch 112/1000
Epoch 113/1000
Epoch 114/1000
Epoch 115/1000
Epoch 116/1000
Epo

W0628 20:05:43.107963 4521258432 callbacks.py:307] Method (on_train_batch_end) is slow compared to the batch update (0.112210). Check your callbacks.


Epoch 131/1000
Epoch 132/1000
Epoch 133/1000
Epoch 134/1000
Epoch 135/1000
Epoch 136/1000
Epoch 137/1000
Epoch 138/1000
 1/21 [>.............................] - ETA: 0s - loss: 1.3784

W0628 20:05:45.825661 4521258432 callbacks.py:307] Method (on_train_batch_end) is slow compared to the batch update (0.121482). Check your callbacks.


Epoch 139/1000
Epoch 140/1000
Epoch 141/1000
Epoch 142/1000
Epoch 143/1000
Epoch 144/1000
Epoch 145/1000
Epoch 146/1000
Epoch 147/1000
Epoch 148/1000
Epoch 149/1000
Epoch 150/1000
Epoch 151/1000
Epoch 152/1000
Epoch 153/1000
Epoch 154/1000
Epoch 155/1000
Epoch 156/1000
Epoch 157/1000
Epoch 158/1000
Epoch 159/1000
Epoch 160/1000
 1/21 [>.............................] - ETA: 0s - loss: 1.3973

W0628 20:05:53.148906 4521258432 callbacks.py:307] Method (on_train_batch_end) is slow compared to the batch update (0.189221). Check your callbacks.


Epoch 161/1000
Epoch 162/1000
Epoch 163/1000
Epoch 164/1000
Epoch 165/1000
Epoch 166/1000
Epoch 167/1000
Epoch 168/1000
Epoch 169/1000
Epoch 170/1000
Epoch 171/1000
Epoch 172/1000
Epoch 173/1000
Epoch 174/1000
Epoch 175/1000
Epoch 176/1000
Epoch 177/1000
Epoch 178/1000
Epoch 179/1000
Epoch 180/1000
Epoch 181/1000
Epoch 182/1000
Epoch 183/1000
Epoch 184/1000
Epoch 185/1000
Epoch 186/1000
Epoch 187/1000
Epoch 188/1000
Epoch 189/1000
Epoch 190/1000
Epoch 191/1000
Epoch 192/1000
Epoch 193/1000
Epoch 194/1000
Epoch 195/1000
Epoch 196/1000
Epoch 197/1000
Epoch 198/1000
Epoch 199/1000
Epoch 200/1000
Epoch 201/1000
Epoch 202/1000
Epoch 203/1000
Epoch 204/1000
Epoch 205/1000
Epoch 206/1000
Epoch 207/1000
Epoch 208/1000
Epoch 209/1000
Epoch 210/1000
Epoch 211/1000
Epoch 212/1000
Epoch 213/1000
Epoch 214/1000
Epoch 215/1000
Epoch 216/1000
Epoch 217/1000
Epoch 218/1000
Epoch 219/1000
Epoch 220/1000
 1/21 [>.............................] - ETA: 0s - loss: 1.4118

W0628 20:06:13.350083 4521258432 callbacks.py:307] Method (on_train_batch_end) is slow compared to the batch update (0.126372). Check your callbacks.


Epoch 221/1000
Epoch 222/1000
Epoch 223/1000
Epoch 224/1000
Epoch 225/1000
Epoch 226/1000
Epoch 227/1000
Epoch 228/1000
Epoch 229/1000
Epoch 230/1000
Epoch 231/1000
Epoch 232/1000
Epoch 233/1000
 1/21 [>.............................] - ETA: 0s - loss: 1.4127

W0628 20:06:18.036993 4521258432 callbacks.py:307] Method (on_train_batch_end) is slow compared to the batch update (0.117678). Check your callbacks.


Epoch 234/1000
Epoch 235/1000
Epoch 236/1000
Epoch 237/1000
Epoch 238/1000
Epoch 239/1000
Epoch 240/1000
Epoch 241/1000
Epoch 242/1000
Epoch 243/1000
Epoch 244/1000
Epoch 245/1000
Epoch 246/1000
Epoch 247/1000
Epoch 248/1000
Epoch 249/1000
Epoch 250/1000
Epoch 251/1000
Epoch 252/1000
Epoch 253/1000
 1/21 [>.............................] - ETA: 0s - loss: 1.3790

W0628 20:06:25.258013 4521258432 callbacks.py:307] Method (on_train_batch_end) is slow compared to the batch update (0.153579). Check your callbacks.


Epoch 254/1000
Epoch 255/1000
Epoch 256/1000
Epoch 257/1000
Epoch 258/1000
Epoch 259/1000
Epoch 260/1000
Epoch 261/1000
Epoch 262/1000
Epoch 263/1000
Epoch 264/1000
Epoch 265/1000
 1/21 [>.............................] - ETA: 0s - loss: 1.3178

W0628 20:06:28.787636 4521258432 callbacks.py:307] Method (on_train_batch_end) is slow compared to the batch update (0.182256). Check your callbacks.


Epoch 266/1000
Epoch 267/1000
Epoch 268/1000
Epoch 269/1000
Epoch 270/1000
Epoch 271/1000
Epoch 272/1000
Epoch 273/1000
Epoch 274/1000
Epoch 275/1000
Epoch 276/1000
Epoch 277/1000
Epoch 278/1000
Epoch 279/1000
Epoch 280/1000
Epoch 281/1000
Epoch 282/1000
Epoch 283/1000
Epoch 284/1000
Epoch 285/1000
Epoch 286/1000
Epoch 287/1000
Epoch 288/1000
Epoch 289/1000
Epoch 290/1000
Epoch 291/1000
Epoch 292/1000
Epoch 293/1000
Epoch 294/1000
Epoch 295/1000
Epoch 296/1000
Epoch 297/1000
Epoch 298/1000
Epoch 299/1000
Epoch 300/1000
Epoch 301/1000
Epoch 302/1000
Epoch 303/1000
Epoch 304/1000
Epoch 305/1000
Epoch 306/1000
Epoch 307/1000
Epoch 308/1000
 1/21 [>.............................] - ETA: 0s - loss: 1.3768

W0628 20:06:42.268386 4521258432 callbacks.py:307] Method (on_train_batch_end) is slow compared to the batch update (0.150484). Check your callbacks.


Epoch 309/1000
Epoch 310/1000
Epoch 311/1000
Epoch 312/1000
Epoch 313/1000
Epoch 314/1000
Epoch 315/1000
Epoch 316/1000
Epoch 317/1000
Epoch 318/1000
Epoch 319/1000
Epoch 320/1000
Epoch 321/1000
Epoch 322/1000
Epoch 323/1000
Epoch 324/1000
Epoch 325/1000
Epoch 326/1000
Epoch 327/1000
Epoch 328/1000
Epoch 329/1000
Epoch 330/1000
Epoch 331/1000
Epoch 332/1000
Epoch 333/1000
Epoch 334/1000
Epoch 335/1000
Epoch 336/1000
Epoch 337/1000
Epoch 338/1000
Epoch 339/1000
Epoch 340/1000
Epoch 341/1000
Epoch 342/1000
Epoch 343/1000
Epoch 344/1000
Epoch 345/1000
Epoch 346/1000
Epoch 347/1000
Epoch 348/1000
Epoch 349/1000
Epoch 350/1000
Epoch 351/1000
Epoch 352/1000
Epoch 353/1000
Epoch 354/1000
Epoch 355/1000
Epoch 356/1000
Epoch 357/1000
Epoch 358/1000
Epoch 359/1000
Epoch 360/1000
Epoch 361/1000
Epoch 362/1000
Epoch 363/1000
Epoch 364/1000
Epoch 365/1000
Epoch 366/1000
Epoch 367/1000
Epoch 368/1000
Epoch 369/1000
Epoch 370/1000
 1/21 [>.............................] - ETA: 0s - loss: 1.4237

W0628 20:07:00.505989 4521258432 callbacks.py:307] Method (on_train_batch_end) is slow compared to the batch update (0.168152). Check your callbacks.


Epoch 371/1000
Epoch 372/1000
Epoch 373/1000
Epoch 374/1000
Epoch 375/1000
Epoch 376/1000
Epoch 377/1000
Epoch 378/1000
Epoch 379/1000
Epoch 380/1000
Epoch 381/1000
Epoch 382/1000
Epoch 383/1000
Epoch 384/1000
Epoch 385/1000
Epoch 386/1000
Epoch 387/1000
Epoch 388/1000
Epoch 389/1000
Epoch 390/1000
Epoch 391/1000
Epoch 392/1000
Epoch 393/1000
Epoch 394/1000
Epoch 395/1000
Epoch 396/1000
 1/21 [>.............................] - ETA: 0s - loss: 1.3287

W0628 20:07:10.743009 4521258432 callbacks.py:307] Method (on_train_batch_end) is slow compared to the batch update (0.356756). Check your callbacks.


 2/21 [=>............................] - ETA: 3s - loss: 1.3827

W0628 20:07:10.751387 4521258432 callbacks.py:307] Method (on_train_batch_end) is slow compared to the batch update (0.180182). Check your callbacks.


Epoch 397/1000
Epoch 398/1000
Epoch 399/1000
Epoch 400/1000
Epoch 401/1000
Epoch 402/1000
Epoch 403/1000
Epoch 404/1000
Epoch 405/1000
Epoch 406/1000
Epoch 407/1000
Epoch 408/1000
Epoch 409/1000
Epoch 410/1000
Epoch 411/1000
Epoch 412/1000
Epoch 413/1000
Epoch 414/1000
Epoch 415/1000
Epoch 416/1000
Epoch 417/1000
Epoch 418/1000
Epoch 419/1000
Epoch 420/1000
Epoch 421/1000
Epoch 422/1000
Epoch 423/1000
Epoch 424/1000
Epoch 425/1000
Epoch 426/1000
Epoch 427/1000
Epoch 428/1000
Epoch 429/1000
Epoch 430/1000
Epoch 431/1000
Epoch 432/1000
Epoch 433/1000
Epoch 434/1000
Epoch 435/1000
Epoch 436/1000
Epoch 437/1000
Epoch 438/1000
Epoch 439/1000
Epoch 440/1000
Epoch 441/1000
Epoch 442/1000
Epoch 443/1000
Epoch 444/1000
Epoch 445/1000
Epoch 446/1000
Epoch 447/1000
Epoch 448/1000
Epoch 449/1000
Epoch 450/1000
Epoch 451/1000
Epoch 452/1000
Epoch 453/1000
Epoch 454/1000
Epoch 455/1000
Epoch 456/1000
Epoch 457/1000
Epoch 458/1000
Epoch 459/1000
Epoch 460/1000
Epoch 461/1000
Epoch 462/1000
Epoch 463/

W0628 20:08:10.837904 4521258432 callbacks.py:307] Method (on_train_batch_end) is slow compared to the batch update (0.165564). Check your callbacks.


Epoch 569/1000
Epoch 570/1000
Epoch 571/1000
Epoch 572/1000
Epoch 573/1000
Epoch 574/1000
Epoch 575/1000
Epoch 576/1000
Epoch 577/1000
Epoch 578/1000
Epoch 579/1000
Epoch 580/1000
Epoch 581/1000
Epoch 582/1000
Epoch 583/1000
Epoch 584/1000
Epoch 585/1000
Epoch 586/1000
Epoch 587/1000
Epoch 588/1000
Epoch 589/1000
Epoch 590/1000
Epoch 591/1000
Epoch 592/1000
Epoch 593/1000
Epoch 594/1000
Epoch 595/1000
Epoch 596/1000
Epoch 597/1000
Epoch 598/1000
Epoch 599/1000
Epoch 600/1000
Epoch 601/1000
Epoch 602/1000
Epoch 603/1000
Epoch 604/1000
Epoch 605/1000
Epoch 606/1000
Epoch 607/1000
Epoch 608/1000
Epoch 609/1000
Epoch 610/1000
Epoch 611/1000
Epoch 612/1000
Epoch 613/1000
Epoch 614/1000
Epoch 615/1000
Epoch 616/1000
Epoch 617/1000
Epoch 618/1000
Epoch 619/1000
Epoch 620/1000
Epoch 621/1000
Epoch 622/1000
Epoch 623/1000
Epoch 624/1000
Epoch 625/1000
Epoch 626/1000
Epoch 627/1000
Epoch 628/1000
Epoch 629/1000
Epoch 630/1000
Epoch 631/1000
Epoch 632/1000
Epoch 633/1000
Epoch 634/1000
Epoch 635/

W0628 20:09:14.993860 4521258432 callbacks.py:307] Method (on_train_batch_end) is slow compared to the batch update (0.160539). Check your callbacks.


Epoch 819/1000
Epoch 820/1000
Epoch 821/1000
Epoch 822/1000
Epoch 823/1000
Epoch 824/1000
Epoch 825/1000
Epoch 826/1000
Epoch 827/1000
Epoch 828/1000
Epoch 829/1000
Epoch 830/1000
Epoch 831/1000
Epoch 832/1000
Epoch 833/1000
Epoch 834/1000
Epoch 835/1000
Epoch 836/1000
Epoch 837/1000
Epoch 838/1000
Epoch 839/1000
Epoch 840/1000
Epoch 841/1000
Epoch 842/1000
Epoch 843/1000
Epoch 844/1000
Epoch 845/1000
Epoch 846/1000
Epoch 847/1000
Epoch 848/1000
Epoch 849/1000
Epoch 850/1000
Epoch 851/1000
Epoch 852/1000
Epoch 853/1000
Epoch 854/1000
Epoch 855/1000
Epoch 856/1000
Epoch 857/1000
Epoch 858/1000
Epoch 859/1000
Epoch 860/1000
Epoch 861/1000
Epoch 862/1000
Epoch 863/1000
Epoch 864/1000
Epoch 865/1000
Epoch 866/1000
Epoch 867/1000
Epoch 868/1000
Epoch 869/1000
Epoch 870/1000
Epoch 871/1000
Epoch 872/1000
Epoch 873/1000
Epoch 874/1000
Epoch 875/1000
Epoch 876/1000
Epoch 877/1000
Epoch 878/1000
Epoch 879/1000
Epoch 880/1000
Epoch 881/1000
Epoch 882/1000
Epoch 883/1000
Epoch 884/1000
Epoch 885/

W0628 20:09:41.403147 4521258432 callbacks.py:307] Method (on_train_batch_end) is slow compared to the batch update (0.178981). Check your callbacks.


Epoch 917/1000
Epoch 918/1000
Epoch 919/1000
Epoch 920/1000
Epoch 921/1000
Epoch 922/1000
Epoch 923/1000
Epoch 924/1000
Epoch 925/1000
Epoch 926/1000
Epoch 927/1000
Epoch 928/1000
Epoch 929/1000
Epoch 930/1000
Epoch 931/1000
Epoch 932/1000
Epoch 933/1000
Epoch 934/1000
Epoch 935/1000
Epoch 936/1000
Epoch 937/1000
Epoch 938/1000
Epoch 939/1000
Epoch 940/1000
Epoch 941/1000
Epoch 942/1000
Epoch 943/1000
Epoch 944/1000
Epoch 945/1000
Epoch 946/1000
Epoch 947/1000
Epoch 948/1000
Epoch 949/1000
Epoch 950/1000
Epoch 951/1000
Epoch 952/1000
Epoch 953/1000
Epoch 954/1000
Epoch 955/1000
Epoch 956/1000
Epoch 957/1000
Epoch 958/1000
Epoch 959/1000
Epoch 960/1000
Epoch 961/1000
Epoch 962/1000
Epoch 963/1000
Epoch 964/1000
Epoch 965/1000
Epoch 966/1000
Epoch 967/1000
 1/21 [>.............................] - ETA: 0s - loss: 1.3686

W0628 20:09:56.115052 4521258432 callbacks.py:307] Method (on_train_batch_end) is slow compared to the batch update (0.128340). Check your callbacks.


Epoch 968/1000
Epoch 969/1000
Epoch 970/1000
Epoch 971/1000
 1/21 [>.............................] - ETA: 0s - loss: 1.3224

W0628 20:09:57.103764 4521258432 callbacks.py:307] Method (on_train_batch_end) is slow compared to the batch update (0.122230). Check your callbacks.


Epoch 972/1000
Epoch 973/1000
Epoch 974/1000
Epoch 975/1000
Epoch 976/1000
Epoch 977/1000
Epoch 978/1000
Epoch 979/1000
Epoch 980/1000
Epoch 981/1000
Epoch 982/1000
Epoch 983/1000
Epoch 984/1000
Epoch 985/1000
Epoch 986/1000
Epoch 987/1000
Epoch 988/1000
Epoch 989/1000
Epoch 990/1000
Epoch 991/1000
Epoch 992/1000
Epoch 993/1000
Epoch 994/1000
Epoch 995/1000
Epoch 996/1000
Epoch 997/1000
Epoch 998/1000
Epoch 999/1000
Epoch 1000/1000


<tensorflow.python.keras.callbacks.History at 0x7fbd0cd18d68>

In [65]:
output = quantum_dmrg_model.predict(X)
output

array([[0.24470682, 0.28424042, 0.24196967, 0.22908309],
       [0.24470682, 0.28424042, 0.24196967, 0.22908309],
       [0.24470682, 0.28424042, 0.24196967, 0.22908309],
       ...,
       [0.24470682, 0.28424042, 0.24196967, 0.22908309],
       [0.24470682, 0.28424042, 0.24196967, 0.22908309],
       [0.24470682, 0.28424042, 0.24196967, 0.22908309]], dtype=float32)

In [42]:
tn_layer.trainable_variables[0:3]

[<tf.Variable 'mps_tensors_0:0' shape=(2, 2) dtype=float32, numpy=
 array([[ 1.0000000e+00, -4.8281468e-10],
        [ 5.7440247e-10,  1.0000000e+00]], dtype=float32)>,
 <tf.Variable 'mps_tensors_1:0' shape=(2, 2, 2) dtype=float32, numpy=
 array([[[ 1.01076369e-09, -9.79604731e-10],
         [ 1.82769536e-10,  1.14186494e-10]],
 
        [[-4.66370109e-10, -1.33745948e-09],
         [-1.02995341e-10, -5.84964133e-10]]], dtype=float32)>,
 <tf.Variable 'mps_tensors_2:0' shape=(2, 2, 2) dtype=float32, numpy=
 array([[[-1.3092647e-09,  1.6047685e-09],
         [-5.3281740e-10, -3.0815253e-09]],
 
        [[ 4.4603912e-10,  1.9199056e-09],
         [ 1.7281526e-09,  5.0067967e-11]]], dtype=float32)>]