# <span style="font-family: 'Computer Modern'; font-size: 42pt; font-weight: bold;">Quantum Convolutional Neural Network (QCNN): *Laboratory of Particle Physics and Cosmology (LPPC)*</span>

In [1]:
#### ***** IMPORTS / DEPENDENCIES *****:

### *** PLOTTING ***:
import matplotlib; # (NOT ACCESSED)
import matplotlib.pyplot as plt

### *** PENNYLANE ***:
import pennylane as qml
from pennylane import numpy as np

### *** DATA ***:
import numpy as np
# import pandas as pd # (NOT ACCESSED)
import seaborn as sns
sns.set()

### *** JAX ***:
import jax;
## JAX CONFIGURATIONS:
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
# import jax.experimental.sparse as jsp # (NOT ACCESSED)
# import jax.scipy.linalg as jsl  # (NOT ACCESSED)

### *** RNG ***:
seed = 0
# Using NumPy (base):
# rng = np.random.default_rng(seed=seed) # ORIGINAL (NumPy)
# Using JAX (base):
rng_jax = jax.random.PRNGKey(seed=seed) # *1* (JAX)
rng_jax_arr = jnp.array(rng_jax) # *2* (JAX)

### *** OTHER ***:
# from glob import glob

In [2]:
#### ***** PACKAGE IMPORTS (IN PROGRESS) *****:
# ****************************************************************************************
## *1* MNIST DATA LOADING CLASS:
# from lppc_qcnn.load_qc_data import LoadDataQC # LoadDataQC() <--- STATIC METHOD

## *2* QUANTUM CIRCUIT AND LAYERS CLASS:
# from lppc_qcnn.circuit_layers import LayersQC # LayersQC() <--- INSTANCE METHOD (SELF)
# -> Define Instance of LayersQC:
# layers_obj = LayersQC()

## *3* TRAIN QCNN / RESULTS CLASS:
# from lppc_qcnn.circuit_layers import TrainQC # TrainQC() <--- INSTANCE METHOD (SELF)
# -> Define Instance of TrainQC:
# layers_obj = TrainQC()

## *$* QUANTUM AND MATH OPERATORS CLASS:
# from lppc_qcnn.qc_operators import QuantumMathOps # <--- STATIC METHOD
# -> Define Instance of QuantumMathOps:
# qmo_obj = QuantumMathOps()
# ****************************************************************************************

<span style="font-family: 'Computer Modern'; font-weight: bold; font-size: 24pt;">LOADING MNIST DATASET</span>

In [3]:
# ********************************************
#           INITIAL PARAMETER SETUP
# ********************************************

## MNIST DATA LOADING CLASS:
from lppc_qcnn.load_qc_data import LoadDataQC # <--- STATIC METHOD

## DEFINE VARIABLES:
n_qubits = 6 # Number of qubits
active_qubits = 6 # Active qubits
# active_qubits = list(range(active_qubits))
num_wires = 6 # Number of wires
num_wires_draw = 2 # Number of wires (DRAWINGS)
# num_wires_test = 4 # Number of wires (TEST)

## QUANTUM DEVICE:
# device = qml.device("default.mixed", wires=num_wires)
device = qml.device("default.qubit", wires=num_wires) # Six-qubit device

In [4]:
# ********************************************
#          LOADING THE MNIST DATASET
# ********************************************

## DEFINE VARIABLES (DATA):
num_train = jnp.int64(2) # Binary classification
num_test = jnp.int64(2)

n_train = jnp.int64(2)
n_test = jnp.int64(2)

## DATA PARAMETERS:
rng_jax = jax.random.PRNGKey(seed=seed)  # Random Number Generator (JAX)
# rng_np = np.random.default_rng(seed=seed) # Random Number Generator (NUMPY)

## DIGITS DATA (ORIGINAL):
# (Note: 'n_train'/'n_test' (N) = VARIABLE, 'num_train'/'num_test' (NUM) = FUNCTION ARGUMENT)
# x_train, y_train, x_test, y_test = LoadDataQC.load_digits_data(n_train, n_test, rng) # Loading digits
x_train, y_train, x_test, y_test = LoadDataQC.load_digits_data_jaxV2(num_train=n_train, num_test=n_test, rng=rng_jax) # Loading digits


##         ***** FUNCTIONALITY CHECK PRINT STATEMENTS (DATA) *****
# -------------------------------------------------------------------------
print(f"{'='*15} FUNCTIONALITY CHECK (DATA) {'='*15}")

# Shapes and Types:
print(f"\n{'='*14} (1) SHAPES AND TYPES {'='*14}")
print(f"• x_train type:  {type(x_train)}  | shape:  {x_train.shape}")
print(f"• y_train type:  {type(y_train)}  | shape:  {y_train.shape}")
print(f"• x_test type:  {type(x_test)}  | shape:  {x_test.shape}")
print(f"• y_test type:  {type(y_test)}  | shape:  {y_test.shape}")

# Normalization:
print(f"\n{'='*10} (2) NORMALIZATION {'='*10}")
print(f"• x_train first row norm:  {np.linalg.norm(x_train[0])}")
print(f"• x_test first row norm:  {np.linalg.norm(x_test[0])}")

# Label Uniqueness:
print(f"\n{'='*10} (3) LABEL UNIQUENESS {'='*10}")
print(f"• Unique labels -> y_train:  {np.unique(y_train)}")
print(f"• Unique labels -> y_test:  {np.unique(y_test)}")
# -------------------------------------------------------------------------

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int64[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function load_digits_data_jaxV2 at /Users/seanchisholm/VSCode_LPPC/qcnn-lppc/lppc_qcnn/load_qc_data.py:173 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i64[1797] = device_put[devices=[None] srcs=[None]] b
    from line /Users/seanchisholm/VSCode_LPPC/qcnn-lppc/lppc_qcnn/load_qc_data.py:200:17 (LoadDataQC.load_digits_data_jaxV2)

  operation a:i64[] = convert_element_type[new_dtype=int64 weak_type=False] b
    from line /Users/seanchisholm/VSCode_LPPC/qcnn-lppc/lppc_qcnn/load_qc_data.py:203:39 (LoadDataQC.load_digits_data_jaxV2)

  operation a:i64[] = convert_element_type[new_dtype=int64 weak_type=False] b
    from line /Users/seanchisholm/VSCode_LPPC/qcnn-lppc/lppc_qcnn/load_qc_data.py:203:55 (LoadDataQC.load_digits_data_jaxV2)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [None]:
# ********************************************
#        VISUALIZING THE MNIST DATASET
# ********************************************

## DRAW MNIST IMAGE:
LoadDataQC.draw_mnist_data()

<span style="font-family: 'Computer Modern'; font-weight: bold; font-size: 24pt;">CONSTRUCTING QUANTUM CIRCUIT</span>

In [None]:
# ********************************************
#    VISUALIZING / PLOTTING QUANTUM CIRCUIT
# ********************************************

## QUANTUM CIRCUIT AND LAYERS CLASS:
from lppc_qcnn.circuit_layers import LayersQC # <--- INSTANCE METHOD (SELF)
# Define Instance of LayersQC:
layers_obj = LayersQC()

## DEFINE SAMPLE WEIGHTS / FEATURES:
# (Note: Adjust second dimension as needed)
weights = np.random.rand(81, 2) # <--- SHAPE ~ [(num_wires // 2) * (3 ** 3)]
# weights = jnp.array(weights) # WRAP WITH JAX
# weights = np.random.rand(num_wires, 2)
last_layer_weights = np.random.rand(4 ** 2 - 1)
# last_layer_weights = jnp.array(last_layer_weights) # WRAP WITH JAX
# last_layer_weights = np.random.rand(4 ** (num_wires // 2) - 1)
features = np.random.rand(2 ** num_wires)
# features = jnp.array(features) # WRAP WITH JAX

## DRAW QUANTUM CIRCUIT:
fig, ax = qml.draw_mpl(layers_obj.conv_net)(
layers_obj, weights, last_layer_weights, features
)

print("*** QCNN QUANTUM CIRCUIT ***")
plt.show()

In [None]:
## DEFINE VARIABLES (CONV DRAWING):
n_qubits = 6  # Number of qubits
active_qubits = list(range(n_qubits))  # Active qubits
num_wires = 6  # Number of wires
num_wires_draw = 2  # Number of wires (DRAWINGS)

params_conv = weights[:, 0]  # Use appropriate slicing based on your need

## CONVOLUTIONAL LAYER:
@qml.qnode(device)
def conv_circuit(params, active_qubits):
    layers_obj.three_conv_layer(params, active_qubits)
    return qml.probs(wires=active_qubits[:num_wires_draw])

## DRAW CONVOLUTIONAL LAYER CIRCUIT:
fig, ax = qml.draw_mpl(conv_circuit)(params_conv, active_qubits)
print("*** QCNN CONVOLUTIONAL LAYER CIRCUIT ***")
plt.show()

In [None]:
## CONVOLUTIONAL AND POOLING LAYER:
@qml.qnode(device)
def conv_and_pooling_circuit(kernel_weights, n_wires):
    layers_obj.conv_and_pooling(kernel_weights, n_wires)
    return qml.probs(wires=n_wires[:num_wires_draw])

## DRAW CONVOLUTIONAL AND POOLING LAYER CIRCUIT:
fig, ax = qml.draw_mpl(conv_and_pooling_circuit)(weights[:, 0], active_qubits)
print("*** QCNN CONVOLUTIONAL AND POOLING LAYER CIRCUIT ***")
plt.show()

In [None]:
# ********************************************
#         TRAINING QCNN / RESULTS (V1)
# ********************************************

## TRAINING QCNN CLASS:
from lppc_qcnn.circuit_layers import TrainQC  # TrainQC() <--- STATIC METHOD (FORMERLY INSTANCE METHOD (SELF))

# Define Instance of TrainQC (AS NEEDED):
# train_obj = TrainQC
# train_obj = TrainQC() # <- INSTANCE METHOD (SELF)

## DEFINE TRAIN PARAMETERS:
n_test = 2
# n_test = jnp.int64(2)
n_train = 2
# n_train = jnp.int64(2)
n_epochs = 100
# n_epochs = jnp.int64(100)
n_reps = 10
# n_reps = jnp.int64(10)

## DEFINE TRAIN PARAMETERS (DUMMY):
num_test = jnp.int64(2) # for jax dummy variable access (functionality)
num_train = jnp.int64(2) # for jax dummy variable access (functionality)
num_epochs = jnp.int64(100) # for jax dummy variable access (functionality)

# Train QCNN and get results:
results_df = TrainQC.compute_aggregated_results(num_train=n_train, num_test=n_test, num_epochs=n_epochs) # *6*

## *** ALTERNATES ***:
#results_df = train_obj.compute_aggregated_results(train_obj, n_train, n_test)  # *1*
#results_df = train_obj.compute_aggregated_results(n_train, n_test)  # *2*
#results_df = train_obj.compute_aggregated_results(n_train=n_train, n_test=n_test) # *3*
#results_df = compute_results(n_train=n_train, n_test=n_test, n_epochs=n_epochs) *4*
#results_df = compute_results_jit(n_train=n_train, n_test=n_test, n_epochs=n_epochs) # *5*

In [None]:
# ********************************************
#     PLOTTING AGGREGATED TRAINING RESULTS
# ********************************************

## DEFINE TRAIN PARAMETERS (ALSO ABOVE):
# n_test = 2
n_train = 2
n_epochs = 100
steps = 100
# train_obj = TrainQC() (RECALL)

# Plot aggregated training results:
TrainQC.plot_aggregated_results(results_df, n_train, steps=n_epochs, 
                                  title_loss='Train and Test Losses', 
                                  title_accuracy='Train and Test Accuracies', 
                                  markevery=10) # *1*

## ALTERNATES:
'''
TrainQC.plot_aggregated_results(results_df, n_train, steps, 
                                  title_loss='Train and Test Losses', 
                                  title_accuracy='Train and Test Accuracies', 
                                  markevery=10) # *2* (NO SELF)
TrainQC.plot_aggregated_results(results_df, n_train=n_train, steps=steps, 
                                  title_loss='Train and Test Losses', 
                                  title_accuracy='Train and Test Accuracies', 
                                  markevery=10) # *3*
'''

***

<span style="font-family: 'Computer Modern'; font-weight: bold; font-size: 24pt;">CODE TESTING / VALIDATION</span>

In [None]:
### ***** DATA (FOR FUTURE IMPLEMENTATION) *****:
# x_train, y_train, x_test, y_test = load_moments(n_train, n_test, rng) # Loading moments
# x_train, y_train, x_test, y_test = load_IC_data(n_train, n_test, rng) # Loading IC data

In [None]:
### ***** TESTING RNG TYPES *****:

# Example usage of numpy.random.Generator
rng = np.random.default_rng()
print(f"Type of rng: {type(rng)}")

# Example usage of jax.random.PRNGKey
rng_jax = jax.random.PRNGKey(0)
print(f"Type of rng_jax: {type(rng_jax)}")

In [None]:
### ***** RANDOM *****:

'''
a = x_train[0]

# ************************************************

print(type(a))

# ************************************************

full_array = jnp.array(
        [
            [
                [0,0,0],
                [2,2,2],
                [0,0,0],
                [0,0,0]
            ],
            [
                [1,1,1],
                [0,0,0],
                [0,0,0],
                [0,0,0]
            ],
            [
                [1,1,1],
                [0,0,0],
                [0,0,0],
                [0,0,0]
            ]
        ]
)
'''

In [None]:
### ***** TESTING JIT-COMPILED DATA LOADING *****:

'''
## JAX DIGITS DATA:
# Load data outside JIT-compiled function (AS NEEDED):
x_train, y_train, x_test, y_test = LoadDataQC.load_digits_data(n_train, n_test, rng)
'''

'''
### ***** ATTEMPT 1 (Wrapping 'compute_aggregated_results') *****:
# Define wrapper function to call method and apply jax:
compute_results = jax.jit(train_obj.compute_aggregated_results)

# Run training for multiple sizes and aggregate results:
results_df = compute_results(n_train=n_train, n_test=n_test, n_epochs=n_epochs) # *4 (WRAPPED WITH JAX) *

### ***** ATTEMPT 2 (Wrapping 'compute_aggregated_results') *****:
# Wrapper Function for 'compute_aggregated_results':
def compute_results(train_obj, n_train, n_test, n_epochs):
    return train_obj.compute_aggregated_results(n_train=n_train, n_test=n_test, n_epochs=n_epochs)
    # return train_obj.compute_aggregated_results(train_obj, n_train=n_train, n_test=n_test, n_epochs=n_epochs)
# compute_results = jax.jit(train_obj.compute_aggregated_results)

### ***** ATTEMPT 3 (Wrapping 'compute_aggregated_results') *****:
train_obj = TrainQC()
# New results function for wrapping:
def compute_results_wrapper(n_train, n_test, n_epochs):
    return train_obj.compute_aggregated_results(n_train, n_test, n_epochs)

# JIT-compile Wrapper Function:
compute_results_jit = jax.jit(compute_results_wrapper)
#compute_jit = jax.jit(lambda n_train, n_test, n_epochs: compute_results_wrapper(train_obj, n_train, n_test, n_epochs))
'''

In [None]:
# ********************************************
#         TRAINING QCNN / RESULTS (V2)
# ********************************************

'''
### ***** TRAINING QCNN W/O LOOP FUNCTION ('train_qcnn) *****:
from lppc_qcnn.circuit_layers import TrainQC  # TrainQC() <--- INSTANCE METHOD (SELF)

# Define Instance of TrainQC:
# train_obj = TrainQC
train_obj = TrainQC()

## DEFINE TRAIN PARAMETERS:
n_test = 2
n_train = 2
n_epochs = 100
n_reps = 10

# Define wrapper function to call method and apply jax:
@jax.jit
def run_iterations(n_train, n_test):
    return train_obj.run_iterations(n_train=n_train, n_test=n_test)

## RUN TRAINING LOOP:
train_sizes = [2]
results_df = run_iterations(n_train=n_train, n_test=n_test)
for n_train in train_sizes[1:]:
    results_df = pd.concat([results_df, run_iterations(n_train=n_train, n_test=n_test)])
'''

In [None]:
### ***** FUNCTIONALITY CHECK FOR TRAINING DICTIONARY (TrainQC.train_qcnn(args, *kwargs)) *****:

'''
train_dict = dict(
    n_train=[n_train] * n_epochs,
    step=jnp.arange(1, n_epochs + 1, dtype=int), # NP -> JNP
    train_cost=train_cost_epochs,
    train_acc=train_acc_epochs,
    test_cost=test_cost_epochs,
    test_acc=test_acc_epochs,
)
# TYPE (DICTIONARY):
print(f"train_dict: type = {type(train_dict)}")

# SHAPES AND TYPES (DICTIONARY ITEMS):
print(f"train_dict shapes and types:")
for key, value in train_dict.items():
    print(f"{key}: shape = {jnp.shape(value)}, type = {type(value)}") # NP -> JNP
'''

***