In [1]:
import numpy as np 
import torch
import torchmps
import numba
import json
from tqdm import tqdm

In [2]:
dimvec = 784
pos_label = 392
nblabels = 10
bond_len = 10

learning_rate = 1e-4

In [3]:
cosx = np.random.uniform(size=dimvec)

inputs = np.array([np.array([cosx, np.sqrt(1-cosx*cosx)]).T], dtype=np.float32)

In [4]:
@numba.njit(numba.float64[:, :](numba.float64[:]))
def convert_pixels_to_tnvector(pixels):
    tnvector = np.concatenate(
        (np.expand_dims(np.cos(0.5*np.pi*pixels/256.), axis=0),
         np.expand_dims(np.sin(0.5*np.pi*pixels/256.), axis=0)),
        axis=0
    ).T
    return tnvector

In [5]:
class MNISTJSON_Dataset(torch.utils.data.Dataset):
    def __init__(self, filepath):
        self.filepath = filepath
        self.digit_map = {digit: int(digit) for digit in '0123456789'}
        self.X = None
        self.Y = None

        for pixels, digit in self.generate_data(open(filepath, 'r')):
            pixel_vector = convert_pixels_to_tnvector(np.array(pixels))
            ans = np.zeros((10,))
            ans[self.digit_map[digit]] = 1.
            if self.X is None:
                self.X = np.expand_dims(pixel_vector, axis=0)
                self.Y = np.expand_dims(ans, axis=0)
            else:
                self.X = np.concatenate((self.X,
                                         np.expand_dims(pixel_vector, axis=0)),
                                        axis=0)
                self.Y = np.concatenate((self.Y,
                                         np.expand_dims(ans, axis=0)),
                                        axis=0)
            if self.Y.shape[0] >= 1000:
                break

    def generate_data(self, mnist_file):
        for line in mnist_file:
            data = json.loads(line)
            pixels = np.array(data['pixels'])
            digit = data['digit']
            yield pixels, digit

    def __getitem__(self, idx):
        x = torch.Tensor(self.X[idx, :, :])
        y = torch.Tensor(self.Y[idx, :])
        return x, y

    def __len__(self):
        return self.Y.shape[0]

In [6]:
dataset = MNISTJSON_Dataset('/data/hok/testdata/mnist/mnist_784/mnist_784.json')

In [7]:
mps = torchmps.MPS(input_dim=dimvec, output_dim=nblabels, bond_dim=bond_len,               
                            adaptive_mode=False, periodic_bc=False, label_site=pos_label)

In [8]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mps.parameters(), lr=learning_rate)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=100)

In [9]:
nb_epochs = 10

for _ in tqdm(range(nb_epochs)):
    for i, (x_train, y_train) in enumerate(dataloader):
        optimizer.zero_grad()
        target_indices = torch.max(y_train, 1)[1]
        y_pred = mps(x_train)
        loss = criterion(y_pred, target_indices)
        loss.backward()
        optimizer.step()

100%|██████████| 10/10 [00:21<00:00,  2.10s/it]


In [10]:
mps(dataset[0:1][0])

tensor([[65.9429, 64.2422, 67.4848, 69.9750, 62.0792, 71.4161, 63.0720, 64.7849,
         69.7529, 64.9899]], grad_fn=<ViewBackward>)

In [11]:
dataset[0:1][1]

tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]])

In [10]:
list(mps.parameters())

[Parameter containing:
 tensor([[[[ 3.5414e-10,  1.5005e-09],
           [-6.9136e-10, -2.8323e-10],
           [ 1.0447e-09,  2.1095e-09],
           ...,
           [-1.5739e-09,  1.9702e-09],
           [ 7.4821e-10,  1.7361e-09],
           [ 6.5842e-10,  1.9907e-10]],
 
          [[ 1.0456e-09,  1.6475e-09],
           [-4.9223e-10, -3.5463e-10],
           [ 9.6608e-10, -1.8243e-10],
           ...,
           [ 1.2428e-09,  2.2870e-09],
           [ 1.5631e-09,  1.6852e-09],
           [-7.5318e-10,  1.1177e-09]],
 
          [[ 9.2228e-10, -1.8183e-10],
           [-8.2833e-10,  1.0934e-09],
           [ 7.4646e-10, -9.6212e-10],
           ...,
           [-4.0425e-10, -1.0105e-10],
           [ 1.0959e-09, -7.3402e-10],
           [ 3.8989e-10,  1.7179e-09]],
 
          ...,
 
          [[ 1.7760e-09, -4.7109e-10],
           [-4.6855e-10,  4.5083e-10],
           [ 5.9861e-10,  1.1681e-09],
           ...,
           [-2.0266e-10, -6.6465e-10],
           [-6.7241e-10, -9.3