Imports

In [42]:
import os.path
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

Constants

In [43]:
# Define constants
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
epochs = 5
batch_size = 64
MODEL_DIR = "./model"
DATA_DIR = "./data"
MODEL_SPLIT_PREFIX = "model_2"

Define Model, Train, Test, Shard functions

In [46]:
# Define Model
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.stack_1 = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU()
        )
        self.stack_2 = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        x = self.stack_1(x)
        return self.stack_2(x)


def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


def shard(model):
    num_splits = len(list(model.children()))
    splits = []
    for i in range(num_splits):
        splits.append(nn.Sequential(list(model.children())[i]))
    return splits
    
    

Setup Data

In [4]:
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(DATA_DIR, exist_ok=True)

# (Down)load training data
MNIST_folder_exists = os.path.exists(f"{DATA_DIR}/MNIST")
train_data = datasets.MNIST(
    root=DATA_DIR,
    train=True,
    download=(not MNIST_folder_exists),
    transform=ToTensor()
)
train_dataloader = DataLoader(train_data, batch_size=batch_size)
test_data = datasets.MNIST(
    root=DATA_DIR,
    train=False,
    download=(not MNIST_folder_exists),
    transform=ToTensor()
)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
print("(Down)loaded MNIST dataset")

(Down)loaded MNIST dataset


Run & Train Model (We need to do this, as an ONNX file can't easily be loaded into PyTorch)

In [5]:
# Instantiate Model
model = Model().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

print("Starting Training")
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Finished Training")

Starting Training
Epoch 1
-------------------------------
loss: 2.295470  [   64/60000]
loss: 2.295720  [ 6464/60000]
loss: 2.306595  [12864/60000]
loss: 2.272050  [19264/60000]
loss: 2.287586  [25664/60000]
loss: 2.285815  [32064/60000]
loss: 2.272544  [38464/60000]
loss: 2.271732  [44864/60000]
loss: 2.250515  [51264/60000]
loss: 2.256143  [57664/60000]
Test Error: 
 Accuracy: 40.0%, Avg loss: 2.252269 

Epoch 2
-------------------------------
loss: 2.245810  [   64/60000]
loss: 2.244970  [ 6464/60000]
loss: 2.262033  [12864/60000]
loss: 2.209453  [19264/60000]
loss: 2.237423  [25664/60000]
loss: 2.234699  [32064/60000]
loss: 2.205138  [38464/60000]
loss: 2.223561  [44864/60000]
loss: 2.182846  [51264/60000]
loss: 2.182479  [57664/60000]
Test Error: 
 Accuracy: 54.0%, Avg loss: 2.181837 

Epoch 3
-------------------------------
loss: 2.173581  [   64/60000]
loss: 2.168354  [ 6464/60000]
loss: 2.197321  [12864/60000]
loss: 2.113449  [19264/60000]
loss: 2.159253  [25664/60000]
loss: 2.

Split Model

In [47]:
shards = shard(model)

Save Shards

In [50]:
# This has to be 
sample_input_tensor = torch.randn(1, 1, 28, 28).to(device)
for i in range(len(shards)):
    MODEL_FILE = f"{MODEL_SPLIT_PREFIX}_shard_{i}.onnx"
    model_shard = shards[i]
    torch.onnx.export(
        model=model_shard,
        args=sample_input_tensor,
        f=f"{MODEL_DIR}/{MODEL_FILE}",
        verbose=True,
        input_names=['input'],
        output_names=['output']
    )
    print(f"Saved {MODEL_FILE}")
    
    # Update sample input tensor for next layer/stack of layers
    model_shard.eval()
    sample_input_tensor = model_shard(sample_input_tensor)

Exported graph: graph(%input : Float(1, 1, 28, 28, strides=[784, 784, 28, 1], requires_grad=0, device=mps:0)):
  %output : Float(1, 784, strides=[784, 1], requires_grad=0, device=mps:0) = onnx::Flatten[axis=1, onnx_name="/0/Flatten"](%input), scope: torch.nn.modules.container.Sequential::/torch.nn.modules.flatten.Flatten::0 # /usr/local/lib/python3.9/site-packages/torch/nn/modules/flatten.py:49:0
  return (%output)

Saved model_2_shard_0.onnx
Exported graph: graph(%input : Float(1, 784, strides=[784, 1], requires_grad=0, device=mps:0),
      %0.0.weight : Float(512, 784, strides=[784, 1], requires_grad=1, device=mps:0),
      %0.0.bias : Float(512, strides=[1], requires_grad=1, device=mps:0),
      %0.2.weight : Float(512, 512, strides=[512, 1], requires_grad=1, device=mps:0),
      %0.2.bias : Float(512, strides=[1], requires_grad=1, device=mps:0),
      %0.4.weight : Float(10, 512, strides=[512, 1], requires_grad=1, device=mps:0),
      %0.4.bias : Float(10, strides=[1], requires_gra