In [1]:
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np

In [2]:
batch_size = 64
learning_rate = 0.001
hidden_size = 32
num_epochs = 25
input_size = 28 * 28  # MNIST images are 28x28 pixels
num_classes = 10  # 10 digits (0-9)

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)), # MNIST mean and std
    transforms.Lambda(lambda x: x * 0.01)
])

In [5]:
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True
)

test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transform
)

# Create data loaders
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True
)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False
)


In [6]:
import brevitas.nn as qnn
import torch.nn as nn
import torch

class QuantMLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes, n_bits=3):
        super(QuantMLP, self).__init__()

        self.flatten = nn.Flatten()

        # Input quantization
        self.quant_inp = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)

        # First quantized linear layer
        self.fc1 = qnn.QuantLinear(input_size, hidden_size, True, weight_bit_width=n_bits, bias_quant=None)

        self.quant2 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc2 = qnn.QuantLinear(hidden_size, hidden_size, True, weight_bit_width=n_bits, bias_quant=None)
        self.quant3 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc3 = qnn.QuantLinear(hidden_size, num_classes, True, weight_bit_width=n_bits, bias_quant=None)
    def forward(self, x):
        x = self.quant_inp(x)
        x = self.quant2(torch.relu(self.fc1(x)))
        x = self.quant3(torch.relu(self.fc2(x)))
        x = self.fc3(x)
        return x

In [7]:
from concrete.ml.torch.compile import compile_brevitas_qat_model
model = QuantMLP(input_size, hidden_size, num_classes).to(device)

Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
    https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md



In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


In [9]:
def train():
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in train_loader:
        # For a batch size of 64, images shape might be [64, 1, 28, 28]
        batch_size = images.size(0)  # Should be 64
        
        # Reshape to [64, 784] for MNIST
        images = images.view(batch_size, -1).to(device)
        labels = labels.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        # Calculate accuracy
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    epoch_loss = running_loss / len(train_loader)
    accuracy = 100 * correct / total
    
    return epoch_loss, accuracy

In [10]:
def test():
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            batch_size = images.size(0)  # Should be 64
        
        # Reshape to [64, 784] for MNIST
            images = images.view(batch_size, -1).to(device)
            labels = labels.to(device)

            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    return accuracy

In [None]:
import time
start = time.time()
print(f"Training on {device}")
for epoch in range(num_epochs):
    train_loss, train_acc = train()
    test_acc = test()

    print(f'Epoch [{epoch+1}/{num_epochs}], '
          f'Train Loss: {train_loss:.4f}, '
          f'Train Accuracy: {train_acc:.2f}%, '
          f'Test Accuracy: {test_acc:.2f}%')

# Save the model
torch.save(model.state_dict(), 'checkpoints/mnist_mlp.pth')
print('Model saved to mnist_mlp_QAT.pth')
end = time.time()
print("Training time", end-start)

Training on cpu


  return super().rename(names)


Epoch [1/25], Train Loss: 1.1057, Train Accuracy: 63.53%, Test Accuracy: 81.62%
Epoch [2/25], Train Loss: 0.5315, Train Accuracy: 83.92%, Test Accuracy: 86.88%
Epoch [3/25], Train Loss: 0.4205, Train Accuracy: 87.31%, Test Accuracy: 88.37%
Epoch [4/25], Train Loss: 0.3669, Train Accuracy: 89.09%, Test Accuracy: 90.10%
Epoch [5/25], Train Loss: 0.3330, Train Accuracy: 90.03%, Test Accuracy: 90.37%
Epoch [6/25], Train Loss: 0.3106, Train Accuracy: 90.71%, Test Accuracy: 90.10%
Epoch [7/25], Train Loss: 0.2907, Train Accuracy: 91.33%, Test Accuracy: 91.64%
Epoch [8/25], Train Loss: 0.2699, Train Accuracy: 91.96%, Test Accuracy: 91.66%
Epoch [9/25], Train Loss: 0.2597, Train Accuracy: 92.19%, Test Accuracy: 92.32%
Epoch [10/25], Train Loss: 0.2483, Train Accuracy: 92.54%, Test Accuracy: 93.00%
Epoch [11/25], Train Loss: 0.2380, Train Accuracy: 92.92%, Test Accuracy: 92.98%
Epoch [12/25], Train Loss: 0.2297, Train Accuracy: 93.17%, Test Accuracy: 93.40%
Epoch [13/25], Train Loss: 0.2229, Tr

In [12]:
torch_input = torch.randn(32,784)
quantized_module = compile_brevitas_qat_model(
    model, # our model
    torch_input, # a representative input-set to be used for both quantization and compilation
    rounding_threshold_bits={"n_bits": 6, "method": "approximate"}
)

In [13]:
from concrete.ml.deployment import FHEModelClient, FHEModelServer,FHEModelDev
fhe_directory = '/Users/prahaladhchandrahasan/Desktop/CMU_Spring2025/EPS/Project/FHE_DIR'

dev = FHEModelDev(path_dir=fhe_directory, model=quantized_module)
dev.save()

In [14]:
client = FHEModelClient(path_dir=fhe_directory, key_dir=fhe_directory)
serialized_evaluation_keys = client.get_serialized_evaluation_keys()

KeySetCache: miss, regenerating /Users/prahaladhchandrahasan/Desktop/CMU_Spring2025/EPS/Project/FHE_DIR/12078945368485275920


In [15]:
server = FHEModelServer(path_dir=fhe_directory)
server.load()

In [16]:
import time 
orginal_label = []
predicted_label = []
count = 0
start = time.time()
for batch_idx, (images, labels) in enumerate(test_loader):
    # Loop through each image in the batch
    for i in range(images.shape[0]):
        # Access individual image
        count = count + 1
        image = images[i]
        label = labels[i]
        orginal_label.append(label.item())
        image = image.view(1,-1).numpy()
        encrypted_data = client.quantize_encrypt_serialize(image)
        encrypted_result = server.run(encrypted_data, serialized_evaluation_keys)
        result = client.deserialize_decrypt_dequantize(encrypted_result)
        pred_label = np.argmax(result, axis=-1)
        predicted_label.append(pred_label[0])
        if count ==1:
            break
    if count==1:
        break
end = time.time()
print("Total time for 1 encrypted image inference", end-start)



Total time for 1 encrypted image inference 29.975343942642212




In [18]:
import time 
orginal_label = []
predicted_label = []
count = 0
start = time.time()
for batch_idx, (images, labels) in enumerate(test_loader):
    # Loop through each image in the batch
    for i in range(images.shape[0]):
        # Access individual image
        count = count + 1
        image = images[i]
        label = labels[i]
        orginal_label.append(label.item())
        image = image.view(1,-1).numpy()
        encrypted_data = client.quantize_encrypt_serialize(image)
        encrypted_result = server.run(encrypted_data, serialized_evaluation_keys)
        result = client.deserialize_decrypt_dequantize(encrypted_result)
        pred_label = np.argmax(result, axis=-1)
        predicted_label.append(pred_label[0])
        if count ==10:
            break
    if count==10:
        break
end = time.time()
print("Total time for 10 encrypted image inference", end-start)



Total time for 10 encrypted image inference 293.4850878715515


