In [2]:
import numpy as np
import torch
from sklearn.datasets import make_moons
from sklearn.preprocessing import StandardScaler

In [3]:
def create_moon_dataset(n_samples=1000):
    # Create moon-shaped data
    X, y = make_moons(n_samples=n_samples, noise=0.1)
    
    # Scale the features
    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    
    # Convert to torch tensors
    X = torch.FloatTensor(X)
    # Convert to one-hot encoding
    y_onehot = torch.zeros(n_samples, 2)
    y_onehot[range(n_samples), y] = 1
    
    return X, y_onehot

# Create and split the data
X, y = create_moon_dataset()  # or create_moon_dataset(1000)

# Split into train and test
train_size = int(0.8 * len(X))
X_train, X_test = X[:train_size], X[train_size:]
y_train, y_test = y[:train_size], y[train_size:]

In [4]:
X[:5]

tensor([[-0.7101,  1.3707],
        [ 0.6174, -0.8954],
        [-0.0175,  1.2468],
        [-0.5357, -0.9086],
        [-1.1670,  1.4153]])

In [5]:
y[:5]

tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [1., 0.]])

In [6]:
IN_FEAT = 2
OUT_FEAT = 2

class SimpleNet(torch.nn.Module):
    """Simple MLP with PyTorch"""

    def __init__(self, n_hidden = 30):
        super().__init__()
        self.fc1 = torch.nn.Linear(in_features=IN_FEAT, out_features=n_hidden)
        self.fc2 = torch.nn.Linear(in_features=n_hidden, out_features=n_hidden)
        self.fc3 = torch.nn.Linear(in_features=n_hidden, out_features=OUT_FEAT)


    def forward(self, x):
        """Forward pass."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [7]:
def train_model(model, X_train, y_train, epochs=10):
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters())
    
    for epoch in range(epochs):
        outputs = model(X_train)
        loss = criterion(outputs, y_train)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

In [8]:
model = SimpleNet()

In [9]:
# Train the model
train_model(model, X_train, y_train)

# Test predictions
with torch.no_grad():
    test_predictions = model(X_test)
    accuracy = ((test_predictions.argmax(dim=1) == y_test.argmax(dim=1)).float().mean())
    print(f"Test accuracy: {accuracy:.4f}")

Epoch [10/10], Loss: 0.4718
Test accuracy: 0.5750


In [10]:
from brevitas import nn as qnn
from brevitas.core.quant import QuantType
from brevitas.quant import Int8ActPerTensorFloat, Int8WeightPerTensorFloat

N_BITS = 4
IN_FEAT = 2
OUT_FEAT = 2
    
class QuantSimpleNet(torch.nn.Module):
    def __init__(
        self,
        n_hidden = 30,
        qlinear_args={
            "weight_bit_width": N_BITS,
            "weight_quant": Int8WeightPerTensorFloat,
            "bias": True,
            "bias_quant": None,
            "narrow_range": True
        },
        qidentity_args={"bit_width": N_BITS, "act_quant": Int8ActPerTensorFloat},
    ):
        super().__init__()

        self.quant_inp = qnn.QuantIdentity(**qidentity_args)
        self.fc1 = qnn.QuantLinear(IN_FEAT, n_hidden, **qlinear_args)
        self.relu1 = qnn.QuantReLU(bit_width=qidentity_args["bit_width"])
        self.fc2 = qnn.QuantLinear(n_hidden, n_hidden, **qlinear_args)
        self.relu2 = qnn.QuantReLU(bit_width=qidentity_args["bit_width"])
        self.fc3 = qnn.QuantLinear(n_hidden, OUT_FEAT, **qlinear_args)

        for m in self.modules():
            if isinstance(m, qnn.QuantLinear):
                torch.nn.init.uniform_(m.weight.data, -1, 1)

    def forward(self, x):
        x = self.quant_inp(x)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return x       


In [11]:
qntmodel = QuantSimpleNet()
train_model(qntmodel, X_train, y_train)

# Test predictions
with torch.no_grad():
    test_predictions = qntmodel(X_test)
    accuracy = ((test_predictions.argmax(dim=1) == y_test.argmax(dim=1)).float().mean())
    print(f"Test accuracy: {accuracy:.4f}")

Epoch [10/10], Loss: 27.6828
Test accuracy: 0.5350


  return super().rename(names)


In [12]:
from concrete.ml.torch.compile import compile_brevitas_qat_model

N_FEAT = 2

torch_input = torch.randn(100, N_FEAT)
quantized_module = compile_brevitas_qat_model(
    qntmodel, # our model
    torch_input, # a representative input-set to be used for both quantization and compilation
    rounding_threshold_bits={"n_bits": N_BITS, "method": "approximate"}
)


In [13]:
y_pred = quantized_module.forward(X_test.numpy(), fhe="execute")

In [16]:
print((test_predictions.argmax(dim=1) == y_test.argmax(dim=1)).float().mean())


tensor(0.5350)
