In [1]:
from quantum_transformers.datasets.nlp import get_imdb_dataloaders
from quantum_transformers.training import train
from quantum_transformers.transformers import ClassicalTransformer

import torch

In [2]:
# Run on the GPU when CUDA is visible, otherwise fall back to the CPU,
# and display which hardware was picked.
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
'cpu' if device.type != 'cuda' else torch.cuda.get_device_name()

'NVIDIA A100-PCIE-40GB'

In [3]:
# Build the IMDb train/validation DataLoaders (batch size 32) together with the vocabulary.
(imdb_train_dataloader, imdb_valid_dataloader), vocab = get_imdb_dataloaders(batch_size=32)

# Sanity-check split sizes and vocabulary size, then peek at the first training example.
imdb_train_dataset = imdb_train_dataloader.dataset
imdb_valid_dataset = imdb_valid_dataloader.dataset
print(f"Number of training examples: {len(imdb_train_dataset)}, Number of validation examples: {len(imdb_valid_dataset)}")
print(f"Vocabulary size: {len(vocab)}")
# Element [1] of an example is the raw review text (see the printed output);
# presumably element [0] is the sentiment label — TODO confirm.
print(imdb_train_dataset[0][1])

Number of training examples: 25000, Number of validation examples: 25000
Vocabulary size: 20439
I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, 

In [4]:
# Classical baseline: hidden_size=64, 2 heads, 4 transformer blocks, MLP width 32,
# binary sentiment (num_classes=2), lr 3e-4, 30 epochs.
# Seed torch's global RNG right before building the model so the weight
# initialization does not depend on how much randomness earlier cells consumed;
# re-running this cell alone (or Restart & Run All) then starts from the same weights.
# If the DataLoaders shuffle via torch's global RNG, their order is fixed too — TODO confirm.
# NOTE: some CUDA kernels remain nondeterministic, so GPU runs may still differ slightly.
torch.manual_seed(42)
model = ClassicalTransformer(num_tokens=len(vocab), num_classes=2, hidden_size=64, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=32)
train(model, imdb_train_dataloader, imdb_valid_dataloader, num_classes=2, learning_rate=0.0003, num_epochs=30, device=device)

Epoch 1/30 (49.87s): Loss = 0.5543, AUC = 80.43%
Epoch 2/30 (98.11s): Loss = 0.4564, AUC = 83.66%
Epoch 3/30 (147.50s): Loss = 0.4526, AUC = 85.12%
Epoch 4/30 (195.85s): Loss = 0.4005, AUC = 86.74%
Epoch 5/30 (244.34s): Loss = 0.3795, AUC = 87.78%
Epoch 6/30 (292.87s): Loss = 0.4103, AUC = 88.44%
Epoch 7/30 (341.44s): Loss = 0.3947, AUC = 88.79%
Epoch 8/30 (389.75s): Loss = 0.3592, AUC = 89.39%
Epoch 9/30 (439.29s): Loss = 0.3721, AUC = 89.84%
Epoch 10/30 (489.57s): Loss = 0.3602, AUC = 90.24%
Epoch 11/30 (539.49s): Loss = 0.3558, AUC = 90.58%
Epoch 12/30 (589.37s): Loss = 0.4232, AUC = 90.67%
Epoch 13/30 (638.96s): Loss = 0.3888, AUC = 90.89%
Epoch 14/30 (689.06s): Loss = 0.3637, AUC = 91.10%
Epoch 15/30 (736.32s): Loss = 0.4269, AUC = 91.21%
Epoch 16/30 (783.19s): Loss = 0.4057, AUC = 91.33%
Epoch 17/30 (829.86s): Loss = 0.4214, AUC = 91.43%
Epoch 18/30 (876.85s): Loss = 0.3710, AUC = 91.58%
Epoch 19/30 (923.44s): Loss = 0.4700, AUC = 91.63%
Epoch 20/30 (970.42s): Loss = 0.4868, AUC 

In [5]:
# Small variant: same architecture depth (4 blocks, 2 heads) and the same training
# settings as the 64-dim run, but hidden_size=6 and mlp_hidden_size=3.
# Re-seed here so this model's initialization is independent of the RNG state left
# behind by the previous 30-epoch run; without it, results depend on whether the
# earlier cell was executed. If the DataLoaders shuffle via torch's global RNG,
# their order is fixed too — TODO confirm.
# NOTE: some CUDA kernels remain nondeterministic, so GPU runs may still differ slightly.
torch.manual_seed(42)
model = ClassicalTransformer(num_tokens=len(vocab), num_classes=2, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3)
train(model, imdb_train_dataloader, imdb_valid_dataloader, num_classes=2, learning_rate=0.0003, num_epochs=30, device=device)

Epoch 1/30 (43.91s): Loss = 0.6910, AUC = 54.30%
Epoch 2/30 (88.22s): Loss = 0.6935, AUC = 55.08%
Epoch 3/30 (132.36s): Loss = 0.6860, AUC = 55.89%
Epoch 4/30 (176.74s): Loss = 0.6737, AUC = 58.78%
Epoch 5/30 (220.92s): Loss = 0.5804, AUC = 63.60%
Epoch 6/30 (265.14s): Loss = 0.5982, AUC = 66.66%
Epoch 7/30 (309.26s): Loss = 0.5927, AUC = 69.03%
Epoch 8/30 (353.41s): Loss = 0.5644, AUC = 71.09%
Epoch 9/30 (397.74s): Loss = 0.4978, AUC = 72.87%
Epoch 10/30 (442.29s): Loss = 0.4949, AUC = 74.38%
Epoch 11/30 (486.43s): Loss = 0.4723, AUC = 75.55%
Epoch 12/30 (529.93s): Loss = 0.4650, AUC = 76.65%
Epoch 13/30 (574.03s): Loss = 0.4591, AUC = 77.57%
Epoch 14/30 (618.21s): Loss = 0.4530, AUC = 78.47%
Epoch 15/30 (662.22s): Loss = 0.4545, AUC = 79.29%
Epoch 16/30 (706.22s): Loss = 0.4576, AUC = 79.83%
Epoch 17/30 (750.47s): Loss = 0.4451, AUC = 80.52%
Epoch 18/30 (794.84s): Loss = 0.4210, AUC = 81.14%
Epoch 19/30 (839.23s): Loss = 0.4195, AUC = 81.71%
Epoch 20/30 (883.80s): Loss = 0.4481, AUC 