In [None]:
!pip install -q transformers
!pip install -q simpletransformers
!pip install -q datasets

In [2]:
import pandas as pd
from datasets import load_dataset
from simpletransformers.classification import ClassificationModel
import pandas as pd
import logging
import sklearn
import torch
from torch.cuda import is_available

# Load Dataset

In [None]:
# Load the IMDB train and test splits and convert them to pandas DataFrames
# for simpletransformers. `to_pandas()` converts the Arrow table in one step
# instead of iterating over every row like `pd.DataFrame(dataset)` does.
dataset_train = load_dataset('imdb', split='train')
dataset_test = load_dataset('imdb', split='test')

# simpletransformers' train_model/eval_model look for a 'text' and a 'labels'
# column; without 'labels' the frame is only usable through its positional
# fallback (first column = text, second = label). NOTE: assumes the dataset's
# target column is named 'label' — the assertions below verify the result.
train_df = dataset_train.to_pandas().rename(columns={'label': 'labels'})
test_df = dataset_test.to_pandas().rename(columns={'label': 'labels'})

# Sanity-check the schema the classifier relies on (num_labels=2 => labels in {0, 1}).
assert {'text', 'labels'} <= set(train_df.columns), train_df.columns
assert {'text', 'labels'} <= set(test_df.columns), test_df.columns
assert train_df['labels'].isin([0, 1]).all()
assert test_df['labels'].isin([0, 1]).all()

# Setting up hyperparameters

In [5]:
# Fine-tuning arguments for simpletransformers' ClassificationModel.
# Originally described as adopted from Yang et al. (2019, XLNet).
# NOTE(review): max_seq_length=64, a batch size of 128 and a single epoch look
# like speed-oriented choices rather than the paper's fine-tuning settings —
# confirm against the paper's appendix before attributing them to it.
SEED = 42  # fixed seed so re-running the notebook reproduces the fine-tune

training_arguments = {
    'reprocess_input_data': True,
    'overwrite_output_dir': True,
    # Cover reviews longer than max_seq_length with a sliding window instead of truncating them.
    'sliding_window': True,
    'max_seq_length': 64,
    'num_train_epochs': 1,
    'learning_rate': 0.00001,
    'weight_decay': 0.01,
    'train_batch_size': 128,
    'fp16': True,  # mixed-precision training
    # Relative to the working directory; the previous '/outputs/' wrote to the
    # filesystem root, which fails without root write permission.
    'output_dir': 'outputs/',
    'manual_seed': SEED,
}

# Setting up CUDA

In [6]:
# Choose the compute device, preferring CUDA whenever torch can see a GPU.
# The boolean `use_cuda` is kept separately because the model constructor
# takes a flag rather than a torch.device.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


# Training

In [None]:
# Import the metric explicitly: `import sklearn` alone does not guarantee that
# the `sklearn.metrics` submodule is loaded, so `sklearn.metrics.accuracy_score`
# only worked if another library had already imported it.
from sklearn.metrics import accuracy_score

# Set up logging. INFO rather than DEBUG: a DEBUG root level lets every
# third-party logger's debug records through, which floods the cell output.
logging.basicConfig(level=logging.INFO)
XLNet_transformers_logger = logging.getLogger('transformers')
XLNet_transformers_logger.setLevel(logging.WARNING)

# Load the pre-trained cased XLNet base model with a 2-label classification head.
XLNet_base_model = ClassificationModel(
    'xlnet',
    'xlnet-base-cased',
    num_labels=2,
    args=training_arguments,
    use_cuda=use_cuda,
)

# Fine-tune on the full training split (no validation set during training).
XLNet_base_model.train_model(train_df)

# Evaluate on the test split. eval_model returns a tuple whose first element is
# the metrics dict; the extra `acc` keyword adds an accuracy entry to it.
output = XLNet_base_model.eval_model(test_df, acc=accuracy_score)
result = output[0]

# Evaluation

In [12]:
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator as a float, or 0.0 when the denominator is 0.

    Same convention as scikit-learn's ``zero_division=0``: a degenerate
    confusion matrix (e.g. no positive predictions, so tp + fp == 0) yields
    0.0 instead of raising ZeroDivisionError or producing NaN.
    """
    return float(numerator) / denominator if denominator else 0.0


# Report the evaluation metrics; precision/recall/F1 are derived from the
# binary confusion-matrix counts in `result`.
print("Evaluation Loss:", result['eval_loss'])
print("Accuracy:", result['acc'])
tp, fp, fn = result['tp'], result['fp'], result['fn']
Precision = _safe_ratio(tp, tp + fp)
print("Precision:", Precision)
Recall = _safe_ratio(tp, tp + fn)
print("Recall:", Recall)
F1_Score = _safe_ratio(2 * Precision * Recall, Precision + Recall)
print("F1-Score:", F1_Score)

Evaluation Loss: 0.38348051576621583
Accuracy: 0.91688
Precision: 0.8845188902007084
Recall: 0.95896
F1-Score: 0.92023645017657


# End Remarks
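- Without any data preprocessing, the cased base XLNet model (`xlnet-base-cased`) reached 91.7% accuracy on the IMDB test set after a single epoch of fine-tuning, using 64-token inputs (`max_seq_length=64` with a sliding window over longer reviews). Recall (95.9%) is noticeably higher than precision (88.5%), so the model leans toward predicting the positive class (label 1).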
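- Compared with BERT, XLNet differs mainly in its pre-training objective rather than its size (the base model used here is roughly comparable in size to BERT-base). Instead of masking tokens, XLNet maximizes the likelihood of a sequence over permutations of the factorization order. Each word is therefore predicted from varying subsets of the other words in the sequence, which lets the model capture bidirectional relationships among words. It also uses relative positional encodings (adopted from Transformer-XL) to capture the positions of words.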