In [52]:
import pennylane as qml
import torch
from torch.utils.data import DataLoader, TensorDataset
import ray
from ray import tune
from ray.air import session
from sklearn.model_selection import train_test_split
import numpy as np
import yaml
import json
import time
import os
import pandas as pd
import torch.optim as optim


# Custom Libraries
from utils.model import Qkernel
from utils.data_generator import DataGenerator
from utils.agent import TrainModel

# Backend Configuration
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [2]:
# Read Configs
with open('config.yaml', 'r') as file:
    config = yaml.safe_load(file)

data = np.load('checkerboard_dataset.npy', allow_pickle=True).item()
x_train, x_test, y_train, y_test = data['x_train'], data['x_test'], data['y_train'], data['y_test']

training_data = torch.tensor(x_train, dtype=torch.float32, requires_grad=True)
testing_data = torch.tensor(x_test, dtype=torch.float32, requires_grad=True)
training_labels = torch.tensor(y_train, dtype=torch.int)
testing_labels = torch.tensor(y_test, dtype=torch.int)

kernel = Qkernel(   
                        device = config['qkernel']['device'], 
                        n_qubits = 4, 
                        trainable = True, 
                        input_scaling = True, 
                        data_reuploading = True, 
                        ansatz = 'embedding_paper', 
                        ansatz_layers = 5
                    )
    
agent = TrainModel(
                        kernel=kernel,
                        training_data=training_data,
                        training_labels=training_labels,
                        testing_data=testing_data,
                        testing_labels=testing_labels,
                        optimizer= 'adam',
                        lr= 0.1,
                        epochs = 200,
                        train_method= 'ccka',
                        target_accuracy=0.95,
                        get_alignment_every=10,  
                        validate_every_epoch=None, 
                        base_path='.',
                        lambda_kao=0.01,
                        lambda_co=0.1,
                        clusters=4
                      )

In [21]:
# Initialize empty lists for main centroids and class centroids
main_centroids = []
class_centroids = []
main_centroid_labels = []
class_centroid_labels = []

for c in [1, -1]:
    class_data = training_data[training_labels == c]
    # Calculate the main centroid and add it to main_centroids
    main_centroid = torch.mean(class_data, axis = 0).requires_grad_()  # Shape [1, feature_dim]
    main_centroids.append(main_centroid)
    main_centroid_labels.append(c)
    
            # Calculate centroids for each cluster in the class and stack them into a single tensor
    #class_centroids.append([np.mean(cluster.tolist(), axis=0).tolist() for cluster in np.array_split(class_data, 4)])
    #class_centroid_labels.append([c] * 4)

In [28]:
training_data.grad

In [77]:
data = training_data.detach().numpy()
data_labels = training_labels.detach().numpy()


for c in [1, -1]:
    cdata = data[data_labels == c]
    mc = [np.mean(cdata, axis=0)]
    sub_centroids = [np.mean(cluster, axis=0) for cluster in np.array_split(cdata, 4)]
    class_centroids = mc + sub_centroids
    

In [78]:
class_centroids

[array([0.47623625, 0.5011184 ], dtype=float32),
 array([0.37070206, 0.6296703 ], dtype=float32),
 array([0.36902234, 0.50664234], dtype=float32),
 array([0.64839965, 0.39306217], dtype=float32),
 array([0.5303492 , 0.46642599], dtype=float32)]

In [49]:
_main_centroids = np.stack(main_centroid)
_main_centroids = torch.tensor(_main_centroids, requires_grad=True)

In [60]:
_class_centroids = np.vstack(class_centroids)
_class_centroids = torch.tensor(_class_centroids, requires_grad=True)
_class_centroids

tensor([0.2706, 0.4972], grad_fn=<SelectBackward0>)

In [63]:
kernel_optimizer = optim.SGD([
            {'params': _main_centroids, 'lr': 0.1},
            {'params': _class_centroids, 'lr': 0.1},
            #{'params': self._class_centroids, 'lr': self._lr},
        ])

In [None]:
  """
    def _get_centroids(self, data, data_labels):
        for c in self._n_classes:
            class_data = data[data_labels == c]
            main_centroid = torch.mean(class_data, dim=0)
            self._main_centroids.append(main_centroid.requires_grad_())
            self._main_centroids_labels.append(c)
            class_centroids = [torch.mean(cluster, dim=0) for cluster in torch.chunk(class_data, self._clusters)]
            self._class_centroids.append([centroid.requires_grad_() for centroid in class_centroids])
            self._class_centroid_labels.append([c] * self._clusters)
    """

    def _get_centroids(self, data, data_labels):
        # Initialize empty lists for main centroids and class centroids
        main_centroids = []
        class_centroids = []
        main_centroid_labels = []
        class_centroid_labels = []

        for c in self._n_classes:
            class_data = np.array(data[data_labels == c].detach())
            # Calculate the main centroid and add it to main_centroids
            main_centroid = np.mean(class_data, axis = 0)  # Shape [1, feature_dim]
            main_centroids.append(main_centroid.tolist())
            main_centroid_labels.append(c)
    
            # Calculate centroids for each cluster in the class and stack them into a single tensor
            class_centroids.append([np.mean(cluster.tolist(), axis=0).tolist() for cluster in np.array_split(class_data, self._clusters)])
            class_centroid_labels.append([c] * self._clusters)

        self._main_centroids = torch.tensor(main_centroids, requires_grad=True)
        self._main_centroids_labels = torch.tensor(main_centroid_labels)
        self._class_centroids = torch.tensor(class_centroids, requires_grad=True)
        self._class_centroid_labels = torch.tensor(class_centroid_labels)

In [None]:
  _class = epoch % len(self._n_classes)
                class_centroids = torch.tensor(self._class_centroids[_class], requires_grad=True)
                class_labels = self._class_centroid_labels[_class]
                main_centroid = torch.tensor(self._main_centroids[_class], requires_grad=True)

                x_0 = main_centroid.repeat(self._clusters, 1)
                x_1 = class_centroids
                
                K = self._kernel(x_0, x_1).to(torch.float32)

                loss = self._centroid_loss(K = K, Y=class_labels, centroid=main_centroid, cl=_class + 1)
                loss = loss.mean()
                loss.backward()

                for param in self._kernel.parameters():
                    if param.grad is not None:
                        print(f"Kernel Gradient: {param.grad}")
                for param in [self._main_centroids[_class], K]:
                    print(f"Centroid Gradient: {param.grad}")
            
                optimizer.step()
                print(f"Epoch {epoch + 1}th, Kernel Loss: {loss}" )

                self._loss_arr.append(loss.item())
                self._per_epoch_executions += x_0.shape[0]
                """
                # Kao loss
                loss_kao = -self._loss_kao(class_centroids, class_labels, self._main_centroids[_class])
                loss_kao.backward(retain_graph=True)
                optimizer.step() 

                # Co loss
                self._centroid_optimizer.zero_grad()
                loss_co = -self.loss_co(class_centroids, class_labels, self._main_centroids[_class], _class + 1)
                loss_co.backward(retain_graph=True)
                self._centroid_optimizer.step()
                
                

                loss_kao, loss_co = self._centroid_loss(class_centroids, class_labels, self._main_centroids[_class], _class + 1)
                loss_kao.backward(retain_graph=True)
                optimizer.step() 
                
                loss_co.backward(retain_graph=True)
                self._centroid_optimizer.step()

                self._per_epoch_executions += self._kernel._circuit_executions
                print(self._per_epoch_executions)
                print(f"Epoch {epoch + 1}th, Kernel Loss: {loss}" )
                """
