# Distributed Quantum Neural Networks in Merlin (arXiv:2505.08474v1)


## Training a conventional CNN on MNIST dataset

In [5]:
#Importing ML libraries

import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils as utils


from scipy.optimize import minimize


#Quandela libraries
import perceval as pcvl
from merlin import PhotonicBackend, CircuitType, StatePattern, AnsatzFactory, QuantumLayer
from merlin import OutputMappingStrategy
import merlin as ML

#Importing local libraries

from boson_sampler import BosonSampler
from utils import MNIST_partial, accuracy, plot_training_metrics

#Importing system libraries
import time
import random
from datetime import datetime

#Importing other libraries
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
from collections.abc import Iterable
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
#Load the datas

# dataset from csv file, to use for the challenge
train_dataset = MNIST_partial(split = 'train')
val_dataset = MNIST_partial(split='val')

# definition of the dataloader, to process the data in the model
# here, we need a batch size of 1 to use the boson sampler
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size, shuffle = False)

In [12]:

# Define the CNN model
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=4)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(4, 4, kernel_size=4)
        self.fc1 = nn.Linear(4*4*4, 10)

        
    def forward(self, x):
        x = self.pool(self.conv1(x))
        x = self.pool(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x


In [13]:
#Define training hyperparameters
learning_rate = 1e-3
num_epochs = 3


# Instantiate the model and loss function
model = CNNModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

#Count number of parameters in the model
num_classical_parameter = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("# of parameters in classical CNN model: ", num_classical_parameter)

# of parameters in classical CNN model:  978


In [5]:
# Training loop
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}")


In [6]:
# Testing loop
model.eval()
correct = 0
total = 0
loss_test_list = [] 
with torch.no_grad():
    for images, labels in val_loader:
        outputs = model(images)
        loss_test = criterion(outputs, labels).cpu().detach().numpy()
        loss_test_list.append(loss_test)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy on the test set: {(100 * correct / total):.2f}%")

Accuracy on the test set: 82.33%


## Reducing the number of parameters to train the CNN by using a photonic quantum neural network

Add Figure 2 of paper, with simpler description.

In [7]:
from torchvision import transforms

#Load the datas
"""transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
])"""

train_dataset = MNIST_partial(split='train')
val_dataset = MNIST_partial(split='val')

# definition of the dataloader, to process the data in the model
# here, we need a batch size of 1 to use the boson sampler
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size, shuffle = False)

In [8]:
#The size of the PQNN (photonic quantum neural network) will depend on the number parameters needed in the classical CNN
#We know that a PQCNN of n photons with m modes covers m!/(n!*(m-n)!) outcomes.
#We want num_classical_parameters in the CNN to be lower than number of outcomes in the PQCNN
#In our case we take two PQCNNs so the total number of outcomes, is the product of the number of outcome in each of them : total_pqcnn_params = number_PQCNN1_outcomes * number_PQCNN2_outcomes


#TODO : 
#- Define a function that takes the number of CNN params and number of PQCNN wanted as the input, and output the number of photons and modes for each PQCNN
#- Hyperparams optimization : additional layers, like in original repo ? Learning  rate scheduler ?
#- Add MPS layer to reduce number of params



#In the meantime we hardcode the numbers of photons and modes for that specific case

#n_1 represents the number of photons in the first PQCNN
n_1 = 4
m_1 = 9

n_2 = 4
m_2 = 8

initial_state_1 = [1]*n_1 + [0] * (m_1 - n_1)
initial_state_2 = [1]*n_2 + [0] * (m_2 - n_2)


In [19]:
class QuantumParameterizedCNN(nn.Module):
    def __init__(self, num_classical_parameter):
        super().__init__()

        self.cnn = CNNModel()

        PQCNN_1 = pcvl.GenericInterferometer(m_1,
                                lambda i: pcvl.BS() // pcvl.PS(pcvl.P(f"theta_1i{i}_ps")) // \
                                            pcvl.BS() // pcvl.PS(pcvl.P(f"theta_1o{i}_ps")),
                                shape=pcvl.InterferometerShape.RECTANGLE)

        PQCNN_2 = pcvl.GenericInterferometer(m_2,
                                lambda i: pcvl.BS() // pcvl.PS(pcvl.P(f"theta_2i{i}_ps")) // \
                                            pcvl.BS() // pcvl.PS(pcvl.P(f"theta_2o{i}_ps")),
                                shape=pcvl.InterferometerShape.RECTANGLE)


        self.QL_1 = ML.QuantumLayer(
                    input_size=None,
                    output_size=None,
                    circuit=PQCNN_1,
                    trainable_parameters=["theta"],
                    input_parameters=[],
                    input_state= initial_state_1,
                    no_bunching=False,
                    output_mapping_strategy=OutputMappingStrategy.NONE,
                )

        self.QL_2 = ML.QuantumLayer(
                    input_size=None,
                    output_size=None,
                    circuit=PQCNN_2,
                    trainable_parameters=["theta"],
                    input_parameters=[],
                    input_state= initial_state_2,
                    no_bunching=False,
                    output_mapping_strategy=OutputMappingStrategy.NONE,
                )

        self.num_classical_parameter = num_classical_parameter

        self.register_buffer('param_shapes', torch.tensor([p.numel() for p in self.cnn.parameters()]))

        # Store a template state dict for reference
        self.template_state_dict = self.cnn.state_dict()

    def probs_to_weights(self, probs_):
        """
        Convert probability tensor to CNN weights + preserves gradient
        """
        weight_dict = {}
        data_iterator = probs_.view(-1)

        for name, param in self.template_state_dict.items():
            shape = param.shape
            num_elements = param.numel()

            if len(data_iterator) < num_elements:
                chunk = torch.cat([
                    data_iterator,
                    torch.zeros(num_elements - len(data_iterator), device=data_iterator.device)
                ])[:num_elements].reshape(shape)
            else:
                chunk = data_iterator[:num_elements].reshape(shape)

            weight_dict[name] = chunk  # to preserve the gradient flow
            data_iterator = data_iterator[num_elements:]

        return weight_dict

    def forward(self, x):
        # Get probabilities from quantum layers
        probs1 = self.QL_1()
        probs2 = self.QL_2()

        tensor_product = torch.outer(probs1, probs2).flatten()

        if tensor_product.size(0) < self.num_classical_parameter:
            padded_tensor = torch.cat([
                tensor_product,
                torch.zeros(self.num_classical_parameter - tensor_product.size(0), device=x.device)
            ])
        else:
            padded_tensor = tensor_product[:self.num_classical_parameter]

        # Convert probabilities to real-valued weights
        eps = 1e-6
        tensor_product_clipped = padded_tensor.clamp(eps, 1 - eps)
        new_params = torch.log(tensor_product_clipped / (1 - tensor_product_clipped))
        new_params = torch.tanh(new_params)  # Optional scaling

        # Convert to CNN weights + preserve gradients
        weight_dict = self.probs_to_weights(new_params)

        # FUNCTIONAL forward pass (preserves gradients)
        return self.forward_functional(x, weight_dict)

    def forward_functional(self, x, weight_dict):
        """
        Functional forward pass that preserves gradients to quantum layers
        """
        
        # Conv1 + Pool
        x = F.conv2d(x, weight_dict['conv1.weight'], weight_dict.get('conv1.bias', None))
        x = F.max_pool2d(x, kernel_size=2, stride=2)

        # Conv2 + Pool
        x = F.conv2d(x, weight_dict['conv2.weight'], weight_dict.get('conv2.bias', None))
        x = F.max_pool2d(x, kernel_size=2, stride=2)

        # Flatten + FC
        x = torch.flatten(x, 1)
        x = F.linear(x, weight_dict['fc1.weight'], weight_dict.get('fc1.bias', None))

        return x

In [20]:
learning_rate = 0.01
num_epochs = 10

# Instantiate the model
#qmodel = HQCNNModel_v2()
qmodel = QuantumParameterizedCNN(num_classical_parameter)

# Instantiate loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(qmodel.parameters(), lr=learning_rate)

In [23]:
# Training loop
# Fixed: proper accuracy calculation and gradient clipping
for epoch in range(num_epochs):
    correct = 0  # Reset for each epoch
    total = 0    # Reset for each epoch
    
    for i, (images, labels) in enumerate(train_loader):
        qmodel.train()
        optimizer.zero_grad()
        outputs = qmodel(images)
        loss = criterion(outputs, labels)
        loss.backward()
        ql1_grads = [p.grad for p in qmodel.parameters() if p.requires_grad]
        #print(ql1_grads)
        
        # Gradient clipping to prevent exploding gradients
        max_norm = 1.0
        utils.clip_grad_norm_(qmodel.parameters(), max_norm)

        optimizer.step()
        
        # Calculate training accuracy for this batch
        with torch.no_grad():
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        if (i+1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}")

    # Print epoch accuracy
    epoch_acc = correct / total
    print(f"Epoch [{epoch+1}/{num_epochs}] completed - Train ACC: {epoch_acc:.4f}")

Epoch [1/10], Step [100/375], Loss: 2.3260
Epoch [1/10], Step [200/375], Loss: 2.3112
Epoch [1/10], Step [300/375], Loss: 2.2987
Epoch [1/10] completed - Train ACC: 0.1288
Epoch [2/10], Step [100/375], Loss: 2.1961
Epoch [2/10], Step [200/375], Loss: 2.3373
Epoch [2/10], Step [300/375], Loss: 2.3675
Epoch [2/10] completed - Train ACC: 0.1357
Epoch [3/10], Step [100/375], Loss: 2.3440
Epoch [3/10], Step [200/375], Loss: 2.3429
Epoch [3/10], Step [300/375], Loss: 2.2669
Epoch [3/10] completed - Train ACC: 0.1448
Epoch [4/10], Step [100/375], Loss: 2.3472
Epoch [4/10], Step [200/375], Loss: 2.2338
Epoch [4/10], Step [300/375], Loss: 2.3031
Epoch [4/10] completed - Train ACC: 0.1415
Epoch [5/10], Step [100/375], Loss: 2.3008
Epoch [5/10], Step [200/375], Loss: 2.2822
Epoch [5/10], Step [300/375], Loss: 2.3310
Epoch [5/10] completed - Train ACC: 0.1393
Epoch [6/10], Step [100/375], Loss: 2.2549
Epoch [6/10], Step [200/375], Loss: 2.1543
Epoch [6/10], Step [300/375], Loss: 2.2807
Epoch [6/10

In [22]:
# Testing loop
qmodel.eval()
correct = 0
total = 0
loss_test_list = []
with torch.no_grad():
    for images, labels in val_loader:
        outputs = qmodel(images)
        loss_test = criterion(outputs, labels).cpu().detach().numpy()
        loss_test_list.append(loss_test)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy on the test set: {(100 * correct / total):.2f}%")

Accuracy on the test set: 8.83%
