In [35]:
import mpemu

In [36]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchsummary import summary
from CustomConv2D import RConv2D

In [37]:
class MNISTModel(nn.Module):
    """Small convolutional classifier that emits 10 logits per input sample.

    Pipeline: ``conv1 -> pool -> conv2 -> pool -> flatten -> fc1 -> ReLU -> fc2``.
    ``fc1`` expects ``64 * 7 * 7`` flattened features, which matches a
    1x28x28 input after the two stride-2 poolings.

    NOTE(review): no activation is applied after the RConv2D layers here;
    whether RConv2D applies one internally cannot be told from this file.
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as layer names (used when listing layers
        # for the quantizer), so they are part of the model's contract.
        self.conv1 = RConv2D(1, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = RConv2D(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for batch ``x``."""
        # Two conv + pool stages; the same pooling module is reused.
        stage1 = self.pool(self.conv1(x))
        stage2 = self.pool(self.conv2(stage1))
        # Collapse everything except the batch dimension.
        flat = stage2.view(stage2.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [38]:
# Pick the compute device once: CUDA when torch reports it available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [39]:
# Preprocessing applied to every image: convert to a tensor, then normalise the
# single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [40]:
# MNIST train and test splits, both stored under ./data (downloaded on demand)
# and both run through the same `transform`.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

In [41]:
# Batches of 256 samples; only the training loader shuffles, so evaluation
# always visits the test set in the same order.
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

In [42]:
# Build the model on the selected device first, then hand its parameters to Adam.
# NOTE(review): no random seed is set anywhere in the visible cells, so weight
# initialisation and shuffling are presumably not reproducible between runs.
model = MNISTModel().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [43]:
# Show the module tree, then a per-layer output-shape / parameter-count table
# for a single 1x28x28 input.
print(model)
summary(model, (1, 28, 28))

MNISTModel(
  (conv1): RConv2D()
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): RConv2D()
  (fc1): Linear(in_features=3136, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
           RConv2D-1           [-1, 32, 28, 28]             320
         MaxPool2d-2           [-1, 32, 14, 14]               0
           RConv2D-3           [-1, 64, 14, 14]          18,496
         MaxPool2d-4             [-1, 64, 7, 7]               0
            Linear-5                  [-1, 128]         401,536
            Linear-6                   [-1, 10]           1,290
Total params: 421,642
Trainable params: 421,642
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.36
Params size (MB): 1.61
Es

In [61]:
def train_model(model, criterion, optimizer, train_loader, device, num_epochs=5):
    """Train ``model`` in place and report the mean batch loss of each epoch.

    Args:
        model: module to optimise. It is switched to training mode before the
            first epoch and is left in training mode on return.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        optimizer: optimizer that holds the parameters of ``model``.
        train_loader: re-iterable yielding ``(inputs, labels)`` batches.
        device: device every batch is moved to before the forward pass.
        num_epochs: number of full passes over ``train_loader``.

    Returns:
        list[float]: mean batch loss of every epoch, in epoch order.

    Raises:
        ValueError: if ``train_loader`` yields no batches in an epoch.
    """
    # A model handed over in eval mode (e.g. after an earlier evaluation in the
    # same kernel) would otherwise be trained with eval-mode layer behaviour.
    model.train()

    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Counting batches avoids requiring __len__ on the loader and lets
            # the empty case be reported explicitly instead of dividing by zero.
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")

        epoch_loss = running_loss / num_batches
        epoch_losses.append(epoch_loss)
        print(f"Train Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}")

    return epoch_losses

In [62]:
# Train with the default number of epochs.
# NOTE(review): the execution counts jump from 43 to 61 and the first printed
# loss is already low, so this cell was presumably re-run on an already-trained
# (and possibly already-quantized) model — confirm with Restart & Run All.
train_model(model, criterion, optimizer, train_loader, device)

# Switch to inference mode before handing the model to the quantizer.
model.eval()

# NOTE(review): import placed mid-cell; consider moving it to the imports cell.
from mpemu import mpt_emu
# NOTE(review): this list names all four layers of MNISTModel. If exempt layers
# are excluded from quantization (presumably — verify against the mpemu docs),
# nothing in the model is actually quantized and the accuracy reported below
# would be that of the unquantized network.
list_exempt_layers = ["conv1", "conv2", "fc1", "fc2"]
# Rebinds `model` to the object returned by the emulator; `emulator` is not
# used again in the visible cells.
model, emulator = mpt_emu.quantize_model(model, dtype="E4M3", list_exempt_layers=list_exempt_layers)

def evaluate_model(model, criterion, test_loader, device):
    """Measure top-1 accuracy of ``model`` over ``test_loader`` and print it.

    Args:
        model: module to evaluate; switched to eval mode by this function.
        criterion: unused; kept so existing call sites remain valid.
        test_loader: iterable yielding ``(inputs, labels)`` batches.
        device: device every batch is moved to before the forward pass.

    Returns:
        float: accuracy in percent (0-100).

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            # Gradients are already disabled here, so the legacy `.data`
            # detour is unnecessary; argmax yields the predicted class index.
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Report an empty loader explicitly rather than as a ZeroDivisionError.
    if total == 0:
        raise ValueError("test_loader yielded no samples")

    accuracy = 100 * correct / total
    print(f"Accuracy after quantization: {accuracy:.2f}%")
    return accuracy

# Score the model returned by the quantizer on the held-out test split;
# evaluate_model prints the accuracy itself.
evaluate_model(model, criterion, test_loader, device)

# Final status line marking the end of the train -> quantize -> evaluate cell.
print("Training complete, quantized model evaluation complete.")

Train Epoch [1/5], Loss: 0.045492727024124026
Train Epoch [2/5], Loss: 0.04155944252109274
Train Epoch [3/5], Loss: 0.03382681631384061
Train Epoch [4/5], Loss: 0.03323990523399032
Train Epoch [5/5], Loss: 0.03439501183939741
e4m3 : quantizing model weights..
Accuracy after quantization: 97.92%
Training complete, quantized model evaluation complete.
