In [1]:
from transformers import XLMRobertaTokenizer, XLMRobertaConfig, XLMRobertaForSequenceClassification
from textpruner import EmbeddingPruner, TransformerPruner
from textpruner import GeneralConfig, EmbeddingPruningConfig, TransformerPruningConfig

In [2]:
import torch,os,json,tqdm
import logging
# Configure logging once for the notebook: INFO and above, each record
# carrying a timestamp, its level, the emitting logger's name and the message.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
# Logger for this notebook's own messages, named after the running module.
logger = logging.getLogger(__name__)


In [3]:
# Run-wide settings; later cells read these names directly.
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing at the first `model.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
output_dir = './pruned_models'  # handed to measure_performance as its output directory
batch_size = 32  # used for the evaluation DataLoader and as the prediction batch size
eval_langs = ['zh']  # languages passed to the dataset and to measure_performance
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
data_dir = '/yrfs1/rc/zqyang5/Xtreme/datasets/XNLI'
split = 'test'
max_seq_length = 128
taskname = 'xnli'

#### Init Model

In [4]:
# Disabled alternative (custom model class and another checkpoint), kept for reference:
#from modeling import XLMRForGLUESimple
#xlmr_class_ckpt_file='/work/rc/zqyang5/cross-lingual/xnli/xlmr/xnli_XbTrainEn_lr3e4_s4_bs32/gs42948.pkl'
# SentencePiece vocabulary and model configuration of the XLM-R base model.
xlmr_vocab_file='./pretrained-models/xlm-r-base/sentencepiece.bpe.model'
xlmr_config_file='./pretrained-models/xlm-r-base/config.json'
# NOTE(review): xlmr_ckpt_file is not referenced by any cell visible in this notebook.
xlmr_ckpt_file = './pretrained-models/xlm-r-base/pytorch_model.bin'
# Checkpoint that is actually loaded into the sequence-classification model
# (the path suggests an XNLI fine-tuning run — TODO confirm).
xlmr_classification_ckpt_file ='/work/rc/zqyang5/for_pruning/xnli/xnli_zhTrainXb_lr2e4_s4_bs32/gs49084.pkl'

def init_xlmrForSentenceClassification(config_file=None, vocab_file=None, ckpt_file=None, num_labels=3):
    """Create the XLM-R tokenizer and sequence-classification model.

    Each piece is optional and only built when its file is supplied:

    * ``vocab_file``: an ``XLMRobertaTokenizer`` is created from it;
      otherwise the returned tokenizer is ``None``.
    * ``config_file``: the configuration is read from this JSON file and its
      ``num_labels`` is overwritten with ``num_labels``. Without it no model
      is built (the returned model is ``None``), even if ``ckpt_file`` is set.
    * ``ckpt_file``: when given, the model is obtained through
      ``from_pretrained(ckpt_file, config=config)``; otherwise it is
      constructed from the configuration alone.

    Returns:
        A ``(tokenizer, model)`` tuple; either element may be ``None``.
    """
    tokenizer = XLMRobertaTokenizer(vocab_file=vocab_file) if vocab_file is not None else None

    # Without a configuration there is nothing to build the model from.
    if config_file is None:
        return tokenizer, None

    config = XLMRobertaConfig.from_json_file(config_file)
    config.num_labels = num_labels

    if ckpt_file is None:
        model = XLMRobertaForSequenceClassification(config=config)
    else:
        model = XLMRobertaForSequenceClassification.from_pretrained(ckpt_file, config=config)
    return tokenizer, model


# Build the tokenizer and the 3-label classifier from the classification checkpoint.
tokenizer, model = init_xlmrForSentenceClassification(
    config_file=xlmr_config_file,
    vocab_file=xlmr_vocab_file,
    ckpt_file=xlmr_classification_ckpt_file,
    num_labels=3,
)
model.to(device)

# Report the sizes before pruning for later comparison.
print("Current Vocab size:", tokenizer.vocab_size)
print("Current Embedding size:", model.get_input_embeddings().weight.shape)



2021-06-25 11:08:35,515 - INFO - transformers.modeling_utils - loading weights file /work/rc/zqyang5/for_pruning/xnli/xnli_zhTrainXb_lr2e4_s4_bs32/gs49084.pkl
2021-06-25 11:08:55,327 - INFO - transformers.modeling_utils - All model checkpoint weights were used when initializing XLMRobertaForSequenceClassification.

2021-06-25 11:08:55,329 - INFO - transformers.modeling_utils - All the weights of XLMRobertaForSequenceClassification were initialized from the model checkpoint at /work/rc/zqyang5/for_pruning/xnli/xnli_zhTrainXb_lr2e4_s4_bs32/gs49084.pkl.
If your task is similar to the task the model of the ckeckpoint was trained on, you can already use XLMRobertaForSequenceClassification for predictions without further training.


Current Vocab size: 250002
Current Embedding size: torch.Size([250002, 768])


In [5]:
# from transformers import BertTokenizer 
# bert_tokenizer = BertTokenizer(vocab_file='/yrfs1/rc/zqyang5/pretrained-models/bert/base_uncased/vocab.txt')
# print(tokenizer([("Hello world","Goodbye!")],max_length=10,truncation=True,padding='max_length',return_token_type_ids=True))
# print(bert_tokenizer([("Hello world","Goodbye!")],max_length=10,truncation=True,padding='max_length',return_token_type_ids=True))

# Embedding Pruning

#### Init data: extract sentences from the datasets

In [7]:
def extract_sentences_from_xnli(data_files):
    """Collect every tab-separated field from every line of the given files.

    Each file is read as UTF-8. Every line is stripped of surrounding
    whitespace and split on tabs, and the resulting fields are appended in
    reading order to one flat list. A tqdm progress bar is shown per file.

    Args:
        data_files: iterable of paths to tab-separated text files.

    Returns:
        A flat list with all fields of all files.
    """
    collected = []
    for path in data_files:
        with open(path, 'r', encoding='utf-8') as handle:
            raw_lines = handle.readlines()
        for raw_line in tqdm.tqdm(raw_lines):
            collected.extend(raw_line.strip().split('\t'))
    return collected

# Text files whose fields are fed to the embedding pruner.
data_files = [
    './datasets/multinli.train.zh.tsv',
]
lines = extract_sentences_from_xnli(data_files)

# Quick sanity check: total number of extracted fields and one sample.
print("Number of lines: ", len(lines))
print(f"Example:{lines[10]}")

100%|██████████| 392703/392703 [00:00<00:00, 601176.19it/s]


Number of lines:  1178109
Example:我 团队 的 一个 成员 将 非常 精确 地 执行 你 的 命令


#### Init Embedding Pruner

In [8]:
# Create default TextPruner configuration objects and the embedding pruner.
# NOTE(review): neither `general_config` nor `embedding_pruning_config` is passed
# to EmbeddingPruner here; it is built from `model` and `tokenizer` only —
# confirm that relying on the pruner's defaults is intended.
general_config = GeneralConfig()
embedding_pruning_config = EmbeddingPruningConfig()
embedding_pruner = EmbeddingPruner(model, tokenizer)

2021-06-25 11:09:18,268 - INFO - textpruner.configurations - Using current cuda device
2021-06-25 11:09:18,269 - INFO - textpruner.configurations - Using current cuda device


In [9]:
# Prune the embedding matrix, using the extracted sentences in `lines` as the
# data iterator (presumably only tokens occurring in them are kept — TODO confirm).
embedding_pruner.prune_embeddings(dataiter=lines)

100%|██████████| 1178109/1178109 [03:02<00:00, 6449.16it/s]


In [10]:
# `model` is inspected again here: its input embedding now has the pruned shape.
print("New embedding size:", model.get_input_embeddings().weight.shape)
# Write the pruned model, configuration and vocabulary file through the pruner.
embedding_pruner.save_model()

New embedding size: torch.Size([23553, 768])
New embedding size 23553 pruned vocab file has been saved to ./pruned_models/pruned_23553V/sentencepiece.bpe.model. Reintialize the tokenizer!


2021-06-25 11:12:28,147 - INFO - transformers.configuration_utils - Configuration saved in ./pruned_models/pruned_23553V/config.json
2021-06-25 11:12:28,148 - INFO - textpruner.embedding_pruner - Model and configuration have been saved to ./pruned_models/pruned_23553V


#### reload

In [11]:
# Re-create tokenizer and model from the files written by the embedding pruner,
# so the tokenizer uses the reduced vocabulary. `num_labels` keeps its default (3).
# NOTE(review): the directory name `pruned_23553V` embeds the vocabulary size of
# this particular run; it must be updated if the pruning data changes.
pruned_tokenizer, pruned_model = init_xlmrForSentenceClassification(
    config_file='./pruned_models/pruned_23553V/config.json',
    vocab_file='./pruned_models/pruned_23553V/sentencepiece.bpe.model',
    ckpt_file='./pruned_models/pruned_23553V/model.pkl')

pruned_model.to(device)
print("New vocab size:", pruned_tokenizer.vocab_size)

2021-06-25 11:12:33,240 - INFO - transformers.modeling_utils - loading weights file ./pruned_models/pruned_23553V/model.pkl
2021-06-25 11:12:39,177 - INFO - transformers.modeling_utils - All model checkpoint weights were used when initializing XLMRobertaForSequenceClassification.

2021-06-25 11:12:39,178 - INFO - transformers.modeling_utils - All the weights of XLMRobertaForSequenceClassification were initialized from the model checkpoint at ./pruned_models/pruned_23553V/model.pkl.
If your task is similar to the task the model of the ckeckpoint was trained on, you can already use XLMRobertaForSequenceClassification for predictions without further training.


New vocab size: 23553


# Measure Performance

In [12]:
from predict_function import predict
def measure_performance(model, eval_datasets,eval_langs,output_dir, device, predict_batch_size, head_mask=None, ffn_mask=None):
    """Evaluate ``model`` with ``predict`` and append the result to ``eval_results.txt``.

    If ``ffn_mask`` is given, the feed-forward weights of every encoder layer
    are multiplied in place by that layer's mask for the duration of the
    evaluation: rows of the intermediate weight, the intermediate bias, and
    columns of the output weight. The original values are always restored
    afterwards, including when masking or ``predict`` raises, so the model is
    never left in a masked state.

    Args:
        model: model whose encoder layers are reachable through
            ``getattr(model, model.base_model_prefix, model).encoder.layer``.
        eval_datasets: forwarded to ``predict``.
        eval_langs: forwarded to ``predict``.
        output_dir: forwarded to ``predict``; ``eval_results.txt`` inside it is
            appended to (the directory is created if it does not exist).
        device: forwarded to ``predict``.
        predict_batch_size: forwarded to ``predict``.
        head_mask: optional tensor forwarded to ``predict``; its sum and
            per-row sums are printed (assumes one row per layer — TODO confirm).
        ffn_mask: optional tensor indexed by layer number; ``ffn_mask[i]`` must
            broadcast against the intermediate bias of layer ``i``.

    Returns:
        The value returned by ``predict``.
    """
    base_model = getattr(model, model.base_model_prefix, model)
    n_layers = base_model.config.num_hidden_layers
    layers = base_model.encoder.layer

    # One (layer, intermediate weight, intermediate bias, output weight) entry
    # per masked layer. The backup is recorded before the layer is modified so
    # that even a failure halfway through masking can be undone.
    backups = []
    try:
        if ffn_mask is not None:
            print("Masking intermediate FFN")
            for layer_num in range(n_layers):
                layer = layers[layer_num]
                inter_dense = layer.intermediate.dense
                output_dense = layer.output.dense
                backups.append((
                    layer,
                    inter_dense.weight.data.clone(),
                    inter_dense.bias.data.clone(),
                    output_dense.weight.data.clone(),
                ))

                layer_mask = ffn_mask[layer_num]
                inter_dense.weight.data *= layer_mask.unsqueeze(1)
                inter_dense.bias.data *= layer_mask
                output_dense.weight.data *= layer_mask.unsqueeze(0)

        res = predict(model,eval_datasets,step=0,eval_langs=eval_langs,output_dir=output_dir,
                      device=device, predict_batch_size=predict_batch_size, head_mask = head_mask, in_lang = None)
    finally:
        # Undo the in-place masking regardless of how the evaluation ended.
        for layer, inter_weight, inter_bias, output_weight in backups:
            layer.intermediate.dense.weight.data.copy_(inter_weight)
            layer.intermediate.dense.bias.data.copy_(inter_bias)
            layer.output.dense.weight.data.copy_(output_weight)

    if head_mask is not None:
        print (head_mask)
        print (f"number of heads: {head_mask.sum()}/{head_mask.view(-1).size(0)}")
        num_heads_per_layer = head_mask.sum(dim=-1)
        print (f"Min/Max number of heads per layer: {num_heads_per_layer.min()}, {num_heads_per_layer.max()}")
    if ffn_mask is not None:
        print (f"ffn size: {ffn_mask.sum()}/{ffn_mask.view(-1).size(0)}")
        ffn_size_per_layer = ffn_mask.sum(dim=-1)
        print (f"Min/Max number of ffn per layer: {ffn_size_per_layer.min()}, {ffn_size_per_layer.max()}")
    print (f"Performance: {res}")

    # Make sure the results directory exists before appending to the log file.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    metric_filename = os.path.join(output_dir, 'eval_results.txt')
    with open(metric_filename,'a') as f:
        if head_mask is not None:
            line = f"{head_mask.sum()}/{head_mask.view(-1).size(0)} {res}\n"
        else:
            line = f"Full-Head {res}\n"
        f.write(line)
    return res

In [13]:
from utils import MultilingualNLIDataset
from torch.utils.data import DataLoader, SequentialSampler

# Build the evaluation data with the pruned tokenizer; the 'xlmrbase_pruned'
# prefix appears in the cached feature file name reported by the loader.
# NOTE(review): `eval_dataloader` (over `split`, i.e. 'test') is also the loader
# handed to the transformer pruner later, so importance scores and the reported
# accuracy are computed on the same examples — confirm this is intended.
eval_dataset = MultilingualNLIDataset(
    task=taskname, data_dir=data_dir, split=split, prefix='xlmrbase_pruned',
    max_seq_length=max_seq_length, langs=eval_langs, local_rank=-1, tokenizer=pruned_tokenizer)
# Sequential sampling keeps the examples in dataset order.
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=batch_size)
# One dataset per evaluated language, in the order of `eval_langs`.
eval_datasets = [eval_dataset.lang_datasets[lang] for lang in eval_langs]

2021-06-25 11:12:45,264 - INFO - utils - Loading features from cached file /yrfs1/rc/zqyang5/Xtreme/datasets/XNLI/xlmrbase_pruned_test_128_zh.tensor


Init NLIDataset


In [9]:
# Accuracy of the embedding-pruned model with no head/FFN masks passed;
# in notebook order this runs before the transformer-pruning cells.
measure_performance(pruned_model, eval_datasets, eval_langs,output_dir, device, batch_size)

2021-06-25 10:38:59,098 - INFO - predict_function - Predicting...
2021-06-25 10:38:59,098 - INFO - predict_function - ***** Running predictions *****
2021-06-25 10:38:59,099 - INFO - predict_function -  task name = xnli
2021-06-25 10:38:59,099 - INFO - predict_function -  lang : zh
2021-06-25 10:38:59,100 - INFO - predict_function -   Num  examples = 5010
Evaluating: 100%|██████████| 157/157 [00:53<00:00,  2.95it/s]
2021-06-25 10:39:52,360 - INFO - predict_function - ***** Eval results 0 Lang zh *****
2021-06-25 10:39:52,361 - INFO - predict_function - zh acc = 0.77964


Performance: {'zh': {'acc': 0.7796407185628742}}


# Transformer Pruning

In [8]:
# from utils import MultilingualNLIDataset
# from torch.utils.data import DataLoader, SequentialSampler

# eval_dataset = MultilingualNLIDataset(
#     task=taskname, data_dir=data_dir, split=split, prefix='xlmrbase',
#     max_seq_length=max_seq_length, langs=eval_langs, local_rank=-1, tokenizer=tokenizer)
# eval_sampler = SequentialSampler(eval_dataset)
# eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=batch_size)

#### Init Transformer Pruner

In [14]:
# Configure transformer pruning with ffn_size=2048 and num_of_heads=9
# (the pruner's log reports these as the new per-layer FFN size and head count).
# `general_config` is re-created here and, unlike in the embedding-pruning cell,
# passed to the pruner explicitly.
general_config = GeneralConfig()
transformer_pruning_config = TransformerPruningConfig(ffn_size=2048, num_of_heads=9)
transformer_pruner = TransformerPruner(pruned_model,general_config,transformer_pruning_config=transformer_pruning_config)

2021-06-25 11:12:57,897 - INFO - textpruner.configurations - Using current cuda device


#### Define aux functions

In [15]:
def adaptor(model_outputs, batch):
    """Return ``{'loss': model_outputs[0]}``.

    The first element of ``model_outputs`` is exposed under the key
    ``'loss'``; ``batch`` is accepted but not used.
    """
    loss = model_outputs[0]
    return dict(loss=loss)
def batch_postprocessor(batch):
    """Name the four items of ``batch``.

    ``batch`` must unpack into exactly four items, taken in order as
    ``input_ids``, ``attention_mask``, ``token_type_ids`` and ``labels``.

    Returns:
        A dict mapping those four names to the corresponding items.
    """
    ids, mask, segment_ids, targets = batch
    return dict(
        input_ids=ids,
        attention_mask=mask,
        token_type_ids=segment_ids,
        labels=targets,
    )

In [16]:
# Run the pruner over `eval_dataloader`: `batch_postprocessor` names the batch
# fields and `adaptor` exposes the loss. Per the pruner's log, this computes
# head/FFN importance and generates `head_mask` and `ffn_mask` on the pruner.
transformer_pruner.prune_transformer(dataloader=eval_dataloader,adaptor=adaptor,batch_postprocessor=batch_postprocessor)

2021-06-25 11:13:10,155 - INFO - textpruner.transformer_pruner - Calculating head importance and ffn importance
2021-06-25 11:13:10,155 - INFO - textpruner.transformer_pruner - ***** Running Forward and Backward from get_ffn_and_head_importance_score*****
2021-06-25 11:13:10,156 - INFO - textpruner.transformer_pruner -  Length of dataloader = 157
Evaluating: 100%|██████████| 157/157 [02:30<00:00,  1.04it/s]
2021-06-25 11:15:41,140 - INFO - textpruner.transformer_pruner - save...
2021-06-25 11:15:41,164 - INFO - textpruner.transformer_pruner - New ffn size:[2048.0, 2048.0, 2048.0, 2048.0, 2048.0, 2048.0, 2048.0, 2048.0, 2048.0, 2048.0, 2048.0, 2048.0]
2021-06-25 11:15:41,165 - INFO - textpruner.transformer_pruner - New num heads:[9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0]
2021-06-25 11:15:41,165 - INFO - textpruner.transformer_pruner - Head and ffn masks have been generated, can be accessed via self.head_mask and self.ffn_mask


In [17]:
# Apply the masks generated by the pruner to the model (presumably removing the
# masked heads and FFN units in place — the later state_dict dump shows
# 576-wide query/key/value projections).
transformer_pruner.prune_transformer_with_masks(head_mask=transformer_pruner.head_mask,ffn_mask=transformer_pruner.ffn_mask)

Num_layers:12
{0: [1, 9, 10], 1: [5, 6, 11], 2: [4, 8, 11], 3: [2, 4, 6], 4: [5, 7, 10], 5: [3, 5, 6], 6: [1, 4, 11], 7: [1, 3, 11], 8: [0, 9, 11], 9: [2, 6, 7], 10: [4, 10, 11], 11: [1, 4, 5]}


In [15]:
# Only the masks are saved in this cell; saving the model is disabled here.
#transformer_pruner.save_model()
transformer_pruner.save_masks()

2021-06-25 10:40:19,174 - INFO - textpruner.transformer_pruner - Masks have been saved to ./pruned_models/pruned_Head_FFN_masks


## Measure performance

In [20]:
# Copy the pruner's masks into tensors on the model's device.
# NOTE(review): `head_mask` and `ffn_mask` are built but not passed below (the
# mask arguments are commented out), so this evaluates `pruned_model` as it is.
head_mask = torch.tensor(transformer_pruner.head_mask).to(pruned_model.device)
ffn_mask = torch.tensor(transformer_pruner.ffn_mask).to(pruned_model.device)
measure_performance(pruned_model, eval_datasets, eval_langs,output_dir, device, batch_size) #, ffn_mask=ffn_mask,head_mask =head_mask)

2021-06-25 11:17:38,047 - INFO - predict_function - Predicting...
2021-06-25 11:17:38,050 - INFO - predict_function - ***** Running predictions *****
2021-06-25 11:17:38,050 - INFO - predict_function -  task name = xnli
2021-06-25 11:17:38,050 - INFO - predict_function -  lang : zh
2021-06-25 11:17:38,051 - INFO - predict_function -   Num  examples = 5010
Evaluating: 100%|██████████| 157/157 [00:40<00:00,  3.88it/s]
2021-06-25 11:18:18,519 - INFO - predict_function - ***** Eval results 0 Lang zh *****
2021-06-25 11:18:18,520 - INFO - predict_function - zh acc = 0.77066


Performance: {'zh': {'acc': 0.7706586826347306}}


In [19]:
# List every entry of the pruned model's state dict together with its shape.
for entry_name, tensor in pruned_model.state_dict().items():
    print(entry_name, tensor.shape)

roberta.embeddings.word_embeddings.weight torch.Size([23553, 768])
roberta.embeddings.position_embeddings.weight torch.Size([514, 768])
roberta.embeddings.token_type_embeddings.weight torch.Size([1, 768])
roberta.embeddings.LayerNorm.weight torch.Size([768])
roberta.embeddings.LayerNorm.bias torch.Size([768])
roberta.encoder.layer.0.attention.self.query.weight torch.Size([576, 768])
roberta.encoder.layer.0.attention.self.query.bias torch.Size([576])
roberta.encoder.layer.0.attention.self.key.weight torch.Size([576, 768])
roberta.encoder.layer.0.attention.self.key.bias torch.Size([576])
roberta.encoder.layer.0.attention.self.value.weight torch.Size([576, 768])
roberta.encoder.layer.0.attention.self.value.bias torch.Size([576])
roberta.encoder.layer.0.attention.output.dense.weight torch.Size([768, 576])
roberta.encoder.layer.0.attention.output.dense.bias torch.Size([768])
roberta.encoder.layer.0.attention.output.LayerNorm.weight torch.Size([768])
roberta.encoder.layer.0.attention.output.

In [27]:
# Save the pruned model and its configuration through the pruner
# (the log output reports the target directory).
transformer_pruner.save_model()

2021-06-25 10:45:44,075 - INFO - transformers.configuration_utils - Configuration saved in ./pruned_models/pruned_9.0H2048.0FFN/config.json
2021-06-25 10:45:44,076 - INFO - textpruner.transformer_pruner - Model and configuration have been saved to ./pruned_models/pruned_9.0H2048.0FFN
