In this approach, we implement a quantum-inspired classification method based on the principles of the quantum anharmonic oscillator. We model the classical MNIST images as a linear combination of terms in the position and momentum operators, resembling the potential and kinetic energy of an anharmonic oscillator.

In this approach we have implemented 2 cases; the remaining 2 cases are improvements on the first two.

Data preparation:

We use the MNIST dataset, consisting of 60000 training images and 10000 test images. The original size of each image is (784,). We reshape the images to size (28,28), normalize them, and then convert them into Hamiltonians using 4 methods:

1) H = (A + A.T) / 2

2)H = AA.T

3)H = outer product of flattened image vectors

4) H = -i*log(V) where V is a unitary matrix.

Then all the Hamiltonians are separated into 10 classes on the basis of the digits from 0 to 9.

In [None]:
# Data preprocessing of the MNIST dataset to produce the normalized train and test Hamiltonians.
# The Hamiltonians can be constructed with any of the four methods described in the paper.
import numpy as np
from skimage.transform import resize
import matplotlib.pyplot as plt
from PIL import Image
import torch
from sklearn.datasets import fetch_openml
import scipy
from tensorflow.keras.datasets import mnist

# Load MNIST using Keras
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Flatten every image to a 784-vector and promote to float64 so that all
# later arithmetic runs in one consistent dtype.
x_train, x_test = (
    split.reshape(-1, 784).astype(np.float64) for split in (x_train, x_test)
)
print("Train:", x_train.shape, y_train.shape)
print("Test:", x_test.shape, y_test.shape)

# ----------------------------
# Helper functions
# ----------------------------
def separate_digits(images, labels):
    """Bucket images by their digit label.

    Returns a dict with keys 0..9 (always all ten), each mapping to the list
    of images whose label equals that key, in input order.  A label outside
    0..9 raises ``KeyError``.
    """
    buckets = {digit: [] for digit in range(10)}
    for label, image in zip(labels, images):
        buckets[label].append(image)
    return buckets

def resize_images_batch(images, new_size=(8, 8), batch_size=500):
    """Resize flattened 28x28 images to ``new_size`` and flatten them again.

    The input is walked in slices of ``batch_size`` rows; every row is viewed
    as a 28x28 image, passed through ``resize`` and flattened.  The results
    are returned as one array, in input order.
    """
    shrunk = []
    for start in range(0, len(images), batch_size):
        chunk = images[start:start + batch_size]
        shrunk += [resize(flat.reshape(28, 28), new_size).flatten() for flat in chunk]
    return np.array(shrunk)

def normalize_batch(images):
    """Scale every row of ``images`` to unit Euclidean norm.

    Parameters
    ----------
    images : 2-D array, one flattened image per row.

    Returns
    -------
    Array of the same shape whose rows have L2 norm 1.  A row that is entirely
    zero (a blank image) is returned unchanged instead of producing NaNs from
    a 0/0 division.
    """
    norms = np.linalg.norm(images, axis=1, keepdims=True)
    # Dividing a zero row by 1 keeps it zero and avoids NaN propagation into
    # the density matrices built downstream.
    safe_norms = np.where(norms == 0, 1.0, norms)
    return images / safe_norms

# Hamiltonian from the outer product of each flattened image with itself.
def density_matrix_batch(images):
    """Return one outer product ``v v^T`` per row ``v`` of ``images``.

    ``images`` has shape (N, D); the result has shape (N, D, D).
    """
    as_columns = images[:, :, np.newaxis]   # (N, D, 1)
    as_rows = images[:, np.newaxis, :]      # (N, 1, D)
    return as_columns @ as_rows

# Hamiltonian from the symmetrisation H = (A + A^†) / 2.
def hamiltonian_symmetric_batch(images):
    """Build one symmetrised matrix per row of ``images``.

    For a row ``a`` of length D, ``A`` is the D x D matrix whose every column
    equals ``a``; the returned matrix is ``(A + A.conj().T) / 2``.
    ``images`` must be 2-D; the result stacks the N matrices.
    """
    _, dim = images.shape
    ones_row = np.ones(dim)

    def _symmetrise(vec):
        spread = np.outer(vec, ones_row)
        return (spread + spread.conj().T) / 2

    return np.array([_symmetrise(vec) for vec in images])

# Hamiltonian from the product H = A @ A.T.
def hamiltonian_product_batch(images):
    """Build one product matrix ``A @ A.T`` per row of ``images``.

    For a row ``a`` of length D, ``A`` is the D x D matrix whose every column
    equals ``a``.  ``images`` must be 2-D; the result stacks the N matrices.
    """
    _, dim = images.shape
    ones_row = np.ones(dim)

    def _gram(vec):
        spread = np.outer(vec, ones_row)
        return spread @ spread.T

    return np.array([_gram(vec) for vec in images])

import scipy.linalg
# Matrices for the "H = -i * log(V)" method.
def hamiltonian_using_log(images):
    """Return ``exp(-i * diag(v))`` for every row ``v`` of ``images``.

    Each row is placed on the diagonal of a D x D complex matrix, Hermitised,
    and exponentiated with ``torch.matrix_exp``.  The result is a
    ``complex128`` NumPy array of shape (N, D, D).

    NOTE(review): what is returned is the unitary ``V = exp(-iH)`` itself;
    no logarithm is taken here, so the output is not the Hermitian
    ``-i * log(V)`` that the method name refers to.  Confirm the intent.
    """
    count, dim = images.shape
    result = np.zeros((count, dim, dim), dtype=np.complex128)
    for row, vec in enumerate(images):
        generator = torch.tensor(np.diag(vec), dtype=torch.complex128)
        hermitian = 0.5 * (generator + generator.conj().transpose(-2, -1))
        result[row] = torch.matrix_exp(-1j * hermitian)
    return result


# ----------------------------
# Process training data
# ----------------------------
digit_images_dict = separate_digits(x_train, y_train)
resized_digit_images, normalized_digit_images, density_matrices = {}, {}, {}

# The Hamiltonian builder used below is the outer-product method.  To try one
# of the alternatives, replace `density_matrix_batch` with
#   hamiltonian_symmetric_batch   (H = (A + A^T) / 2)
#   hamiltonian_product_batch     (H = A @ A^T)
#   hamiltonian_using_log         (H = -i * log(V))
# and keep the Frobenius normalisation that follows it.
for digit, imgs in digit_images_dict.items():
    imgs = np.array(imgs)

    imgs_resized = resize_images_batch(imgs, new_size=(8, 8), batch_size=500)
    resized_digit_images[digit] = imgs_resized

    imgs_normalized = normalize_batch(imgs_resized)
    normalized_digit_images[digit] = imgs_normalized
    print(f"normalized_images shape:- {imgs_normalized.shape}")

    density1 = density_matrix_batch(imgs_normalized)
    print(f"shape 1:- {density1.shape}")
    # Rescale every matrix to unit Frobenius norm.
    density1 /= np.linalg.norm(density1, axis=(1, 2), keepdims=True)
    density_matrices[digit] = density1

# Stack the per-digit batches in digit order 0..9, so the training set is
# sorted by class.
train_density_matrices = np.concatenate(
    [density_matrices[d] for d in range(10)],
    axis=0,
)
train_density_matrices_tensor = torch.tensor(
    train_density_matrices, dtype=torch.cfloat
)

# ----------------------------
# Process test data
# ----------------------------
# Same pipeline as the training data, applied to the whole test split at once
# (test samples keep their original order).
test_images_resized = np.array(
    [resize(flat.reshape(28, 28), (8, 8)).flatten() for flat in x_test]
)
test_normed = normalize_batch(test_images_resized)
test_density = density_matrix_batch(test_normed)
# Rescale every matrix to unit Frobenius norm.
test_density /= np.linalg.norm(test_density, axis=(1, 2), keepdims=True)
test_density_tensor = torch.tensor(test_density, dtype=torch.cfloat)

# ----------------------------
# Visualization example
# ---------------------------

# Show the first ten resized (8x8) training images of every digit, one
# figure (a single row of ten panels) per digit.
for digit in range(10):
    images_to_plot = resized_digit_images[digit][:10]
    plt.figure(figsize=(10,2))
    for i in range(10):
        # pyplot works on the "current" axes, so the subplot must be selected
        # before the image, title and axis calls that follow.
        plt.subplot(1, 10, i+1)
        plt.imshow(images_to_plot[i].reshape(8,8), cmap='magma')
        plt.title(f"{digit}")
        plt.axis('off')
    plt.show()

# Names under which the training code refers to the prepared tensors.
normalized_Hermitian_Digit_matrices, normalized_hermitian_matrices_test_input = (
    train_density_matrices_tensor,
    test_density_tensor,
)

print(f"normalized_Hermitian_Digit_matrices shape:- {normalized_Hermitian_Digit_matrices.shape}")
print(f"normalized_hermitian_matrices_test_input shape:- {normalized_hermitian_matrices_test_input.shape}")

# Digit classes 0..9.
labels = list(range(10))

print(labels)

# Number of training images per digit, in digit order.  The training matrices
# are concatenated digit by digit, so repeating each label this many times
# reproduces the label of every row.
D = [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]
# D = [100] * 10

# One constant integer array per digit (names kept for any later cell that
# refers to them individually).
(
    labels_zero,
    labels_one,
    labels_two,
    labels_three,
    labels_four,
    labels_five,
    labels_six,
    labels_seven,
    labels_eigth,
    labels_nineth,
) = (np.full(count, digit, dtype=int) for digit, count in zip(labels, D))

# Full training label vector, aligned with the digit-sorted training matrices.
labels_new_train = np.concatenate((
    labels_zero,
    labels_one,
    labels_two,
    labels_three,
    labels_four,
    labels_five,
    labels_six,
    labels_seven,
    labels_eigth,
    labels_nineth,
))


**Case 1:-**

In this case we model the classical MNIST images using a linear combination of trainable base matrices. The model learns two things simultaneously: image-specific coefficients that define the combination, and a single unitary transformation matrix which is trained to rotate each reconstructed Hamiltonian into a pre-defined diagonal target matrix, effectively mapping each image to a unique class-specific quantum state. Classification is then done by simply checking the position of the largest diagonal element in the transformed test Hamiltonian.

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import time
from typing import Tuple

def set_seed(seed: int = 42) -> None:
    """Seed the NumPy and PyTorch random generators (and CUDA, if present)."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)

class IntegratedHermitianClassifier(nn.Module):
    """Case-1 model: one shared pair of bases plus one global unitary.

    Trainable state:
      * ``x``, ``p``   - complex (n, n) matrices; they are Frobenius-normalised
                         and Hermitised before every use.
      * ``U_param``    - complex (n, n) matrix; ``exp(-i * herm(U_param))`` is
                         the unitary used for the class mapping.

    A sample's Hamiltonian is rebuilt as
    ``H0 + sum_k coeffs[k] * powers[k]`` with ``H0 = (p_h^2 + x_h^2) / 2`` and
    ``powers[k] = x_h^(k+2)`` for ``k = 0 .. d-1``.
    """

    def __init__(
        self,
        matrix_size: int = 64,
        d_order: int = 3,
        lr: float = 1e-3,
        epochs: int = 50,
        batch_size: int = 128,
        device: str = None,
    ) -> None:
        super().__init__()

        # Hyper-parameters (read back by the training helpers).
        self.n, self.d = matrix_size, d_order
        self.lr, self.epochs, self.batch_size = lr, epochs, batch_size
        fallback = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device or fallback)

        shape = (matrix_size, matrix_size)

        # Bases start small (0.1 * N(0, 1)) for stability.
        self.x = nn.Parameter(0.1 * torch.randn(*shape, dtype=torch.cfloat))
        self.p = nn.Parameter(0.1 * torch.randn(*shape, dtype=torch.cfloat))

        # Unconstrained matrix; only its exponential is ever used.
        self.U_param = nn.Parameter(torch.randn(*shape, dtype=torch.cfloat))

        # Cache for the d powers x_h^2 ... x_h^(d+1).
        self.register_buffer("powers", torch.zeros(d_order, *shape, dtype=torch.cfloat))
        self._powers_computed = False

        # Data tensors are attached later by `load_data`.
        self.Mtr = None
        self.Mte = None
        self.y = None
        self.yte = None
        self.target_matrices = None

        self.to(self.device)

    @staticmethod
    def _make_hermitian(M: torch.Tensor) -> torch.Tensor:
        """Return the Hermitian part (M + M^†) / 2 over the last two axes."""
        dagger = M.conj().transpose(-2, -1)
        return 0.5 * (M + dagger)

    @staticmethod
    def _frobenius_norm(mat: torch.Tensor) -> torch.Tensor:
        """Frobenius norm over the last two axes (dims kept), offset by 1e-8."""
        fro = torch.norm(mat, p="fro", dim=(-2, -1), keepdim=True)
        return fro + 1e-8

    def _normalise(self, M: torch.Tensor) -> torch.Tensor:
        """Divide ``M`` by its (offset) Frobenius norm."""
        return M / self._frobenius_norm(M)

    def _make_unitary(self, M: torch.Tensor) -> torch.Tensor:
        """Map any square matrix to a unitary: U = exp(-i H), H = (M + M^†)/2."""
        return torch.matrix_exp(-1j * self._make_hermitian(M))

    @torch.no_grad()
    def _compute_powers(self) -> None:
        """Refresh the cached powers of the normalised, Hermitised ``x``.

        Slot ``k`` receives ``x_h^(k+2)``.  This runs under ``no_grad``, so
        the cached powers carry no gradient back to ``x``.
        """
        x_herm = self._make_hermitian(self._normalise(self.x))

        running = x_herm @ x_herm
        self.powers[0] = running
        for slot in range(1, self.d):
            running = running @ x_herm
            self.powers[slot] = running

        self._powers_computed = True

    def _base_hamiltonian(self) -> torch.Tensor:
        """Return H0 = herm(p_h^2 / 2 + x_h^2 / 2)."""
        x_h = self._make_hermitian(self._normalise(self.x))
        p_h = self._make_hermitian(self._normalise(self.p))
        kinetic = 0.5 * (p_h @ p_h)
        potential = 0.5 * (x_h @ x_h)
        return self._make_hermitian(kinetic + potential)

    def _reconstruct_batch(self, coeffs: torch.Tensor) -> torch.Tensor:
        """Rebuild one Hamiltonian per coefficient row.

        ``coeffs`` has shape (batch, d); the result has shape (batch, n, n)
        and is Hermitised before being returned.
        """
        if not self._powers_computed:
            self._compute_powers()

        H0 = self._base_hamiltonian()
        n_rows = coeffs.shape[0]

        recon = H0.unsqueeze(0).expand(n_rows, -1, -1).clone()
        for k in range(self.d):
            weight = coeffs[:, k].unsqueeze(-1).unsqueeze(-1)
            recon = recon + weight * self.powers[k]
        return 0.5 * (recon + recon.conj().transpose(-2, -1))

    @staticmethod
    def _frobenius_batch(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """Squared Frobenius distance per sample: |Tr((A-B)^†(A-B))|, shape (batch,)."""
        delta = A - B
        gram = delta.conj().transpose(-2, -1) @ delta
        trace = torch.diagonal(gram, dim1=-2, dim2=-1).sum(-1)
        return torch.abs(trace)

    def load_data(
        self,
        train_mats: np.ndarray,
        train_labels: np.ndarray,
        test_mats: np.ndarray,
        test_labels: np.ndarray,
        target_mats: np.ndarray,
    ) -> None:
        """Move the data onto the model device and pre-compute the x-powers."""
        print(f"► Loading data on {self.device} …")

        dev = self.device
        self.Mtr = torch.as_tensor(train_mats, dtype=torch.cfloat, device=dev)
        self.y = torch.as_tensor(train_labels, dtype=torch.long, device=dev)
        self.Mte = torch.as_tensor(test_mats, dtype=torch.cfloat, device=dev)
        self.yte = torch.as_tensor(test_labels, dtype=torch.long, device=dev)
        self.target_matrices = torch.as_tensor(target_mats, dtype=torch.cfloat, device=dev)

        self._compute_powers()

        print(
            f"✓ Data loaded — train {self.Mtr.shape}, test {self.Mte.shape}"
        )

    def forward(
        self,
        coeffs: torch.Tensor,             # (batch, d)
        originals: torch.Tensor,          # (batch, n, n)
        targets: torch.Tensor,            # (batch, n, n)
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(total, L_recon, L_unitary, L_uc0, L_uc1)``.

        ``total = L_recon + 10 * L_unitary + 0.1 * (L_uc0 + L_uc1)``.
        """
        # Reconstruction: rebuilt Hamiltonian vs. the input matrix.
        reconstructed = self._reconstruct_batch(coeffs)
        L_recon = self._frobenius_batch(reconstructed, originals).mean()

        # Mapping: rotated Hamiltonian vs. the class target.
        U = self._make_unitary(self.U_param)
        U_dag = U.conj().transpose(-2, -1)
        transformed = U @ reconstructed @ U_dag
        L_unitary = self._frobenius_batch(transformed, targets).mean()

        # Unitarity penalties ||U U^† - I|| and ||U^† U - I||.
        I = torch.eye(self.n, dtype=torch.cfloat, device=self.device)
        L_uc0 = torch.norm(U @ U.conj().T - I, p="fro")
        L_uc1 = torch.norm(U.conj().T @ U - I, p="fro")

        total = L_recon + 10*L_unitary + 0.1 * (L_uc0 + L_uc1)
        return total, L_recon, L_unitary, L_uc0, L_uc1

    @torch.no_grad()
    def accuracy(self, mats: torch.Tensor, labels: torch.Tensor) -> float:
        """Fraction of ``mats`` whose predicted class equals ``labels``.

        The prediction is the index of the largest |diagonal| entry of
        ``U M U^†`` integer-divided by 7 (class slots are spaced by 7).
        """
        U = self._make_unitary(self.U_param)
        rotated = U @ mats @ U.conj().transpose(-2, -1)
        diag_abs = torch.abs(torch.diagonal(rotated, dim1=-2, dim2=-1))   # (batch, n)
        predicted = torch.argmax(diag_abs, dim=-1) // 7
        hits = (predicted == labels).float()
        return torch.mean(hits).item()

class ComplexCoefficients(nn.Module):
    """Per-sample complex coefficients stored as separate real/imag tables.

    Both tables have shape (n_samples, d) and are initialised from a normal
    distribution scaled by ``0.01 / sqrt(d)``.
    """

    def __init__(self, d: int, n_samples: int, device: torch.device) -> None:
        super().__init__()
        scale = 0.01 / np.sqrt(d)
        # Registration order (real, then imag) also fixes the RNG draw order.
        for part in ("real", "imag"):
            table = scale * torch.randn(n_samples, d, device=device)
            setattr(self, part, nn.Parameter(table))

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """Return the complex coefficient rows selected by ``idx``."""
        re_rows = self.real[idx]
        im_rows = self.imag[idx]
        return torch.complex(re_rows, im_rows)

    def all(self) -> torch.Tensor:
        """Return the full (n_samples, d) complex coefficient table."""
        return torch.complex(self.real, self.imag)

class ReduceLROnPlateau:
    """Minimal plateau scheduler: shrink the LR after ``patience`` bad epochs.

    An epoch is "bad" when its metric is not strictly below the best value
    seen so far.  After ``patience`` consecutive bad epochs every parameter
    group's learning rate is multiplied by ``factor`` (never dropping below
    ``min_lr``) and the bad-epoch counter restarts.
    """

    def __init__(self, optimiser: optim.Optimizer, factor: float = 0.5, patience: int = 3, min_lr: float = 1e-8):
        self.opt = optimiser
        self.factor, self.patience, self.min_lr = factor, patience, min_lr
        self.best: float = None
        self.bad_epochs = 0

    def step(self, metric: float):
        """Record one epoch's metric and decay the learning rates if due."""
        improved = self.best is None or metric < self.best
        if improved:
            self.best = metric
            self.bad_epochs = 0
            return

        self.bad_epochs += 1
        if self.bad_epochs < self.patience:
            return

        for group in self.opt.param_groups:
            group["lr"] = max(group["lr"] * self.factor, self.min_lr)
        self.bad_epochs = 0


def create_target_matrices(labels: np.ndarray, matrix_size: int = 64) -> np.ndarray:
    """Build one diagonal target matrix per label.

    Class ``c`` owns diagonal slot ``7 * c`` (0, 7, ..., 63); the target for a
    sample is the ``complex64`` matrix with a single 1 in its class slot,
    scaled to unit Frobenius norm.  Returns an array of shape
    (len(labels), matrix_size, matrix_size).
    """
    diag_slots = (0, 7, 14, 21, 28, 35, 42, 49, 56, 63)
    targets = np.zeros((len(labels), matrix_size, matrix_size), dtype=np.complex64)

    for row, label in enumerate(labels):
        slot = diag_slots[label]
        targets[row, slot, slot] = 1.0

    for row, matrix in enumerate(targets):
        scale = np.linalg.norm(matrix, 'fro')
        if scale > 0:
            targets[row] = matrix / scale

    return targets


def train_model(model: IntegratedHermitianClassifier) -> torch.Tensor:
    """Jointly train per-sample coefficients, the bases and the unitary.

    Three Adam optimisers are used (coefficients at ``model.lr``, bases at
    ``0.1 * model.lr``, unitary at ``0.05 * model.lr``), each with its own
    plateau scheduler fed by the total, reconstruction and mapping loss
    respectively.  Returns the full table of trained training coefficients.
    """
    print("► Training started …")
    t0 = time.time()

    n_train = model.Mtr.shape[0]
    coeffs = ComplexCoefficients(model.d, n_train, model.device)

    base_params = [model.x, model.p]
    unitary_params = [model.U_param]

    opt_coeff = optim.Adam(coeffs.parameters(), lr=model.lr)
    opt_bases = optim.Adam(base_params, lr=model.lr * 0.1)
    opt_unitary = optim.Adam(unitary_params, lr=model.lr * 0.05)
    optimisers = (opt_coeff, opt_bases, opt_unitary)

    sched_coeff, sched_bases, sched_unitary = (
        ReduceLROnPlateau(opt) for opt in optimisers
    )

    # The sample index travels with each batch so the matching coefficient
    # rows can be looked up.
    sample_ids = torch.arange(n_train, device=model.device)
    loader = DataLoader(
        TensorDataset(model.Mtr, model.target_matrices, sample_ids),
        batch_size=model.batch_size,
        shuffle=True,
    )

    for epoch in range(1, model.epochs + 1):
        print(f"Epoch: {epoch}")
        loss_total, recon_total, unit_total = 0.0, 0.0, 0.0
        n_batches = 0

        for orig, tgt, idx in loader:
            batch_coeffs = coeffs(idx)

            for opt in optimisers:
                opt.zero_grad()

            tot, Lr, Lu, _, _ = model(batch_coeffs, orig, tgt)
            tot.backward()

            torch.nn.utils.clip_grad_norm_(coeffs.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(base_params, 0.5)
            torch.nn.utils.clip_grad_norm_(unitary_params, 0.5)

            for opt in optimisers:
                opt.step()

            # x changed, so the cached powers must be rebuilt.
            model._compute_powers()

            loss_total += tot.item()
            recon_total += Lr.item()
            unit_total += Lu.item()
            n_batches += 1

        mean_loss = loss_total / n_batches
        mean_recon = recon_total / n_batches
        mean_unit = unit_total / n_batches

        sched_coeff.step(mean_loss)
        sched_bases.step(mean_recon)
        sched_unitary.step(mean_unit)

        train_acc = model.accuracy(model.Mtr, model.y)
        test_acc = model.accuracy(model.Mte, model.yte)

        print(
            f"  Epoch {epoch:02d}/{model.epochs} | "
            f"Loss {mean_loss:.4e} | "
            f"Recon {mean_recon:.4e} | "
            f"Unit {mean_unit:.4e} | "
            f"Acc train {train_acc:.3f} │ test {test_acc:.3f} | "
            f"Δt {time.time()-t0:.1f}s"
        )

    print(f"✓ Training finished in {time.time() - t0:.1f}s")
    return coeffs.all()


# Not wrapped in torch.no_grad(): the coefficients are fitted by gradient descent.
def predict_coefficients(
    model: IntegratedHermitianClassifier,
    d_order: int,
    lr: float = 1e-3,
    epochs: int = 200,
    batch_size: int = 2000,
) -> torch.Tensor:
    """Fit reconstruction coefficients for every test matrix.

    The test set (``model.Mte``) is processed in chunks of ``batch_size``.
    For each chunk a fresh ``ComplexCoefficients`` table is optimised with
    Adam for ``epochs`` steps to minimise the mean reconstruction loss; the
    model's own parameters are not stepped.

    Returns
    -------
    Complex tensor of shape (N_test, d_order) on ``model.device``, detached
    from the autograd graph.
    """
    print("► Predicting coefficients for test set …")
    Nt = model.Mte.shape[0]
    out = torch.empty((Nt, d_order), dtype=torch.cfloat, device=model.device)

    ds = TensorDataset(model.Mte, torch.arange(Nt, device=model.device))
    dl = DataLoader(ds, batch_size=batch_size)

    for mats, idx in dl:
        coeff = ComplexCoefficients(d_order, len(idx), model.device)
        opt = optim.Adam(coeff.parameters(), lr=lr)
        # Row selector for the chunk-local table; constant across the epochs.
        local_rows = torch.arange(len(idx), device=model.device)

        for _ in range(epochs):
            opt.zero_grad()
            preds = coeff(local_rows)
            recon = model._reconstruct_batch(preds)
            loss = model._frobenius_batch(recon, mats).mean()
            loss.backward()
            opt.step()

        # Detach so the result does not keep every chunk's graph alive (and
        # does not itself become a grad-tracking tensor).
        out[idx] = coeff.all().detach()
    print("✓ Coefficient prediction complete")
    return out

def run_pipeline(
    train_mats: np.ndarray,
    train_labels: np.ndarray,
    test_mats: np.ndarray,
    test_labels: np.ndarray,
    *,
    matrix_size: int = 64,
    d_order: int = 10,
    lr: float = 5e-3,
    epochs: int = 50,
    batch_size: int = 128,
    out_dir: str = "./",
    predict_test_coeffs: bool = False,
) -> Tuple[IntegratedHermitianClassifier, torch.Tensor, "torch.Tensor | None"]:
    """Train the Case-1 classifier end to end and report its accuracy.

    Parameters
    ----------
    train_mats, test_mats : Hamiltonian matrices, shape (N, matrix_size, matrix_size).
    train_labels, test_labels : integer class labels aligned with the matrices.
    matrix_size, d_order, lr, epochs, batch_size : forwarded to the model.
    out_dir : currently unused (the artefact-saving step is disabled).
    predict_test_coeffs : when True, additionally fit reconstruction
        coefficients for the test set with ``predict_coefficients`` (at
        ``lr * 0.1``).  This step is slow and is skipped by default.

    Returns
    -------
    ``(model, train_coeffs, test_coeffs)``; ``test_coeffs`` is ``None`` unless
    ``predict_test_coeffs`` is True.
    """
    set_seed(42)

    # Build class-specific target Hamiltonians
    targets = create_target_matrices(train_labels, matrix_size)

    # Initialise model and load data
    model = IntegratedHermitianClassifier(
        matrix_size=matrix_size,
        d_order=d_order,
        lr=lr,
        epochs=epochs,
        batch_size=batch_size,
    )
    model.load_data(train_mats, train_labels, test_mats, test_labels, targets)

    # Train
    train_coeffs = train_model(model)

    # Test coefficients are optional; always bind the name so the return
    # statement below cannot raise NameError.
    test_coeffs = None
    if predict_test_coeffs:
        test_coeffs = predict_coefficients(model, d_order, lr=lr * 0.1)

    # Final accuracies
    print("\n================ FINAL ACCURACY ================ ")
    print(f"Train: {model.accuracy(model.Mtr, model.y):.3%}")
    print(f" Test: {model.accuracy(model.Mte, model.yte):.3%}")
    print("===============================================\n")

    return model, train_coeffs, test_coeffs


def create_labels_from_class_counts(class_counts):
    """Expand per-class counts into a flat label list.

    Entry ``i`` of ``class_counts`` says how many times label ``i`` appears;
    labels are emitted in increasing class order.
    """
    return [
        class_idx
        for class_idx, count in enumerate(class_counts)
        for _ in range(count)
    ]

# Per-digit sample counts of the (digit-sorted) training matrices.
class_counts = [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]

# Training labels are rebuilt from the counts; test samples keep their
# original order, so the dataset's own test labels are used directly.
train_X = normalized_Hermitian_Digit_matrices.numpy()
train_y = create_labels_from_class_counts(class_counts)
test_X = normalized_hermitian_matrices_test_input.numpy()
test_y = y_test

run_pipeline(
    train_X,
    train_y,
    test_X,
    test_y,
    matrix_size=64,
    d_order=50,
    lr=1e-1,
    epochs=60,
    batch_size=100,
)

**Case 2:-**

This is a robust improvement over Case 1. Instead of one global pair of base matrices, the model learns 10 separate pairs of base matrices, one per class, and hence trains class-specific base Hamiltonians, with the training objective of minimizing the combined loss from reconstruction and unitary mapping. Classification is done on a class-by-class basis: a new set of coefficients is learnt for each class and two errors are calculated, the reconstruction error and the target loss. First the classes with a low target loss are selected, and among them the class with the smallest reconstruction loss is chosen.

In [None]:
from __future__ import annotations
import time
from typing import Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def set_seed(seed: int = 42) -> None:
    """Seed the NumPy and PyTorch random generators (and CUDA, if present)."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)


class IntegratedHermitianClassifier(nn.Module):
    def __init__(
        self,
        matrix_size: int = 64,
        n_classes: int = 10,
        d_order: int = 10,
        lr: float = 2e-3,
        epochs: int = 100,
        batch_size: int = 512,
        device: str | None = None,
    ) -> None:
        super().__init__()

        self.n, self.C, self.d = matrix_size, n_classes, d_order
        self.lr, self.epochs, self.batch_size = lr, epochs, batch_size
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

        # class-wise Hermitian bases  (x_c , p_c)  → Parameters (C,n,n)
        self.x_cls = nn.Parameter(0.1 * torch.randn(self.C, self.n, self.n, dtype=torch.cfloat))
        self.p_cls = nn.Parameter(0.1 * torch.randn(self.C, self.n, self.n, dtype=torch.cfloat))

        # global unconstrained parameter whose exponential is unitary
        self.U_param = nn.Parameter(torch.randn(self.n, self.n, dtype=torch.cfloat))

        # data placeholders
        self.Mtr = self.Mte = self.y = self.yte = self.target_mats = None

        self.to(self.device)

    @staticmethod
    def _make_hermitian(M: torch.Tensor) -> torch.Tensor:
        return 0.5 * (M + M.conj().transpose(-2, -1))

    @staticmethod
    def _fro(M: torch.Tensor) -> torch.Tensor:
        return torch.norm(M, p="fro", dim=(-2, -1), keepdim=True) + 1e-8

    def _unitary(self) -> torch.Tensor:
        H = self._make_hermitian(self.U_param)
        return torch.matrix_exp(-1j * H)

    def _reconstruct(
        self,
        coeffs: torch.Tensor,      # (B,d)
        x_b:   torch.Tensor,       # (B,n,n)  — already class-selected
        p_b:   torch.Tensor,       # (B,n,n)
    ) -> torch.Tensor:
        """Vectorised reconstruction for a batch."""
        B = coeffs.shape[0]
        x_h = self._make_hermitian(self._normalise(x_b := self._normalise(x_b)))
        p_h = self._make_hermitian(self._normalise(p_b := self._normalise(p_b)))

        H0 = self._make_hermitian(0.5 * (p_h @ p_h) + 0.5 * (x_h @ x_h))
        powers = []
        x_pow = x_h @ x_h
        powers.append(x_pow)
        for _ in range(1, self.d):
            x_pow = x_pow @ x_h
            powers.append(x_pow)

        recon = H0.clone()
        for k in range(self.d):
            recon = recon + coeffs[:, k].view(B, 1, 1) * powers[k]
        recon = 0.5 * (recon + recon.conj().transpose(-2, -1))
        return recon

    def _normalise(self, M: torch.Tensor) -> torch.Tensor:
        return M / self._fro(M)

    def forward(
        self,
        coeffs: torch.Tensor,      # (B,d)
        labels: torch.Tensor,      # (B,)
        originals: torch.Tensor,   # (B,n,n)
        targets: torch.Tensor,     # (B,n,n)
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

        x_b = self.x_cls[labels]
        p_b = self.p_cls[labels]
        recon = self._reconstruct(coeffs, x_b, p_b)

        diff     = recon - originals
        L_recon  = torch.diagonal(diff.conj().transpose(-2, -1) @ diff,
                                  dim1=-2, dim2=-1).sum(-1).abs().mean()

        U  = self._unitary()
        Uh = U.conj().transpose(-2, -1)
        diffU = U @ recon @ Uh - targets
        L_unit = torch.diagonal(diffU.conj().transpose(-2, -1) @ diffU,
                                dim1=-2, dim2=-1).sum(-1).abs().mean()

        I = torch.eye(self.n, dtype=torch.cfloat, device=self.device)
        L_uc = torch.norm(U @ Uh - I, p="fro") + torch.norm(Uh @ U - I, p="fro")

        total = L_recon + L_unit #+ 0.1 * L_uc
        return total, L_recon, L_unit, L_uc

    @torch.no_grad()
    def accuracy_diag(self, mats: torch.Tensor, labels: torch.Tensor) -> float:
        """Quick diagnostic accuracy – diag( UHU† )."""
        U = self._unitary()
        outs = U @ mats @ U.conj().transpose(-2, -1)
        preds = torch.diagonal(outs.abs(), dim1=-2, dim2=-1).argmax(-1) // 7
        return (preds == labels).float().mean().item()

    def load_data(
        self,
        train_mats: np.ndarray,
        train_labels: np.ndarray,
        test_mats:  np.ndarray,
        test_labels: np.ndarray,
        target_mats: np.ndarray,
    ) -> None:
        print(f"► Loading data on {self.device} …")

        self.Mtr = torch.tensor(train_mats, dtype=torch.cfloat, device=self.device)
        self.y   = torch.tensor(train_labels, dtype=torch.long, device=self.device)
        self.Mte = torch.tensor(test_mats,  dtype=torch.cfloat, device=self.device)
        self.yte = torch.tensor(test_labels, dtype=torch.long, device=self.device)
        self.target_mats = torch.tensor(target_mats, dtype=torch.cfloat, device=self.device)

        print(f"Data loaded — train {self.Mtr.shape}, test {self.Mte.shape}")


class ComplexCoefficients(nn.Module):
    """Per-sample complex coefficients stored as separate real/imag tables.

    Both tables have shape (n_samples, d) and are initialised from a normal
    distribution scaled by ``0.01 / sqrt(d)``.
    """

    def __init__(self, d: int, n_samples: int, device: torch.device):
        super().__init__()
        scale = 0.01 / np.sqrt(d)
        # Registration order (real, then imag) also fixes the RNG draw order.
        for part in ("real", "imag"):
            table = scale * torch.randn(n_samples, d, device=device)
            setattr(self, part, nn.Parameter(table))

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """Return the complex coefficient rows selected by ``idx``."""
        re_rows = self.real[idx]
        im_rows = self.imag[idx]
        return torch.complex(re_rows, im_rows)

    def all(self) -> torch.Tensor:
        """Return the full (n_samples, d) complex coefficient table."""
        return torch.complex(self.real, self.imag)

class ReduceLROnPlateau:
    """Minimal plateau scheduler: shrink the LR after ``patience`` bad epochs.

    An epoch is "bad" when its metric is not strictly below the best value
    seen so far.  After ``patience`` consecutive bad epochs every parameter
    group's learning rate is multiplied by ``factor`` (never dropping below
    ``min_lr``) and the bad-epoch counter restarts.
    """

    def __init__(self, opt: optim.Optimizer, factor=.5, patience=3, min_lr=1e-8):
        self.opt = opt
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.best = None
        self.bad = 0

    def step(self, metric: float):
        """Record one epoch's metric and decay the learning rates if due."""
        improved = self.best is None or metric < self.best
        if improved:
            self.best = metric
            self.bad = 0
            return

        self.bad += 1
        if self.bad < self.patience:
            return

        for group in self.opt.param_groups:
            group["lr"] = max(group["lr"] * self.factor, self.min_lr)
        self.bad = 0

def create_target_matrices(labels: np.ndarray | List[int], n: int = 64) -> np.ndarray:
    pos = [0, 7, 14, 21, 28, 35, 42, 49, 56, 63]
    out = np.zeros((len(labels), n, n), np.complex64)
    for i, lab in enumerate(labels):
        out[i, pos[lab], pos[lab]] = 1.0
    out /= np.linalg.norm(out, axis=(-2, -1), keepdims=True, ord='fro')
    return out


def train(model: IntegratedHermitianClassifier) -> torch.Tensor:
    """Jointly train per-sample coefficients, class bases and the unitary.

    Three Adam optimisers are used (coefficients at ``model.lr``, bases at
    ``0.1 * model.lr``, unitary at ``0.05 * model.lr``), each with its own
    plateau scheduler fed by the total, reconstruction and mapping loss
    respectively.  Returns the full table of trained training coefficients.
    """
    print("► Training …")
    t0 = time.time()

    n_train = model.Mtr.shape[0]
    coeffs = ComplexCoefficients(model.d, n_train, model.device)

    base_params = [model.x_cls, model.p_cls]
    unitary_params = [model.U_param]

    opt_coeff = optim.Adam(coeffs.parameters(), lr=model.lr)
    opt_bases = optim.Adam(base_params, lr=model.lr * .1)
    opt_unit = optim.Adam(unitary_params, lr=model.lr * .05)
    optimisers = (opt_coeff, opt_bases, opt_unit)

    sch_coeff, sch_bases, sch_unit = (ReduceLROnPlateau(opt) for opt in optimisers)

    # Labels pick the class bases; the sample index picks the coefficient row.
    sample_ids = torch.arange(n_train, device=model.device)
    loader = DataLoader(
        TensorDataset(model.Mtr, model.target_mats, model.y, sample_ids),
        batch_size=model.batch_size,
        shuffle=True,
    )

    for ep in range(1, model.epochs + 1):
        print(f"Epoch: {ep}")
        tot_sum, rec_sum, unit_sum = 0.0, 0.0, 0.0
        n_batches = 0

        for H, T, lab, idx in loader:
            batch_coeffs = coeffs(idx)

            for opt in optimisers:
                opt.zero_grad()

            tot, Lr, Lu, L_uc = model(batch_coeffs, lab, H, T)
            tot.backward()

            torch.nn.utils.clip_grad_norm_(coeffs.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(base_params, .5)
            torch.nn.utils.clip_grad_norm_(unitary_params, .5)

            for opt in optimisers:
                opt.step()

            tot_sum += tot.item()
            rec_sum += Lr.item()
            unit_sum += Lu.item()
            n_batches += 1

        mean_tot = tot_sum / n_batches
        mean_rec = rec_sum / n_batches
        mean_unit = unit_sum / n_batches

        sch_coeff.step(mean_tot)
        sch_bases.step(mean_rec)
        sch_unit.step(mean_unit)

        train_acc = model.accuracy_diag(model.Mtr, model.y)
        print(f"  Ep {ep:03d}/{model.epochs} | "
              f"Loss {mean_tot:.4e} | "
              f"Recon {mean_rec:.4e} | "
              f"Unit {mean_unit:.4e} | "
              f"Acc {train_acc:.3f} │ "
              f"Δt {time.time()-t0:.1f}s")

    print("Done  (training time {:.1f}s)".format(time.time()-t0))
    return coeffs.all()


def classify_test_set(
    model: IntegratedHermitianClassifier,
    d_order: int,
    lr: float = 1e-3,
    epochs: int = 100,
    batch_size: int = 10000,
    tol_target: float = 1e-3,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Classify every test matrix by fitting coefficients against each class basis.

    For each test sample and each class ``c`` a fresh coefficient vector is
    optimised with Adam for ``epochs`` steps so that the reconstruction built
    from ``model.x_cls[c]`` / ``model.p_cls[c]`` matches the sample.  A class
    "matches" when the unitary-transformed reconstruction lies within
    ``tol_target`` (squared Frobenius distance) of that class's target matrix.
    Among matching classes the one with the smallest reconstruction error
    wins; if no class matches, the smallest reconstruction error overall wins.

    The trained model is treated as frozen: its bases and unitary are detached
    so that this routine neither accumulates gradients on the model parameters
    nor keeps an autograd graph alive in the returned tensors.

    NOTE(review): ``d_order`` is presumably expected to equal the order the
    model was built with, since ``model._reconstruct`` consumes the
    coefficients — confirm against the model class.

    Returns
    -------
    coeffs_out : (N_test,d)  coefficients of the chosen class per sample
    preds      : (N_test,)   predicted class labels
    """
    print("Classifying test set …")
    N = model.Mte.shape[0]
    coeffs_out = torch.empty((N, d_order), dtype=torch.cfloat, device=model.device)
    preds      = torch.empty(N, dtype=torch.long, device=model.device)

    # pre-compute class targets (C,n,n)
    tgt_cls = torch.tensor(create_target_matrices(list(range(model.C)), model.n),
                           dtype=torch.cfloat, device=model.device)

    # Frozen model quantities: computed once, without autograd tracking.
    with torch.no_grad():
        U  = model._unitary()
        Uh = U.conj().transpose(-2, -1)
    x_all = model.x_cls.detach()
    p_all = model.p_cls.detach()

    # Every per-class coefficient table has a single row; reuse one index tensor.
    row0 = torch.tensor([0], device=model.device)

    ds = TensorDataset(model.Mte, torch.arange(N, device=model.device))
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False)

    for H_batch, idx_batch in dl:
        for j, idx in enumerate(idx_batch):
            H = H_batch[j:j+1]

            best_match_loss = float("inf")
            best_nomatch_loss = float("inf")
            best_coeff = None
            best_class = None
            matched = False

            for c in range(model.C):
                coeff = ComplexCoefficients(d_order, 1, model.device)
                opt   = optim.Adam(coeff.parameters(), lr=lr)

                x_c = x_all[c:c+1]  # (1,n,n)
                p_c = p_all[c:c+1]

                # Only the coefficients are trainable here.
                for _ in range(epochs):
                    opt.zero_grad()
                    recon = model._reconstruct(coeff(row0), x_c, p_c)
                    loss = (recon - H).abs().pow(2).sum()
                    loss.backward()
                    opt.step()

                # Final evaluation needs no graph.
                with torch.no_grad():
                    a_opt = coeff.all()  # (1,d)
                    recon = model._reconstruct(a_opt, x_c, p_c)
                    rec_loss = (recon - H).abs().pow(2).sum().item()

                    out = U @ recon @ Uh
                    tgt_diff = (out - tgt_cls[c:c+1]).abs().pow(2).sum().item()

                if tgt_diff < tol_target:      # criterion 1 satisfied
                    matched = True
                    if rec_loss < best_match_loss:
                        best_match_loss = rec_loss
                        best_coeff, best_class = a_opt.squeeze(0), c
                elif not matched and rec_loss < best_nomatch_loss:
                    # Fallback candidate, only used while no class has matched.
                    best_nomatch_loss = rec_loss
                    best_coeff, best_class = a_opt.squeeze(0), c

            coeffs_out[idx] = best_coeff
            preds[idx]      = best_class

    print("Classification complete")
    return coeffs_out, preds


def run_pipeline(
    train_mats: np.ndarray,
    train_labels: np.ndarray,
    test_mats:  np.ndarray,
    test_labels: np.ndarray,
    *,
    matrix_size: int = 64,
    d_order: int = 10,
    lr: float = 5e-3,
    epochs: int = 200,
    batch_size: int = 1024,
    classify_test: bool = False,
) -> "Tuple[IntegratedHermitianClassifier, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]":
    """Build the classifier, train it and report diagonal accuracies.

    Parameters
    ----------
    train_mats, train_labels : training matrices and their integer labels.
    test_mats, test_labels   : test matrices and their integer labels.
    matrix_size, d_order, lr, epochs, batch_size :
        forwarded to ``IntegratedHermitianClassifier``.
    classify_test :
        When True, additionally run ``classify_test_set`` (a per-sample,
        per-class optimisation over the whole test set) and print the
        resulting class accuracy.  Off by default.

    Returns
    -------
    (model, coeff_tr, coeff_te, preds_te).  ``coeff_te`` and ``preds_te`` are
    ``None`` unless ``classify_test`` is True.  Previously they were returned
    without ever being assigned, which raised ``NameError`` after training.
    """
    set_seed(42)

    tgt_train = create_target_matrices(train_labels, matrix_size)

    model = IntegratedHermitianClassifier(
        matrix_size=matrix_size,
        d_order=d_order,
        lr=lr,
        epochs=epochs,
        batch_size=batch_size,
    )
    model.load_data(train_mats, train_labels, test_mats, test_labels, tgt_train)

    coeff_tr = train(model)

    # Placeholders keep the 4-tuple return shape when classification is skipped.
    coeff_te = preds_te = None
    if classify_test:
        coeff_te, preds_te = classify_test_set(model, d_order, lr=lr*0.1)

    print(f" Train diag-acc  : {model.accuracy_diag(model.Mtr, model.y):.3%}")
    print(f" Test  diag-acc  : {model.accuracy_diag(model.Mte, model.yte):.3%}")
    if preds_te is not None:
        class_acc = (preds_te.cpu().numpy() == np.asarray(test_labels)).mean()
        print(f" Test  class-acc : {class_acc:.3%}")

    return model, coeff_tr, coeff_te, preds_te


def create_labels_from_class_counts(class_counts):
    """Expand per-class counts into a flat label list.

    Class index ``i`` is repeated ``class_counts[i]`` times, in class order,
    e.g. ``[2, 0, 1] -> [0, 0, 2]``.
    """
    return [
        class_idx
        for class_idx, count in enumerate(class_counts)
        for _ in range(count)
    ]

# Number of training samples for each digit class 0-9 (the ten counts sum to 60000).
class_counts = [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]

# Convert the tensors prepared earlier in the notebook to NumPy arrays for run_pipeline.
train_X = normalized_Hermitian_Digit_matrices.numpy()
# Training labels are rebuilt from the counts instead of being read from y_train.
# NOTE(review): this assumes train_X is stacked class by class in digit order
# (all 0s, then all 1s, ...) — confirm against the cell that builds it.
train_y = create_labels_from_class_counts(class_counts)
test_X = normalized_hermitian_matrices_test_input.numpy()
test_y = y_test

# Train and evaluate with 64x64 matrices and 90 coefficients per sample.
run_pipeline(train_X, train_y, test_X, test_y, matrix_size = 64, d_order = 90, lr = 1e-1, epochs = 60, batch_size = 100)



**Case 1 + Case 5**

This is an improved extension of case 1 in which we have used the same coefficients for all classes but a different unitary matrix for each class.

In [None]:
# x and p shared across all classes, U separate for each class (OvR strategy) - Case 1 plus Case 5

from __future__ import annotations
import time
from typing import Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def set_seed(seed: int = 42) -> None:
    """Seed the NumPy and PyTorch random number generators with ``seed``.

    When CUDA is available, ``torch.cuda.manual_seed`` is called as well.
    """
    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)

def create_target_matrices(labels: np.ndarray | List[int], n: int = 64) -> np.ndarray:
    """Create diagonal target matrices with class-specific positions."""
    pos = [0, 7, 14, 21, 28, 35, 42, 49, 56, 63]
    out = np.zeros((len(labels), n, n), np.complex64)
    for i, lab in enumerate(labels):
        out[i, pos[lab], pos[lab]] = 1.0
    norms = np.linalg.norm(out, axis=(-2, -1), ord='fro', keepdims=True)
    out = out / (norms + 1e-8)
    return out

class ComplexCoefficients(nn.Module):
    """Trainable table of complex coefficients, one row of ``d`` values per sample.

    The real and imaginary parts are held as two separate real-valued
    parameters of shape ``(n_samples, d)``, both initialised from a normal
    distribution scaled by ``0.01 / sqrt(d)``.
    """

    def __init__(self, d: int, n_samples: int, device: torch.device):
        super().__init__()
        init_scale = 0.01 / np.sqrt(d)
        table_shape = (n_samples, d)
        # Real part is drawn first, then the imaginary part.
        self.real = nn.Parameter(init_scale * torch.randn(table_shape, device=device))
        self.imag = nn.Parameter(init_scale * torch.randn(table_shape, device=device))

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """Return the complex coefficient rows selected by ``idx``."""
        re_rows = self.real[idx]
        im_rows = self.imag[idx]
        return torch.complex(re_rows, im_rows)

    def all(self) -> torch.Tensor:
        """Return the whole table as one complex tensor of shape (n_samples, d)."""
        return torch.complex(self.real, self.imag)

class ReduceLROnPlateau:
    """Minimal plateau scheduler: shrink the learning rate when a metric stalls.

    After ``patience`` consecutive calls to ``step`` without a strictly lower
    metric, every parameter group's learning rate is multiplied by ``factor``
    (never going below ``min_lr``) and the stall counter is reset.
    """

    def __init__(self, opt: optim.Optimizer, factor=0.5, patience=3, min_lr=1e-8):
        self.opt = opt
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.best = None   # lowest metric seen so far
        self.bad = 0       # consecutive non-improving steps

    def step(self, metric: float):
        """Record ``metric`` and lower the learning rate if it has plateaued."""
        improved = self.best is None or metric < self.best
        if improved:
            self.best = metric
            self.bad = 0
            return

        self.bad += 1
        if self.bad < self.patience:
            return

        for group in self.opt.param_groups:
            group["lr"] = max(group["lr"] * self.factor, self.min_lr)
        self.bad = 0

class OvRMultiMarginHermitianClassifier(nn.Module):
    """
    Multi-class classifier with:
    - Shared Hermitian bases (x, p) for all classes
    - Separate unitary matrices U_k for each class (OvR strategy)
    - Multi-margin hinge loss objective
    """
    def __init__(
        self,
        matrix_size: int = 64,
        n_classes: int = 10,
        d_order: int = 10,
        lr: float = 2e-3,
        epochs: int = 100,
        batch_size: int = 128,
        chunk_size: int = 100,
        margin: float = 1.0,
        device: str | None = None,
    ) -> None:
        super().__init__()

        self.n, self.C, self.d = matrix_size, n_classes, d_order
        self.lr, self.epochs, self.batch_size = lr, epochs, batch_size
        self.chunk_size = chunk_size
        self.margin = margin
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

        # Shared Hermitian bases (x, p), parameters (n, n)
        self.x_mat = nn.Parameter(0.1 * torch.randn(self.n, self.n, dtype=torch.cfloat,device = self.device))
        self.p_mat = nn.Parameter(0.1 * torch.randn(self.n, self.n, dtype=torch.cfloat,device = self.device))

        # Separate unitary parameters for each class (C, n, n)
        self.U_param = nn.Parameter(torch.randn(self.C, self.n, self.n, dtype=torch.cfloat,device = self.device))

        # Class-specific target matrices (C, n, n)
        class_targets = create_target_matrices(list(range(self.C)), self.n)
        self.class_targets = torch.tensor(class_targets, dtype=torch.cfloat, device=self.device)

        # Data placeholders
        self.Mtr = self.Mte = self.y = self.yte = self.target_mats = None

        self.to(self.device)

    @staticmethod
    def _make_hermitian(M: torch.Tensor) -> torch.Tensor:
        return 0.5 * (M + M.conj().transpose(-2, -1))

    @staticmethod
    def _fro_norm(M: torch.Tensor) -> torch.Tensor:
        return torch.norm(M, p="fro", dim=(-2, -1), keepdim=True) + 1e-8

    def _normalise(self, M: torch.Tensor) -> torch.Tensor:
        return M / self._fro_norm(M)

    def _make_unitary(self, M: torch.Tensor) -> torch.Tensor:
        H = self._make_hermitian(M)
        return torch.matrix_exp(-1j * H)

    def _reconstruct(
        self,
        coeffs: torch.Tensor,      # (B, d)
        x_b: torch.Tensor,         # (B, n, n) or (n, n) if shared
        p_b: torch.Tensor,         # (B, n, n) or (n, n) if shared
    ) -> torch.Tensor:
        """Vectorized reconstruction for a batch."""
        B = coeffs.shape[0]

        x_h = self._make_hermitian(self._normalise(x_b))
        p_h = self._make_hermitian(self._normalise(p_b))

        # Base Hamiltonian: H0 = 0.5*(p²) + 0.5*(x²)
        H0 = self._make_hermitian(0.5 * (p_h @ p_h) + 0.5 * (x_h @ x_h))

        # Compute powers of x
        powers = []
        x_pow = x_h @ x_h
        powers.append(x_pow)
        for _ in range(1, self.d):
            x_pow = x_pow @ x_h
            powers.append(x_pow)

        # Reconstruct: H = H0 + Σ a_k * x^{k+2}
        recon = H0.clone()
        for k in range(self.d):
            recon = recon + coeffs[:, k].view(B, 1, 1) * powers[k]

        recon = 0.5 * (recon + recon.conj().transpose(-2, -1))
        return recon

    def forward(
        self,
        coeffs: torch.Tensor,      # (B, d)
        labels: torch.Tensor,      # (B,)
        originals: torch.Tensor,   # (B, n, n)
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

        B = originals.size(0)
        # Shared x and p for all classes, expand to (B, n, n)
        x_b = self.x_mat.unsqueeze(0).expand(B, self.n, self.n)
        p_b = self.p_mat.unsqueeze(0).expand(B, self.n, self.n)

        # Reconstruct Hamiltonians
        recon = self._reconstruct(coeffs, x_b, p_b)

        # 1) Reconstruction loss
        diff = recon - originals
        L_recon = torch.diagonal(diff.conj().transpose(-2, -1) @ diff, dim1=-2, dim2=-1).sum(-1).abs().mean()

        # 2) Multi-margin hinge loss with separate unitaries
        U = self._make_unitary(self.U_param)  # (C, n, n)

        # Compute scores for all classes (memory-efficient)
        scores = torch.zeros(B, self.C, device=self.device)
        for k in range(self.C):
            U_k = U[k]
            U_k_dag = U_k.conj().T
            target_k = self.class_targets[k]
            transformed_k = U_k @ recon @ U_k_dag
            dists_k = torch.norm(transformed_k - target_k.unsqueeze(0), p="fro", dim=(-2, -1))
            scores[:, k] = -dists_k

        # Multi-margin hinge loss: L = (1/C) * Σ_{j≠i} max(0, margin + s_j - s_i)
        correct_scores = scores[torch.arange(B), labels]
        hinge_losses = torch.zeros(B, device=self.device)
        for i in range(B):
            correct_score = correct_scores[i]
            wrong_scores = torch.cat([scores[i, :labels[i]], scores[i, labels[i]+1:]])
            margins = torch.clamp(self.margin + wrong_scores - correct_score, min=0.0)
            hinge_losses[i] = margins.mean()
        L_multimargin = hinge_losses.mean()

        # 3) Unitary constraints (batched)
        I = torch.eye(self.n, dtype=torch.cfloat, device=self.device)
        U_dag = U.conj().transpose(-2, -1)
        L_uc = torch.norm(U @ U_dag - I, p="fro") + torch.norm(U_dag @ U - I, p="fro")

        # Total loss
        total = L_recon + 5.0 * L_multimargin + 0.1 * L_uc
        return total, L_recon, L_multimargin, L_uc

    @torch.no_grad()
    def accuracy(self, mats: torch.Tensor, labels: torch.Tensor) -> float:
        """Memory-efficient accuracy computation using chunking."""
        U = self._make_unitary(self.U_param)  # (C, n, n)

        total_correct = 0
        total_samples = 0

        for start_idx in range(0, mats.size(0), self.chunk_size):
            end_idx = min(start_idx + self.chunk_size, mats.size(0))
            chunk_mats = mats[start_idx:end_idx]
            chunk_labels = labels[start_idx:end_idx]

            chunk_size_actual = chunk_mats.size(0)
            scores = torch.zeros(chunk_size_actual, self.C, device=self.device)

            for k in range(self.C):
                U_k = U[k]
                U_k_dag = U_k.conj().T
                target_k = self.class_targets[k]

                transformed_k = U_k @ chunk_mats @ U_k_dag
                dists_k = torch.norm(transformed_k - target_k.unsqueeze(0), p="fro", dim=(-2, -1))
                scores[:, k] = -dists_k

            preds = torch.argmax(scores, dim=1)
            total_correct += (preds == chunk_labels).sum().item()
            total_samples += chunk_size_actual

        return total_correct / total_samples

    def load_data(
        self,
        train_mats: np.ndarray,
        train_labels: np.ndarray,
        test_mats: np.ndarray,
        test_labels: np.ndarray,
        target_mats: np.ndarray,
    ) -> None:
        print(f"Loading data on {self.device} …")
        self.Mtr = torch.tensor(train_mats, dtype=torch.cfloat, device=self.device)
        self.y   = torch.tensor(train_labels, dtype=torch.long, device=self.device)
        self.Mte = torch.tensor(test_mats, dtype=torch.cfloat, device=self.device)
        self.yte = torch.tensor(test_labels, dtype=torch.long, device=self.device)
        self.target_mats = torch.tensor(target_mats, dtype=torch.cfloat, device=self.device)
        print(f"Data loaded — train {self.Mtr.shape}, test {self.Mte.shape}")

def train_model(model: OvRMultiMarginHermitianClassifier, eval_every: int = 1) -> torch.Tensor:
    """Training routine with multi-margin hinge loss.

    A new ``ComplexCoefficients`` table (one row per training matrix) is
    trained together with the shared bases ``model.x_mat``/``model.p_mat`` and
    the per-class unitary parameters ``model.U_param``.  Each group has its own
    Adam optimiser (``lr``, ``0.1 * lr``, ``0.05 * lr``), gradient-norm clip
    and plateau scheduler; the schedulers follow the epoch means of the total,
    reconstruction and margin losses respectively.

    Parameters
    ----------
    model : classifier whose data has been loaded with ``load_data``.
    eval_every : train/test accuracy is computed and printed on the first
        epoch, the last epoch and every ``eval_every``-th epoch.  The default
        of 1 evaluates after every epoch.

    Returns
    -------
    The learned coefficient table, ``coeffs.all()``.

    Raises
    ------
    ValueError : if ``eval_every`` is smaller than 1.
    """
    if eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {eval_every}")

    print("► Training started …")
    t0 = time.time()

    N = model.Mtr.shape[0]
    coeffs = ComplexCoefficients(model.d, N, model.device)

    # Optimizers with different learning rates
    opt_coeff = optim.Adam(coeffs.parameters(), lr=model.lr)
    opt_bases = optim.Adam([model.x_mat, model.p_mat], lr=model.lr * 0.1)
    opt_units = optim.Adam([model.U_param], lr=model.lr * 0.05)

    # Schedulers
    sch_coeff = ReduceLROnPlateau(opt_coeff)
    sch_bases = ReduceLROnPlateau(opt_bases)
    sch_units = ReduceLROnPlateau(opt_units)

    # DataLoader. The model's forward pass does not take the target matrices,
    # so they are not batched; the sample index selects the coefficient rows.
    ds = TensorDataset(model.Mtr, model.y, torch.arange(N, device=model.device))
    dl = DataLoader(ds, batch_size=model.batch_size, shuffle=True)

    on_cuda = model.device.type == "cuda"

    for ep in range(1, model.epochs + 1):
        print(f"Epoch:{ep}")
        tot_sum = rec_sum = margin_sum = 0.0
        batches = 0

        for H, lab, idx in dl:
            H, lab, idx = H.to(model.device), lab.to(model.device), idx.to(model.device)
            a = coeffs(idx)
            opt_coeff.zero_grad()
            opt_bases.zero_grad()
            opt_units.zero_grad()
            tot, L_recon, L_margin, _ = model(a, lab, H)
            tot.backward()
            torch.nn.utils.clip_grad_norm_(coeffs.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_([model.x_mat, model.p_mat], 0.5)
            torch.nn.utils.clip_grad_norm_([model.U_param], 0.5)
            opt_coeff.step()
            opt_bases.step()
            opt_units.step()
            tot_sum += tot.item()
            rec_sum += L_recon.item()
            margin_sum += L_margin.item()
            batches += 1

        # Release cached blocks once per epoch (before the differently shaped
        # evaluation pass) rather than after every batch.
        if on_cuda:
            torch.cuda.empty_cache()

        sch_coeff.step(tot_sum / batches)
        sch_bases.step(rec_sum / batches)
        sch_units.step(margin_sum / batches)

        if ep == 1 or ep == model.epochs or ep % eval_every == 0:
            train_acc = model.accuracy(model.Mtr, model.y)
            test_acc = model.accuracy(model.Mte, model.yte)
            acc_str = f"Acc train {train_acc:.3f} │ test {test_acc:.3f} | "
        else:
            acc_str = ""

        print(f"  Ep {ep:03d}/{model.epochs} | "
              f"Loss {tot_sum/batches:.4e} | "
              f"Recon {rec_sum/batches:.4e} | "
              f"Margin {margin_sum/batches:.4e} | "
              f"{acc_str}"
              f"Δt {time.time()-t0:.1f}s")

    print(f"Training finished in {time.time() - t0:.1f}s")
    return coeffs.all()

def classify_test_set(
    model: OvRMultiMarginHermitianClassifier,
    d_order: int,
    lr: float = 1e-3,
    epochs: int = 100,
    batch_size: int = 1000,
    tol_target: float = 1e-3,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Classify test samples by finding best reconstruction for each class.

    For each test sample and each class a fresh coefficient vector is fitted
    with Adam for ``epochs`` steps.  A class "matches" when the reconstruction,
    transformed by that class's unitary, lies within ``tol_target`` (squared
    Frobenius distance) of the class target.  Among matching classes the one
    with the smallest reconstruction error wins; if none matches, the smallest
    reconstruction error overall wins.

    The trained model is treated as frozen: bases and unitaries are detached,
    so no gradients accumulate on the model parameters and the returned
    tensors carry no autograd graph.

    ``d_order`` must be at least ``model.d``, because ``model._reconstruct``
    reads the first ``model.d`` coefficients.

    NOTE(review): x and p are shared by all classes here, so the per-class
    fits differ only through their random initialisation; when no class meets
    ``tol_target`` the chosen class therefore depends on that initialisation.

    Returns:
        coeffs_out: (N_test, d) coefficients of the chosen class per sample
        preds: (N_test,) predicted class labels
    """
    print("Classifying test set …")
    N = model.Mte.shape[0]
    coeffs_out = torch.empty((N, d_order), dtype=torch.cfloat, device=model.device)
    preds = torch.empty(N, dtype=torch.long, device=model.device)

    # Pre-compute unitaries and the (class-independent) bases without autograd.
    with torch.no_grad():
        U = model._make_unitary(model.U_param)
    x_c = model.x_mat.detach().unsqueeze(0)  # (1, n, n)
    p_c = model.p_mat.detach().unsqueeze(0)

    # Every per-class coefficient table has a single row; reuse one index tensor.
    row0 = torch.tensor([0], device=model.device)

    ds = TensorDataset(model.Mte, torch.arange(N, device=model.device))
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False)

    for H_batch, idx_batch in dl:
        for j, idx in enumerate(idx_batch):
            H = H_batch[j:j+1]

            best_match_loss = float("inf")
            best_nomatch_loss = float("inf")
            best_coeff = None
            best_class = None
            matched = False

            for c in range(model.C):
                coeff = ComplexCoefficients(d_order, 1, model.device)
                opt = optim.Adam(coeff.parameters(), lr=lr)

                # Optimize coefficients for reconstruction (only they are trainable)
                for _ in range(epochs):
                    opt.zero_grad()
                    recon = model._reconstruct(coeff(row0), x_c, p_c)
                    loss = (recon - H).abs().pow(2).sum()
                    loss.backward()
                    opt.step()

                # Final evaluation needs no graph.
                with torch.no_grad():
                    a_opt = coeff.all()
                    recon = model._reconstruct(a_opt, x_c, p_c)
                    rec_loss = (recon - H).abs().pow(2).sum().item()
                    U_c = U[c]
                    transformed = U_c @ recon @ U_c.conj().T
                    target_c = model.class_targets[c:c+1]
                    tgt_diff = (transformed - target_c).abs().pow(2).sum().item()

                if tgt_diff < tol_target:
                    matched = True
                    if rec_loss < best_match_loss:
                        best_match_loss = rec_loss
                        best_coeff, best_class = a_opt.squeeze(0), c
                elif not matched and rec_loss < best_nomatch_loss:
                    # Fallback candidate, only used while no class has matched.
                    best_nomatch_loss = rec_loss
                    best_coeff, best_class = a_opt.squeeze(0), c

            coeffs_out[idx] = best_coeff
            preds[idx] = best_class

    print("Classification complete")
    return coeffs_out, preds

def run_pipeline(
    train_mats: np.ndarray,
    train_labels: np.ndarray,
    test_mats: np.ndarray,
    test_labels: np.ndarray,
    *,
    matrix_size: int = 64,
    d_order: int = 10,
    lr: float = 2e-3,
    epochs: int = 100,
    batch_size: int = 64,
    chunk_size: int = 100,
    margin: float = 1.0,
    classify_test: bool = False,
) -> Tuple[OvRMultiMarginHermitianClassifier, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Build the OvR classifier, train it and report train/test accuracy.

    Parameters
    ----------
    train_mats, train_labels : training matrices and integer labels
        (array or list); the number of classes is ``max(label) + 1``.
    test_mats, test_labels   : test matrices and integer labels.
    matrix_size, d_order, lr, epochs, batch_size, chunk_size, margin :
        forwarded to ``OvRMultiMarginHermitianClassifier``.
    classify_test :
        When True, additionally run ``classify_test_set`` (a per-sample,
        per-class optimisation over the whole test set) and print the
        resulting class accuracy.  Off by default.

    Returns
    -------
    (model, coeff_tr, coeff_te, preds_te).  ``coeff_te`` and ``preds_te`` are
    ``None`` unless ``classify_test`` is True.  Previously they were returned
    without ever being assigned, which raised ``NameError`` after training.
    """
    set_seed(42)

    # Accept plain lists as well as arrays for the labels.
    train_labels = np.asarray(train_labels)
    test_labels = np.asarray(test_labels)

    tgt_train = create_target_matrices(train_labels, matrix_size)
    n_classes = int(train_labels.max() + 1)

    model = OvRMultiMarginHermitianClassifier(
        matrix_size=matrix_size,
        n_classes=n_classes,
        d_order=d_order,
        lr=lr,
        epochs=epochs,
        batch_size=batch_size,
        chunk_size=chunk_size,
        margin=margin,
    )
    model.load_data(train_mats, train_labels, test_mats, test_labels, tgt_train)
    coeff_tr = train_model(model)

    # Placeholders keep the 4-tuple return shape when classification is skipped.
    coeff_te = preds_te = None
    if classify_test:
        coeff_te, preds_te = classify_test_set(model, d_order, lr=lr*0.1)

    print(f"Train accuracy: {model.accuracy(model.Mtr, model.y):.2%}")
    print(f"Test accuracy:  {model.accuracy(model.Mte, model.yte):.2%}")
    if preds_te is not None:
        print(f"Test class-acc: {(preds_te.cpu().numpy() == test_labels).mean():.2%}")

    return model, coeff_tr, coeff_te, preds_te

def create_labels_from_class_counts(class_counts):
    """Turn per-class sample counts into a flat list of class labels.

    The label for position ``i`` in ``class_counts`` appears
    ``class_counts[i]`` times; labels are emitted in ascending class order.
    """
    return [
        label
        for label, repeats in enumerate(class_counts)
        for _ in range(repeats)
    ]

# Number of training samples for each digit class 0-9 (the ten counts sum to 60000).
class_counts = [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]

# Convert the tensors prepared earlier in the notebook to NumPy arrays for run_pipeline.
train_X = normalized_Hermitian_Digit_matrices.numpy()
# Labels are rebuilt from the counts and wrapped in an array because
# run_pipeline calls .max() on them.
# NOTE(review): this assumes train_X is stacked class by class in digit order
# (all 0s, then all 1s, ...) — confirm against the cell that builds it.
train_y = np.array(create_labels_from_class_counts(class_counts))
test_X = normalized_hermitian_matrices_test_input.numpy()
test_y = np.array(y_test)

# Train and evaluate with 64x64 matrices and 100 coefficients per sample.
run_pipeline(train_X, train_y, test_X, test_y, matrix_size = 64, d_order = 100, lr = 1e-1, epochs = 60, batch_size = 100)

**Case 2 + Case 5:-**


This is an improved extension of case 2 in which the coefficients (xcap and pcap) as well as the unitary matrix (U) are separate for each class.

In [None]:
# x and p separate for each class, U separate for each class (OvR strategy) - Case 2 plus Case 5
# For d_order = 10 , perform sweeps d_order = 10, 100 ,1000 , and find the sweet point of best accuracy...

# Accuracy for d_order 10 - 76 percent

from __future__ import annotations
import time
from typing import Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


def set_seed(seed: int = 42) -> None:
    """Make runs repeatable by seeding NumPy and PyTorch with ``seed``.

    ``torch.cuda.manual_seed`` is additionally called when CUDA is available.
    """
    seeders = [np.random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)


def create_target_matrices(labels: np.ndarray | List[int], n: int = 64) -> np.ndarray:
    """Create diagonal target matrices with class-specific positions."""
    pos = [0, 7, 14, 21, 28, 35, 42, 49, 56, 63]
    out = np.zeros((len(labels), n, n), np.complex64)
    for i, lab in enumerate(labels):
        out[i, pos[lab], pos[lab]] = 1.0
    norms = np.linalg.norm(out, axis=(-2, -1), ord='fro', keepdims=True)
    out = out / (norms + 1e-8)
    return out


class ComplexCoefficients(nn.Module):
    """Per-sample complex coefficients kept as two real-valued parameters.

    ``real`` and ``imag`` both have shape ``(n_samples, d)`` and start from
    normal noise scaled by ``0.01 / sqrt(d)``.
    """

    def __init__(self, d: int, n_samples: int, device: torch.device):
        super().__init__()
        std = 0.01 / np.sqrt(d)

        def _draw() -> nn.Parameter:
            # One independent (n_samples, d) draw per call.
            return nn.Parameter(std * torch.randn(n_samples, d, device=device))

        self.real = _draw()
        self.imag = _draw()

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """Complex coefficients for the rows listed in ``idx``."""
        picked_real, picked_imag = self.real[idx], self.imag[idx]
        return torch.complex(picked_real, picked_imag)

    def all(self) -> torch.Tensor:
        """Complex coefficients for every row, shape (n_samples, d)."""
        return torch.complex(self.real, self.imag)


class ReduceLROnPlateau:
    """Simple learning rate scheduler.

    Tracks the lowest metric seen; once ``patience`` consecutive steps fail to
    beat it, all learning rates of the wrapped optimizer are scaled by
    ``factor`` (bounded below by ``min_lr``) and counting starts over.
    """

    def __init__(self, opt: optim.Optimizer, factor=0.5, patience=3, min_lr=1e-8):
        self.opt = opt
        self.factor, self.patience, self.min_lr = factor, patience, min_lr
        self.best = None
        self.bad = 0

    def step(self, metric: float):
        """Feed one metric value; may reduce the optimizer's learning rates."""
        if self.best is not None and not metric < self.best:
            # No strict improvement: count it and decay once patience runs out.
            self.bad += 1
            if self.bad >= self.patience:
                for param_group in self.opt.param_groups:
                    decayed = param_group["lr"] * self.factor
                    param_group["lr"] = max(decayed, self.min_lr)
                self.bad = 0
            return
        self.best = metric
        self.bad = 0


class OvRMultiMarginHermitianClassifier(nn.Module):
    """
    Multi-class classifier with:
    - Class-specific Hermitian bases (x_c, p_c) for each class
    - Separate unitary matrices U_k for each class (OvR strategy)
    - Multi-margin hinge loss objective
    """

    def __init__(
        self,
        matrix_size: int = 64,
        n_classes: int = 10,
        d_order: int = 10,
        lr: float = 2e-3,
        epochs: int = 100,
        batch_size: int = 128,
        chunk_size: int = 100,
        margin: float = 1.0,
        device: str | None = None,
    ) -> None:
        super().__init__()

        self.n, self.C, self.d = matrix_size, n_classes, d_order
        self.lr, self.epochs, self.batch_size = lr, epochs, batch_size
        self.chunk_size = chunk_size
        self.margin = margin
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

        # Class-wise Hermitian bases (x_c, p_c) → Parameters (C, n, n)
        self.x_cls = nn.Parameter(0.1 * torch.randn(self.C, self.n, self.n, dtype=torch.cfloat,device = self.device))
        self.p_cls = nn.Parameter(0.1 * torch.randn(self.C, self.n, self.n, dtype=torch.cfloat,device = self.device))

        # Separate unitary parameters for each class (C, n, n)
        self.U_param = nn.Parameter(torch.randn(self.C, self.n, self.n, dtype=torch.cfloat,device = self.device))

        # Class-specific target matrices (C, n, n)
        class_targets = create_target_matrices(list(range(self.C)), self.n)
        self.class_targets = torch.tensor(class_targets, dtype=torch.cfloat, device=self.device)

        # Data placeholders
        self.Mtr = self.Mte = self.y = self.yte = self.target_mats = None

        self.to(self.device)

    @staticmethod
    def _make_hermitian(M: torch.Tensor) -> torch.Tensor:
        return 0.5 * (M + M.conj().transpose(-2, -1))

    @staticmethod
    def _fro_norm(M: torch.Tensor) -> torch.Tensor:
        return torch.norm(M, p="fro", dim=(-2, -1), keepdim=True) + 1e-8

    def _normalise(self, M: torch.Tensor) -> torch.Tensor:
        return M / self._fro_norm(M)

    def _make_unitary(self, M: torch.Tensor) -> torch.Tensor:
        H = self._make_hermitian(M)
        return torch.matrix_exp(-1j * H)

    def _reconstruct(
        self,
        coeffs: torch.Tensor,      # (B, d)
        x_b: torch.Tensor,         # (B, n, n) - class-selected x matrices
        p_b: torch.Tensor,         # (B, n, n) - class-selected p matrices
    ) -> torch.Tensor:
        """Vectorized reconstruction for a batch."""
        B = coeffs.shape[0]
        x_h = self._make_hermitian(self._normalise(x_b))
        p_h = self._make_hermitian(self._normalise(p_b))

        # Base Hamiltonian: H0 = 0.5*(p²) + 0.5*(x²)
        H0 = self._make_hermitian(0.5 * (p_h @ p_h) + 0.5 * (x_h @ x_h))

        # Compute powers of x
        powers = []
        x_pow = x_h @ x_h
        powers.append(x_pow)
        for _ in range(1, self.d):
            x_pow = x_pow @ x_h
            powers.append(x_pow)

        # Reconstruct: H = H0 + Σ a_k * x^{k+2}
        recon = H0.clone()
        for k in range(self.d):
            recon = recon + coeffs[:, k].view(B, 1, 1) * powers[k]

        recon = 0.5 * (recon + recon.conj().transpose(-2, -1))
        return recon

    def forward(
        self,
        coeffs: torch.Tensor,      # (B, d)
        labels: torch.Tensor,      # (B,)
        originals: torch.Tensor,   # (B, n, n)
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

        B = originals.size(0)

        # Get class-specific bases for this batch
        x_b = self.x_cls[labels]  # (B, n, n)
        p_b = self.p_cls[labels]  # (B, n, n)

        # Reconstruct Hamiltonians
        recon = self._reconstruct(coeffs, x_b, p_b)

        # 1) Reconstruction loss
        diff = recon - originals
        L_recon = torch.diagonal(diff.conj().transpose(-2, -1) @ diff,
                                dim1=-2, dim2=-1).sum(-1).abs().mean()

        # 2) Multi-margin hinge loss with separate unitaries
        U = self._make_unitary(self.U_param)  # (C, n, n)

        # Compute scores for all classes (memory-efficient)
        scores = torch.zeros(B, self.C, device=self.device)
        for k in range(self.C):
            U_k = U[k]  # (n, n)
            U_k_dag = U_k.conj().T
            target_k = self.class_targets[k]
            transformed_k = U_k @ recon @ U_k_dag
            # Score = negative distance to target_k (higher is better)
            dists_k = torch.norm(transformed_k - target_k.unsqueeze(0), p="fro", dim=(-2, -1))
            scores[:, k] = -dists_k

        # Multi-margin hinge loss: L = (1/C) * Σ_{j≠i} max(0, margin + s_j - s_i)
        correct_scores = scores[torch.arange(B), labels]

        hinge_losses = torch.zeros(B, device=self.device)
        for i in range(B):
            correct_score = correct_scores[i]
            wrong_scores = torch.cat([scores[i, :labels[i]], scores[i, labels[i]+1:]])
            margins = torch.clamp(self.margin + wrong_scores - correct_score, min=0.0)
            hinge_losses[i] = margins.mean()

        L_multimargin = hinge_losses.mean()

        # 3) Unitary constraints (batched)
        I = torch.eye(self.n, dtype=torch.cfloat, device=self.device)
        U_dag = U.conj().transpose(-2, -1)
        L_uc = torch.norm(U @ U_dag - I, p="fro") + torch.norm(U_dag @ U - I, p="fro")

        # Total loss
        total = L_recon + 5.0 * L_multimargin + 0.1 * L_uc
        return total, L_recon, L_multimargin, L_uc

    @torch.no_grad()
    def accuracy(self, mats: torch.Tensor, labels: torch.Tensor) -> float:
        """Memory-efficient accuracy computation using chunking."""
        U = self._make_unitary(self.U_param)  # (C, n, n)

        total_correct = 0
        total_samples = 0

        # Process in chunks to avoid memory issues
        for start_idx in range(0, mats.size(0), self.chunk_size):
            end_idx = min(start_idx + self.chunk_size, mats.size(0))
            chunk_mats = mats[start_idx:end_idx]
            chunk_labels = labels[start_idx:end_idx]

            chunk_size_actual = chunk_mats.size(0)
            scores = torch.zeros(chunk_size_actual, self.C, device=self.device)

            # Compute scores for each class
            for k in range(self.C):
                U_k = U[k]
                U_k_dag = U_k.conj().T
                target_k = self.class_targets[k]

                transformed_k = U_k @ chunk_mats @ U_k_dag
                dists_k = torch.norm(transformed_k - target_k.unsqueeze(0), p="fro", dim=(-2, -1))
                scores[:, k] = -dists_k  # higher is better

            preds = torch.argmax(scores, dim=1)
            total_correct += (preds == chunk_labels).sum().item()
            total_samples += chunk_size_actual

        return total_correct / total_samples

    def load_data(
        self,
        train_mats: np.ndarray,
        train_labels: np.ndarray,
        test_mats: np.ndarray,
        test_labels: np.ndarray,
        target_mats: np.ndarray,
    ) -> None:
        print(f"Loading data on {self.device} …")

        self.Mtr = torch.tensor(train_mats, dtype=torch.cfloat, device=self.device)
        self.y   = torch.tensor(train_labels, dtype=torch.long, device=self.device)
        self.Mte = torch.tensor(test_mats, dtype=torch.cfloat, device=self.device)
        self.yte = torch.tensor(test_labels, dtype=torch.long, device=self.device)
        self.target_mats = torch.tensor(target_mats, dtype=torch.cfloat, device=self.device)

        print(f"Data loaded — train {self.Mtr.shape}, test {self.Mte.shape}")


def train_model(
    model: OvRMultiMarginHermitianClassifier,
    eval_every: int = 1,
) -> Tuple[np.ndarray, np.ndarray, torch.Tensor]:
    """Train the per-sample coefficients, class bases and class unitaries.

    Three Adam optimizers are used: the coefficients at ``model.lr``, the bases
    ``x_cls``/``p_cls`` at ``0.1 * model.lr`` and ``U_param`` at
    ``0.05 * model.lr``. Each has its own ``ReduceLROnPlateau`` scheduler, fed
    respectively with the epoch-mean total, reconstruction and margin loss.

    Args:
        model: Classifier whose data tensors (``Mtr``, ``y``, ``Mte``, ``yte``,
            ``target_mats``) have already been loaded.
        eval_every: Compute train/test accuracy every this many epochs (the
            first epoch is always evaluated). The default of 1 evaluates after
            every epoch.

    Returns:
        ``(x_cls, p_cls, coeffs)``: the two trained bases as NumPy arrays and
        the result of ``coeffs.all()`` for the training coefficients.

    Raises:
        ValueError: If ``eval_every`` is smaller than 1 or the training set is
            empty.
    """
    if eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {eval_every}")

    N = model.Mtr.shape[0]
    if N == 0:
        # Without this the epoch means below would divide by zero.
        raise ValueError("training set is empty")

    print("► Training started …")
    t0 = time.time()

    coeffs = ComplexCoefficients(model.d, N, model.device)

    # Optimizers with different learning rates
    opt_coeff = optim.Adam(coeffs.parameters(), lr=model.lr)
    opt_bases = optim.Adam([model.x_cls, model.p_cls], lr=model.lr * 0.1)
    opt_units = optim.Adam([model.U_param], lr=model.lr * 0.05)
    optimizers = (opt_coeff, opt_bases, opt_units)

    # Schedulers
    sch_coeff = ReduceLROnPlateau(opt_coeff)
    sch_bases = ReduceLROnPlateau(opt_bases)
    sch_units = ReduceLROnPlateau(opt_units)

    # DataLoader; the sample index is carried along so each batch can look up
    # its own rows of the coefficient table.
    ds = TensorDataset(model.Mtr, model.target_mats, model.y, torch.arange(N, device=model.device))
    dl = DataLoader(ds, batch_size=model.batch_size, shuffle=True)

    for ep in range(1, model.epochs + 1):
        print(f"Epoch : {ep}")
        tot_sum = rec_sum = margin_sum = 0.0
        batches = 0

        # The per-sample target matrices are part of the dataset but are not
        # consumed by the forward pass below.
        for H, _targets, lab, idx in dl:
            H, lab, idx = H.to(model.device), lab.to(model.device), idx.to(model.device)
            a = coeffs(idx)

            for opt in optimizers:
                opt.zero_grad()

            # Forward pass; the fourth returned term is not logged here.
            tot, L_recon, L_margin, _ = model(a, lab, H)
            tot.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(coeffs.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_([model.x_cls, model.p_cls], 0.5)
            torch.nn.utils.clip_grad_norm_([model.U_param], 0.5)

            for opt in optimizers:
                opt.step()

            tot_sum += tot.item()
            rec_sum += L_recon.item()
            margin_sum += L_margin.item()
            batches += 1

        # Release cached blocks once per epoch (before evaluation) instead of
        # after every batch; doing it per batch only slows training because the
        # allocator has to re-acquire the same memory on the next step.
        torch.cuda.empty_cache()

        mean_tot = tot_sum / batches
        mean_rec = rec_sum / batches
        mean_margin = margin_sum / batches

        sch_coeff.step(mean_tot)
        sch_bases.step(mean_rec)
        sch_units.step(mean_margin)

        if ep % eval_every == 0 or ep == 1:
            train_acc = model.accuracy(model.Mtr, model.y)
            test_acc = model.accuracy(model.Mte, model.yte)
            acc_str = f"Acc train {train_acc:.3f} │ test {test_acc:.3f} | "
        else:
            acc_str = ""

        print(f"  Ep {ep:03d}/{model.epochs} | "
              f"Loss {mean_tot:.4e} | "
              f"Recon {mean_rec:.4e} | "
              f"Margin {mean_margin:.4e} | "
              f"{acc_str}"
              f"Δt {time.time()-t0:.1f}s")

    print(f"✓ Training finished in {time.time() - t0:.1f}s")
    xcap_np = model.x_cls.detach().cpu().numpy()
    pcap_np = model.p_cls.detach().cpu().numpy()
    return xcap_np, pcap_np, coeffs.all()


def _fit_class_coefficients(model, H, c, d_order, lr, epochs):
    """Fit fresh coefficients so class ``c``'s bases reconstruct ``H``; return (coeffs, recon, loss)."""
    coeff = ComplexCoefficients(d_order, 1, model.device)
    opt = optim.Adam(coeff.parameters(), lr=lr)

    # Only the fresh coefficients are optimized here, so the trained bases are
    # detached: backward() must not accumulate gradients into the model.
    x_c = model.x_cls[c:c + 1].detach()  # (1, n, n)
    p_c = model.p_cls[c:c + 1].detach()  # (1, n, n)
    row = torch.tensor([0], device=model.device)  # the single coefficient row

    for _ in range(epochs):
        opt.zero_grad()
        recon = model._reconstruct(coeff(row), x_c, p_c)
        loss = (recon - H).abs().pow(2).sum()
        loss.backward()
        opt.step()

    with torch.no_grad():
        a_opt = coeff.all().detach()  # (1, d)
        recon = model._reconstruct(a_opt, x_c, p_c)
        rec_loss = (recon - H).abs().pow(2).sum().item()
    return a_opt, recon, rec_loss


def classify_test_set(
    model: OvRMultiMarginHermitianClassifier,
    d_order: int,
    lr: float = 1e-3,
    epochs: int = 100,
    batch_size: int = 1000,
    tol_target: float = 1e-3,
    max_batches: "int | None" = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Classify test samples by finding the best reconstruction for each class.

    For every processed sample and every class, a fresh coefficient vector is
    fitted with Adam (``epochs`` steps at ``lr``) so that the class bases
    reconstruct the sample. The reconstruction is then conjugated by the class
    unitary and compared with the class target. A class "matches" when that
    squared difference is below ``tol_target``. Among matching classes the one
    with the lowest reconstruction loss wins; if no class matches, the lowest
    reconstruction loss overall wins.

    Args:
        model: Trained classifier with the test matrices loaded in ``Mte``.
        d_order: Number of coefficients fitted per sample.
        lr: Learning rate of the per-sample Adam optimizer.
        epochs: Number of optimizer steps per (sample, class) pair.
        batch_size: Number of test samples per DataLoader batch.
        tol_target: Threshold on the squared distance to the class target.
        max_batches: How many batches to classify, in dataset order. The
            default of 1 processes only the first batch (the historical
            behaviour of this function, since each sample costs
            ``model.C * epochs`` optimizer steps); pass ``None`` to classify
            the whole test set.

    Returns:
        coeffs_out: (N_test, d) coefficients of the chosen class per sample
        preds: (N_test,) predicted class labels

        Samples that were not processed (beyond ``max_batches``), or for which
        no class produced a comparable loss, keep the placeholders ``0`` in
        ``coeffs_out`` and ``-1`` in ``preds``.
    """
    print("► Classifying test set …")
    N = model.Mte.shape[0]

    # Deterministic placeholders instead of uninitialised memory, so that
    # unprocessed samples are recognisable in the returned tensors.
    coeffs_out = torch.zeros((N, d_order), dtype=torch.cfloat, device=model.device)
    preds = torch.full((N,), -1, dtype=torch.long, device=model.device)

    # Pre-compute unitaries; evaluation only, so no autograd graph is needed.
    with torch.no_grad():
        U = model._make_unitary(model.U_param)  # (C, n, n)

    ds = TensorDataset(model.Mte, torch.arange(N, device=model.device))
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False)

    for batch_no, (H_batch, idx_batch) in enumerate(dl):
        if max_batches is not None and batch_no >= max_batches:
            break

        H_batch = H_batch.to(model.device)
        idx_batch = idx_batch.to(model.device)
        print(f"Idx batch: {idx_batch}")

        for j, idx in enumerate(idx_batch):
            print(f"{j},{idx}")
            H = H_batch[j:j + 1]  # (1, n, n)

            best_match_loss = torch.inf
            best_nomatch_loss = torch.inf
            best_coeff = None
            best_class = None
            matched = False

            for c in range(model.C):
                a_opt, recon, rec_loss = _fit_class_coefficients(
                    model, H, c, d_order, lr, epochs
                )

                # Check if the transformed reconstruction matches the class target
                with torch.no_grad():
                    U_c = U[c]
                    transformed = U_c @ recon @ U_c.conj().T
                    target_c = model.class_targets[c:c + 1]
                    tgt_diff = (transformed - target_c).abs().pow(2).sum().item()

                # Classification logic: any matching class outranks every
                # non-matching one, regardless of reconstruction loss.
                if tgt_diff < tol_target:
                    matched = True
                    if rec_loss < best_match_loss:
                        best_match_loss = rec_loss
                        best_coeff, best_class = a_opt.squeeze(0), c
                elif not matched and rec_loss < best_nomatch_loss:
                    best_nomatch_loss = rec_loss
                    best_coeff, best_class = a_opt.squeeze(0), c

            # best_class stays None only if no loss was comparable (e.g. NaN);
            # the placeholders are kept in that case.
            if best_class is not None:
                coeffs_out[idx] = best_coeff
                preds[idx] = best_class

    print("Classification complete")
    return coeffs_out, preds


def run_pipeline(
    train_mats: np.ndarray,
    train_labels: np.ndarray,
    test_mats: np.ndarray,
    test_labels: np.ndarray,
    *,
    matrix_size: int = 64,
    d_order: int = 10,
    lr: float = 2e-3,
    epochs: int = 100,
    batch_size: int = 64,
    chunk_size: int = 100,
    margin: float = 1.0,
    classify_batch_size: int = 15,
) -> Tuple[
    OvRMultiMarginHermitianClassifier,
    torch.Tensor,
    torch.Tensor,
    np.ndarray,
    np.ndarray,
    torch.Tensor,
]:
    """Build, train and evaluate the classifier end to end.

    Steps: call ``set_seed(42)``, build the training target matrices from
    ``train_labels``, construct the model, load the data, train it, run the
    reconstruction-based test classification and print the train/test accuracy.

    Args:
        train_mats: Training matrices.
        train_labels: Integer training labels; the number of classes is taken
            as ``train_labels.max() + 1``.
        test_mats: Test matrices.
        test_labels: Integer test labels.
        matrix_size: Forwarded to the model and to ``create_target_matrices``.
        d_order: Number of coefficients per sample.
        lr: Training learning rate; the test-time fit uses ``0.1 * lr``.
        epochs: Number of training epochs (the test-time fit uses the default
            of ``classify_test_set``).
        batch_size: Training batch size.
        chunk_size: Chunk size forwarded to the model.
        margin: Margin forwarded to the model.
        classify_batch_size: Batch size passed to ``classify_test_set``.

    Returns:
        ``(model, coeff_te, preds_te, x, p, coeff_tr)``: the trained model, the
        test coefficients and predictions (moved to CPU), the two trained bases
        as returned by ``train_model``, and the training coefficients.
    """
    set_seed(42)

    # Create target matrices for training
    tgt_train = create_target_matrices(train_labels, matrix_size)
    n_classes = int(train_labels.max() + 1)

    # Initialize model
    model = OvRMultiMarginHermitianClassifier(
        matrix_size=matrix_size,
        n_classes=n_classes,
        d_order=d_order,
        lr=lr,
        epochs=epochs,
        batch_size=batch_size,
        chunk_size=chunk_size,
        margin=margin,
    )
    model.load_data(train_mats, train_labels, test_mats, test_labels, tgt_train)

    # Train the model
    x, p, coeff_tr = train_model(model)

    # Classify test set with a 10x smaller learning rate than training
    coeff_te, preds_te = classify_test_set(
        model, d_order, lr=lr * 0.1, batch_size=classify_batch_size
    )
    print(coeff_te.shape)

    # Final results
    print(f"Train accuracy: {model.accuracy(model.Mtr, model.y):.2%}")
    print(f"Test accuracy:  {model.accuracy(model.Mte, model.yte):.2%}")
    print("===============================================\n")

    return model, coeff_te.cpu(), preds_te.cpu(), x, p, coeff_tr



# Sweep d_order over 10, 100 and 1000, and find the sweet spot via a grid search ...

def create_labels_from_class_counts(class_counts):
    """Expand per-class counts into a flat label list.

    The class at position ``i`` of ``class_counts`` contributes ``i`` repeated
    ``class_counts[i]`` times, in order; e.g. ``[2, 1]`` gives ``[0, 0, 1]``.
    """
    return [
        label
        for label, n_samples in enumerate(class_counts)
        for _ in range(n_samples)
    ]

# Number of training samples per class, in class order (index = label).
class_counts = [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]

# Model inputs as NumPy arrays.
train_X = normalized_Hermitian_Digit_matrices.numpy()
test_X = normalized_hermitian_matrices_test_input.numpy()

# NOTE(review): the training labels are regenerated from the counts, which
# presumably relies on the training matrices being stored grouped by class in
# ascending order — confirm against the preprocessing step.
train_y = np.array(create_labels_from_class_counts(class_counts))
test_y = np.array(y_test)

model, coeff_te, preds_te, x, p, coeff_tr = run_pipeline(
    train_X,
    train_y,
    test_X,
    test_y,
    matrix_size=64,
    d_order=20,
    lr=1e-1,
    epochs=60,
    batch_size=100,
)