# IMDb Reviews (Quantum)

This notebook trains and evaluates a quantum transformer on the IMDb Reviews sentiment classification task. This is a binary text classification problem: each movie review is labeled as either positive or negative.
You can find information about the dataset at https://www.tensorflow.org/datasets/catalog/imdb_reviews.

In [1]:
import jax

from quantum_transformers.datasets import get_imdb_dataloaders
from quantum_transformers.training import train_and_evaluate
from quantum_transformers.transformers import Transformer
from quantum_transformers.quantum_layer import get_circuit

data_dir = '/global/cfs/cdirs/m4392/salcc/data'



The models are trained using the following devices:

In [2]:
for d in jax.devices():
    print(d, d.device_kind)

gpu:0 NVIDIA A100-SXM4-40GB


Let's check how big the vocabulary is and look at one example review, in both tokenized and raw form.

In [3]:
(imdb_train_dataloader, imdb_valid_dataloader, imdb_test_dataloader), vocab, tokenizer = get_imdb_dataloaders(batch_size=32, data_dir=data_dir, max_vocab_size=20_000, max_seq_len=512)
print(f"Vocabulary size: {len(vocab)}")
first_batch = next(iter(imdb_train_dataloader))
print(first_batch[0][0])
print(' '.join(map(bytes.decode, tokenizer.detokenize(first_batch[0])[0].numpy().tolist())))

Cardinalities (train, val, test): 22500 2500 25000
Vocabulary size: 19769
[  129    50   397   183    42  1734   940    17   101   163   495   163
  1023    96   163   270    17    50   510   376   102   103   109    17
   259   183   433   121   298   110    95 13096   586    17  7746  7130
    99   177   102   103    96    50    10    54   576   240   267   109
   108   131   102   104    50   142   167   152  1042   113    17   163
   381    42   259    17     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0  
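The tokenized review above ends in a long run of `0`s, which suggests that every review is right-padded with token id `0` up to `max_seq_len`. The notebook does not show how `get_imdb_dataloaders` does this, so the following is only a minimal sketch under that assumption. `pad_and_mask` and `PAD_ID` are hypothetical names. The boolean mask is the kind of array an attention layer would use to ignore padding positions.

```python
import numpy as np

# Assumption: 0 is the padding id (inferred from the trailing zeros in the
# printed example); the real dataloader code is not shown in this notebook.
PAD_ID = 0


def pad_and_mask(token_ids, max_seq_len, pad_id=PAD_ID):
    """Right-pad (or truncate) each sequence to max_seq_len.

    Returns an int32 array of shape (batch, max_seq_len) and a boolean mask
    that is True for real tokens and False for padding positions.
    """
    batch = np.full((len(token_ids), max_seq_len), pad_id, dtype=np.int32)
    for i, seq in enumerate(token_ids):
        seq = list(seq)[:max_seq_len]  # truncate overly long reviews
        batch[i, :len(seq)] = seq
    # This mask is only correct if real tokens never use pad_id.
    mask = batch != pad_id
    return batch, mask


ids, mask = pad_and_mask([[129, 50, 397], [17]], max_seq_len=5)
print(ids)
print(mask.sum(axis=1))  # number of real tokens per review
```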

Now let's train the quantum transformer using the best hyperparameters found by random hyperparameter search.
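The search itself is not part of this notebook. The following is a minimal sketch of random hyperparameter search in general. Each trial samples a configuration uniformly from a search space and scores it, for example by the best validation AUC of a training run. The best-scoring configuration is kept. `SEARCH_SPACE`, `sample_config` and `random_search` are hypothetical names. The value ranges are illustrative and are not the ranges that were actually searched. They are chosen so that `hidden_size` is always divisible by `num_heads`.

```python
import random

# Hypothetical search space. The notebook only reports the selected values
# (hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3).
SEARCH_SPACE = {
    "hidden_size": [4, 6, 8],
    "num_heads": [1, 2],
    "num_transformer_blocks": [1, 2, 4],
    "mlp_hidden_size": [3, 4, 6],
}


def sample_config(rng):
    """Draw one configuration uniformly at random from SEARCH_SPACE."""
    return {name: rng.choice(values) for name, values in SEARCH_SPACE.items()}


def random_search(evaluate, num_trials, seed=0):
    """Return (best_config, best_score), maximizing evaluate(config)."""
    rng = random.Random(seed)
    best_config, best_score = None, float("-inf")
    for _ in range(num_trials):
        config = sample_config(rng)
        # In practice evaluate() would train a model with this config and
        # return a validation metric such as AUC.
        score = evaluate(config)
        if score > best_score:
            best_config, best_score = config, score
    return best_config, best_score
```

Random search is often preferred over a full grid when only a few hyperparameters really matter, because it tries more distinct values of each one for the same number of trials.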

In [4]:
model = Transformer(num_tokens=len(vocab), max_seq_len=512, num_classes=2, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3,
                    quantum_attn_circuit=get_circuit(), quantum_mlp_circuit=get_circuit())
train_and_evaluate(model, imdb_train_dataloader, imdb_valid_dataloader, imdb_test_dataloader, num_classes=2, num_epochs=30)

Number of parameters = 122096
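With `hidden_size=6`, the token embedding table probably accounts for most of these parameters. Here is a back-of-the-envelope check. It assumes the embedding is a dense table of shape `(num_tokens, hidden_size)`. The `Transformer` source is not shown here, so treat that as an assumption rather than a description of the implementation.

```python
# Assumption: the token embedding is a dense (num_tokens, hidden_size) table.
num_tokens = 19_769      # "Vocabulary size" printed earlier
hidden_size = 6          # Transformer(hidden_size=6, ...)
total_params = 122_096   # "Number of parameters" printed above

embedding_params = num_tokens * hidden_size
share = embedding_params / total_params
remaining_params = total_params - embedding_params

print(f"Embedding table: {embedding_params} params ({share:.1%} of the total)")
print(f"Everything else: {remaining_params} params")
```

Under that assumption, the embedding holds 118,614 parameters, about 97% of the total. That leaves 3,482 parameters for all other layers. Lowering `max_vocab_size` would then be the most direct way to shrink the model.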


Epoch   1/30: 100%|██████████| 703/703 [01:44<00:00,  6.76batch/s, Loss = 0.6954, AUC = 49.56%]
Epoch   2/30: 100%|██████████| 703/703 [01:09<00:00, 10.18batch/s, Loss = 0.6930, AUC = 50.92%]
Epoch   3/30: 100%|██████████| 703/703 [01:09<00:00, 10.19batch/s, Loss = 0.6907, AUC = 58.98%]
Epoch   4/30: 100%|██████████| 703/703 [01:09<00:00, 10.19batch/s, Loss = 0.6861, AUC = 69.50%]
Epoch   5/30: 100%|██████████| 703/703 [01:08<00:00, 10.19batch/s, Loss = 0.6691, AUC = 71.83%]
Epoch   6/30: 100%|██████████| 703/703 [01:09<00:00, 10.19batch/s, Loss = 0.6076, AUC = 78.95%]
Epoch   7/30: 100%|██████████| 703/703 [01:08<00:00, 10.19batch/s, Loss = 0.5150, AUC = 85.68%]
Epoch   8/30: 100%|██████████| 703/703 [01:08<00:00, 10.19batch/s, Loss = 0.4528, AUC = 87.83%]
Epoch   9/30: 100%|██████████| 703/703 [01:08<00:00, 10.19batch/s, Loss = 0.4179, AUC = 89.23%]
Epoch  10/30: 100%|██████████| 703/703 [01:09<00:00, 10.19batch/s, Loss = 0.4574, AUC = 89.75%]
Epoch  11/30: 100%|██████████| 703/703 [

Total training time = 2105.31s, best validation AUC = 91.69% at epoch 25


Testing: 100%|██████████| 781/781 [00:28<00:00, 26.94batch/s, Loss = 0.9679, AUC = 89.46%]


(Array(0.96788067, dtype=float32),
 89.45922399288703,
 array([0.00000000e+00, 0.00000000e+00, 2.40038406e-04, ...,
        9.99679949e-01, 9.99679949e-01, 1.00000000e+00]),
 array([0.        , 0.00760365, 0.01816872, ..., 0.99991996, 1.        ,
        1.        ]))
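`train_and_evaluate` returns a tuple. Its first two entries match the test loss (0.9679) and test AUC in percent (89.46%) from the `Testing` line. The two arrays appear to be the false-positive and true-positive rates of the ROC curve. That reading is inferred from the printed values and is not stated in the notebook. Assuming it holds, here is a minimal NumPy sketch of how an ROC curve is built from binary labels and scores, and how the AUC follows from it by the trapezoidal rule. `roc_curve` and `auc` are illustrative helpers, not functions from `quantum_transformers`.

```python
import numpy as np


def roc_curve(labels, scores):
    """Return (fpr, tpr) for binary labels in {0, 1} and real-valued scores.

    Both classes must be present, otherwise the rates are undefined.
    """
    labels = np.asarray(labels)
    scores = np.asarray(scores, dtype=float)
    order = np.argsort(-scores, kind="mergesort")  # sort by descending score
    labels, scores = labels[order], scores[order]
    # Use only the last position of each run of tied scores as a threshold,
    # so that tied examples cross the threshold together.
    distinct = np.where(np.diff(scores))[0]
    idx = np.r_[distinct, labels.size - 1]
    tps = np.cumsum(labels)[idx]   # true positives at each threshold
    fps = (idx + 1) - tps          # false positives at each threshold
    # Prepend the (0, 0) point where nothing is predicted positive.
    tpr = np.r_[0.0, tps / tps[-1]]
    fpr = np.r_[0.0, fps / fps[-1]]
    return fpr, tpr


def auc(fpr, tpr):
    """Area under a piecewise-linear ROC curve (trapezoidal rule)."""
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2.0))
```

For example, `auc(*roc_curve([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))` gives `0.75`. Multiplying by 100 gives the percentage form used in the training and test logs.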