Enhancing BERT training with Semi-supervised Generative Adversarial Networks and LIME visualizations
GAN-BERT combines the power of pre-trained BERT and Generative Adversarial Networks (GANs).
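The semi-supervised trick behind GAN-BERT is to let the discriminator double as the classifier: it assigns each input to one of the K real classes or to an extra (K+1)-th "fake" class reserved for generator output, so unlabeled examples still contribute to the real/fake decision. Below is a minimal sketch of that idea, written against the TF 1.14 API installed later in this tutorial (illustrative only, not the repo's actual code):
# Illustrative sketch of the semi-supervised GAN idea, NOT the repo's code.
# The discriminator scores K real classes plus one extra "fake" class;
# the generator tries to fake BERT-like sentence representations.
import tensorflow as tf  # TF 1.14 API

K = 3          # number of real classes
HIDDEN = 768   # size of BERT's [CLS] representation

def generator(z):
    # noise -> fake "sentence representation"
    h = tf.layers.dense(z, HIDDEN, activation=tf.nn.leaky_relu, name="g_h")
    return tf.layers.dense(h, HIDDEN, name="g_out")

def discriminator(x, reuse=False):
    # representation -> logits over K real classes + 1 fake class
    with tf.variable_scope("discriminator", reuse=reuse):
        h = tf.layers.dense(x, HIDDEN, activation=tf.nn.leaky_relu, name="d_h")
        return tf.layers.dense(h, K + 1, name="d_out")

bert_repr = tf.placeholder(tf.float32, [None, HIDDEN])  # real [CLS] vectors
noise = tf.placeholder(tf.float32, [None, 100])
real_logits = discriminator(bert_repr)                    # labeled + unlabeled
fake_logits = discriminator(generator(noise), reuse=True)  # generated
# Labeled examples add a cross-entropy term over the K real classes on
# real_logits; unlabeled and generated examples only push the real/fake split.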
GAN-BERT uses three data sets:
- labeled.tsv - examples with labels for supervised training
- unlabeled.tsv - examples without labels for adversarial training
- test.tsv - examples with labels for evaluation
Every example in labeled.tsv, unlabeled.tsv, and test.tsv must come from the same distribution (source).
For a K-class classification task, modify line 105 in data_processors.py to include the class labels (in upper case) along with the UNK label for unlabeled examples (a sketch follows the format examples below).
- labeled.tsv and test.tsv: each line is "label sentence", e.g.
label_1 sentence_1
label_2 sentence_2
...
- unlabeled.tsv: each line is "UNK sentence", e.g.
UNK sentence_1
UNK sentence_2
...
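For example, for a 3-class task the modification might look like the following (the method name follows BERT's DataProcessor convention and the label names are purely illustrative; check the actual code around line 105 of data_processors.py):
# Hypothetical sketch of the change around line 105 of data_processors.py:
# list the K class labels in upper case and keep UNK for unlabeled examples.
def get_labels(self):
    return ["POSITIVE", "NEGATIVE", "NEUTRAL", "UNK"]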
Tutorial for training the model on Google Colab.
To run the code, download and unzip the BERT-Base Cased model:
!wget https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip
!unzip cased_L-12_H-768_A-12.zip
The code was tested on Google Colab using a GPU runtime. To set up a similar environment, execute
!pip uninstall -y tensorflow
!pip install tensorflow-gpu==1.14.0
!pip install gast==0.2.2
!pip install git+https://github.com/guillaumegenthial/tf_metrics.git
!pip install nltk
!pip install autocorrect
!pip install lime
!pip install tqdm
To make sure that the runtime is using the GPU, try this
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
!python -c 'import tensorflow as tf; print(tf.test.gpu_device_name())'
If a GPU is available, this prints a device name such as /device:GPU:0; an empty string means TensorFlow cannot see the GPU.
Then, run the ganbert model as follows:
%%shell
python -u ganbert.py \
--num_classes=3 \
--label_rate=0.02 \
--do_train=true \
--do_eval=true \
--do_predict=false \
--data_dir=data \
--vocab_file=cased_L-12_H-768_A-12/vocab.txt \
--bert_config_file=cased_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=cased_L-12_H-768_A-12/bert_model.ckpt \
--max_seq_length=64 \
--train_batch_size=64 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--warmup_proportion=0.1 \
--do_lower_case=false \
--output_dir=ganbert_output_model
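Here label_rate presumably encodes how scarce the labels are (0.02 suggests roughly 2% of the training material is labeled; this reading is an assumption, check ganbert.py). A hypothetical helper to compare that flag against your own files:
# Hypothetical helper, not part of the repo: compute the labeled share of
# the training material from the data files described above.
def count_lines(path):
    with open(path, encoding="utf-8") as f:
        return sum(1 for _ in f)

n_labeled = count_lines("data/labeled.tsv")
n_unlabeled = count_lines("data/unlabeled.tsv")
print("labeled share: %.4f" % (n_labeled / (n_labeled + n_unlabeled)))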
To generate visualizations, prepare a .txt file of the instances you want to visualize. Then, run the ganbert model as follows:
%%shell
python -u ganbert.py \
--num_classes=3 \
--label_rate=0.02 \
--do_train=true \
--do_eval=false \
--do_visual=true \
--do_predict=false \
--data_dir=data \
--comment_dir=path_to_the_txt_file \
--visual_dir=directory_to_store_visualization \
--vocab_file=cased_L-12_H-768_A-12/vocab.txt \
--bert_config_file=cased_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=cased_L-12_H-768_A-12/bert_model.ckpt \
--max_seq_length=64 \
--train_batch_size=64 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--warmup_proportion=0.1 \
--do_lower_case=false \
--num_features=5 \
--num_samples=20 \
--output_dir=ganbert_output_model
The visualizations will be generated as HTML files (one file per instance) in the visual_dir directory. num_features and num_samples are hyper-parameters for the LIME visualization; read the LIME documentation for more details.
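For intuition, here is how those two flags map onto LIME's text explainer in a standalone sketch (independent of ganbert.py; the classifier below is a uniform stand-in for the trained model, and the mapping of the flags onto LIME's arguments is an assumption): num_features caps how many tokens appear in each explanation, and num_samples sets how many perturbed copies of the input LIME classifies to fit its local model.
# Standalone LIME sketch, independent of the repo's internals.
from lime.lime_text import LimeTextExplainer
import numpy as np

def predict_proba(texts):
    # Stand-in for the trained GAN-BERT classifier: LIME only needs a
    # function mapping a list of texts to an array (len(texts), num_classes).
    return np.full((len(texts), 3), 1.0 / 3)

explainer = LimeTextExplainer(class_names=["class_0", "class_1", "class_2"])
exp = explainer.explain_instance("an example sentence to explain",
                                 predict_proba,
                                 num_features=5,   # cf. --num_features
                                 num_samples=20)   # cf. --num_samples
exp.save_to_file("example.html")  # one HTML file per instance, as above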