This is the PyTorch implementation of Triplet Attention from the KDD'21 paper: Triplet Attention: Rethinking the Similarity in Transformers.
- Python 3.6
- numpy==1.17.3
- scipy==1.1.0
- pandas==0.25.1
- torch==1.2.0
- tqdm==4.36.1
- matplotlib==3.1.1
- tokenizers==0.10.3
- ...
Dependencies can be installed using the following command:
pip install -r requirements.txt
We implement BERT-A3 and DistilBERT-A3 on top of Hugging Face Transformers, so you can use a BERT-A3 or DistilBERT-A3 model just like a BERT or DistilBERT model in Hugging Face Transformers.
Build BERT-A3
from transformers import BertModel, BertConfig
config = BertConfig.from_pretrained('bert-base-uncased')
config.group_size = 2       # number of triplet attention head groups (3 heads per group)
config.cross_type = 0       # cross product type (0: L cross product with permutation, 1: L*L cross product)
config.agg_type = 0         # aggregation type when using the L*L cross product
config.absolute_flag = 0    # whether to take the absolute value of the triplet attention (1: use abs)
config.random_flag = 0      # permutation type (0: multiple permutations, 1: only random permutation)
config.permute_type = '1,2,3,4,5'   # permutation type groups
config.permute_back = 0     # whether to apply the inverse permutation (0: apply the inverse permutation)
config.Tlayers = '0,1,2,8,9,10'     # layers that use triplet attention heads
config.key2_flag = 0        # whether to use a key2 linear layer to get key_triplet2 (0: use the key2 layer)
config.head_choice = 0      # how to choose triplet attention heads (0: use the last 3*group_size heads, 1: choose 3*group_size heads at random)
bert_A3 = BertModel.from_pretrained('bert-base-uncased', config=config)
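For illustration only, here is a minimal sketch (a hypothetical helper, not part of the repository) of how head_choice and group_size determine which of the 12 attention heads per layer of bert-base-uncased become triplet attention heads, following the comment above:

import random

def pick_triplet_heads(num_heads=12, group_size=2, head_choice=0):
    # Illustrative only: mirrors the head_choice comment above.
    # head_choice == 0: the last 3*group_size heads become triplet attention heads.
    # head_choice == 1: 3*group_size heads are chosen at random.
    n_triplet = 3 * group_size
    if head_choice == 0:
        return list(range(num_heads - n_triplet, num_heads))
    return sorted(random.sample(range(num_heads), n_triplet))

print(pick_triplet_heads())                 # [6, 7, 8, 9, 10, 11]
print(pick_triplet_heads(head_choice=1))    # e.g. [0, 2, 5, 7, 9, 11]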
Use BERT-A3
from transformers import BertTokenizer, BertModel, BertConfig
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
config = BertConfig.from_pretrained('bert-base-uncased')
bert_A3 = BertModel.from_pretrained('bert-base-uncased', config=config)
inputs = tokenizer('hello world', return_tensors="pt")
outputs = bert_A3(**inputs)   # outputs.last_hidden_state has shape (batch_size, seq_len, hidden_size)
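DistilBERT-A3 can be built and used the same way. The sketch below is a minimal example assuming the DistilBERT-A3 config accepts the same flags shown above for BERT-A3; the values are illustrative (DistilBERT has 6 layers, so Tlayers indices are assumed to lie in 0-5):

from transformers import DistilBertTokenizer, DistilBertModel, DistilBertConfig
config = DistilBertConfig.from_pretrained('distilbert-base-uncased')
config.group_size = 1               # assumed: same flag set as the BERT-A3 config above
config.cross_type = 0
config.agg_type = 0
config.absolute_flag = 0
config.random_flag = 0
config.permute_type = '1,2,3,4,5'
config.permute_back = 0
config.Tlayers = '0,1,2'            # assumed: layer indices within DistilBERT's 6 layers
config.key2_flag = 0
config.head_choice = 0
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
distilbert_A3 = DistilBertModel.from_pretrained('distilbert-base-uncased', config=config)
inputs = tokenizer('hello world', return_tensors="pt")
outputs = distilbert_A3(**inputs)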
You can refer to /transformers/models/bert/modeling_bert.py for more details.
Command for training and testing the BERT-A3 model on the GLUE RTE task:
python run_glue.py --model_name_or_path bert-base-uncased --task_name rte --do_train --do_eval --do_predict --max_seq_length 128 --per_device_train_batch_size 16 --per_device_eval_batch_size 16 --cross_type 0 --agg_type 0 --tlayers '0,1,2,3' --learning_rate 3e-5 --num_train_epochs 6 --key2_flag 0 --random_flag 0 --absolute_flag 0 --permute_back 0 --permute_type '0,1,3,5,6' --head_choice 0 --group_size 1 --overwrite_output_dir --output_dir ./run/
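As a hedged example (not taken from the repository), evaluating a finetuned checkpoint saved under ./run/ would typically follow the standard run_glue.py pattern of pointing --model_name_or_path at the checkpoint directory; the A3-specific flags are repeated below on the assumption that they must match the values used at training time:
python run_glue.py --model_name_or_path ./run/ --task_name rte --do_eval --do_predict --max_seq_length 128 --per_device_eval_batch_size 16 --cross_type 0 --agg_type 0 --tlayers '0,1,2,3' --key2_flag 0 --random_flag 0 --absolute_flag 0 --permute_back 0 --permute_type '0,1,3,5,6' --head_choice 0 --group_size 1 --output_dir ./run/eval/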
If you find this repository useful in your research, please consider citing the following paper:
@inproceedings{haoyietal-tripletAttention-2021,
author = {Haoyi Zhou and
Jianxin Li and
Jieqi Peng and
Shuai Zhang and
Shanghang Zhang},
editor = {Feida Zhu and
Beng Chin Ooi and
Chunyan Miao},
title = {Triplet Attention: Rethinking the Similarity in Transformers},
booktitle = {The 27th {ACM} {SIGKDD} Conference on Knowledge Discovery and Data Mining, {KDD} 2021, Virtual Event},
pages = {2378--2388},
publisher = {{ACM}},
year = {2021},
}