In [43]:
import numpy as np 
import torch
import torchmps
import numba
import json
from tqdm import tqdm

In [2]:
dimvec = 784
pos_label = 392
nblabels = 10
bond_len = 10

learning_rate = 1e-4

In [8]:
cosx = np.random.uniform(size=dimvec)

inputs = np.array([np.array([cosx, np.sqrt(1-cosx*cosx)]).T], dtype=np.float32)

In [9]:
mps(torch.tensor(inputs))

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], grad_fn=<ViewBackward>)

In [74]:
@numba.njit(numba.float64[:, :](numba.float64[:]))
def convert_pixels_to_tnvector(pixels):
    tnvector = np.concatenate(
        (np.expand_dims(np.cos(0.5*np.pi*pixels/256.), axis=0),
         np.expand_dims(np.sin(0.5*np.pi*pixels/256.), axis=0)),
        axis=0
    ).T
    return tnvector

In [86]:
class MNISTJSON_Dataset(torch.utils.data.Dataset):
    def __init__(self, filepath):
        self.filepath = filepath
        self.digit_map = {digit: int(digit) for digit in '0123456789'}
        self.X = None
        self.Y = None

        for pixels, digit in self.generate_data(open(filepath, 'r')):
            pixel_vector = convert_pixels_to_tnvector(np.array(pixels))
            ans = np.zeros((10,))
            ans[self.digit_map[digit]] = 1.
            if self.X is None:
                self.X = np.expand_dims(pixel_vector, axis=0)
                self.Y = np.expand_dims(ans, axis=0)
            else:
                self.X = np.concatenate((self.X,
                                         np.expand_dims(pixel_vector, axis=0)),
                                        axis=0)
                self.Y = np.concatenate((self.Y,
                                         np.expand_dims(ans, axis=0)),
                                        axis=0)
            if self.Y.shape[0] >= 1000:
                break

    def generate_data(self, mnist_file):
        for line in mnist_file:
            data = json.loads(line)
            pixels = np.array(data['pixels'])
            digit = data['digit']
            yield pixels, digit

    def __getitem__(self, idx):
        x = torch.Tensor(self.X[idx, :, :])
        y = torch.Tensor(self.Y[idx, :])
        return x, y

    def __len__(self):
        return self.Y.shape[0]

In [87]:
dataset = MNISTJSON_Dataset('/Users/stephenhky/PyProjects/tensornetwork-learn/experiments/mnist_784/mnist_784.json')

In [91]:
mps = torchmps.torchmps.MPS(input_dim=dimvec, output_dim=nblabels, bond_dim=bond_len,               
                            adaptive_mode=False, periodic_bc=False, label_site=pos_label)

In [92]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mps.parameters(), lr=learning_rate)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=100)

In [93]:
nb_epochs = 10

for _ in tqdm(range(nb_epochs)):
    for i, (x_train, y_train) in enumerate(dataloader):
        optimizer.zero_grad()
        target_indices = torch.max(y_train, 1)[1]
        y_pred = mps(x_train)
        loss = criterion(y_pred, target_indices)
        loss.backward()
        optimizer.step()

100%|██████████| 10/10 [00:35<00:00,  3.56s/it]


In [95]:
mps(dataset[0:1][0])

tensor([[281.4480, 279.4002, 282.6530, 284.3252, 278.3716, 284.2715, 281.6510,
         279.2263, 283.2119, 279.5974]], grad_fn=<ViewBackward>)

In [96]:
dataset[0:1][1]

tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]])