In [47]:
import torch
import json
import os
import warnings
from torch.utils.data import TensorDataset, DataLoader
from torch import nn
import torchquantum as tq
import numpy as np
import math
from torch.optim import Adam
from torchquantum.measurement import expval_joint_analytical

warnings.simplefilter("ignore")

In [48]:
def load_data_setting(setting):
    folder = "../exps"
    json_file = "data_config.json"
    json_path = os.path.join(folder, setting, json_file)
    with open(json_path, "r") as f:
        config = json.load(f)

    permutation_seed = config["permutation_seed"]
    test_size = config["test_size"]
    partition_seed = config["partition_seed"]
    n_class = config["n_class"]
    
    data_tensors = config["data_tensors"]
    loaded_data_tensors = torch.load(data_tensors)
    data_tr = loaded_data_tensors["data_tr"]
    label_tr = loaded_data_tensors["label_tr"]
    data_te = loaded_data_tensors["data_te"]
    label_te = loaded_data_tensors["label_te"]

    return permutation_seed, test_size, partition_seed, n_class, data_tr, label_tr, data_te, label_te

In [49]:
def create_training_dataloader(data_tr, label_tr, batch_size):
    training_dataset = TensorDataset(data_tr, label_tr)
    training_dataloader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True)
    return training_dataloader

### Exp1: 10 classes, 8 clients

In [50]:
# Changing the expectations values to concat the clients' outputs
class QNNsubModel(nn.Module):
    def __init__(self, n_qubits=8, n_block=5, n_depth_per_block=1):
        # params is numpy array
        super().__init__()
        self.n_wires = n_qubits
        self.encoder_gates_x = ([tq.functional.rx] * self.n_wires + [tq.functional.ry] * self.n_wires)*2
        self.n_block = n_block
        self.n_depth_per_block = n_depth_per_block
        params = np.random.rand( self.n_wires*self.n_depth_per_block*self.n_block*2)*math.pi
        self.u_layers = tq.QuantumModuleList()
        for j in range(self.n_depth_per_block*self.n_block):
            for i in range(self.n_wires):
                self.u_layers.append( tq.RX(has_params=True, trainable=True, init_params=params[i+(2*j)*self.n_wires]) )
            for i in range(self.n_wires):
                self.u_layers.append( tq.RY(has_params=True, trainable=True, init_params=params[i+(2*j+1)*self.n_wires]) )

    def forward(self, x):
        bsz, nx_features = x.shape
        qdev = tq.QuantumDevice(
            n_wires=self.n_wires, bsz = bsz, device=x.device, record_op=False
        )
        n_depth_per_block = self.n_depth_per_block
        for d in range(self.n_block-1): # (2,4)
            for k in range(n_depth_per_block):
                for j in range(2*d*n_depth_per_block+2*k,2*d*n_depth_per_block+2*k+2):
                    for i in range(self.n_wires):
                        self.u_layers[i+j*self.n_wires](qdev, wires=i)
                for i in range(self.n_wires):
                    qdev.cz(wires=[i,(i+1)%self.n_wires])
            # data encoding
            #for j in range(2*d,2*d+1): # (0,2) (2,4)
            for k in range(self.n_wires):
                    #self.encoder_gates_x[k+j*self.n_wires](qdev, wires=k, params=x[:, (k+j*self.n_wires)])
                index = k + d * self.n_wires
                self.encoder_gates_x[index](qdev, wires=k, params=x[:, (index)])
            for i in range(self.n_wires):
                qdev.cz(wires=[i,(i+1)%self.n_wires])
        for d in range(self.n_block-1,self.n_block): # (4,5)
            for k in range(n_depth_per_block):
                for j in range(2*d*n_depth_per_block+2*k,2*d*n_depth_per_block+2*k+2):
                    for i in range(self.n_wires):
                        self.u_layers[i+j*self.n_wires](qdev, wires=i)
                if k==n_depth_per_block-1:
                    break
                for i in range(self.n_wires):
                    qdev.cz(wires=[i,(i+1)%self.n_wires])

        obs_list = [ expval_joint_analytical(qdev, "I"*i+Pauli+"I"*(self.n_wires-1-i)) for Pauli in ["Z"] for i in range(self.n_wires)]
        ret = torch.stack(obs_list, dim=1)
        return ret

In [51]:
class QantumServerSubModel(nn.Module):
    def __init__(self, n_qubits=8, n_block=5, n_depth_per_block=1):
        # params is numpy array
        super().__init__()
        self.n_wires = n_qubits
        self.encoder_gates_x = ([tq.functional.rx] * self.n_wires + [tq.functional.ry] * self.n_wires)*4
        self.n_block = n_block
        self.n_depth_per_block = n_depth_per_block
        params = np.random.rand( self.n_wires*self.n_depth_per_block*self.n_block*2)*math.pi
        self.u_layers = tq.QuantumModuleList()
        for j in range(self.n_depth_per_block*self.n_block):
            for i in range(self.n_wires):
                self.u_layers.append( tq.RX(has_params=True, trainable=True, init_params=params[i+(2*j)*self.n_wires]) )
            for i in range(self.n_wires):
                self.u_layers.append( tq.RY(has_params=True, trainable=True, init_params=params[i+(2*j+1)*self.n_wires]) )

    def forward(self, x):
        bsz, nx_features = x.shape
        qdev = tq.QuantumDevice(
            n_wires=self.n_wires, bsz = bsz, device=x.device, record_op=False
        )
        n_depth_per_block = self.n_depth_per_block
        for d in range(self.n_block-1): # (2,4)
            for k in range(n_depth_per_block):
                for j in range(2*d*n_depth_per_block+2*k,2*d*n_depth_per_block+2*k+2):
                    for i in range(self.n_wires):
                        self.u_layers[i+j*self.n_wires](qdev, wires=i)
                for i in range(self.n_wires):
                    qdev.cz(wires=[i,(i+1)%self.n_wires])
            # data encoding
            for j in range(2*d,2*d+2): # (0,2) (2,4)
                for k in range(self.n_wires):
                #for k in range(2):
                    self.encoder_gates_x[k+j*self.n_wires](qdev, wires=k, params=x[:, (k+j*self.n_wires)])
                    #index = k + j * 2
                    #self.encoder_gates_x[index](qdev, wires=k, params=x[:, (index)])
            for i in range(self.n_wires):
                qdev.cz(wires=[i,(i+1)%self.n_wires])
        for d in range(self.n_block-1,self.n_block): # (4,5)
            for k in range(n_depth_per_block):
                for j in range(2*d*n_depth_per_block+2*k,2*d*n_depth_per_block+2*k+2):
                    for i in range(self.n_wires):
                        self.u_layers[i+j*self.n_wires](qdev, wires=i)
                if k==n_depth_per_block-1:
                    break
                for i in range(self.n_wires):
                    qdev.cz(wires=[i,(i+1)%self.n_wires])

        obs_list = [ expval_joint_analytical(qdev, "I"*i+Pauli+"I"*(self.n_wires-1-i)) for Pauli in ["Z", "X"] for i in range(n_class//2)]
        ret = torch.stack(obs_list, dim=1)
        return ret

In [52]:
class QuantumServer(nn.Module):
    def __init__(self):
        super().__init__()
        # self.coeff = coeff
        self.softmax = nn.Softmax(dim=1)
        #self.fc1 = nn.Linear(in_features=16*4, out_features=10)
        self.q = QantumServerSubModel().to(device)
        #self.q = QantumServerSubModel_96()

    def forward(self, clients_outputs):
        # result = sum(clients_outputs)
        # result = result * self.coeff
        # result = self.softmax(result)
        concatenated_result = torch.cat(clients_outputs, dim=1)
        #result = self.fc1(concatenated_result)
        result = self.q(concatenated_result)
        result = self.softmax(result)

        return result

In [53]:
def train(data, label, clients_models, clients_optimizers, server_model, server_optimizer):
    for i, client_model in enumerate(clients_models.values()):
        client_model.train(mode=True)
    server_model.train(mode=True)

    for key, client_optimizer in clients_optimizers.items():
        client_optimizer.zero_grad()
    server_optimizer.zero_grad()

    num_clients = len(clients_models)
    features_per_client = data.shape[1] // num_clients
    #features_per_client = 64
    clients_data = [data[:, i * features_per_client:(i + 1) * features_per_client] for i in range(num_clients)]
    print(clients_data[0].shape)

    clients_outputs = []
    for i, client_model in enumerate(clients_models.values()):
        client_pred = client_model(clients_data[i])
        # print("Client pred shape:", client_pred.shape)
        clients_outputs.append(client_pred)

    #result = sum(clients_outputs)
    # print("Client 1 outuput shape:", clients_outputs[0].shape)
    result = server_model(clients_outputs)
    # result = clients_outputs[0] + clients_outputs[1] + clients_outputs[2] + clients_outputs[3]
    # result = server_model(result)
    # print("Result shape:", result.shape)
    loss = torch.nn.CrossEntropyLoss()(result, label)
    acc = (result.argmax(axis=1) == label).sum().item() / len(label)
    # acc = accuracy_score(y_tr, pred.argmax(axis=1).cpu().detach().numpy() )
    print(f"train loss: {loss.item():.5f}, train acc: {acc:.3f}", end=' ')
    loss.backward()
    for key, client_optimizer in clients_optimizers.items():
        client_optimizer.step()
    server_optimizer.step()

    return loss.item(), acc


def test(data, label, clients_models, server_model):
    num_clients = len(clients_models)
    features_per_client = data.shape[1] // num_clients
    #features_per_client = 64
    clients_data = [data[:, i * features_per_client:(i + 1) * features_per_client] for i in range(num_clients)]

    clients_outputs = []
    for i, client_model in enumerate(clients_models.values()):
        client_model.train(mode=False)
        with torch.no_grad():
            client_pred = client_model(clients_data[i])
        clients_outputs.append(client_pred)

    #result = sum(clients_outputs)
    with torch.no_grad():
        result = server_model(clients_outputs)
    # result = clients_outputs[0] + clients_outputs[1] + clients_outputs[2] + clients_outputs[3]
    # result = server_model(result)
    loss = torch.nn.CrossEntropyLoss()(result, label)
    acc = (result.argmax(axis=1) == label).sum().item() / len(label)
    print(f"test loss: {loss.item():.5f} test acc: {acc:.4f}")
    return loss.item(), acc

In [91]:
def run_n_save_exp_8clients(setting_folder, save_path):

    # Load data
    permutation_seed, test_size, partition_seed, n_class, data_tr, label_tr, data_te, label_te = load_data_setting(setting_folder)
    print("PERMUTATION SEED", permutation_seed)
    training_dataloader = create_training_dataloader(data_tr, label_tr, batch_size=32)

    
    # Hyperparameters
    max_epochs = 2
    lr = 0.002
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    
    # Models' initialization
    server_model = QuantumServer().to(device)
    server_optimizer = Adam(server_model.parameters(), lr=lr)
    client1_model = QNNsubModel().to(device)
    client1_optimizer = Adam(client1_model.parameters(), lr=lr)
    client2_model = QNNsubModel().to(device)
    client2_optimizer = Adam(client2_model.parameters(), lr=lr)
    client3_model = QNNsubModel().to(device)
    client3_optimizer = Adam(client3_model.parameters(), lr=lr)
    client4_model = QNNsubModel().to(device)
    client4_optimizer = Adam(client4_model.parameters(), lr=lr)
    client5_model = QNNsubModel().to(device)
    client5_optimizer = Adam(client5_model.parameters(), lr=lr)
    client6_model = QNNsubModel().to(device)
    client6_optimizer = Adam(client6_model.parameters(), lr=lr)
    client7_model = QNNsubModel().to(device)
    client7_optimizer = Adam(client7_model.parameters(), lr=lr)
    client8_model = QNNsubModel().to(device)
    client8_optimizer = Adam(client8_model.parameters(), lr=lr)
    
    clients_models = {"client1_model": client1_model,
                     "client2_model": client2_model,
                     "client3_model": client3_model,
                     "client4_model": client4_model,
                     "client5_model": client5_model,
                     "client6_model": client6_model,
                     "client7_model": client7_model,
                     "client8_model": client8_model}
    clients_optimizers = {"client1_optimizer": client1_optimizer,
                         "client2_optimizer": client2_optimizer,
                         "client3_optimizer": client3_optimizer,
                         "client4_optimizer": client4_optimizer,
                         "client5_optimizer": client5_optimizer,
                         "client6_optimizer": client6_optimizer,
                         "client7_optimizer": client7_optimizer,
                         "client8_optimizer": client8_optimizer}

    
    # Trainining and testing
    all_tr_loss = []
    all_test_loss = []
    all_tr_acc = []
    all_test_acc = []
    for i_epoch in range(max_epochs):
        epoch_loss = 0.0 
        epoch_accuracy = 0.0 
        total_samples = 0 
        print(f"Epoch {i_epoch}:", end=" ")
        for batch_X, batch_y in training_dataloader:
            print(batch_X.shape, batch_y.shape)
            loss_tr, acc_tr = train(batch_X, batch_y, clients_models, clients_optimizers, server_model, server_optimizer)
            batch_size = batch_X.size(0)
            epoch_loss += loss_tr * batch_size  # Scale loss by batch size
            epoch_accuracy += acc_tr * batch_size  # Scale accuracy by batch size
            total_samples += batch_size  # Update sample count
        epoch_loss /= total_samples
        epoch_accuracy /= total_samples
        print(f"Training loss:{epoch_loss:.4f} Training acc:{epoch_accuracy:.4f}")
        loss_test, acc_test = test(data_te, label_te, clients_models, server_model)
        all_tr_loss.append(epoch_loss)
        all_test_loss.append(loss_test)
        all_tr_acc.append(epoch_accuracy)
        all_test_acc.append(acc_test)

    
    # Save models & parameters
    models_tensors_path = save_path+"/"+"models_tensors.pth"
    torch.save({
        "client1_model": clients_models["client1_model"].state_dict(),
        "client2_model": clients_models["client2_model"].state_dict(),
        "client3_model": clients_models["client2_model"].state_dict(),
        "client4_model": clients_models["client2_model"].state_dict(),
        "client5_model": clients_models["client2_model"].state_dict(),
        "client6_model": clients_models["client2_model"].state_dict(),
        "client7_model": clients_models["client2_model"].state_dict(),
        "client8_model": clients_models["client2_model"].state_dict(),
        "server_model": server_model.state_dict(),
    }, models_tensors_path)
    
    params = {
        "permutation_seed": permutation_seed, 
        "test_size": test_size, 
        "partition_seed": partition_seed,
        "n_class": n_class,
        "all_tr_loss": all_tr_loss, 
        "all_test_loss": all_test_loss, 
        "all_tr_acc": all_tr_acc, 
        "all_test_acc": all_test_acc,
        "models_tensors": models_tensors_path
    }
    
    exp_path = save_path+"/"+"exp.json"
    with open(exp_path, "w") as f:
        json.dump(params, f, indent=4)



In [92]:
settings = ["Setting_1", "Setting_2"]
root = "../exps"

for setting in settings:
    path = os.path.join(root, setting)
    os.makedirs(path, exist_ok=True)
    run_n_save_exp_8clients(setting, path)

PERMUTATION SEED 42
Epoch 0: torch.Size([32, 256]) torch.Size([32])
torch.Size([32, 32])
train loss: 2.30984, train acc: 0.062 torch.Size([32, 256]) torch.Size([32])
torch.Size([32, 32])
train loss: 2.30361, train acc: 0.094 torch.Size([32, 256]) torch.Size([32])
torch.Size([32, 32])
train loss: 2.30355, train acc: 0.125 torch.Size([32, 256]) torch.Size([32])
torch.Size([32, 32])
train loss: 2.29577, train acc: 0.156 torch.Size([32, 256]) torch.Size([32])
torch.Size([32, 32])
train loss: 2.31098, train acc: 0.031 torch.Size([32, 256]) torch.Size([32])
torch.Size([32, 32])
train loss: 2.30721, train acc: 0.062 torch.Size([32, 256]) torch.Size([32])
torch.Size([32, 32])
train loss: 2.30266, train acc: 0.156 torch.Size([32, 256]) torch.Size([32])
torch.Size([32, 32])
train loss: 2.30255, train acc: 0.125 torch.Size([32, 256]) torch.Size([32])
torch.Size([32, 32])
train loss: 2.30789, train acc: 0.062 torch.Size([32, 256]) torch.Size([32])
torch.Size([32, 32])
train loss: 2.30611, train ac