# IMDB

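This notebook trains two sizes of `Transformer` classifier (`num_classes=2`) on the IMDB reviews dataset and reports the loss and AUC of each run.

Information about the dataset: https://www.tensorflow.org/datasets/catalog/imdb_reviews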

In [1]:
import os

import jax
import tensorflow as tf
# Hide every GPU from TensorFlow so that it does not see the GPU and grab all GPU memory.
# NOTE(review): kept ahead of the quantum_transformers imports, presumably so TF is
# configured before those modules touch it - confirm before reordering.
tf.config.set_visible_devices([], device_type='GPU')
# Seed TensorFlow's global random generator for reproducibility.
# JAX randomness is not seeded in this cell.
tf.random.set_seed(42)

from quantum_transformers.datasets import get_imdb_dataloaders
from quantum_transformers.training import train_and_evaluate
from quantum_transformers.transformers import Transformer

# Directory passed to get_imdb_dataloaders as `data_dir`.
# Defaults to the original cluster path; set the QT_DATA_DIR environment variable
# to run the notebook on another machine without editing this cell.
# An unset or empty variable falls back to the default.
data_dir = os.environ.get('QT_DATA_DIR') or '/global/cfs/cdirs/m4392/salcc/data'

2023-08-27 07:21:06.859612: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-08-27 07:21:06.859651: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-08-27 07:21:06.859675: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.
Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.


Please first ``pip install -U cirq`` to enable related functionality in translation module


In [2]:
# Show each accelerator JAX can see, together with its hardware model string.
for device in jax.devices():
    print(device, device.device_kind)

gpu:0 NVIDIA A100-SXM4-40GB


In [3]:
# Build the train / validation / test dataloaders together with the vocabulary
# and the tokenizer that belong to them.
(
    (imdb_train_dataloader, imdb_val_dataloader, imdb_test_dataloader),
    vocab,
    tokenizer,
) = get_imdb_dataloaders(
    batch_size=32,
    data_dir=data_dir,
    max_vocab_size=20_000,
    max_seq_len=512,
)
print(f"Vocabulary size: {len(vocab)}")

# Sanity check on one training batch: print the first entry of the batch's first
# element (presumably the token ids of one review - verify against the
# dataloader), then the same entry converted back to text.
first_batch = next(iter(imdb_train_dataloader))
first_batch_inputs = first_batch[0]
print(first_batch_inputs[0])

# detokenize() is applied to the whole batch; its first row is converted to a
# Python list of bytes tokens, which are decoded and joined with spaces.
first_review_tokens = tokenizer.detokenize(first_batch_inputs)[0].numpy().tolist()
print(' '.join(token.decode() for token in first_review_tokens))

Vocabulary size: 19769
[  136    95  3739    97   103   111   159   182   192   674   122  7218
   739    98    95   202    15   101   151   118   166   133   120   143
   541    97   148    15   373    15    96  1664   875    17  3282    15
   124  2883   121 13747    97   739    98    95  4431    15   123   118
  1174    42  1885  6049    17   126   188  1483   147    42   111   102
   153   125  1603  1915  2808    95   532   111   206   129   768  3960
    15   706   108    95  2209  1589    61    16  3932    17   170   114
  8124   103   111   106  1280    17     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0    

In [4]:
# First run: hidden_size=64, mlp_hidden_size=32.
# max_seq_len matches the value passed to get_imdb_dataloaders.
model = Transformer(
    num_tokens=len(vocab),
    max_seq_len=512,
    num_classes=2,
    hidden_size=64,
    num_heads=2,
    num_transformer_blocks=4,
    mlp_hidden_size=32,
)

# Train for up to 30 epochs; progress and final test metrics are printed by
# train_and_evaluate itself.
train_and_evaluate(
    model,
    imdb_train_dataloader,
    imdb_val_dataloader,
    imdb_test_dataloader,
    num_classes=2,
    num_epochs=30,
)

Epoch   1/30: 100%|██████████| 703/703 [00:10<00:00, 65.29batch/s, Loss = 0.5371, AUC = 81.43%] 
Epoch   2/30: 100%|██████████| 703/703 [00:04<00:00, 157.47batch/s, Loss = 0.3601, AUC = 92.67%]
Epoch   3/30: 100%|██████████| 703/703 [00:04<00:00, 158.32batch/s, Loss = 0.3100, AUC = 94.51%]
Epoch   4/30: 100%|██████████| 703/703 [00:04<00:00, 156.87batch/s, Loss = 0.3042, AUC = 94.81%]
Epoch   5/30: 100%|██████████| 703/703 [00:04<00:00, 155.62batch/s, Loss = 0.4049, AUC = 94.48%]
Epoch   6/30: 100%|██████████| 703/703 [00:04<00:00, 160.08batch/s, Loss = 0.4603, AUC = 93.93%]
Epoch   7/30: 100%|██████████| 703/703 [00:04<00:00, 158.13batch/s, Loss = 0.5529, AUC = 94.15%]
Epoch   8/30: 100%|██████████| 703/703 [00:04<00:00, 158.78batch/s, Loss = 0.4452, AUC = 94.12%]
Epoch   9/30: 100%|██████████| 703/703 [00:04<00:00, 159.93batch/s, Loss = 0.5232, AUC = 94.32%]
Epoch  10/30: 100%|██████████| 703/703 [00:04<00:00, 157.99batch/s, Loss = 0.8063, AUC = 92.93%]
Epoch  11/30: 100%|██████████|

Total training time = 140.17s, best validation AUC = 94.81% at epoch 4


Testing: 100%|██████████| 781/781 [00:04<00:00, 176.52batch/s, Loss = 0.3609, AUC = 92.90%]


In [5]:
# Second run: same settings as the previous cell except for the much smaller
# hidden_size=8 and mlp_hidden_size=4. Rebinds `model`, replacing the earlier one.
model = Transformer(
    num_tokens=len(vocab),
    max_seq_len=512,
    num_classes=2,
    hidden_size=8,
    num_heads=2,
    num_transformer_blocks=4,
    mlp_hidden_size=4,
)

# Train for up to 30 epochs; progress and final test metrics are printed by
# train_and_evaluate itself.
train_and_evaluate(
    model,
    imdb_train_dataloader,
    imdb_val_dataloader,
    imdb_test_dataloader,
    num_classes=2,
    num_epochs=30,
)

Epoch   1/30: 100%|██████████| 703/703 [00:08<00:00, 78.31batch/s, Loss = 0.6904, AUC = 54.57%] 
Epoch   2/30: 100%|██████████| 703/703 [00:03<00:00, 183.14batch/s, Loss = 0.6843, AUC = 59.06%]
Epoch   3/30: 100%|██████████| 703/703 [00:03<00:00, 182.05batch/s, Loss = 0.5915, AUC = 79.31%]
Epoch   4/30: 100%|██████████| 703/703 [00:03<00:00, 181.99batch/s, Loss = 0.4942, AUC = 85.90%]
Epoch   5/30: 100%|██████████| 703/703 [00:03<00:00, 181.91batch/s, Loss = 0.4596, AUC = 88.95%]
Epoch   6/30: 100%|██████████| 703/703 [00:03<00:00, 184.21batch/s, Loss = 0.4069, AUC = 91.04%]
Epoch   7/30: 100%|██████████| 703/703 [00:03<00:00, 184.09batch/s, Loss = 0.3745, AUC = 92.02%]
Epoch   8/30: 100%|██████████| 703/703 [00:03<00:00, 182.92batch/s, Loss = 0.3801, AUC = 92.36%]
Epoch   9/30: 100%|██████████| 703/703 [00:03<00:00, 181.71batch/s, Loss = 0.4106, AUC = 92.84%]
Epoch  10/30: 100%|██████████| 703/703 [00:03<00:00, 184.05batch/s, Loss = 0.4546, AUC = 93.38%]
Epoch  11/30: 100%|██████████|

Total training time = 120.23s, best validation AUC = 93.39% at epoch 12


Testing: 100%|██████████| 781/781 [00:04<00:00, 188.65batch/s, Loss = 0.6430, AUC = 91.69%]
