In [1]:
import json
import concurrent.futures

import numpy as np
import numba
import tensornetwork as tn
from tqdm import tqdm

In [2]:
# Sizes shared by the cells below.
dimvec, nblabels = 784, 10   # sites per encoded sample / width of the one-hot label rows
pos_label = 392              # handed to the classifier constructor as ``pos_label``
m = 10                       # handed to the classifier constructor as ``bond_len``
nbdata = 70_000              # number of rows preallocated for X and Y

# Data Loading

In [3]:
def generate_data(mnist_file):
    """Yield ``(pixels, digit)`` pairs from a JSON-lines source.

    Parameters
    ----------
    mnist_file : iterable of str
        Any iterable of lines (an open text file, a list of strings, ...).
        Every non-blank line must be a JSON object with the keys
        ``'pixels'`` and ``'digit'``.

    Yields
    ------
    tuple
        ``(np.array(record['pixels']), record['digit'])`` for each record,
        in input order.

    Raises
    ------
    json.JSONDecodeError
        If a non-blank line is not valid JSON.
    KeyError
        If a record lacks ``'pixels'`` or ``'digit'``.
    """
    for line in mnist_file:
        # A blank line (e.g. stray newlines at the end of the file) holds no
        # record; json.loads would raise on it, so skip it instead.
        if not line.strip():
            continue
        data = json.loads(line)
        yield np.array(data['pixels']), data['digit']


@numba.njit(numba.float64[:, :](numba.float64[:]))
def convert_pixels_to_tnvector(pixels):
    """Encode each pixel value as a (cos, sin) pair.

    ``pixels`` is a 1-D float64 array.  The result has one row per pixel:
    column 0 is ``cos(0.5*pi*p/256)`` and column 1 is ``sin(0.5*pi*p/256)``.
    """
    angle = 0.5*np.pi*pixels/256.
    # Fill a (2, n) buffer row by row, then hand back its transpose (n, 2).
    stacked = np.empty((2, pixels.shape[0]))
    stacked[0, :] = np.cos(angle)
    stacked[1, :] = np.sin(angle)
    return stacked.T

In [4]:
# Reading the data
mnist_path = '/data/hok/PyProjects/tensornetwork-learn/experiments/mnist_784/mnist_784.json'
label_dict = {str(i): i for i in range(10)}   # digit string -> column of the one-hot row
X = np.zeros((nbdata, dimvec, 2))
Y = np.zeros((nbdata, nblabels))
# Open through a context manager so the file handle is closed once loading
# finishes (or fails) instead of staying open for the life of the kernel.
with open(mnist_path, 'r') as mnist_file:
    for i, (pixels, label) in enumerate(generate_data(mnist_file)):
        X[i, :, :] = convert_pixels_to_tnvector(pixels)
        Y[i, label_dict[label]] = 1.

# Classifier

In [5]:
class QuantumTNSweepingClassifier:
    """Tensor-chain classifier whose label index is moved along the chain.

    One ``tn.Node`` per input site is kept in ``self.nodes`` and one rank-1
    ``tn.Node`` per site in ``self.input_nodes`` (the local 2-vector of the
    sample being evaluated).  Site tensors are created with the axes
    ``[input, left bond, right bond]``; node 0 is created as
    ``[input, bond, label]`` and the last node as ``[input, bond]``.

    Parameters
    ----------
    dimvec : int
        Number of sites, i.e. rows of one encoded sample.
    pos_label : int
        Stored on the instance; not read anywhere in this class.
    nblabels : int
        Size of the label axis.
    bond_len : int
        Bond dimension ``m`` between neighbouring site tensors; also the cap
        on singular values kept when a block is split in ``fit``.
    nearzero_std : float, optional
        Standard deviation of the normal noise used to initialise the tensors.

    NOTE(review): interior nodes are created with three axes, yet axis ``3``
    is indexed whenever ``current_pos_label != 0``.  This presumably relies on
    the axis layout that ``tn.split_node`` leaves behind during ``fit`` --
    confirm the layout actually matches before trusting a sweep beyond site 0.
    """

    def __init__(self, dimvec, pos_label, nblabels, bond_len, nearzero_std=1e-9):
        self.dimvec = dimvec
        self.pos_label = pos_label
        self.nblabels = nblabels
        self.m = bond_len
        self.nearzero_std = nearzero_std
        # Index of the site whose tensor currently carries the label axis.
        self.current_pos_label = 0
        self.construct_nodes()

    def construct_nodes(self):
        """Create ``self.nodes`` and ``self.input_nodes``.

        Every site tensor is drawn from a normal distribution with standard
        deviation ``nearzero_std``; the last node additionally gets ``1`` added
        on its leading diagonal.  Input nodes are filled with placeholder
        2-vectors that are overwritten on every evaluation.
        """
        self.nodes = []
        # Use the instance's own size, not a module-level ``dimvec``.
        for i in range(self.dimvec):
            if i == 0:
                node = tn.Node(np.random.normal(loc=0.0, scale=self.nearzero_std, size=(2, self.m, self.nblabels)),
                               name='node{}'.format(i))
            elif i == self.dimvec - 1:
                node = tn.Node(np.random.normal(loc=0.0, scale=self.nearzero_std, size=(2, self.m)),
                               name='node{}'.format(i))
                for j in range(min(self.m, 2)):
                    node.tensor[j, j] += 1.
            else:
                node = tn.Node(np.random.normal(loc=0.0, scale=self.nearzero_std, size=(2, self.m, self.m)),
                               name='node{}'.format(i))
            self.nodes.append(node)

        cosx = np.random.uniform(size=self.dimvec)
        self.input_nodes = [tn.Node(np.array([cosx[i], np.sqrt(1-cosx[i]*cosx[i])]),
                                    name='input{}'.format(i))
                            for i in range(self.dimvec)]

    def infer_single_datum(self, input):
        """Contract the whole chain against one encoded sample.

        Parameters
        ----------
        input : array of shape ``(dimvec, 2)``
            One encoded sample; row ``i`` becomes the tensor of input node ``i``.

        Returns
        -------
        tn.Node
            Result of the contraction, with the label edge of the node at
            ``current_pos_label`` as its only remaining edge.
        """
        assert input.shape[0] == self.dimvec
        assert input.shape[1] == 2

        for i in range(self.dimvec):
            self.input_nodes[i].tensor = input[i, :]

        # connecting edges
        for i in range(self.dimvec):
            self.nodes[i][0] ^ self.input_nodes[i][0]
        self.nodes[0][1] ^ self.nodes[1][1]
        for i in range(1, self.dimvec-1):
            self.nodes[i][2] ^ self.nodes[i+1][1]

        # contraction
        # Node 0 keeps its label on axis 2; every other position reads axis 3.
        label_axis = 2 if self.current_pos_label == 0 else 3
        final_node = tn.contractors.auto(
            self.nodes + self.input_nodes,
            output_edge_order=[self.nodes[self.current_pos_label][label_axis]])

        return final_node

    def predict_proba(self, X):
        """Evaluate ``infer_single_datum`` on every row of ``X``.

        ``X`` must have shape ``(nbdata, dimvec, 2)``; the return value is a
        numpy array stacking the resulting tensors, one per sample.
        """
        nbdata = X.shape[0]
        assert X.shape[1] == self.dimvec
        assert X.shape[2] == 2

        return np.array([self.infer_single_datum(X[i, :, :]).tensor for i in range(nbdata)])

    def label_block(self):
        """Contract the node at ``current_pos_label`` with its right neighbour.

        Returns
        -------
        tn.Node
            The two-site block.  Its edge order is
            ``[input 0, input 1, right bond, label]`` at site 0,
            ``[left bond, input, input, label]`` at site ``dimvec - 2`` and
            ``[left bond, input i, input i+1, right bond, label]`` otherwise.
        """
        assert self.current_pos_label != self.dimvec - 1

        if self.current_pos_label == 0:
            self.nodes[0][1] ^ self.nodes[1][1]
            block_node = tn.contractors.auto([self.nodes[0], self.nodes[1]],
                                             output_edge_order=[self.nodes[0][0],
                                                                self.nodes[1][0],
                                                                self.nodes[1][2],
                                                                self.nodes[0][2]])
            return block_node
        elif self.current_pos_label == self.dimvec-2:
            self.nodes[self.dimvec-2][2] ^ self.nodes[self.dimvec-1][1]
            block_node = tn.contractors.auto([self.nodes[self.dimvec-2], self.nodes[self.dimvec-1]],
                                             output_edge_order=[self.nodes[self.dimvec-2][1],
                                                                self.nodes[self.dimvec-2][0],
                                                                self.nodes[self.dimvec-1][0],
                                                                self.nodes[self.dimvec-2][3]])
            return block_node
        else:
            self.nodes[self.current_pos_label][2] ^ self.nodes[self.current_pos_label+1][1]
            block_node = tn.contractors.auto([self.nodes[self.current_pos_label],
                                              self.nodes[self.current_pos_label+1]],
                                             output_edge_order=[self.nodes[self.current_pos_label][1],
                                                                self.nodes[self.current_pos_label][0],
                                                                self.nodes[self.current_pos_label+1][0],
                                                                self.nodes[self.current_pos_label+1][2],
                                                                self.nodes[self.current_pos_label][3]])
            return block_node

    def phi(self, input):
        """Contract everything except the current two-site block for one sample.

        Parameters
        ----------
        input : array of shape ``(dimvec, 2)``
            One encoded sample.

        Returns
        -------
        tn.Node
            Remaining edges, in order: the bond towards the left part of the
            chain (absent at site 0), the two input edges of the block's
            sites, and the bond towards the right part (absent at site
            ``dimvec - 2``).
        """
        assert input.shape[0] == self.dimvec
        assert input.shape[1] == 2
        assert self.current_pos_label != self.dimvec - 1

        for i in range(self.dimvec):
            self.input_nodes[i].tensor = input[i, :]

        if self.current_pos_label == 0:
            for i in range(2, self.dimvec-1):
                self.nodes[i][2] ^ self.nodes[i+1][1]
                self.nodes[i][0] ^ self.input_nodes[i][0]
            self.nodes[self.dimvec-1][0] ^ self.input_nodes[self.dimvec-1][0]

            phi = tn.contractors.auto(self.nodes[2:]+self.input_nodes,
                                      output_edge_order=[self.input_nodes[0][0],
                                                         self.input_nodes[1][0],
                                                         self.nodes[2][1]])
            return phi
        elif self.current_pos_label == self.dimvec-2:
            self.nodes[0][1] ^ self.nodes[1][1]
            self.nodes[0][0] ^ self.input_nodes[0][0]
            for i in range(1, self.dimvec-3):
                self.nodes[i][2] ^ self.nodes[i+1][1]
                self.nodes[i][0] ^ self.input_nodes[i][0]
            self.nodes[self.dimvec-3][0] ^ self.input_nodes[self.dimvec-3][0]

            phi = tn.contractors.auto(self.nodes[:-2]+self.input_nodes,
                                      output_edge_order=[self.nodes[self.dimvec-3][2],
                                                         self.input_nodes[self.dimvec-2][0],
                                                         self.input_nodes[self.dimvec-1][0]])
            return phi
        else:
            # left
            # NOTE(review): for current_pos_label == 1 the first line joins
            # node 0 to node 1 (which belongs to the block, not to phi) and the
            # line after the loop connects node 0's input edge a second time --
            # verify this position works.
            self.nodes[0][1] ^ self.nodes[1][1]
            self.nodes[0][0] ^ self.input_nodes[0][0]
            for i in range(1, self.current_pos_label-1):
                self.nodes[i][2] ^ self.nodes[i+1][1]
                self.nodes[i][0] ^ self.input_nodes[i][0]
            self.nodes[self.current_pos_label-1][0] ^ self.input_nodes[self.current_pos_label-1][0]

            # right
            for i in range(self.current_pos_label+2, self.dimvec-1):
                self.nodes[i][2] ^ self.nodes[i+1][1]
                self.nodes[i][0] ^ self.input_nodes[i][0]
            self.nodes[self.dimvec-1][0] ^ self.input_nodes[self.dimvec-1][0]

            phi = tn.contractors.auto(self.nodes[:self.current_pos_label] + \
                                      self.nodes[(self.current_pos_label+2):] + \
                                      self.input_nodes,
                                      output_edge_order=[self.nodes[self.current_pos_label-1][2],
                                                         self.input_nodes[self.current_pos_label][0],
                                                         self.input_nodes[self.current_pos_label+1][0],
                                                         self.nodes[self.current_pos_label+2][1]])
            return phi

    def calculate_deltalabel_phi(self, n, x=None, y=None):
        """Return ``phi(x) (outer) (y - prediction(x))`` for one sample.

        Parameters
        ----------
        n : int
            Sample index.  Only used to look the sample up in the module-level
            ``X``/``Y`` arrays when ``x``/``y`` are not supplied.
        x : array of shape ``(dimvec, 2)``, optional
            Encoded sample.  ``fit`` always passes it so that the data given
            to ``fit`` is the data being trained on.
        y : array of shape ``(nblabels,)``, optional
            Target row belonging to ``x``.

        Returns
        -------
        tn.Node
            Outer product of ``self.phi(x)`` with the label residual.
        """
        # Backward-compatible fallback: the original one-argument call read
        # the sample from the notebook-level arrays.
        if x is None:
            x = X[n, :, :]
        if y is None:
            y = Y[n, :]
        phi = self.phi(x)
        diff_predlabel_node = tn.Node(y) - self.infer_single_datum(x)
        gcfcn = tn.outer_product(phi, diff_predlabel_node)
        return gcfcn

    def fit(self, X, Y, epochs=10, lr=0.001):
        """Sweep the label block along the chain, updating it at every site.

        Parameters
        ----------
        X : array of shape ``(nbdata, dimvec, 2)``
            Encoded training samples.
        Y : array of shape ``(nbdata, nblabels)``
            Target rows.
        epochs : int, optional
            Number of sweeps.
        lr : float, optional
            Step size applied to the accumulated update.

        NOTE(review): every epoch restarts at site 0 and the entry assertion
        requires ``current_pos_label == 0``; whether the label axis is really
        back on node 0 after a sweep is not established by this code.
        """
        nbdata = X.shape[0]
        assert X.shape[1] == self.dimvec
        assert X.shape[2] == 2
        assert Y.shape[0] == nbdata
        assert Y.shape[1] == self.nblabels

        assert self.current_pos_label == 0
        assert len(self.nodes[0].shape) == 3

        for epoch in range(epochs):
            print('Round {}'.format(epoch))
            # A block spans sites i and i+1, so the last block starts at
            # dimvec - 2 (label_block() rejects dimvec - 1).
            for i in range(self.dimvec - 1):
                print('\tsweeping node {}'.format(i))
                self.current_pos_label = i
                label_block_node = self.label_block()
                delta_label_block_node = tn.Node(np.zeros(label_block_node.shape))
                with concurrent.futures.ProcessPoolExecutor() as executor:
                    # Send each sample's own rows with its task so the workers
                    # train on the arrays passed to fit(), not on globals.
                    for gcfcn in executor.map(self.calculate_deltalabel_phi, range(nbdata), X, Y):
                        delta_label_block_node += gcfcn

                # The accumulated tensor is sum_n phi_n (outer) (y_n - f_n).  For a
                # squared-error cost 0.5 * sum |f - y|^2 that is the negative
                # gradient with respect to the block, so the descent step
                # scales it by lr and adds it.
                delta_label_block_node.tensor *= lr
                label_block_node += delta_label_block_node
                if i == 0:
                    left, right, _ = tn.split_node(label_block_node,
                                                   [label_block_node[0]],
                                                   [label_block_node[1], label_block_node[2], label_block_node[3]],
                                                   max_singular_values=self.m)
                elif i == self.dimvec - 2:
                    left, right, _ = tn.split_node(label_block_node,
                                                   [label_block_node[0], label_block_node[1]],
                                                   [label_block_node[2], label_block_node[3]],
                                                   max_singular_values=self.m)
                else:
                    left, right, _ = tn.split_node(label_block_node,
                                                   [label_block_node[0], label_block_node[1]],
                                                   [label_block_node[2], label_block_node[3], label_block_node[4]],
                                                   max_singular_values=self.m)
                self.nodes[i] = tn.Node(left, name='node{}'.format(i))
                self.nodes[i+1] = tn.Node(right, name='node{}'.format(i+1))

# Computation

In [6]:
# Draw 1000 distinct row indices and slice the matching training subsample.
sampled_ids = np.random.choice(nbdata, size=1000, replace=False)
samX, samY = X[sampled_ids], Y[sampled_ids]

In [7]:
# Build the classifier from the sizes set in the configuration cell.
dmrg_classifier = QuantumTNSweepingClassifier(
    dimvec=dimvec,
    pos_label=pos_label,
    nblabels=nblabels,
    bond_len=m,
)

In [None]:
# Train on the sampled subset with the default epochs / learning rate.
dmrg_classifier.fit(X=samX, Y=samY)

Round 0
	sweeping node 0


In [70]:
# Shape of the first site tensor; the saved output below shows (2, 10, 10).
# NOTE(review): this cell was last run as In [70] while the fit cell shows
# In [None], so the saved output does not come from a clean top-to-bottom run.
dmrg_classifier.nodes[0].shape

(2, 10, 10)