In [14]:
import json

import numpy as np
import numba
import tensornetwork as tn

In [12]:
# Problem size and model hyper-parameters.
dimvec = 28 * 28         # pixels per MNIST image (784) = number of sites / nodes in the chain
pos_label = dimvec // 2  # 392; passed to the classifier as pos_label (stored there, not used yet)
nblabels = 10            # number of classes (digits 0-9)
m = 10                   # bond dimension of the tensor chain
nbdata = 70000           # number of records expected in the MNIST JSON file

# Seed numpy's global RNG so the random tensor initialisation is reproducible
# under Restart & Run All.
SEED = 42
np.random.seed(SEED)

# Data Loading

In [10]:
def generate_data(mnist_file):
    """Yield ``(pixels, digit)`` pairs from a JSON-lines MNIST file.

    Each non-blank line must be a JSON object with a ``'pixels'`` list and a
    ``'digit'`` field.

    Parameters
    ----------
    mnist_file : iterable of str
        An open text file (or any iterable of lines). The caller owns it and
        is responsible for closing it.

    Yields
    ------
    pixels : numpy.ndarray
        1-D float64 array of the raw pixel values.
    digit
        The ``'digit'`` field exactly as stored in the JSON record.
    """
    for line in mnist_file:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline) instead of
            # failing inside json.loads.
            continue
        data = json.loads(line)
        # Force float64: convert_pixels_to_tnvector is compiled for float64[:]
        # only, and integer pixels in the JSON would otherwise give int64.
        pixels = np.asarray(data['pixels'], dtype=np.float64)
        digit = data['digit']
        yield pixels, digit


@numba.njit(numba.float64[:, :](numba.float64[:]))
def convert_pixels_to_tnvector(pixels):
    """Map each pixel intensity to a 2-vector ``(cos(theta), sin(theta))``.

    ``theta = 0.5 * pi * pixel / 256``, so a 1-D input of length ``n`` becomes
    an ``(n, 2)`` float64 array whose rows are unit vectors (row ``i`` is the
    feature vector for pixel ``i``).

    NOTE(review): dividing by 256 rather than 255 means the brightest 8-bit
    pixel (255, if that is the data's range) maps slightly below pi/2 —
    confirm this is intended.
    """
    theta = 0.5 * np.pi * pixels / 256.
    features = np.empty((pixels.shape[0], 2))
    features[:, 0] = np.cos(theta)
    features[:, 1] = np.sin(theta)
    return features

In [15]:
# Reading the data: each JSON line becomes one row of X (per-pixel 2-vectors)
# and one one-hot row of Y.
# TODO: replace this absolute, machine-specific path with a configurable data dir.
mnist_json_path = '/data/hok/PyProjects/tensornetwork-learn/experiments/mnist_784/mnist_784.json'
label_dict = {str(i): i for i in range(nblabels)}
X = np.zeros((nbdata, dimvec, 2))
Y = np.zeros((nbdata, nblabels))
nbread = 0
# `with` guarantees the file is closed once the generator is exhausted.
with open(mnist_json_path, 'r') as mnist_file:
    for i, (pixels, label) in enumerate(generate_data(mnist_file)):
        # float64 to match the numba signature of convert_pixels_to_tnvector.
        X[i, :, :] = convert_pixels_to_tnvector(np.asarray(pixels, dtype=np.float64))
        # str() accepts digits stored either as "7" or as 7 in the JSON.
        Y[i, label_dict[str(label)]] = 1.
        nbread = i + 1
# Fail loudly instead of silently leaving all-zero rows at the end of X / Y.
assert nbread == nbdata, 'expected {} records, read {}'.format(nbdata, nbread)

# Classifier

In [46]:
class QuantumTNSweepingClassifier:
    """Matrix-product-state (tensor chain) classifier built from ``tn.Node`` objects.

    The chain has one node per input feature. Axis 0 of every node is the
    2-dimensional physical leg that is contracted with that site's feature
    vector. Neighbouring nodes share bond legs of size ``bond_len``. The label
    leg (size ``nblabels``) starts on node 0. The name and ``current_pos_label``
    suggest the label is meant to be moved by sweeps; ``fit`` is not
    implemented yet.

    Parameters
    ----------
    dimvec : int
        Number of sites (input features); must be at least 2.
    pos_label : int
        Stored as ``self.pos_label``; currently unused (the label always
        starts on node 0).
    nblabels : int
        Number of classes, i.e. the size of the label leg.
    bond_len : int
        Bond dimension, stored as ``self.m``.
    nearzero_std : float, optional
        Standard deviation of the Gaussian noise used to initialise tensors.
    """

    def __init__(self, dimvec, pos_label, nblabels, bond_len, nearzero_std=1e-9):
        """Store the hyper-parameters and build a randomly initialised chain.

        Raises
        ------
        ValueError
            If ``dimvec < 2``; the chain wiring always links node 0 to node 1.
        """
        if dimvec < 2:
            raise ValueError('dimvec must be at least 2, got {}'.format(dimvec))
        self.dimvec = dimvec
        self.pos_label = pos_label
        self.nblabels = nblabels
        self.m = bond_len
        self.nearzero_std = nearzero_std
        # Index of the node that currently carries the label leg.
        self.current_pos_label = 0
        self.construct_nodes()

    def construct_nodes(self):
        """(Re)create the chain nodes and placeholder input nodes.

        Node shapes (axis order):
          * node 0:               (2, m, nblabels) -- physical, right bond, label
          * nodes 1..dimvec-2:    (2, m, m)        -- physical, left bond, right bond
          * node dimvec-1:        (2, m)           -- physical, left bond

        NOTE(review): only the last node gets an identity-like component. All
        other tensors are pure N(0, nearzero_std) noise, so with the default
        std=1e-9 the full contraction over ~dimvec sites underflows to 0. This
        matches the all-zero predictions this notebook shows. Consider a
        near-identity initialisation for the bulk tensors.
        """
        self.nodes = []
        for i in range(self.dimvec):
            if i == 0:
                node = tn.Node(np.random.normal(loc=0.0, scale=self.nearzero_std, size=(2, self.m, self.nblabels)),
                               name='node{}'.format(i))
            elif i == self.dimvec - 1:
                node = tn.Node(np.random.normal(loc=0.0, scale=self.nearzero_std, size=(2, self.m)),
                               name='node{}'.format(i))
                # Add a (partial) identity so the last tensor is not pure noise.
                for j in range(min(self.m, 2)):
                    node.tensor[j, j] += 1.
            else:
                node = tn.Node(np.random.normal(loc=0.0, scale=self.nearzero_std, size=(2, self.m, self.m)),
                               name='node{}'.format(i))
            self.nodes.append(node)

        # Placeholder unit vectors (c, sqrt(1 - c^2)); infer_single_datum
        # overwrites their tensors with the real sample.
        cosx = np.random.uniform(size=self.dimvec)
        self.input_nodes = [tn.Node(np.array([cosx[i], np.sqrt(1 - cosx[i] * cosx[i])]),
                                    name='input{}'.format(i))
                            for i in range(self.dimvec)]

    def infer_single_datum(self, input):
        """Contract the whole network with one encoded sample.

        Parameters
        ----------
        input : numpy.ndarray
            Array of shape ``(dimvec, 2)``; row ``i`` is the feature vector
            fed into site ``i``.

        Returns
        -------
        The tensor of the fully contracted network, whose only open leg is the
        label leg (shape ``(nblabels,)`` while the label sits on node 0).
        """
        assert input.shape[0] == self.dimvec
        assert input.shape[1] == 2

        for i in range(self.dimvec):
            self.input_nodes[i].tensor = input[i, :]

        # Connect physical legs to the inputs, then wire the bond chain.
        # NOTE(review): this reconnects on every call, which presumably relies
        # on tensornetwork leaving the original nodes with fresh dangling edges
        # after a contraction — confirm for the installed version.
        for i in range(self.dimvec):
            self.nodes[i][0] ^ self.input_nodes[i][0]
        self.nodes[0][1] ^ self.nodes[1][1]
        for i in range(1, self.dimvec - 1):
            self.nodes[i][2] ^ self.nodes[i + 1][1]

        # The label leg is axis 2 on node 0 and assumed to be axis 3 on any
        # other node that carries it.
        # NOTE(review): a label on the last node would have only 3 axes — confirm.
        label_axis = 2 if self.current_pos_label == 0 else 3
        final_node = tn.contractors.auto(self.nodes + self.input_nodes,
                                         output_edge_order=[self.nodes[self.current_pos_label][label_axis]])

        return final_node.tensor

    def predict_proba(self, X):
        """Run ``infer_single_datum`` on every sample of ``X``.

        Parameters
        ----------
        X : numpy.ndarray
            Array of shape ``(n_samples, dimvec, 2)``.

        Returns
        -------
        numpy.ndarray
            The per-sample outputs stacked along axis 0. Despite the name,
            these are raw contraction values: they are not normalised and not
            guaranteed non-negative.
        """
        n_samples = X.shape[0]
        assert X.shape[1] == self.dimvec
        assert X.shape[2] == 2

        return np.array([self.infer_single_datum(X[i, :, :]) for i in range(n_samples)])

    def label_block(self):
        """Contract the label-carrying node with a neighbour into a two-site block.

        Returns the contracted ``tn.Node`` with output axes:
          * label on node 0: (label, right bond of node 1, phys 0, phys 1)
          * label on node p (middle): (left bond of p, label, right bond of p+1,
            phys p, phys p+1)
          * label on the last node: see NOTE in that branch.
        """
        p = self.current_pos_label
        last = self.dimvec - 1
        if p == 0:
            self.nodes[0][1] ^ self.nodes[1][1]
            return tn.contractors.auto([self.nodes[0], self.nodes[1]],
                                       output_edge_order=[self.nodes[0][2], self.nodes[1][2],
                                                          self.nodes[0][0], self.nodes[1][0]])
        if p == last:
            # NOTE(review): this reads axis 3 of node dimvec-2 as the label leg,
            # which only exists if the label sits on node dimvec-2 rather than
            # on node dimvec-1 — confirm the intended convention.
            self.nodes[last - 1][2] ^ self.nodes[last][1]
            return tn.contractors.auto([self.nodes[last - 1], self.nodes[last]],
                                       output_edge_order=[self.nodes[last - 1][1],
                                                          self.nodes[last - 1][3],
                                                          self.nodes[last - 1][0],
                                                          self.nodes[last][0]])
        self.nodes[p][2] ^ self.nodes[p + 1][1]
        return tn.contractors.auto([self.nodes[p], self.nodes[p + 1]],
                                   output_edge_order=[self.nodes[p][1],
                                                      self.nodes[p][3],
                                                      self.nodes[p + 1][2],
                                                      self.nodes[p][0],
                                                      self.nodes[p + 1][0]])

    def fit(self, X, Y):
        """Train on ``(X, Y)``. TODO: not implemented — currently a no-op returning None."""
        pass

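# Computation

Build an untrained classifier and evaluate it on the first 10 samples. Every tensor except the last is initialised with tiny Gaussian noise (std `1e-9`), so the contraction over 784 sites underflows and all predicted scores below are exactly 0.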

In [47]:
# Build the (untrained) classifier from the configuration constants.
dmrg_classifier = QuantumTNSweepingClassifier(
    dimvec=dimvec,
    pos_label=pos_label,
    nblabels=nblabels,
    bond_len=m,
)

In [48]:
# Forward pass on the first 10 samples; the result is displayed as cell output.
dmrg_classifier.predict_proba(X[:10])

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])