# XLNet

XLNet is a generalized autoregressive (AR) language model that learns unsupervised representations of text sequences. It incorporates modelling ideas from autoencoding (AE) models such as BERT into AR models, while avoiding the limitations of AE models.
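The main way XLNet gains bidirectional context while staying autoregressive is permutation language modelling. It samples a factorization order over the positions and predicts each token only from the tokens that come earlier *in that order*, while every token keeps its original position. The plain-Python sketch below builds the resulting visibility mask for one sampled order. It is a simplified illustration under our own assumptions (one permutation, query stream only, no two-stream attention or partial prediction), and `permutation_mask` is a hypothetical helper, not part of the XLNet codebase.

```python
import random


def permutation_mask(seq_len, order=None, seed=0):
    """Return mask[i][j] == True when position i may attend to position j.

    Simplified sketch of XLNet's permutation LM objective (query stream only):
    position i sees exactly the positions that precede it in the sampled
    factorization order, and never itself.
    """
    if order is None:
        order = list(range(seq_len))
        random.Random(seed).shuffle(order)
    # rank[p] = where position p falls in the factorization order
    rank = {pos: r for r, pos in enumerate(order)}
    return [[rank[j] < rank[i] for j in range(seq_len)] for i in range(seq_len)]


if __name__ == "__main__":
    order = [2, 1, 3, 0]  # an example factorization order over 4 positions
    for pos, row in enumerate(permutation_mask(4, order)):
        visible = [j for j, ok in enumerate(row) if ok]
        print(f"token {pos} is predicted from positions {visible}")
```

With the order 2, 1, 3, 0, position 2 is predicted with no context at all, while position 0, which comes last in the order, sees all three other positions. Averaged over many sampled orders, every position learns to use context from both sides.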

To read more about its architecture, refer to [this](https://analyticsindiamag.com/guide-to-xlnet-for-language-understanding/) article.

# Usage

Let's fine-tune the cased XLNet base model for classification, using sentiment analysis on the IMDB movie-review dataset.

At the time of writing, XLNet isn't available on TensorFlow Hub, but we can still clone the official implementation from GitHub and work with it. The model is large, so fine-tuning it requires a GPU with plenty of VRAM. Use TensorFlow 1.x, since the official implementation may not work with TensorFlow 2.x.
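Since the official code fails in confusing ways on TensorFlow 2.x, a cheap guard at the top of the notebook can catch the wrong runtime before training starts. This is a minimal sketch; `check_tf_version` is a hypothetical helper, and in the notebook you would pass it `tf.__version__` rather than the sample strings used here.

```python
def check_tf_version(version):
    """Return (ok, message) for a TensorFlow version string such as tf.__version__."""
    major = version.split(".")[0]
    if major == "1":
        return True, f"TensorFlow {version}: OK for the official XLNet code."
    return False, (
        f"TensorFlow {version}: the official XLNet code targets 1.x; "
        "on Colab, run `%tensorflow_version 1.x` before importing TensorFlow."
    )


if __name__ == "__main__":
    # In the notebook: ok, msg = check_tf_version(tf.__version__)
    for v in ("1.15.2", "2.5.0"):
        print(check_tf_version(v)[1])
```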


In [None]:
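# Install dependencies, then restart the kernel so the newly installed packages are picked up
# (the PyPI name for sklearn is scikit-learn; the old "sklearn" alias no longer installs)
!python -m pip install pip --upgrade --user -q --no-warn-script-location
!python -m pip install numpy pandas seaborn matplotlib scipy statsmodels scikit-learn nltk gensim sentencepiece tensorflow keras --user -q --no-warn-script-location

import IPython
IPython.Application.instance().kernel.do_shutdown(True)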

In [None]:
%tensorflow_version 1.x

TensorFlow 1.x selected.


In [None]:
! wget https://storage.googleapis.com/xlnet/released_models/cased_L-12_H-768_A-12.zip
! unzip cased_L-12_H-768_A-12.zip

--2021-06-18 07:12:11--  https://storage.googleapis.com/xlnet/released_models/cased_L-12_H-768_A-12.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.65.80, 142.251.33.208, 142.250.81.208, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.65.80|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 433638019 (414M) [application/zip]
Saving to: ‘cased_L-12_H-768_A-12.zip’


2021-06-18 07:12:14 (130 MB/s) - ‘cased_L-12_H-768_A-12.zip’ saved [433638019/433638019]

Archive:  cased_L-12_H-768_A-12.zip
   creating: xlnet_cased_L-12_H-768_A-12/
  inflating: xlnet_cased_L-12_H-768_A-12/xlnet_model.ckpt.index  
  inflating: xlnet_cased_L-12_H-768_A-12/xlnet_model.ckpt.data-00000-of-00001  
  inflating: xlnet_cased_L-12_H-768_A-12/spiece.model  
  inflating: xlnet_cased_L-12_H-768_A-12/xlnet_model.ckpt.meta  
  inflating: xlnet_cased_L-12_H-768_A-12/xlnet_config.json  
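The directory name follows the BERT-style naming convention: `L` is the number of layers, `H` the hidden size and `A` the number of attention heads. A quick sanity check is to compare those numbers with `xlnet_config.json`. This sketch assumes the config stores them under the keys `n_layer`, `d_model` and `n_head`, which is our understanding of the released XLNet configs; `expected_dims` and `config_mismatches` are hypothetical helpers.

```python
import json
import os
import re

# Keys we assume xlnet_config.json uses for the L, H and A numbers in the directory name.
DIM_KEYS = ("n_layer", "d_model", "n_head")


def expected_dims(model_dir):
    """Parse L (layers), H (hidden size) and A (attention heads) from a directory name."""
    name = os.path.basename(os.path.normpath(model_dir))
    match = re.search(r"L-(\d+)_H-(\d+)_A-(\d+)", name)
    if match is None:
        raise ValueError(f"cannot parse L/H/A from {model_dir!r}")
    return dict(zip(DIM_KEYS, map(int, match.groups())))


def config_mismatches(model_dir, config=None):
    """Return {key: (expected, found)} for every dimension that disagrees with the config."""
    if config is None:
        with open(os.path.join(model_dir, "xlnet_config.json")) as f:
            config = json.load(f)
    return {
        key: (value, config.get(key))
        for key, value in expected_dims(model_dir).items()
        if config.get(key) != value
    }


if __name__ == "__main__":
    model_dir = "xlnet_cased_L-12_H-768_A-12"
    if os.path.isfile(os.path.join(model_dir, "xlnet_config.json")):
        print("mismatches:", config_mismatches(model_dir) or "none")
```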


In [None]:
! wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
! tar zxf aclImdb_v1.tar.gz

--2021-06-18 07:12:19--  http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
Resolving ai.stanford.edu (ai.stanford.edu)... 171.64.68.10
Connecting to ai.stanford.edu (ai.stanford.edu)|171.64.68.10|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 84125825 (80M) [application/x-gzip]
Saving to: ‘aclImdb_v1.tar.gz’


2021-06-18 07:12:22 (27.3 MB/s) - ‘aclImdb_v1.tar.gz’ saved [84125825/84125825]
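The archive unpacks into `aclImdb/train` and `aclImdb/test`, each with `pos` and `neg` subfolders that hold one review per `.txt` file: 25,000 labelled reviews per split, split evenly between the two classes. The sketch below counts the files in each folder to confirm the extraction worked. `count_reviews` is a hypothetical helper, and it deliberately ignores the unlabelled `train/unsup` folder.

```python
import os


def count_reviews(data_dir="aclImdb"):
    """Return {(split, label): number of .txt reviews} for the labelled IMDB folders."""
    counts = {}
    for split in ("train", "test"):
        for label in ("pos", "neg"):
            folder = os.path.join(data_dir, split, label)
            if os.path.isdir(folder):
                counts[(split, label)] = sum(1 for name in os.listdir(folder) if name.endswith(".txt"))
            else:
                counts[(split, label)] = 0  # folder missing: extraction incomplete
    return counts


if __name__ == "__main__":
    for (split, label), n in sorted(count_reviews().items()):
        print(f"{split}/{label}: {n}")
```

A complete extraction should report 12,500 files in each of the four folders.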



In [None]:
! git clone https://github.com/zihangdai/xlnet.git

Cloning into 'xlnet'...
remote: Enumerating objects: 122, done.
remote: Total 122 (delta 0), reused 0 (delta 0), pack-reused 122
Receiving objects: 100% (122/122), 2.92 MiB | 33.99 MiB/s, done.
Resolving deltas: 100% (59/59), done.


In [None]:
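SCRIPTS_DIR = 'xlnet'                                  # cloned official XLNet repository
DATA_DIR = 'aclImdb'                                   # extracted IMDB reviews
OUTPUT_DIR = 'proc_data/imdb'                          # tokenized TFRecords written by run_classifier.py
PRETRAINED_MODEL_DIR = 'xlnet_cased_L-12_H-768_A-12'   # pretrained checkpoint, config and SentencePiece model
CHECKPOINT_DIR = 'exp/imdb'                            # fine-tuned checkpoints and TensorBoard event files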
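Before launching a long training run, it helps to confirm that every input the command will reference actually exists. This is a minimal sketch with a hypothetical `missing_inputs` helper. Two assumptions are worth flagging: `--init_checkpoint` takes a checkpoint *prefix* rather than a file, so the check looks for the `.index` file next to it, and we assume the IMDB task reads the `train` and `test` folders under `--data_dir`.

```python
import os


def missing_inputs(pretrained_dir, data_dir):
    """List the files and folders the training command relies on that do not exist yet."""
    required = [
        os.path.join(pretrained_dir, "spiece.model"),
        os.path.join(pretrained_dir, "xlnet_config.json"),
        # --init_checkpoint is a prefix; its .index file shows the checkpoint is present.
        os.path.join(pretrained_dir, "xlnet_model.ckpt.index"),
        os.path.join(data_dir, "train"),
        os.path.join(data_dir, "test"),
    ]
    return [path for path in required if not os.path.exists(path)]


if __name__ == "__main__":
    # In the notebook: missing_inputs(PRETRAINED_MODEL_DIR, DATA_DIR)
    missing = missing_inputs("xlnet_cased_L-12_H-768_A-12", "aclImdb")
    print("all inputs present" if not missing else f"missing: {missing}")
```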

In [None]:
# Fine-tune on IMDB, then evaluate every saved checkpoint (--eval_all_ckpt=True)
train_command = "python "+SCRIPTS_DIR+"/run_classifier.py \
  --do_train=True \
  --do_eval=True \
  --eval_all_ckpt=True \
  --task_name=imdb \
  --data_dir="+DATA_DIR+" \
  --output_dir="+OUTPUT_DIR+" \
  --model_dir="+CHECKPOINT_DIR+" \
  --uncased=False \
  --spiece_model_file="+PRETRAINED_MODEL_DIR+"/spiece.model \
  --model_config_path="+PRETRAINED_MODEL_DIR+"/xlnet_config.json \
  --init_checkpoint="+PRETRAINED_MODEL_DIR+"/xlnet_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=8 \
  --eval_batch_size=8 \
  --num_hosts=1 \
  --num_core_per_host=1 \
  --learning_rate=2e-5 \
  --train_steps=4000 \
  --warmup_steps=500 \
  --save_steps=500 \
  --iterations=500"

! {train_command}




W0618 07:12:37.479147 140369649297280 module_wrapper.py:139] From xlnet/run_classifier.py:637: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.


W0618 07:12:37.479389 140369649297280 module_wrapper.py:139] From xlnet/run_classifier.py:637: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.


W0618 07:12:37.479596 140369649297280 module_wrapper.py:139] From xlnet/run_classifier.py:661: The name tf.gfile.Exists is deprecated. Please use tf.io.gfile.exists instead.


W0618 07:12:37.479862 140369649297280 module_wrapper.py:139] From xlnet/run_classifier.py:662: The name tf.gfile.MakeDirs is deprecated. Please use tf.io.gfile.makedirs instead.


W0618 07:12:37.544680 140369649297280 module_wrapper.py:139] From /content/xlnet/model_utils.py:27: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.


W0618 07:12:37.545109 140369649297280 module_wrapper.py:139] From /cont
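A little arithmetic on the training flags: with `--train_batch_size=8` and `--train_steps=4000`, the model sees 32,000 training examples, about 1.28 passes over the 25,000 labelled training reviews. The learning rate warms up over the first 500 steps and then decays. The sketch below reproduces that schedule in plain Python, assuming the repository defaults as we understand them (linear warmup from 0, then polynomial decay with power 1 down to `learning_rate * min_lr_ratio`, with `min_lr_ratio = 0`). `learning_rate_at` is a hypothetical helper, not a function from the repo.

```python
def learning_rate_at(step, base_lr=2e-5, warmup_steps=500, train_steps=4000, min_lr_ratio=0.0):
    """Learning rate at a global step: linear warmup, then linear (power-1 polynomial) decay."""
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * step / warmup_steps
    end_lr = base_lr * min_lr_ratio
    decay_steps = train_steps - warmup_steps
    # Clamp so steps beyond train_steps stay at end_lr.
    progress = min(step - warmup_steps, decay_steps) / decay_steps
    return (base_lr - end_lr) * (1.0 - progress) + end_lr


if __name__ == "__main__":
    examples_seen = 4000 * 8
    print(f"examples seen: {examples_seen}, epochs: {examples_seen / 25000:.2f}")
    for step in (0, 250, 500, 2250, 4000):
        print(step, learning_rate_at(step))
```

The peak rate of 2e-5 is reached at step 500, and halfway through the decay phase (step 2250) the rate is back to 1e-5.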

In [None]:
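import tensorflow as tf

# Inspect the first serialized example of the evaluation TFRecord written by run_classifier.py.
# The file holds one record per test review, so stop after the first to keep the output short.
for example in tf.python_io.tf_record_iterator("/content/proc_data/imdb/spiece.model.len-128.dev.eval.tf_record"):
    print(tf.train.Example.FromString(example))
    break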
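One detail in the printed record surprises readers coming from BERT: XLNet pads on the **left**, appends `<sep>` and `<cls>` at the **end** of the sequence, and uses `input_mask = 0` for real tokens and `1` for padding. The sketch below mimics that layout for a single-segment example. The special-token ids (`<cls>` = 3, `<sep>` = 4, padding = 0) and segment ids (A = 0, CLS = 2, padding = 4) are our reading of the repository's preprocessing code and should be checked against it; `build_features` is hypothetical, and the token ids in the usage example are arbitrary.

```python
# Assumed constants (our reading of the XLNet repo; verify before relying on them).
CLS_ID, SEP_ID, PAD_ID = 3, 4, 0
SEG_ID_A, SEG_ID_CLS, SEG_ID_PAD = 0, 2, 4


def build_features(token_ids, max_seq_length=128):
    """Lay out one single-segment example the way XLNet's classifier inputs appear to be built."""
    token_ids = list(token_ids)[: max_seq_length - 2]  # leave room for <sep> and <cls>
    input_ids = token_ids + [SEP_ID, CLS_ID]
    segment_ids = [SEG_ID_A] * (len(token_ids) + 1) + [SEG_ID_CLS]
    input_mask = [0] * len(input_ids)  # 0 marks a real token
    pad = max_seq_length - len(input_ids)
    # Pad on the left so <cls> always sits in the last position.
    input_ids = [PAD_ID] * pad + input_ids
    input_mask = [1] * pad + input_mask  # 1 marks padding
    segment_ids = [SEG_ID_PAD] * pad + segment_ids
    return {"input_ids": input_ids, "input_mask": input_mask, "segment_ids": segment_ids}


if __name__ == "__main__":
    for name, values in build_features([101, 102, 103], max_seq_length=8).items():
        print(name, values)
```

Keeping `<cls>` in the final slot matters because the classification head summarizes the sequence from the last position.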

In [None]:
%load_ext tensorboard

In [None]:
# Event-file names embed a timestamp and hostname that change on every run, so inspect the whole eval directory
!tensorboard --inspect --logdir=/content/exp/imdb/eval

In [None]:
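import glob
import tensorflow as tf

# Collect the evaluation metrics the Estimator wrote to exp/imdb/eval.
# Event-file names differ per run, so match them with a glob instead of hard-coding one.
eval_loss = []
eval_accuracy = []
for event_file in sorted(glob.glob('/content/exp/imdb/eval/events.out.tfevents.*')):
    for e in tf.train.summary_iterator(event_file):
        for v in e.summary.value:
            if v.tag == 'eval_loss':
                eval_loss.append(v.simple_value)
            elif v.tag == 'eval_accuracy':
                eval_accuracy.append(v.simple_value)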
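Because the run uses `--eval_all_ckpt=True`, `eval_accuracy` holds one value per evaluated checkpoint. A small helper can map the best value back to a global step, which is what you would pass as the checkpoint to keep. This sketch assumes one evaluation per checkpoint, in step order, with checkpoints saved every 500 steps starting at step 500 (`--save_steps=500`) and none deleted; `best_checkpoint` is hypothetical. If that assumption is doubtful, each eval event's `e.step` field records the exact global step and is the more reliable source.

```python
def best_checkpoint(eval_accuracy, save_steps=500):
    """Return (global_step, accuracy) of the most accurate evaluated checkpoint.

    Assumes the i-th accuracy belongs to the checkpoint saved at step (i + 1) * save_steps.
    On ties, the earliest checkpoint wins.
    """
    if not eval_accuracy:
        raise ValueError("no evaluation results found")
    best = max(range(len(eval_accuracy)), key=eval_accuracy.__getitem__)
    return (best + 1) * save_steps, eval_accuracy[best]
```

In the notebook, `best_checkpoint(eval_accuracy)` returns the step and accuracy of the strongest checkpoint.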

In [None]:
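import glob
import tensorflow as tf

# The per-step training loss is logged in the model directory itself (exp/imdb), not in eval/.
loss = []
for event_file in sorted(glob.glob('/content/exp/imdb/events.out.tfevents.*')):
    for e in tf.train.summary_iterator(event_file):
        for v in e.summary.value:
            if v.tag == 'loss':
                loss.append(v.simple_value)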

In [None]:
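import matplotlib.pyplot as plt

# Left: validation loss; right: validation accuracy. Each point is one evaluated checkpoint,
# so the x-axis counts checkpoints (saved every 500 steps), not epochs.
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(eval_loss)
plt.xlabel('Checkpoint (saved every 500 steps)')
plt.ylabel('Validation loss')
plt.subplot(1, 2, 2)
plt.plot(eval_accuracy)
plt.xlabel('Checkpoint (saved every 500 steps)')
plt.ylabel('Validation accuracy')
plt.tight_layout()
plt.savefig('acc.png')
plt.show()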
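The plot above shows only validation metrics. The per-step training loss collected in `loss` is noisy with a batch size of 8, and an exponential moving average, similar in spirit to TensorBoard's smoothing slider, makes its trend easier to read. This is a minimal sketch; `smooth` is a hypothetical helper and the default weight of 0.9 is an arbitrary choice.

```python
def smooth(values, weight=0.9):
    """Exponential moving average of a noisy series; higher weight means smoother."""
    if not 0.0 <= weight < 1.0:
        raise ValueError("weight must be in [0, 1)")
    smoothed, last = [], None
    for v in values:
        # The first value seeds the average; later values are blended in.
        last = v if last is None else weight * last + (1.0 - weight) * v
        smoothed.append(last)
    return smoothed
```

In the notebook, `plt.plot(smooth(loss))` draws the smoothed training-loss curve, with the x-axis in logged steps.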