In [41]:
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchonn as onn
from torchonn.models import ONNBaseModel
import torch.optim as optim

In [118]:
class ONNModel(ONNBaseModel):
    """Small optical CNN classifier built from torchonn MZI blocks.

    Architecture:
      * two MZI-block 3x3 convolutions (1 -> 6 -> 10 channels, stride 1,
        padding 1), each followed by ReLU and adaptive average pooling to 5x5;
      * an MZI-block linear layer mapping the flattened 10*5*5 features to
        10 class scores (used as logits by ``F.cross_entropy`` in training).

    Args:
        device: torch device on which the torchonn layers are created.
    """

    def __init__(self, device=torch.device("cpu")):
        super().__init__()
        # Settings shared by every MZI block in this model.
        mzi_opts = dict(
            bias=True,
            miniblock=4,
            mode="usv",
            decompose_alg="clements",
            photodetect=True,
            device=device,
        )
        # Both convolutions use the same "same-size" 3x3 geometry.
        conv_opts = dict(kernel_size=3, stride=1, padding=1, dilation=1, **mzi_opts)

        self.conv1 = onn.layers.MZIBlockConv2d(in_channels=1, out_channels=6, **conv_opts)
        self.conv2 = onn.layers.MZIBlockConv2d(in_channels=6, out_channels=10, **conv_opts)
        self.pool = nn.AdaptiveAvgPool2d(5)
        self.linear = onn.layers.MZIBlockLinear(
            in_features=10 * 5 * 5,
            out_features=10,
            **mzi_opts,
        )

        # Re-initialise the optical layers explicitly, in construction order.
        for layer in (self.conv1, self.conv2, self.linear):
            layer.reset_parameters()

    def forward(self, x):
        """Map a batch of single-channel images to (batch, 10) class scores.

        Assumes ``x`` is (batch, H, W); a channel axis is inserted at dim 1.
        """
        h = x.unsqueeze(1)  # (batch, H, W) -> (batch, 1, H, W)
        for conv in (self.conv1, self.conv2):
            h = self.pool(torch.relu(conv(h)))
        return self.linear(h.flatten(1))

In [110]:
from scipy import stats, fft

class MNISTDataset(torch.utils.data.Dataset):
    """MNIST split whose images are standardized one image at a time.

    Each item is ``(image, label)``. ``image`` is the raw image (converted via
    ``np.array``) as a float32 tensor. It is shifted to zero mean and scaled to
    unit (population) standard deviation, using that image's own statistics.
    ``label`` is passed through unchanged from the wrapped torchvision dataset.

    Args:
        train: True selects the training split, False the test split. The value
            is also stored as ``self.train``, so evaluation code can report which
            split it is scoring.
        dataset_path: NOTE(review): currently unused. Data is always downloaded
            to and read from ``./data``. Kept so existing calls keep working.
    """

    def __init__(self, train, dataset_path='/mnist'):
        self.dataset = torchvision.datasets.MNIST('./data', train=train, download=True)
        self.train = train

    def __getitem__(self, index):
        img, label = self.dataset[index]
        # np.array(..., dtype=...) always copies, so the in-place ops below never
        # touch the underlying dataset storage.
        pixels = np.array(img, dtype=np.float32)
        mean = pixels.mean()
        std = pixels.std()
        pixels -= mean
        # A constant image has std == 0; dividing would make every pixel NaN
        # and poison the loss. Leave it as the (already centred) zero image.
        if std > 0:
            pixels /= std
        return torch.from_numpy(pixels), label

    def __len__(self):
        return len(self.dataset)

# Build both MNIST splits and wrap each in a DataLoader of 512-sample batches.
# Only the training loader shuffles; evaluation order does not affect accuracy.
mnist_trainset = MNISTDataset(train=True)
mnist_trainloader = torch.utils.data.DataLoader(
    mnist_trainset,
    batch_size=512,
    shuffle=True,
)

mnist_testset = MNISTDataset(train=False)
mnist_testloader = torch.utils.data.DataLoader(
    mnist_testset,
    batch_size=512,
    shuffle=False,
)

In [111]:
dtype = torch.float32

def mnist_check_accuracy(loader, model, eval_device=None):
    """Measure classification accuracy of ``model`` over every batch in ``loader``.

    Prints which split is being checked (chosen by ``loader.dataset.train``)
    and a ``correct / total (percent)`` summary line.

    Args:
        loader: iterable of ``(inputs, labels)`` batches. Its ``dataset``
            attribute must expose a boolean ``train`` flag.
        model: module whose output is (batch, num_classes) scores; the
            predicted class is the argmax over dim 1.
        eval_device: device to move each batch to. When None, the
            notebook-level ``device`` global is used (looked up at call time).

    Returns:
        float: fraction of correctly classified samples in [0, 1]. Returns
        0.0 when the loader yields no samples.
    """
    # Only touch the global when no explicit device was given.
    target = device if eval_device is None else eval_device

    if loader.dataset.train:
        print('Checking accuracy on train set')
    else:
        print('Checking accuracy on test set')

    num_correct = 0
    num_samples = 0
    model.eval()  # switch to inference mode for evaluation
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=target, dtype=dtype)
            y = y.to(device=target, dtype=torch.long)
            scores = model(x)
            _, preds = scores.max(1)
            # Accumulate plain Python ints rather than tensors.
            num_correct += int((preds == y).sum().item())
            num_samples += preds.size(0)

    # Guard against an empty loader instead of raising ZeroDivisionError.
    acc = num_correct / num_samples if num_samples else 0.0
    print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
    # Returned so callers (e.g. accuracy-history lists) record real values.
    return acc

In [136]:
test_accuracies = []
train_accuracies = []

device = torch.device("cpu")

def mnist_train(model, optimizer, epochs=10, train_loader=None, test_loader=None,
                max_grad_norm=1.0):
    """
    Train a model on MNIST using the PyTorch Module API.

    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - epochs: (Optional) A Python integer giving the number of epochs to train for
    - train_loader: (Optional) loader of training batches; defaults to the
      notebook-level ``mnist_trainloader``. Also used for the per-epoch
      train-accuracy check.
    - test_loader: (Optional) loader used for the per-epoch test-accuracy check;
      defaults to the notebook-level ``mnist_testloader``.
    - max_grad_norm: (Optional) total gradient-norm limit passed to
      ``nn.utils.clip_grad_norm_`` before every optimizer step.

    Returns: Nothing. After each epoch it prints the loss of that epoch's last
    batch, and appends whatever ``mnist_check_accuracy`` returns for the test
    and train loaders to the module-level ``test_accuracies`` and
    ``train_accuracies`` lists.
    """
    if train_loader is None:
        train_loader = mnist_trainloader
    if test_loader is None:
        test_loader = mnist_testloader

    model = model.to(device=device)  # move the model parameters to CPU/GPU
    for e in range(epochs):
        # mnist_check_accuracy switches the model to eval mode at the end of each
        # epoch, so re-enter training mode once at the start of every epoch.
        model.train()
        loss = None
        for x, y in train_loader:
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            loss = F.cross_entropy(scores, y)

            # Clear gradients left over from the previous step before backprop.
            optimizer.zero_grad()
            loss.backward()

            # Cap the total gradient norm to avoid explosions that create NaNs.
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            optimizer.step()

        # An empty loader leaves `loss` unset; skip the report instead of crashing.
        if loss is not None:
            print(f'Epoch {e}, loss = {loss.item()}')
        test_accuracies.append(mnist_check_accuracy(test_loader, model))
        train_accuracies.append(mnist_check_accuracy(train_loader, model))
        print()

In [137]:
# Seed torch's global RNG before building the model and starting training so a
# Restart & Run All reproduces the run. This covers the DataLoader shuffle, and
# presumably the torchonn layer initialisation (TODO confirm it draws from torch's RNG).
SEED = 0
torch.manual_seed(SEED)

mnist_model = ONNModel()
learning_rate = 0.001

optimizer = optim.SGD(mnist_model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)
mnist_train(mnist_model, optimizer, epochs=50)

Epoch 0, loss = 130.7095489501953
Checking accuracy on test set
Got 3002 / 10000 correct (30.02)
Checking accuracy on train set
Got 17460 / 60000 correct (29.10)

Epoch 1, loss = 1.9106556177139282
Checking accuracy on test set
Got 4495 / 10000 correct (44.95)
Checking accuracy on train set
Got 26470 / 60000 correct (44.12)

Epoch 2, loss = 1.3894652128219604
Checking accuracy on test set
Got 6497 / 10000 correct (64.97)
Checking accuracy on train set
Got 38626 / 60000 correct (64.38)

Epoch 3, loss = 0.953540027141571
Checking accuracy on test set
Got 7807 / 10000 correct (78.07)
Checking accuracy on train set
Got 46361 / 60000 correct (77.27)

Epoch 4, loss = 0.9261117577552795
Checking accuracy on test set
Got 8485 / 10000 correct (84.85)
Checking accuracy on train set
Got 50979 / 60000 correct (84.97)

Epoch 5, loss = 0.4505365788936615
Checking accuracy on test set
Got 8952 / 10000 correct (89.52)
Checking accuracy on train set
Got 53583 / 60000 correct (89.31)

Epoch 6, loss = 0.

[E thread_pool.cpp:109] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:109] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:109] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:109] Exception in thread pool task: mutex lock failed: Invalid argument


KeyboardInterrupt: 