In [1]:
import torch
import ray
from ray import tune
from ray.air import session
from sklearn.model_selection import train_test_split
import yaml
import argparse
import shutil
from collections import namedtuple
import os
import datetime


# Custom Libraries
from utils.data_generator import DataGenerator
from utils.agent import TrainModel
from utils.helper import to_python_native, gen_experiment_name, set_seed, save_model_state
from utils.plotter import alignment_progress_over_iterations, plot_initial_final_accuracies
from utils.pennylane.model import Qkernel

2025-07-24 10:25:03,321	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2025-07-24 10:25:03,641	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
# === Backend Configuration ===
SEED = 42  # single place to change the global RNG seed used by set_seed

# Preference order: Apple MPS, then CUDA, then Intel XPU, then CPU.
# Backends are probed defensively: `torch.backends.mps` and `torch.xpu` are
# absent from older PyTorch builds, and an unguarded attribute access would
# raise AttributeError instead of falling through to the next option.
if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    device = torch.device("xpu")
else:
    device = torch.device("cpu")

# NOTE(review): `device` is not passed to Qkernel or TrainModel in the cells
# below (the kernel is built on the 'lightning.qubit' simulator), so this
# selection currently has no effect on where training runs — confirm intent.

set_seed(SEED)

In [41]:
# Load the wine dataset; generate_dataset() returns the train/test split.
data_generator = DataGenerator(
    dataset_name='wine',
    file_path=None,
)

# Keep the raw split (objects exposing .to_numpy(), presumably pandas) under
# their own names so `training_data` & co. only ever refer to tensors.
(
    train_features_raw,
    train_labels_raw,
    test_features_raw,
    test_labels_raw,
) = data_generator.generate_dataset()

# NOTE(review): requires_grad=True on the input features makes autograd track
# gradients w.r.t. the data itself; kernel training presumably only needs
# gradients for the kernel's own parameters. Confirm TrainModel does not rely
# on this before dropping it.
training_data = torch.tensor(train_features_raw.to_numpy(), dtype=torch.float32, requires_grad=True)
testing_data = torch.tensor(test_features_raw.to_numpy(), dtype=torch.float32, requires_grad=True)
training_labels = torch.tensor(train_labels_raw.to_numpy(), dtype=torch.int)
testing_labels = torch.tensor(test_labels_raw.to_numpy(), dtype=torch.int)

# Fail fast on a malformed split instead of deep inside kernel evaluation.
assert training_data.shape[0] == training_labels.shape[0], "train features/labels length mismatch"
assert testing_data.shape[0] == testing_labels.shape[0], "test features/labels length mismatch"
assert training_data.shape[1] == testing_data.shape[1], "train/test feature count mismatch"

In [42]:
# Shape of the train and test feature tensors, one per line.
print(training_data.size(), testing_data.size(), sep="\n")

torch.Size([142, 13])
torch.Size([36, 13])


In [43]:
# Trainable quantum kernel evaluated on the 'lightning.qubit' simulator.
# NOTE(review): 13 input features vs n_qubits=4 — presumably the ansatz /
# data re-uploading spreads features across layers; confirm in Qkernel.
kernel = Qkernel(
    # Simulator and register size
    device='lightning.qubit',
    n_qubits=4,
    # Circuit structure
    ansatz='embedding_paper',
    ansatz_layers=5,
    # Trainability switches
    trainable=True,
    input_scaling=True,
    data_reuploading=True,
)

In [44]:
# Training driver for `kernel` using the 'ccka' train method.
# NOTE(review): the saved output printed "Epochs: 40" and reported alignment at
# epochs 9 and 19, which does not match epochs=400 / get_alignment_every=100.
# Either TrainModel rescales these values or that output came from a different
# configuration — confirm before trusting the logs.
agent = TrainModel(
    kernel=kernel,
    # Data splits
    training_data=training_data,
    training_labels=training_labels,
    testing_data=testing_data,
    testing_labels=testing_labels,
    # Optimizer, learning rates and schedule
    optimizer='gd',
    lr=0.1,
    mclr=0.01,
    cclr=0.01,
    epochs=400,
    train_method='ccka',
    target_accuracy=0.95,
    # Reporting cadence and output location
    get_alignment_every=100,
    validate_every_epoch=1,
    base_path='./',
    # lambda_* weights and cluster count (semantics defined inside TrainModel)
    lambda_kao=0.001,
    lambda_co=0.001,
    clusters=8,
)

Epochs:  40


In [45]:
# Baseline metrics before any training, tagged 'before' so they can be
# compared with the post-training evaluation.
before_metrics = agent.evaluate(
    testing_data,
    testing_labels,
    'before',
)
before_metrics

{'alignment': tensor(1.5675, grad_fn=<SqueezeBackward0>),
 'executions': None,
 'training_accuracy': 0.9577464788732394,
 'testing_accuracy': 0.8055555555555556,
 'f1_score': 0.8130839550984743,
 'alignment_arr': [[], []],
 'loss_arr': [],
 'validation_accuracy_arr': []}

In [46]:
# Train the kernel on the training split. Long-running: let it finish before
# running the 'after' evaluation, otherwise the comparison is meaningless.
agent.fit_multiclass(
    training_data,
    training_labels,
)

Started Training
------------------------------------------------------------------
Epoch: 9 — Alignment per main centroid
Centroid 0 (label=0): Alignment = -0.43751469254493713
Centroid 1 (label=1): Alignment = -0.46624064445495605
Centroid 2 (label=2): Alignment = 0.17968833446502686
------------------------------------------------------------------
------------------------------------------------------------------
Epoch: 19 — Alignment per main centroid
Centroid 0 (label=0): Alignment = -0.4552779495716095
Centroid 1 (label=1): Alignment = -0.5152655839920044
Centroid 2 (label=2): Alignment = 0.3714822232723236
------------------------------------------------------------------


KeyboardInterrupt: 

In [38]:
# Post-training metrics under the 'after' tag. Only meaningful once
# fit_multiclass has completed on this same `agent`.
# NOTE(review): the saved output of this cell has execution count 38, i.e. it
# predates the current `agent` (In [44]) and the interrupted fit (In [46]);
# re-run top to bottom for a valid before/after comparison.
after_metrics = agent.evaluate(
    testing_data,
    testing_labels,
    'after',
)
after_metrics

{'alignment': tensor(1.0494, grad_fn=<SqueezeBackward0>),
 'executions': 14400,
 'training_accuracy': 0.9647887323943662,
 'testing_accuracy': 0.8055555555555556,
 'f1_score': 0.8130839550984743,
 'alignment_arr': [[], []],
 'loss_arr': [],
 'validation_accuracy_arr': []}