# Comparison with MNIST dataset and Vision Transformers

## PyTorch

In [1]:
import torch
import quantum_transformers.qmlperfcomp.torch_backend as qpctorch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
train_dataloader, valid_dataloader = qpctorch.data.get_mnist_dataloaders(batch_size=64, num_workers=4, pin_memory=True)

Using device: cuda


### Classical

In [2]:
model = qpctorch.classical.VisionTransformer(img_size=28, num_channels=1, num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3)
qpctorch.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10, device=device)

Epoch   1/10: 100%|██████████| 938/938 [00:12<00:00, 73.64batch/s, Loss = 1.3147, AUC = 92.17%]                                                                                                                                           
Epoch   2/10: 100%|██████████| 938/938 [00:08<00:00, 104.42batch/s, Loss = 0.9392, AUC = 95.47%]                                                                                                                                          
Epoch   3/10: 100%|██████████| 938/938 [00:09<00:00, 100.25batch/s, Loss = 0.7299, AUC = 97.09%]                                                                                                                                          
Epoch   4/10: 100%|██████████| 938/938 [00:08<00:00, 104.74batch/s, Loss = 0.6336, AUC = 97.56%]                                                                                                                                          
Epoch   5/10: 100%|██████████| 938/938 [00:08<00:00, 110.71b

TOTAL TIME = 93.33s
BEST AUC = 98.57% AT EPOCH 10





### Quantum with PennyLane

In [3]:
model = qpctorch.quantum.VisionTransformer(img_size=28, num_channels=1, num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3)
qpctorch.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10, device=device)

Epoch   1/10: 100%|██████████| 938/938 [06:39<00:00,  2.35batch/s, Loss = 1.7853, AUC = 85.27%]                                                                                                                                           
Epoch   2/10: 100%|██████████| 938/938 [06:45<00:00,  2.32batch/s, Loss = 1.4544, AUC = 90.11%]                                                                                                                                           
Epoch   3/10: 100%|██████████| 938/938 [06:30<00:00,  2.40batch/s, Loss = 1.2237, AUC = 93.31%]                                                                                                                                           
Epoch   4/10: 100%|██████████| 938/938 [06:37<00:00,  2.36batch/s, Loss = 1.0775, AUC = 94.67%]                                                                                                                                           
Epoch   5/10: 100%|██████████| 938/938 [06:35<00:00,  2.37ba

TOTAL TIME = 3982.81s
BEST AUC = 96.13% AT EPOCH 10





### Quantum with PennyLane with Lightning-GPU device

In [4]:
model = qpctorch.quantum.VisionTransformer(img_size=28, num_channels=1, num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3, qdevice="lightning.gpu")
qpctorch.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10, device=device)

Epoch   1/10:   0%|          | 0/938 [00:00<?, ?batch/s]                                                                                                                                                                                  

Epoch   1/10:   0%|          | 2/938 [08:41<67:48:54, 260.83s/batch]                                                                                                                                                                      

The execution is very slow, so I stopped it.

### Quantum with TensorCircuit

In [2]:
model = qpctorch.quantum.VisionTransformer(img_size=28, num_channels=1, num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3, qml_backend="tensorcircuit")
qpctorch.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10, device=device)

Please first ``pip install -U cirq`` to enable related functionality in translation module
Epoch   1/10: 100%|██████████| 938/938 [01:48<00:00,  8.63batch/s, Loss = 1.8936, AUC = 85.94%]                                                                                                                                           
Epoch   2/10: 100%|██████████| 938/938 [00:43<00:00, 21.46batch/s, Loss = 1.5321, AUC = 87.37%]                                                                                                                                           
Epoch   3/10: 100%|██████████| 938/938 [00:44<00:00, 21.26batch/s, Loss = 1.3957, AUC = 88.96%]                                                                                                                                           
Epoch   4/10: 100%|██████████| 938/938 [00:44<00:00, 21.11batch/s, Loss = 1.3201, AUC = 90.09%]                                                                                                             

TOTAL TIME = 504.57s
BEST AUC = 93.51% AT EPOCH 10





## JAX

In [3]:
import traceback
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'  # See https://github.com/google/jax/issues/12461#issuecomment-1256266598
import jaxlib
from jax.config import config
config.update("jax_enable_x64", True)
import catalyst
import quantum_transformers.qmlperfcomp.jax_backend as qpcjax
train_dataloader, valid_dataloader = qpcjax.data.get_mnist_dataloaders(batch_size=64)

### Classical

In [4]:
model = qpcjax.classical.VisionTransformer(num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3)
qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)

Epoch   1/10: 100%|██████████| 937/937 [00:12<00:00, 74.94batch/s, Loss = 1.7705, AUC = 80.97%]                                                                                                                                           
Epoch   2/10: 100%|██████████| 937/937 [00:05<00:00, 175.54batch/s, Loss = 1.4029, AUC = 89.91%]                                                                                                                                          
Epoch   3/10: 100%|██████████| 937/937 [00:05<00:00, 176.86batch/s, Loss = 1.1311, AUC = 93.58%]                                                                                                                                          
Epoch   4/10: 100%|██████████| 937/937 [00:05<00:00, 172.44batch/s, Loss = 0.9299, AUC = 94.91%]                                                                                                                                          
Epoch   5/10: 100%|██████████| 937/937 [00:05<00:00, 174.38b

TOTAL TIME = 60.51s
BEST AUC = 97.83% AT EPOCH 10





### Quantum with PennyLane

In [5]:
model = qpcjax.quantum.VisionTransformer(num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3)
qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)

Epoch   1/10: 100%|██████████| 937/937 [01:14<00:00, 12.65batch/s, Loss = 2.3199, AUC = 50.00%]                                                                                                                                           
Epoch   2/10: 100%|██████████| 937/937 [00:48<00:00, 19.31batch/s, Loss = 2.3050, AUC = 50.10%]                                                                                                                                           
Epoch   3/10: 100%|██████████| 937/937 [00:48<00:00, 19.35batch/s, Loss = 2.3025, AUC = 50.01%]                                                                                                                                           
Epoch   4/10: 100%|██████████| 937/937 [00:48<00:00, 19.32batch/s, Loss = 2.3020, AUC = 50.07%]                                                                                                                                           
Epoch   5/10: 100%|██████████| 937/937 [00:48<00:00, 19.22ba

TOTAL TIME = 510.67s
BEST AUC = 50.40% AT EPOCH 9





### Quantum with PennyLane with Lightning-GPU device

Not working. See: https://discuss.pennylane.ai/t/incompatible-function-arguments-error-on-lightning-qubit-with-jax/2900.

In [6]:
try:
    model = qpcjax.quantum.VisionTransformer(num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3, qdevice="lightning.gpu")
    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)
except jaxlib.xla_extension.XlaRuntimeError as e:
    print(traceback.format_exc())

Traceback (most recent call last):
  File "/tmp/ipykernel_443038/3515301543.py", line 3, in <module>
    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)
  File "/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/training.py", line 73, in train_and_evaluate
    variables = model.init(params_key, x, train=False)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py", line 1845, in init
    _, v_out = self.init_with_output(
               ^^^^^^^^^^^^^^^^^^^^^^
  File "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py"

2023-08-14 06:44:18.259223: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: CpuCallback error: TypeError: RX(): incompatible function arguments. The following argument types are supported:
    1. (self: pennylane_lightning_gpu.lightning_gpu_qubit_ops.LightningGPU_C128, arg0: List[int], arg1: bool, arg2: List[float]) -> None

Invoked with: <pennylane_lightning_gpu.lightning_gpu_qubit_ops.LightningGPU_C128 object at 0x7fcd033c9270>, [0], False, [array([-1.99065936, -1.39067658, -1.40838223, -1.28091254, -1.236017  ,
       -1.99065936, -0.79695004, -0.97037919, -0.83079853, -0.9329554 ,
       -1.99065936, -1.00111164, -1.10162356, -1.36933823, -0.90581727,
       -1.99065936, -0.77338261, -1.07736925, -1.27714229, -1.22361589,
       -1.99065936, -1.41910524, -1.18701463, -0.96778634, -1.39457697,
       -1.99065936, -1.48105851, -1.51591239, -0.46449003, -0.78824321,
       -1.99065936, -0.71558203, -1.24144874, -1.09839354, -1.3799354 ,
 

### Quantum with PennyLane with Lightning-GPU device and catalyst

Not supported.

In [7]:
try:
    model = qpcjax.quantum.VisionTransformer(num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3, qdevice="lightning.gpu", use_catalyst=True)
    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)
except catalyst.CompileError as e:
    print(traceback.format_exc())

Traceback (most recent call last):
  File "/tmp/ipykernel_443038/3502656055.py", line 3, in <module>
    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)
  File "/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/training.py", line 73, in train_and_evaluate
    variables = model.init(params_key, x, train=False)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py", line 1845, in init
    _, v_out = self.init_with_output(
               ^^^^^^^^^^^^^^^^^^^^^^
  File "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py"

### Quantum with PennyLane with Lightning

Same error as before.

In [8]:
try:
    model = qpcjax.quantum.VisionTransformer(num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3, qdevice="lightning.qubit")
    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)
except jaxlib.xla_extension.XlaRuntimeError as e:
    print(traceback.format_exc())

Traceback (most recent call last):
  File "/tmp/ipykernel_443038/3447823501.py", line 3, in <module>
    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)
  File "/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/training.py", line 73, in train_and_evaluate
    variables = model.init(params_key, x, train=False)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py", line 1845, in init
    _, v_out = self.init_with_output(
               ^^^^^^^^^^^^^^^^^^^^^^
  File "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py"

2023-08-14 06:44:19.334928: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: CpuCallback error: TypeError: RX(): incompatible function arguments. The following argument types are supported:
    1. (self: pennylane_lightning.lightning_qubit_ops.StateVectorC128, arg0: List[int], arg1: bool, arg2: List[float]) -> None

Invoked with: <pennylane_lightning.lightning_qubit_ops.StateVectorC128 object at 0x7fcc7838ccb0>, [0], False, [array([-1.99065936, -1.39067658, -1.40838223, -1.28091254, -1.236017  ,
       -1.99065936, -0.79695004, -0.97037919, -0.83079853, -0.9329554 ,
       -1.99065936, -1.00111164, -1.10162356, -1.36933823, -0.90581727,
       -1.99065936, -0.77338261, -1.07736925, -1.27714229, -1.22361589,
       -1.99065936, -1.41910524, -1.18701463, -0.96778634, -1.39457697,
       -1.99065936, -1.48105851, -1.51591239, -0.46449003, -0.78824321,
       -1.99065936, -0.71558203, -1.24144874, -1.09839354, -1.3799354 ,
       -1.99065936, -

### Quantum with PennyLane with Lightning and catalyst

Also results in error.

In [9]:
try:
    model = qpcjax.quantum.VisionTransformer(num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3, qdevice="lightning.qubit", use_catalyst=True)
    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)
except Exception as e:
    print(traceback.format_exc())

Traceback (most recent call last):
  File "/tmp/ipykernel_443038/2603097982.py", line 3, in <module>
    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)
  File "/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/training.py", line 73, in train_and_evaluate
    variables = model.init(params_key, x, train=False)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py", line 1845, in init
    _, v_out = self.init_with_output(
               ^^^^^^^^^^^^^^^^^^^^^^
  File "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py"

2023-08-14 06:44:20.747541: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
UNKNOWN: xla_python_gpu_callback XLA extension have thrown an exception: [/__w/catalyst/catalyst/runtime-build/_deps/pennylane_lightning-src/pennylane_lightning/src/simulator/KernelMap.hpp][Line:270][Method:assignKernelForOp]: Error in PennyLane Lightning: The given interval conflicts with existing intervals.
2023-08-14 06:44:20.747573: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2461] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: xla_python_gpu_callback XLA extension have thrown an exception: [/__w/catalyst/catalyst/runtime-build/_deps/pennylane_lightning-src/pennylane_lightning/src/simulator/KernelMap.hpp][Line:270][Method:assignKernelForOp]: Error in PennyLane Lightning: The given interval conflicts with existing intervals.; current profiling annotation: XlaModule:#hl

### Quantum with TensorCircuit

In [10]:
model = qpcjax.quantum.VisionTransformer(num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3, qml_backend="tensorcircuit")
qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)

Epoch   1/10: 100%|██████████| 937/937 [00:44<00:00, 20.97batch/s, Loss = 2.0976, AUC = 68.97%]                                                                                                                                           
Epoch   2/10: 100%|██████████| 937/937 [00:12<00:00, 76.27batch/s, Loss = 1.9541, AUC = 76.78%]                                                                                                                                           
Epoch   3/10: 100%|██████████| 937/937 [00:12<00:00, 76.29batch/s, Loss = 1.7557, AUC = 80.60%]                                                                                                                                           
Epoch   4/10: 100%|██████████| 937/937 [00:12<00:00, 73.74batch/s, Loss = 1.6315, AUC = 84.99%]                                                                                                                                           
Epoch   5/10: 100%|██████████| 937/937 [00:12<00:00, 75.55ba

TOTAL TIME = 154.64s
BEST AUC = 93.98% AT EPOCH 10



