In this approach, we implement a quantum-inspired classification method for the MNIST dataset. Classical images are represented as quantum Hamiltonians, the adiabatic theorem is used to find a quantum state that corresponds to a specific image, and classification is then performed on the basis of that state's similarity to other states.

We implement two methods for this.

Data preparation:

We use the MNIST dataset, which consists of 60,000 training images and 10,000 test images. Each image is loaded as a flattened vector of size 784. We reshape every image to 28×28, downsample it to 8×8, normalize it, and then convert it into a Hamiltonian using one of four methods:

1) $H = (A + A^T)/2$

2) $H = A A^T$

3) $H = v v^T$, the outer product of the flattened image vector $v$ with itself

4) $H = -i \log(V)$, where $V$ is a unitary matrix.

Finally, the Hamiltonians are grouped into 10 classes according to their digit label (0–9).

In [None]:
# Data preprocessing of the MNIST dataset to produce the normalized train and test Hamiltonians.
# The Hamiltonians can be constructed with any of the four methods described in the paper.
import numpy as np
from skimage.transform import resize
import matplotlib.pyplot as plt
from PIL import Image
import torch
from sklearn.datasets import fetch_openml
import scipy
from tensorflow.keras.datasets import mnist

# Load MNIST using Keras.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Flatten every 28x28 image into a 784-vector and promote it to float64.
x_train, x_test = (
    split.reshape(-1, 784).astype(np.float64) for split in (x_train, x_test)
)
print("Train:", x_train.shape, y_train.shape)
print("Test:", x_test.shape, y_test.shape)

# ----------------------------
# Helper functions
# ----------------------------
def separate_digits(images, labels):
    """Bucket ``images`` by their digit label.

    Returns a dict whose keys are 0..9 (in that order), each mapping to the
    list of images carrying that label, in input order. A label outside 0..9
    raises ``KeyError``.
    """
    grouped = {digit: [] for digit in range(10)}
    for image, label in zip(images, labels):
        grouped[label].append(image)
    return grouped

def resize_images_batch(images, new_size=(8, 8), batch_size=500):
    """Downsample flattened 28x28 images to ``new_size`` and re-flatten them.

    Images are processed in chunks of ``batch_size`` using ``resize`` (imported
    from skimage.transform at the top of the file). Returns an array with one
    flattened resized image per row.
    """
    out = []
    for start in range(0, len(images), batch_size):
        chunk = images[start:start + batch_size]
        out.extend(resize(im.reshape(28, 28), new_size).flatten() for im in chunk)
    return np.array(out)

def normalize_batch(images):
    """Scale every row of ``images`` to unit L2 norm.

    Args:
        images: (N, D) array, one vector per row.

    Returns:
        (N, D) array of unit-norm rows. All-zero rows are returned unchanged
        (as zeros) instead of turning into NaNs from a 0/0 division.
    """
    norms = np.linalg.norm(images, axis=1, keepdims=True)
    # Guard against division by zero for empty/blank images.
    safe_norms = np.where(norms == 0, 1.0, norms)
    return images / safe_norms

# Creating the Hamiltonian using the outer-product method
def density_matrix_batch(images):
    """Return the per-row outer products v v^T of a batch of vectors.

    For an (N, D) input the result has shape (N, D, D), with entry [n, i, j]
    equal to images[n, i] * images[n, j]. No complex conjugation is applied.
    """
    column_vectors = images[:, :, np.newaxis]
    row_vectors = images[:, np.newaxis, :]
    return column_vectors * row_vectors

# Creating the Hamiltonian using H = (A + A^H) / 2 with A = outer(a, 1)
def hamiltonian_symmetric_batch(images):
    """Symmetrized "broadcast" Hamiltonian for every row of ``images``.

    For a row ``a``, ``A = outer(a, ones)`` repeats ``a`` along every column
    and ``H = (A + A^H) / 2``, i.e. ``H[i, j] = (a_i + conj(a_j)) / 2``.

    Args:
        images: (N, D) array, one vector per row.

    Returns:
        (N, D, D) array. An empty batch yields shape (0, D, D).

    Raises:
        ValueError: if ``images`` is not two-dimensional.
    """
    images = np.asarray(images)
    if images.ndim != 2:
        raise ValueError(f"expected a 2-D (N, D) array, got shape {images.shape}")
    # Vectorized form of the per-row loop: column a_i plus row conj(a_j), halved.
    return (images[:, :, np.newaxis] + images[:, np.newaxis, :].conj()) / 2

# Creating the Hamiltonian using the H = A @ A.T method
def hamiltonian_product_batch(images):
    """Return ``A @ A.T`` for every row ``a`` of ``images``, with ``A = outer(a, 1)``.

    Because every column of ``A`` equals ``a``, the result is
    ``D * outer(a, a)`` (a rank-1, real-symmetric matrix).

    Args:
        images: (N, D) array, one vector per row.

    Returns:
        Array of the N stacked (D, D) products.
    """
    _, width = images.shape
    ones = np.ones(width)

    def _product(vec):
        spread = np.outer(vec, ones)
        return spread @ spread.T

    return np.array([_product(vec) for vec in images])

import scipy.linalg
# Creating the Hamiltonian using the H = -i * log(V) method
def hamiltonian_using_log(images):
    """Build ``H = -i log(V)`` per image, where ``V = exp(-i diag(a))`` is unitary.

    Each row ``a`` of ``images`` is placed on the diagonal of a matrix ``M``,
    mapped to the unitary ``V = exp(-i M)`` with ``torch.matrix_exp``, and then
    turned back into a Hermitian matrix with the principal matrix logarithm,
    ``H = -i logm(V)``. With this ``V`` the result equals ``-M`` whenever every
    ``|a_k| < pi`` (true for unit-normalized images).

    Args:
        images: (N, D) array, one vector per row.

    Returns:
        (N, D, D) complex128 array of Hermitian matrices.
    """
    def _make_hermitian(M):
        return 0.5 * (M + M.conj().transpose(-2, -1))

    def _make_unitary(M):
        H = _make_hermitian(M)
        return torch.matrix_exp(-1j * H)

    N, D = images.shape
    hamiltonians = np.zeros((N, D, D), dtype=np.complex128)
    for i in range(N):
        mat_torch = torch.tensor(np.diag(images[i]), dtype=torch.complex128)
        V = _make_unitary(mat_torch).numpy()
        # Apply the documented H = -i log(V); storing V itself would give a
        # unitary (non-Hermitian) matrix rather than a Hamiltonian.
        H = -1j * scipy.linalg.logm(V)
        # logm can leave round-off asymmetry; symmetrize (NumPy transpose via .T).
        hamiltonians[i] = 0.5 * (H + H.conj().T)
    return hamiltonians


# ----------------------------
# Process training data
# ----------------------------
# Group the training images by digit, downsample each to 8x8, L2-normalize it,
# and build one matrix per image. To switch construction method, uncomment the
# matching densityN lines below together with the corresponding
# density_matrices[digit] assignment (and comment out the density1 ones).
# NOTE(review): the test-set cell always uses the outer-product method, so
# switching here also requires switching it there.
digit_images_dict = separate_digits(x_train, y_train)
resized_digit_images = {}
normalized_digit_images = {}
density_matrices = {}

# separate_digits inserts keys 0..9 in order, so this loop (and the later
# concatenation over range(10)) processes the digits in ascending order.
for digit, imgs in digit_images_dict.items():
    imgs = np.array(imgs)
    imgs_resized = resize_images_batch(imgs, new_size=(8,8), batch_size=500)
    imgs_normalized = normalize_batch(imgs_resized)
    print(f"normalized_images shape:- {imgs_normalized.shape}")
    # Method 3: outer product v v^T of each normalized 64-vector.
    density1 = density_matrix_batch(imgs_normalized)
    print(f"shape 1:- {density1.shape}")
    #OR
    #density2 = hamiltonian_symmetric_batch(imgs_normalized)
    #print(f"shape 2:- {density2.shape}")
    #OR
    #density3 = hamiltonian_product_batch(imgs_normalized)
    #print(f"shape 3:- {density3.shape}")
    #OR
    #density4 = hamiltonian_using_log(imgs_normalized)
    #print(f"shape 4:- {density4.shape}")
    # Scale each matrix to unit Frobenius norm. For the outer-product method this
    # is a no-op up to rounding, since each normalized row v gives ||v v^T||_F = ||v||^2 = 1.
    density1 /= np.linalg.norm(density1, axis=(1,2), keepdims=True)
    #density2 /= np.linalg.norm(density2, axis=(1,2), keepdims=True)
    #density3 /= np.linalg.norm(density3, axis=(1,2), keepdims=True)
    #density4 /= np.linalg.norm(density4, axis=(1,2), keepdims=True)
    resized_digit_images[digit] = imgs_resized
    normalized_digit_images[digit] = imgs_normalized
    density_matrices[digit] = density1
    #density_matrices[digit] = density2
    #density_matrices[digit] = density3
    #density_matrices[digit] = density4

# Stack the per-digit matrices in digit order (0..9) so the sample order matches
# labels built from per-class counts in the same order.
train_density_matrices = np.concatenate([density_matrices[d] for d in range(10)], axis=0)
# complex128 rather than torch.cfloat (complex64): the matrices are float64 and
# the training code works in complex128, so complex64 would silently drop precision.
train_density_matrices_tensor = torch.tensor(train_density_matrices, dtype=torch.complex128)

# ----------------------------
# Process test data
# ----------------------------
# Same pipeline as the default training path: 8x8 downsampling, L2 normalization,
# outer-product matrix, unit Frobenius norm.
# NOTE: if the training cell is switched to another Hamiltonian construction,
# apply the same construction here.
test_images_resized = resize_images_batch(x_test, new_size=(8, 8))
test_normed = normalize_batch(test_images_resized)
test_density = density_matrix_batch(test_normed)
test_density /= np.linalg.norm(test_density, axis=(1, 2), keepdims=True)
# complex128 keeps the float64 precision of test_density (torch.cfloat is complex64).
test_density_tensor = torch.tensor(test_density, dtype=torch.complex128)

# ----------------------------
# Visualization example
# ----------------------------
# Show (up to) the first 10 downsampled 8x8 training images of every digit.
for digit in range(10):
    images_to_plot = resized_digit_images[digit][:10]
    plt.figure(figsize=(10, 2))
    # Iterate over the slice itself so a class with fewer than 10 images does
    # not raise IndexError.
    for i, image in enumerate(images_to_plot):
        plt.subplot(1, 10, i + 1)
        plt.imshow(image.reshape(8, 8), cmap='magma')
        plt.title(f"{digit}")
        plt.axis('off')
    plt.show()

# Names used by the Method 1 / Method 2 cells for the train and test matrices.
normalized_Hermitian_Digit_matrices = train_density_matrices_tensor
normalized_hermitian_matrices_test_input = test_density_tensor

print(f"normalized_Hermitian_Digit_matrices shape:- {normalized_Hermitian_Digit_matrices.shape}")
print(f"normalized_hermitian_matrices_test_input shape:- {normalized_hermitian_matrices_test_input.shape}")

# Class labels are simply the digits 0..9.
labels = list(range(10))

print(labels)

# Per-class sample counts, taken from the per-digit arrays that were concatenated
# (in digit order 0..9) into the training tensor. Deriving them from the data keeps
# the labels aligned with the tensor even when a subset of MNIST is used (for the
# full MNIST training set: [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]).
D = [len(density_matrices[d]) for d in range(10)]
per_class_labels = [np.full(D[d], labels[d], dtype=int) for d in range(10)]
# Keep the per-digit names (including their original spellings) for later cells.
(labels_zero, labels_one, labels_two, labels_three, labels_four,
 labels_five, labels_six, labels_seven, labels_eigth, labels_nineth) = per_class_labels

labels_new_train = np.concatenate(per_class_labels)


**Method 1:-**

In this method, the classical data (MNIST images) are converted into quantum Hamiltonians by representing each image as a matrix. A trainable PyTorch model (a neural network with learnable eigenvalues and eigenvectors) is then used to learn 10 class Hamiltonians, one per digit, which act as prototypes. In the classification phase, a new image is first converted into a quantum Hamiltonian, and this Hamiltonian is compared with each of the 10 trained class Hamiltonians using the Frobenius norm as the distance metric. The image is assigned to the digit whose class Hamiltonian is closest, i.e. has the smallest Frobenius-norm distance.

To obtain the trained class Hamiltonians, the learnable eigenvector matrices are turned into unitaries using one of three methods:

1) Singular value decomposition (make_unitary1)

2) QR decomposition (make_unitary2)

3) Quantum unitary encoding (make_unitary3)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset

# Class to handle the training Hamiltonian matrices
class MatrixModel(nn.Module):
    """Learnable class-prototype Hamiltonians for Method 1.

    Class ``c`` owns a real eigenvalue vector ``eigenvalues[c]`` and an
    unconstrained complex matrix (``eigenvectors_real[c]`` + 1j *
    ``eigenvectors_imag[c]``). One of ``make_unitary1/2/3`` turns that matrix
    into a unitary eigenbasis ``U_c``. The prototype is
    ``H_c = U_c diag(eigenvalues[c]) U_c^H``, explicitly symmetrized to be Hermitian.

    Args:
        num_classes: number of prototypes (one per class).
        matrix_size: side length of each Hamiltonian.
    """

    def __init__(self, num_classes=10, matrix_size=64):
        super().__init__()
        self.num_classes = num_classes
        self.matrix_size = matrix_size
        self.eigenvalues = nn.Parameter(
            torch.randn(num_classes, matrix_size, dtype=torch.float64))
        self.eigenvectors_real = nn.Parameter(
            torch.randn(num_classes, matrix_size, matrix_size, dtype=torch.float64))
        self.eigenvectors_imag = nn.Parameter(
            torch.randn(num_classes, matrix_size, matrix_size, dtype=torch.float64))

    def get_complex_eigenvectors(self):
        """Return the raw (num_classes, n, n) complex eigenvector matrices."""
        return torch.complex(self.eigenvectors_real, self.eigenvectors_imag)

    def find_unitary_transformation(self, input_density_matrix, output_density_matrix):
        """Return a unitary built from the SVD of ``output @ pinv(input)`` (NumPy arrays)."""
        X = np.dot(output_density_matrix, np.linalg.pinv(input_density_matrix))
        U, S, V_dagger = np.linalg.svd(X, full_matrices=False)
        # Singular values are real and non-negative, so their angle is 0 and this
        # phase matrix is the identity.
        phase_matrix = np.diag(np.exp(1j * np.angle(S)))
        unitary_matrix = U @ (phase_matrix @ V_dagger)
        # |det| == 1 for a unitary, so this only multiplies by a global phase.
        unitary_matrix /= np.linalg.det(unitary_matrix) ** (1 / 2)
        return unitary_matrix

    @staticmethod
    def is_unitary(matrix):
        """Return True if the square NumPy ``matrix`` satisfies M M^H = M^H M = I."""
        identity = np.eye(matrix.shape[0])
        return (np.allclose(matrix @ matrix.conj().T, identity)
                and np.allclose(matrix.conj().T @ matrix, identity))

    def make_unitary1(self, matrix):
        """Unitary ``U @ Vh`` from the SVD of ``matrix`` (differentiable)."""
        U, _, Vh = torch.linalg.svd(matrix, full_matrices=False)
        return U @ Vh

    def make_unitary2(self, matrix):
        """Unitary Q factor of the QR decomposition of ``matrix`` (differentiable)."""
        # A positive rescaling mathematically leaves Q unchanged; it only keeps the
        # magnitudes fed to QR bounded.
        matrix = matrix / (torch.linalg.norm(matrix) + 1e-12)
        Q, _ = torch.linalg.qr(matrix)
        return Q

    @staticmethod
    def _triangular_density_matrix(matrix, lower=True):
        """Hermitian, unit-trace matrix built from one triangle of NumPy ``matrix``."""
        triangle = np.tril(matrix) if lower else np.triu(matrix)
        density = triangle + np.conj(triangle).T
        # The diagonal was counted twice above; halve it.
        np.fill_diagonal(density, density.diagonal() / 2)
        return density / np.trace(density)

    def make_unitary3(self, matrix, use_upper=False):
        """Unitary mapping |0><0| towards a density matrix encoded from ``matrix``.

        The density matrix is built from the lower triangle of ``matrix``, or
        from the upper triangle when ``use_upper`` is True. The computation goes
        through NumPy, so no gradient flows back into ``matrix``. The result is
        returned on ``matrix``'s device as complex128.
        """
        size = matrix.shape[-1]
        target = self._triangular_density_matrix(
            matrix.detach().cpu().numpy(), lower=not use_upper)
        input_density_matrix = np.zeros((size, size), dtype=np.complex128)
        input_density_matrix[0, 0] = 1.0
        unitary = self.find_unitary_transformation(input_density_matrix, target)
        return torch.from_numpy(unitary).to(device=matrix.device, dtype=torch.complex128)

    def _build_hamiltonians(self, make_unitary):
        """Return the Hermitian (num_classes, n, n) stack ``U_c diag(lambda_c) U_c^H``."""
        eigenvectors_complex = self.get_complex_eigenvectors()
        unitary_vecs = torch.stack([make_unitary(mat) for mat in eigenvectors_complex])
        diag_matrices = torch.diag_embed(self.eigenvalues.to(torch.complex128))
        hamiltonians = unitary_vecs @ diag_matrices @ unitary_vecs.conj().transpose(-1, -2)
        return (hamiltonians + hamiltonians.conj().transpose(-1, -2)) / 2

    def get_hamiltonians_orig1(self):
        """Class Hamiltonians using the SVD-based unitary."""
        return self._build_hamiltonians(self.make_unitary1)

    def get_hamiltonians_orig2(self):
        """Class Hamiltonians using the QR-based unitary."""
        return self._build_hamiltonians(self.make_unitary2)

    def get_hamiltonians_orig3(self):
        """Class Hamiltonians using the density-matrix encoding unitary."""
        return self._build_hamiltonians(self.make_unitary3)

    def forward(self):
        # Swap in get_hamiltonians_orig2 / get_hamiltonians_orig3 to change the unitary.
        return self.get_hamiltonians_orig1()

def combined_loss_batched(output, target_batch, labels_batch):
    """Mean Frobenius distance between each target and its class prototype.

    Args:
        output: stack of class prototypes indexed by class along dim 0.
        target_batch: (B, n, n) target matrices.
        labels_batch: (B,) integer labels selecting one prototype per sample.
    """
    residuals = output[labels_batch] - target_batch
    per_sample = torch.linalg.norm(residuals, dim=(1, 2))
    return per_sample.mean()

def combined_loss_batched2(output, target_batch, labels_batch):
    """Mean Frobenius distance between class prototypes and pure-state targets.

    Each target is flattened to a vector ``v`` and turned into ``v v^H``
    before comparison.

    Args:
        output: either (num_classes, n, n) prototypes shared by the whole batch,
            or (B, num_classes, n, n) per-sample prototypes.
        target_batch: B target vectors, reshaped to (B, n) (n*n must match
            the prototype size).
        labels_batch: (B,) integer class labels.
    """
    batch_size = target_batch.size(0)

    if output.dim() == 4:
        rows, cols = output.shape[-2], output.shape[-1]
        index = labels_batch.view(-1, 1, 1, 1).expand(-1, 1, rows, cols)
        class_hamiltonians = torch.gather(output, 1, index).squeeze(1)
    else:
        # Advanced indexing is differentiable w.r.t. ``output``. torch.gather
        # needs an index with as many dims as its input, so it cannot be
        # applied to a 3-D output with a 4-D index.
        class_hamiltonians = output[labels_batch]

    # Convert each target vector v into the Hamiltonian v v^H.
    vectors = target_batch.reshape(batch_size, -1)
    target_hamiltonians = torch.einsum('bi,bj->bij', vectors, vectors.conj())

    losses = torch.linalg.norm(class_hamiltonians - target_hamiltonians, dim=(1, 2))
    return torch.mean(losses)


def create_labels_from_class_counts(class_counts):
    """Expand per-class counts into a flat label list.

    ``class_counts[k]`` copies of ``k`` are emitted for each class ``k`` in
    order, e.g. ``[2, 1]`` -> ``[0, 0, 1]``.
    """
    return [class_idx
            for class_idx, count in enumerate(class_counts)
            for _ in range(count)]

def create_batched_data(data, labels, batch_size=64):
    """Wrap ``data``/``labels`` in a shuffling DataLoader of complex128 samples.

    Args:
        data: a tensor, a list/tuple of equally shaped tensors, or anything
            NumPy can turn into an array (ndarray, list of ndarrays, nested lists).
        labels: tensor or sequence of per-sample labels.
        batch_size: samples per batch.

    Returns:
        ``DataLoader`` over ``TensorDataset(data, labels)`` with ``shuffle=True``.
    """
    if not isinstance(data, torch.Tensor):
        if (isinstance(data, (list, tuple)) and len(data) > 0
                and all(isinstance(item, torch.Tensor) for item in data)):
            data = torch.stack(list(data))
        else:
            # torch.stack only accepts tensors, so everything else goes through
            # NumPy; torch.tensor copies, leaving the caller's array untouched.
            data = torch.tensor(np.asarray(data))
    if not isinstance(labels, torch.Tensor):
        labels = torch.tensor(labels)
    if data.dtype != torch.complex128:
        data = data.to(torch.complex128)
    dataset = TensorDataset(data, labels)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

def train_model(model, dataloader, optimizer, scheduler, threshold=0.0000001, num_epochs=100,
                patience=5, device=None):
    """Fit ``model``'s class prototypes with ``combined_loss_batched``.

    Each step calls ``model()`` to get all class Hamiltonians and compares them
    with the batch targets. From epoch index 3 on, an epoch whose average loss
    improves on the previous one by less than ``threshold`` counts as stalled.
    Training stops early after ``patience + 1`` consecutive stalled epochs.

    Args:
        model: module whose ``forward()`` returns the class prototypes.
        dataloader: yields ``(batch_data, batch_labels)``.
        optimizer: optimizer over ``model.parameters()``.
        scheduler: stepped once per epoch with the average loss
            (e.g. ``ReduceLROnPlateau``).
        threshold: minimum per-epoch loss improvement.
        num_epochs: maximum number of epochs.
        patience: stalled epochs tolerated before stopping.
        device: where batches are moved; defaults to the device of the
            model's parameters.

    Returns:
        List of per-epoch average losses.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    all_losses = []
    wait = 0
    for epoch in range(num_epochs):
        print(f"Epoch:- {epoch}")
        total_loss = 0.0
        num_batches = 0

        for batch_data, batch_labels in dataloader:
            if num_batches % 1000 == 0:
                print(num_batches)
            batch_data = batch_data.to(device)
            batch_labels = batch_labels.to(device)
            optimizer.zero_grad()
            outputs = model()
            loss = combined_loss_batched(outputs, batch_data, batch_labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("dataloader yielded no batches")
        avg_loss = total_loss / num_batches
        scheduler.step(avg_loss)

        # Improvement relative to the previous epoch, computed before appending
        # (comparing after the append would compare avg_loss with itself).
        diff = 0
        if epoch > 2:
            diff = all_losses[-1] - avg_loss
        all_losses.append(avg_loss)

        stop = False
        if epoch > 2:
            if diff < threshold:
                print("less than threshold")
                wait += 1
                stop = wait > patience
            else:
                wait = 0
        print(f'Epoch [{epoch}/{num_epochs}], Average Loss: {avg_loss:.4e}, Difference = {diff:.10e}')
        if stop:
            print(f"Early stopping: improvement below {threshold} for {wait} consecutive epochs")
            break

    print("Training completed!")
    return all_losses

def inference(model, test_data, test_labels=None):
    """Assign each sample to the class prototype at minimum Frobenius distance.

    ``model()`` must return the stacked class prototypes (10 classes are
    scanned). Returns the list of predicted labels, or ``(labels, accuracy %)``
    when ``test_labels`` is given.
    """
    model.eval()

    with torch.no_grad():
        prototypes = model().cpu()
        predictions = []

        for sample in test_data:
            if isinstance(sample, torch.Tensor):
                sample = sample.cpu()
            distances = [
                torch.linalg.norm(sample - prototypes[c], ord='fro').item()
                for c in range(10)
            ]
            predictions.append(np.argmin(distances))

        if test_labels is None:
            return predictions
        n_correct = np.sum(np.array(predictions) == np.array(test_labels))
        return predictions, (n_correct / len(test_labels)) * 100

We train with a batch size of 500 for 80 epochs.

In [None]:
import scipy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Creating the training data and labels")
# Per-class counts come from the per-digit arrays that were concatenated (in digit
# order) into normalized_Hermitian_Digit_matrices. The labels therefore line up with
# the training tensor even if only a subset of MNIST is used.
class_counts = [len(density_matrices[d]) for d in range(10)]
training_labels = create_labels_from_class_counts(class_counts)
training_data = torch.as_tensor(normalized_Hermitian_Digit_matrices, dtype=torch.complex128, device=device)
print(f"training data shape:- {training_data.shape}")
if training_data.shape[0] != len(training_labels):
    raise ValueError(
        f"{training_data.shape[0]} training matrices but {len(training_labels)} labels"
    )


print("Initialising the model,optimier,scheduler")
model = MatrixModel(num_classes=10, matrix_size=64).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=10)
batch_size = 500

dataloader = create_batched_data(training_data, training_labels, batch_size=batch_size)
print(f"Training with {len(dataloader)} batches of size {batch_size}")
train_model(model, dataloader, optimizer, scheduler, num_epochs=80)

model.eval()
with torch.no_grad():
    trained_hamiltonians = model().cpu()
    trained_eigenvalues = model.eigenvalues.cpu()
    trained_eigenvectors = model.get_complex_eigenvectors().cpu()

    print("\nTrained Components:")
    for class_idx in range(10):
        print(f"\nClass {class_idx}:")
        print(f"Eigenvalues shape: {trained_eigenvalues[class_idx].shape}")
        print(f"Eigenvectors shape: {trained_eigenvectors[class_idx].shape}")
        print(f"Hamiltonian shape: {trained_hamiltonians[class_idx].shape}")
        H = trained_hamiltonians[class_idx]
        hermitian_error = torch.max(torch.abs(H - H.conj().T))
        print(f"Hermiticity error: {hermitian_error:.2e}")



Performing inference

visualize_eigenvalues() is used to visualize the trained eigenvalues of the model per class.

In [None]:
print("Performing inference..")
# Both splits are compared against the prototypes in complex128.
train_data, test_data = (
    torch.as_tensor(mats, dtype=torch.complex128)
    for mats in (normalized_Hermitian_Digit_matrices, normalized_hermitian_matrices_test_input)
)

predicted_labels, train_acc = inference(
    model, train_data, torch.tensor(training_labels, dtype=torch.long)
)
print(f"Train Accuracy: {train_acc:.2f}%")

predicted_labels, accuracy = inference(
    model, test_data, torch.tensor(y_test, dtype=torch.long)
)
print(f"Test Accuracy: {accuracy:.2f}%")

def visualize_eigenvalues(model):
    """Plot every class's learned eigenvalues, sorted ascending, on a 2x5 grid."""
    with torch.no_grad():
        plt.figure(figsize=(15, 6))
        for cls in range(10):
            spectrum = np.sort(model.eigenvalues[cls].cpu().numpy())
            plt.subplot(2, 5, cls + 1)
            plt.plot(spectrum, 'o--')
            plt.title(f'Class {cls}')
            plt.grid(True)
        plt.tight_layout()
        plt.show()


visualize_eigenvalues(model)

**Method 2:-**

This method is an improvement on the first method. Instead of training a single prototype Hamiltonian per class, the model learns class-specific mean eigenvectors and eigenvalues for all 10 digits. During classification, for each test image and each class, the model iteratively optimizes a new set of eigenvalues to best reconstruct the test Hamiltonian using that class's pretrained eigenvectors. The test image is then assigned the label of the class that yields the lowest reconstruction error.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Number of training samples per digit. ComplexMatrixModel allocates one weight row
# per sample, so this must match the training tensor. It is derived from the
# per-digit arrays built during preprocessing (for the full MNIST training set:
# [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949], summing to 60000).
class_sizes = [len(density_matrices[d]) for d in range(10)]

class ComplexMatrixModel(nn.Module):
    """Per-class eigenbases with per-sample eigenvalue scalings (Method 2).

    Class ``c`` holds an unconstrained complex matrix ``eigenvectors[c]``,
    which one of ``make_unitary1/2/3`` maps to a unitary ``U``. It also holds a
    real eigenvalue vector ``eigenvalues[c]`` and a weight table
    ``class_weights[c]`` with ``sizes[c]`` rows (intended: one row per training
    sample of class ``c``). Row ``i`` yields ``U diag(w_i * lambda) U^H``.

    Args:
        num_classes: number of classes.
        matrix_size: Hamiltonian dimension.
        sizes: rows of ``class_weights`` per class; defaults to the module-level
            ``class_sizes`` list.
        param_device: device for the parameters; defaults to the module-level
            ``device``.

    Raises:
        ValueError: if ``len(sizes) != num_classes``.
    """

    def __init__(self, num_classes=10, matrix_size=64, sizes=None, param_device=None):
        super().__init__()
        # Module-level globals are consulted only when no explicit value is given.
        sizes = list(class_sizes if sizes is None else sizes)
        if len(sizes) != num_classes:
            raise ValueError(f"expected {num_classes} class sizes, got {len(sizes)}")
        dev = device if param_device is None else param_device

        self.num_classes = num_classes
        self.matrix_size = matrix_size
        self.class_sizes = sizes
        self.eigenvectors = nn.Parameter(
            torch.randn(num_classes, matrix_size, matrix_size, dtype=torch.complex128, device=dev)
        )
        self.eigenvalues = nn.Parameter(
            torch.randn(num_classes, matrix_size, dtype=torch.float64, device=dev)
        )
        self.class_weights = nn.ParameterList([
            nn.Parameter(torch.randn(size, matrix_size, dtype=torch.float64, device=dev))
            for size in self.class_sizes
        ])

    def make_unitary1(self, matrix):
        """Unitary ``U @ Vh`` from the SVD of ``matrix``."""
        U, _, Vh = torch.linalg.svd(matrix)
        return U @ Vh

    def make_unitary2(self, matrix):
        """Unitary Q factor of the QR decomposition of ``matrix``."""
        Q, _ = torch.linalg.qr(matrix)
        return Q

    def make_unitary3(self, matrix):
        """Unitary ``exp(-i * herm(matrix))``."""
        matrix = (matrix + matrix.conj().T) / 2
        return torch.matrix_exp((-1j) * matrix)

    def _unitary(self, class_idx, method):
        """complex128 eigenbasis of ``class_idx`` built with make_unitary{method}."""
        makers = {1: self.make_unitary1, 2: self.make_unitary2, 3: self.make_unitary3}
        return makers[method](self.eigenvectors[class_idx]).to(torch.complex128)

    # To change the unitary construction used in training, replace
    # _generate_class_hamiltonians1 below with ..._2 or ..._3.
    def forward(self, class_idx=None, batch_indices=None):
        """Hamiltonians for one class (optionally a subset of its rows) or for all classes."""
        if class_idx is not None:
            return self._generate_class_hamiltonians1(class_idx, batch_indices)
        return torch.cat(
            [self._generate_class_hamiltonians1(idx) for idx in range(self.num_classes)],
            dim=0,
        )

    def _generate_class_hamiltonians(self, class_idx, batch_indices=None, method=1):
        """Stack of Hermitian ``U diag(w_i * lambda) U^H`` for the selected weight rows."""
        U = self._unitary(class_idx, method)
        eigvals = self.eigenvalues[class_idx]
        weights = self.class_weights[class_idx]
        if batch_indices is not None:
            weights = weights[batch_indices]

        scaled_eig = weights * eigvals.unsqueeze(0)
        diag_mats = torch.diag_embed(scaled_eig).to(torch.complex128)
        U_expanded = U.unsqueeze(0).expand(diag_mats.shape[0], -1, -1)
        U_conj_T = U_expanded.conj().transpose(-1, -2)
        H = U_expanded @ diag_mats @ U_conj_T
        return (H + H.conj().transpose(-1, -2)) / 2

    def _generate_class_hamiltonians1(self, class_idx, batch_indices=None):
        return self._generate_class_hamiltonians(class_idx, batch_indices, method=1)

    def _generate_class_hamiltonians2(self, class_idx, batch_indices=None):
        return self._generate_class_hamiltonians(class_idx, batch_indices, method=2)

    def _generate_class_hamiltonians3(self, class_idx, batch_indices=None):
        return self._generate_class_hamiltonians(class_idx, batch_indices, method=3)

    def _reconstruct_hamiltonian(self, class_idx, eigenvalues, method):
        """Hermitian ``U diag(eigenvalues) U^H`` in the eigenbasis of ``class_idx``."""
        U = self._unitary(class_idx, method)
        d = torch.diag(eigenvalues.to(torch.complex128))
        H = U @ d @ U.conj().T
        return (H + H.conj().T) / 2

    def reconstruct_hamiltonian1(self, class_idx, eigenvalues):
        return self._reconstruct_hamiltonian(class_idx, eigenvalues, method=1)

    def reconstruct_hamiltonian2(self, class_idx, eigenvalues):
        return self._reconstruct_hamiltonian(class_idx, eigenvalues, method=2)

    def reconstruct_hamiltonian3(self, class_idx, eigenvalues):
        return self._reconstruct_hamiltonian(class_idx, eigenvalues, method=3)

def frobenius_loss(output, target):
    """Mean Frobenius norm of ``output - target`` over the leading (batch) dims."""
    per_sample = torch.linalg.norm(output - target, dim=(-2, -1))
    return per_sample.mean()

def train_model_batched(model, train_data, train_labels, epochs=10, batch_size=256):
    """Fit a ComplexMatrixModel-style model by minimizing the Frobenius reconstruction loss.

    A training sample of class ``c`` is reconstructed from the row of
    ``model.class_weights[c]`` given by the sample's position among all
    class-``c`` samples (its within-class index).

    Args:
        model: exposes ``num_classes``, ``class_weights`` and
            ``forward(class_idx=..., batch_indices=...)``.
        train_data: (N, n, n) target Hamiltonians (array-like or tensor).
        train_labels: (N,) integer class labels.
        epochs: number of passes over the data.
        batch_size: samples per optimizer step.

    Raises:
        ValueError: if a class has more samples than ``model.class_weights`` rows.
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=10)
    # Batches are moved to wherever the model's parameters live.
    model_device = next(model.parameters()).device

    train_data = torch.as_tensor(train_data, dtype=torch.complex128)
    train_labels = torch.as_tensor(train_labels, dtype=torch.long)

    # Within-class index of every sample. Indexing class_weights by the sample's
    # position inside a shuffled batch would pair samples with arbitrary weight rows.
    within_class_idx = torch.zeros_like(train_labels)
    for class_idx in range(model.num_classes):
        members = torch.nonzero(train_labels == class_idx).flatten()
        available = model.class_weights[class_idx].shape[0]
        if members.numel() > available:
            raise ValueError(
                f"class {class_idx} has {members.numel()} samples but the model "
                f"only has {available} weight rows for it"
            )
        within_class_idx[members] = torch.arange(members.numel())

    train_dataset = TensorDataset(train_data, train_labels, within_class_idx)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        total_loss = 0
        batch_count = 0
        for batch_data, batch_labels, batch_within in train_loader:
            if batch_count % 100 == 0:
                print(f"Batch:{batch_count}")
            batch_data = batch_data.to(model_device)
            batch_labels = batch_labels.to(model_device)
            batch_within = batch_within.to(model_device)
            optimizer.zero_grad()
            batch_outputs, batch_targets = [], []
            for class_idx in range(model.num_classes):
                mask = batch_labels == class_idx
                if not mask.any():
                    continue
                class_out = model(class_idx=class_idx, batch_indices=batch_within[mask])
                batch_outputs.append(class_out)
                batch_targets.append(batch_data[mask])

            if batch_outputs:
                outputs = torch.cat(batch_outputs, dim=0)
                targets = torch.cat(batch_targets, dim=0)
                loss = frobenius_loss(outputs, targets)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                batch_count += 1
        avg_loss = total_loss / max(batch_count, 1)
        scheduler.step(avg_loss)
        print(f"Epoch [{epoch+1}/{epochs}] Loss: {avg_loss:.5e}")

# To use make_unitary2 / make_unitary3 for classification, change the
# make_unitary1 call in _classify_by_eigenvalue_fit below.
import random


def _reconstruction_error(U, eigvals, target):
    """Frobenius norm of herm(U diag(eigvals) U^H) - target."""
    H = U @ torch.diag(eigvals.to(torch.complex128)) @ U.conj().T
    return torch.linalg.norm((H + H.conj().T) / 2 - target)


def _classify_by_eigenvalue_fit(model, hamiltonians, optimize_epochs, lr):
    """Label each Hamiltonian with the class whose eigenbasis reconstructs it best.

    For every class, the class's trained eigenvalues are refined with
    ``optimize_epochs`` Adam steps (fixed eigenbasis ``make_unitary1``). The
    class with the smallest final reconstruction error wins.
    """
    model.eval()
    model_device = next(model.parameters()).device
    with torch.no_grad():
        fixed_eigenvecs = [
            model.make_unitary1(model.eigenvectors[c]).to(torch.complex128)
            for c in range(model.eigenvalues.shape[0])
        ]

    results = []
    for idx, test_H in enumerate(hamiltonians):
        min_error = float('inf')
        best_class = -1
        test_H = test_H.to(model_device).to(torch.complex128)
        for class_idx, U in enumerate(fixed_eigenvecs):
            eigvals = model.eigenvalues[class_idx].detach().clone().to(model_device).requires_grad_(True)
            optimizer = optim.Adam([eigvals], lr=lr)
            for _ in range(optimize_epochs):
                optimizer.zero_grad()
                loss = _reconstruction_error(U, eigvals, test_H)
                loss.backward()
                optimizer.step()
            # Score the eigenvalues *after* the last step. The in-loop loss
            # predates that step and does not exist when optimize_epochs == 0.
            with torch.no_grad():
                err = _reconstruction_error(U, eigvals, test_H).item()
            if err < min_error:
                min_error = err
                best_class = class_idx
        results.append(best_class)
        if (idx + 1) % 50 == 0 or idx < 10:
            print(f"Test {idx+1}/{len(hamiltonians)} done, best class: {best_class}, loss: {min_error:.3e}")
    return np.array(results)


def classify_train(model, test_hamiltonians, test_labels, optimize_epochs=20, lr=0.1):
    """Classify training Hamiltonians by eigenvalue refitting.

    ``test_labels`` is accepted for API compatibility but not used.
    """
    return _classify_by_eigenvalue_fit(model, test_hamiltonians, optimize_epochs, lr)


def classify_test(model, test_hamiltonians, test_labels, optimize_epochs=20, lr=0.1):
    """Classify test Hamiltonians by eigenvalue refitting.

    ``test_labels`` is accepted for API compatibility but not used.
    """
    return _classify_by_eigenvalue_fit(model, test_hamiltonians, optimize_epochs, lr)


def evaluate(pred_labels, true_labels):
    """Print overall and per-class (0..9) accuracy and return the overall accuracy.

    Args:
        pred_labels: predicted labels (array, list or tensor-like).
        true_labels: ground-truth labels as a torch tensor or array-like.

    Returns:
        Fraction of correct predictions.

    Raises:
        ValueError: if the two label sequences differ in length.
    """
    preds = np.asarray(pred_labels)
    if isinstance(true_labels, torch.Tensor):
        truth = true_labels.cpu().numpy()
    else:
        truth = np.asarray(true_labels)
    if preds.shape != truth.shape:
        raise ValueError(f"{preds.shape[0] if preds.ndim else preds.size} predictions "
                         f"for {truth.shape[0] if truth.ndim else truth.size} labels")

    acc = (preds == truth).mean()
    print(f"\nTest Accuracy: {acc*100:.2f}%")
    for c in range(10):
        mask = truth == c
        acc_c = (preds[mask] == c).mean() if mask.sum() > 0 else np.nan
        print(f"Class {c}: {acc_c*100:.2f}% ({mask.sum()} samples)")
    return acc

def visualize_eigenvalues(model):
    """Plot every class's learned eigenvalues, sorted ascending, on a 2x5 grid."""
    with torch.no_grad():
        plt.figure(figsize=(15, 6))
        for cls in range(10):
            spectrum = np.sort(model.eigenvalues[cls].cpu().numpy())
            plt.subplot(2, 5, cls + 1)
            plt.plot(spectrum, 'o--')
            plt.title(f'Class {cls}')
            plt.grid(True)
        plt.tight_layout()
        plt.show()

Training

We train with a batch size of 600 for 80 epochs.

In [None]:
def create_labels_from_class_counts(class_counts):
    """Expand per-class counts into a flat label list.

    ``class_counts[k]`` copies of ``k`` are emitted for each class ``k`` in
    order, e.g. ``[2, 1]`` -> ``[0, 0, 1]``.
    """
    return [class_idx
            for class_idx, count in enumerate(class_counts)
            for _ in range(count)]

def create_batched_data(data, labels, batch_size=64):
    """Wrap ``data``/``labels`` in a shuffling DataLoader of complex128 samples.

    Args:
        data: a tensor, a list/tuple of equally shaped tensors, or anything
            NumPy can turn into an array (ndarray, list of ndarrays, nested lists).
        labels: tensor or sequence of per-sample labels.
        batch_size: samples per batch.

    Returns:
        ``DataLoader`` over ``TensorDataset(data, labels)`` with ``shuffle=True``.
    """
    if not isinstance(data, torch.Tensor):
        if (isinstance(data, (list, tuple)) and len(data) > 0
                and all(isinstance(item, torch.Tensor) for item in data)):
            data = torch.stack(list(data))
        else:
            # torch.stack only accepts tensors, so everything else goes through
            # NumPy; torch.tensor copies, leaving the caller's array untouched.
            data = torch.tensor(np.asarray(data))
    if not isinstance(labels, torch.Tensor):
        labels = torch.tensor(labels)
    if data.dtype != torch.complex128:
        data = data.to(torch.complex128)
    dataset = TensorDataset(data, labels)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


# Per-class counts from the per-digit arrays that make up the training tensor, so the
# labels stay aligned with normalized_Hermitian_Digit_matrices.
class_counts = [len(density_matrices[d]) for d in range(10)]
train_hamiltonians = normalized_Hermitian_Digit_matrices
train_labels = create_labels_from_class_counts(class_counts)
test_hamiltonians = normalized_hermitian_matrices_test_input
test_labels = y_test

# ComplexMatrixModel allocates one weight row per sample from class_sizes; it must
# agree with the data or samples will be paired with the wrong rows.
if list(class_sizes) != class_counts:
    raise ValueError(f"class_sizes {list(class_sizes)} do not match data counts {class_counts}")

model = ComplexMatrixModel(num_classes=10, matrix_size=64).to(device)
print("Model Instantiated.")

print("----- TRAINING PHASE -----")
train_model_batched(model, train_hamiltonians, np.array(train_labels), epochs=80, batch_size=600)


Inference

In [None]:
print("----- CLASSIFICATION ON TEST -----")
print("Train prediction")
# NOTE: this refits eigenvalues for every sample and every class (10 x optimize_epochs
# Adam steps per sample), so the train pass over all samples is slow.
pred_labels_train = classify_train(model, train_hamiltonians, torch.tensor(train_labels, dtype=torch.long),
                                   optimize_epochs=20, lr=0.1)
print("test prediction")
# All test Hamiltonians are classified, so pass the matching full label set.
pred_labels_test = classify_test(model, test_hamiltonians, torch.tensor(test_labels, dtype=torch.long),
                                 optimize_epochs=20, lr=0.1)

print("----- ACCURACY -----")
# Use this cell's own train_labels rather than Method 1's training_labels, so the
# cell does not depend on the Method 1 cells having been run.
train_accuracy = evaluate(pred_labels_train, torch.tensor(train_labels, dtype=torch.long))
test_accuracy = evaluate(pred_labels_test, torch.tensor(test_labels, dtype=torch.long))
print(f"Train accuracy: {train_accuracy}")
print(f"Test accuracy: {test_accuracy}")
