In [1]:
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchonn as onn
from torchonn.models import ONNBaseModel
import torch.optim as optim



In [2]:
class ONNModel(ONNBaseModel):
    def __init__(self, device=torch.device("cpu")):
        super().__init__()
        self.conv1 = onn.layers.MZIBlockConv2d(
            in_channels=3,
            out_channels=6,
            kernel_size=3,
            stride=1,
            padding=1,
            dilation=1,
            bias=True,
            miniblock=4,
            mode="usv",
            decompose_alg="clements",
            photodetect=True,
            device=device,
        )
        self.conv2 = onn.layers.MZIBlockConv2d(
            in_channels=6,
            out_channels=10,
            kernel_size=3,
            stride=1,
            padding=1,
            dilation=1,
            bias=True,
            miniblock=4,
            mode="usv",
            decompose_alg="clements",
            photodetect=True,
            device=device,
        )
        self.pool = nn.AdaptiveAvgPool2d(5)
        self.linear = onn.layers.MZIBlockLinear(
            in_features=10*5*5,
            out_features=10,
            bias=True,
            miniblock=4,
            mode="usv",
            decompose_alg="clements",
            photodetect=True,
            device=device,
        )
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.linear.reset_parameters()

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = self.linear(x)
        return x

In [4]:
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize the data
])

cifar_trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
cifar_trainloader = torch.utils.data.DataLoader(cifar_trainset, batch_size=64,
                                          shuffle=True, num_workers=2)

cifar_testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
cifar_testloader = torch.utils.data.DataLoader(cifar_testset, batch_size=64,
                                         shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
dtype = torch.float32

def cifar_check_accuracy(loader, model):
    if loader.dataset.train:
        print('Checking accuracy on train set')
    else:
        print('Checking accuracy on test set')   
    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            _, preds = scores.max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)
        acc = float(num_correct) / num_samples
        print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))

In [6]:
test_accuracies = []
train_accuracies = []

device = torch.device("cpu")

def cifar_train(model, optimizer, epochs=10):
    """
    Train a model on CIFAR using the PyTorch Module API.
    
    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - epochs: (Optional) A Python integer giving the number of epochs to train for
    
    Returns: Nothing, but prints model accuracies during training.
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    for e in range(epochs):
        for t, (x, y) in enumerate(cifar_trainloader):
            model.train()  # put model to training mode
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            loss = F.cross_entropy(scores, y)

            # Zero out all of the gradients for the variables which the optimizer
            # will update.
            optimizer.zero_grad()

            # This is the backwards pass: compute the gradient of the loss with
            # respect to each  parameter of the model.
            loss.backward()

            # to avoid explosions that create NaNs
            nn.utils.clip_grad_norm_(model.parameters(), 1)

            # Actually update the parameters of the model using the gradients
            # computed by the backwards pass.
            optimizer.step()

        print(f'Epoch {e}, loss = {loss.item()}')
        if e % 5 == 0:
            test_accuracies.append(cifar_check_accuracy(cifar_testloader, model))
        if e % 10 == 0:
            train_accuracies.append(cifar_check_accuracy(cifar_trainloader, model))
        print()

In [7]:
# cifar_model = ONNModel()
learning_rate = 0.001

optimizer = optim.SGD(cifar_model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)
cifar_train(cifar_model, optimizer, epochs = 50)

Epoch 0, loss = 1.9971799850463867
Checking accuracy on test set
Got 3361 / 10000 correct (33.61)
Checking accuracy on train set
Got 17218 / 50000 correct (34.44)

Epoch 1, loss = 1.6028977632522583

Epoch 2, loss = 1.5796785354614258

Epoch 3, loss = 1.6570179462432861

Epoch 4, loss = 1.6343716382980347

Epoch 5, loss = 0.8567443490028381
Checking accuracy on test set
Got 4562 / 10000 correct (45.62)

Epoch 6, loss = 1.6326223611831665

Epoch 7, loss = 1.0611159801483154

Epoch 8, loss = 1.29726243019104

Epoch 9, loss = 1.587180495262146

Epoch 10, loss = 1.458213448524475
Checking accuracy on test set
Got 4678 / 10000 correct (46.78)
Checking accuracy on train set
Got 24006 / 50000 correct (48.01)

Epoch 11, loss = 1.5025169849395752

Epoch 12, loss = 1.442958116531372

Epoch 13, loss = 1.1779186725616455

Epoch 14, loss = 1.364230751991272

Epoch 15, loss = 1.453565239906311
Checking accuracy on test set
Got 4791 / 10000 correct (47.91)

Epoch 16, loss = 2.1398983001708984

Epoch 

KeyboardInterrupt: 

In [10]:
learning_rate = 0.001

optimizer = optim.SGD(cifar_model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)
cifar_train(cifar_model, optimizer, epochs = 50)

Epoch 0, loss = 1.4161609411239624
Checking accuracy on test set
Got 5423 / 10000 correct (54.23)
Checking accuracy on train set
Got 27707 / 50000 correct (55.41)

Epoch 1, loss = 1.8847999572753906

Epoch 2, loss = 0.7959414720535278

Epoch 3, loss = 1.2352678775787354

Epoch 4, loss = 1.5925626754760742

Epoch 5, loss = 1.2439088821411133
Checking accuracy on test set
Got 5406 / 10000 correct (54.06)

Epoch 6, loss = 0.8024638295173645

Epoch 7, loss = 1.5905225276947021

Epoch 8, loss = 1.1617804765701294

Epoch 9, loss = 1.7929065227508545

Epoch 10, loss = 0.6599476933479309
Checking accuracy on test set
Got 5446 / 10000 correct (54.46)
Checking accuracy on train set
Got 27632 / 50000 correct (55.26)

Epoch 11, loss = 1.406890869140625

Epoch 12, loss = 1.1086044311523438

Epoch 13, loss = 1.3430298566818237

Epoch 14, loss = 1.2047455310821533

Epoch 15, loss = 1.5407623052597046
Checking accuracy on test set
Got 5497 / 10000 correct (54.97)

Epoch 16, loss = 0.6278024911880493



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1222c7490>
Traceback (most recent call last):
  File "/Users/matthewho/Photonic_computing/photonics_env/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/Users/matthewho/Photonic_computing/photonics_env/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1442, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/opt/homebrew/Cellar/python@3.10/3.10.6_2/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/opt/homebrew/Cellar/python@3.10/3.10.6_2/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/opt/homebrew/Cellar/python@3.10/3.10.6_2/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multipr

KeyboardInterrupt: 