This is a toy example for using SynaptogenML. 
---------------------------------------------

Please note that, in order to have a minimal example, we are using MNIST here. Its purpose is to serve as a usage example only. As stated in our paper, we do not believe that it is a good idea to draw any conclusions from experiments with MNIST, especially with a highly simplified network such as the one we are using here.

This example should be used as a playground to understand how SynaptogenML can be used, and which adjustable parameters are available.

For proper example setups with SynaptogenML, please have a look at the ASR examples accompanying our publication, which are listed here: https://github.com/rwth-i6/returnn-experiments/tree/master/2025-memristor-asr


In [None]:
import copy
import sys
import time
from typing import Optional

import torch
from torch import nn
from torchvision.datasets.mnist import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

In [None]:
# please make sure to install the repo itself
from synaptogen_ml.quant_modules import LinearQuant, ActivationQuantizer
from synaptogen_ml.memristor_modules import TiledMemristorLinear, DacAdcHardwareSettings

In [None]:
def create_mnist_dataloaders(batch_size):
    """
    Build the MNIST train and test dataloaders.

    The data is stored under the current working directory ("./"); only the
    training dataset is constructed with ``download=True``.

    :param batch_size: batch size used for both dataloaders
    :return: tuple ``(train dataloader, test dataloader)``; the train loader
        shuffles, the test loader does not
    """
    # presumably the MNIST pixel mean / std used for normalization — confirm if changed
    preprocessing = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    print("download training data")
    train_set = MNIST("./", train=True, download=True, transform=preprocessing)

    print("download testing data")
    test_set = MNIST("./", train=False, transform=preprocessing)

    print("prepare dataloaders")
    train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [None]:
class QuantizedModel(nn.Module):
    """
    Two-layer MNIST classifier (784 -> 512 -> tanh -> 10) built from quantized linear layers.

    The model has two execution paths selected in :meth:`forward`:

    * the quantization-aware path using ``LinearQuant`` layers, each preceded by an
      ``ActivationQuantizer`` on its input, and
    * the memristor path using ``TiledMemristorLinear`` layers, which only exist after
      :meth:`prepare_memristor` has been called.

    :param weight_precision: bit precision used for the weights of both linear layers
        and for the memristor layers derived from them
    """

    def __init__(self, weight_precision: int = 4):
        super().__init__()

        self.weight_precision = weight_precision

        self.linear_1 = LinearQuant(
            in_features=28 * 28,
            out_features=512,
            weight_bit_prec=self.weight_precision,
            weight_quant_dtype=torch.qint8,
            weight_quant_method="per_tensor_symmetric",
            bias=False,
        )
        self.final_linear = LinearQuant(
            in_features=512,
            out_features=10,
            weight_bit_prec=self.weight_precision,
            weight_quant_dtype=torch.qint8,
            weight_quant_method="per_tensor_symmetric",
            bias=False,
        )

        # One independent quantizer per layer input/output; all share the same configuration.
        self.activation_quant_l1_in = self._make_activation_quantizer()
        self.activation_quant_l1_out = self._make_activation_quantizer()
        self.activation_quant_final_in = self._make_activation_quantizer()
        self.activation_quant_final_out = self._make_activation_quantizer()

        # Created lazily by `prepare_memristor`, as they are initialized from the trained layers.
        self.memristor_linear_1 = None
        self.memristor_final = None

    @staticmethod
    def _make_activation_quantizer():
        """Create an 8-bit, per-tensor symmetric activation quantizer."""
        return ActivationQuantizer(
            bit_precision=8,
            dtype=torch.qint8,
            method="per_tensor_symmetric",
            channel_axis=None,
            moving_avrg=None,
            reduce_range=False,
        )

    def forward(self, image, use_memristor=False):
        """
        Compute the (output-quantized) logits for a batch of images.

        :param image: image tensor, reshaped internally to ``(-1, 28 * 28)``
        :param use_memristor: if True, use the memristor layers instead of the
            ``LinearQuant`` layers; requires a prior call to :meth:`prepare_memristor`
        :return: result of ``activation_quant_final_out`` applied to the logits
        :raises RuntimeError: if ``use_memristor`` is set but the memristor layers
            have not been created yet
        """
        if use_memristor and (
            self.memristor_linear_1 is None or self.memristor_final is None
        ):
            raise RuntimeError(
                "memristor layers are not initialized, call prepare_memristor() first"
            )

        inp = torch.reshape(image, shape=(-1, 28 * 28))
        if use_memristor:
            # The input quantizer is not applied here; it was handed to the memristor
            # layer in `prepare_memristor` via `init_from_linear_quant`.
            linear_out = self.memristor_linear_1(inp)
        else:
            linear_out = self.linear_1(self.activation_quant_l1_in(inp))
        out1 = nn.functional.tanh(self.activation_quant_l1_out(linear_out))
        if use_memristor:
            logits = self.memristor_final(out1)
        else:
            logits = self.final_linear(self.activation_quant_final_in(out1))
        quant_out = self.activation_quant_final_out(logits)
        return quant_out

    def prepare_memristor(self, hardware_settings, mem_array_inputs, mem_array_outputs):
        """
        (Re-)create the memristor layers from the current state of the quantized layers.

        Every call builds new ``TiledMemristorLinear`` modules, so the caller has to move
        the model to the target device again afterwards.

        :param hardware_settings: passed as ``converter_hardware_settings`` to both layers
        :param mem_array_inputs: passed as ``memristor_inputs`` to both layers
        :param mem_array_outputs: passed as ``memristor_outputs`` to both layers
        """
        self.memristor_linear_1 = TiledMemristorLinear(
            in_features=28 * 28,
            out_features=512,
            weight_precision=self.weight_precision,
            converter_hardware_settings=hardware_settings,
            memristor_inputs=mem_array_inputs,
            memristor_outputs=mem_array_outputs,
        )
        self.memristor_final = TiledMemristorLinear(
            in_features=512,
            out_features=10,
            weight_precision=self.weight_precision,
            converter_hardware_settings=hardware_settings,
            memristor_inputs=mem_array_inputs,
            memristor_outputs=mem_array_outputs,
        )
        # NOTE(review): the semantics of `num_cycles=0` are defined by synaptogen_ml — confirm there.
        self.memristor_linear_1.init_from_linear_quant(
            self.activation_quant_l1_in, self.linear_1, num_cycles=0
        )
        self.memristor_final.init_from_linear_quant(
            self.activation_quant_final_in, self.final_linear, num_cycles=0
        )

In [None]:
# Batch size used for both the train and the test dataloader created below.
BATCH_SIZE = 10
# These module-level dataloaders are read as globals inside `run_training`.
dataloader_train, dataloader_test = create_mnist_dataloaders(BATCH_SIZE)

# Use the GPU when CUDA is available, otherwise fall back to the CPU;
# `device` is likewise read as a global inside `run_training`.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device: %s" % device)

In [None]:
def _evaluate_on_test_set(model: nn.Module, use_memristor: bool):
    """
    Run one pass over the global ``dataloader_test`` without gradient tracking.

    :param model: model whose ``forward(image, use_memristor=...)`` returns logits
    :param use_memristor: forwarded to ``model.forward``
    :return: tuple ``(average cross entropy, accuracy, total seconds, seconds per sample)``
    """
    total_ce = 0.0
    total_correct = 0
    num_examples = 0
    start = time.time()
    with torch.no_grad():
        for image, labels in dataloader_test:
            image = image.to(device=device)
            labels = labels.to(device=device)
            logits = model.forward(image, use_memristor=use_memristor)
            ce = nn.functional.cross_entropy(logits, target=labels, reduction="sum")
            total_ce += ce.item()
            total_correct += torch.eq(torch.argmax(logits, dim=-1), labels).sum().item()
            num_examples += image.shape[0]
    elapsed = time.time() - start
    # avoid a division by zero for an empty test set
    denom = max(num_examples, 1)
    return total_ce / denom, total_correct / denom, elapsed, elapsed / denom


def run_training(
    model: nn.Module,
    hardware_settings: Optional[DacAdcHardwareSettings],
    mem_array_inputs: int = 64,
    mem_array_outputs: int = 64,
    num_epochs: int = 5,
    batch_size: int = 10,
    include_memristor_evaluation: bool = True,
    expected_accuracy: Optional[float] = None,
):
    """
    Train the model on the global ``dataloader_train`` and evaluate it after every epoch.

    After each training epoch the model is evaluated on the global ``dataloader_test``
    with the regular quantized layers and, if requested, again with freshly prepared
    memristor layers. Each evaluation reports its own cross entropy and accuracy.

    :param model: model providing ``forward(image, use_memristor=...)`` and
        ``prepare_memristor(...)``
    :param hardware_settings: settings passed to ``model.prepare_memristor``; must not be
        None when ``include_memristor_evaluation`` is set
    :param mem_array_inputs: passed to ``model.prepare_memristor``
    :param mem_array_outputs: passed to ``model.prepare_memristor``
    :param num_epochs: number of train/evaluate rounds
    :param batch_size: currently unused, the batch size is fixed by the globally
        created dataloaders
    :param include_memristor_evaluation: whether to prepare and evaluate the memristor
        layers after each epoch
    :param expected_accuracy: if given and at least one memristor evaluation was run,
        assert that the best memristor accuracy over all epochs reaches this value
    :return: None
    :raises ValueError: if memristor evaluation is requested without hardware settings
    :raises AssertionError: if ``expected_accuracy`` is given and not reached
    """
    if include_memristor_evaluation and hardware_settings is None:
        raise ValueError(
            "hardware_settings is required when include_memristor_evaluation is True"
        )

    # keep the CPU run short, this is only a toy example
    cpu_max_train_examples = 2000

    model.to(device=device)
    optimizer = torch.optim.RAdam(lr=1e-4, params=model.parameters())

    memristor_accs = []

    for i in range(num_epochs):
        print("\nstart train epoch %i" % i)
        total_ce = 0.0
        total_correct = 0
        num_examples = 0
        model.train()

        for image, labels in dataloader_train:
            batch_examples = image.shape[0]
            if device == "cpu" and num_examples + batch_examples > cpu_max_train_examples:
                # do not train so much on CPU; the skipped batch is not counted
                break
            image = image.to(device=device)
            labels = labels.to(device=device)
            logits = model.forward(image)
            ce = nn.functional.cross_entropy(logits, target=labels, reduction="sum")
            total_ce += ce.item()
            total_correct += torch.eq(torch.argmax(logits, dim=-1), labels).sum().item()
            num_examples += batch_examples
            ce.backward()

            optimizer.step()
            optimizer.zero_grad()

        denom = max(num_examples, 1)
        print(f"train ce: {total_ce / denom:.3f} acc: {total_correct / denom:.3f}")

        model.eval()
        print("\nstart normal quantized evaluation")
        test_ce, test_acc, end_float, end_float_avg = _evaluate_on_test_set(
            model, use_memristor=False
        )
        print(
            f"Normal-quant test ce: {test_ce:.6f}, acc: {test_acc:.6f}, time: {end_float:.2f}s, per sample: {end_float_avg:.2f}s"
        )

        if include_memristor_evaluation:
            # Rebuild the memristor layers from the current weights; the new modules
            # are created on the default device and have to be moved explicitly.
            model.prepare_memristor(
                hardware_settings=hardware_settings,
                mem_array_inputs=mem_array_inputs,
                mem_array_outputs=mem_array_outputs,
            )
            model.to(device=device)

            print("\nstart memristor evaluation")
            mem_ce, memristor_acc, end_float, end_float_avg = _evaluate_on_test_set(
                model, use_memristor=True
            )
            memristor_accs.append(memristor_acc)
            print(
                f"test memristor ce: {mem_ce:.6f}, acc: {memristor_acc:.6f}, time: {end_float:.2f}s, per sample: {end_float_avg:.2f}s"
            )

    if expected_accuracy is not None and memristor_accs:
        best_acc = max(memristor_accs)
        assert best_acc >= expected_accuracy, (
            f"accuracy too low: {best_acc:.2f} < {expected_accuracy:.2f}"
        )

In [None]:
# Converter configuration that `run_training` forwards to `model.prepare_memristor`.
# NOTE(review): the exact meaning and units of the individual values are defined by
# `DacAdcHardwareSettings` in synaptogen_ml and are not visible here — confirm there.
hardware_settings = DacAdcHardwareSettings(
    input_bits=8,
    output_precision_bits=4,
    output_range_bits=4,
    hardware_input_vmax=0.6,
    hardware_output_current_scaling=8020.0,
)

# 4-bit weights for both the quantized linear layers and the memristor layers.
model = QuantizedModel(weight_precision=4)

In [None]:
# Train and evaluate with the default settings of `run_training`
# (5 epochs, 64x64 memristor arrays, memristor evaluation enabled).
run_training(model=model, hardware_settings=hardware_settings)