In [1]:
import sys
sys.path.append('../src/')
from transformers import XLMRobertaTokenizer, XLMRobertaConfig, XLMRobertaForSequenceClassification
from textpruner import EmbeddingPruner, TransformerPruner
from textpruner import GeneralConfig, EmbeddingPruningConfig, TransformerPruningConfig
import torch,os,json,tqdm
import logging
logger = logging.getLogger(__name__)

In [2]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU so the
# notebook still executes on CPU-only machines instead of failing at `model.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Directory handed to measure_performance, which appends to eval_results.txt inside it.
output_dir = './pruned_models'
batch_size = 32
# Languages to evaluate; also used to select the per-language datasets.
eval_langs = ['zh']
# Location of the XNLI data passed to MultilingualNLIDataset.
data_dir = '/yrfs1/rc/zqyang5/Xtreme/datasets/XNLI'
split = 'test'
max_seq_length = 128
taskname = 'xnli'

#### Init Model

In [3]:
# Files describing the original (unpruned) XLM-R base model.
xlmr_vocab_file = './pretrained-models/xlm-r-base/sentencepiece.bpe.model'
xlmr_config_file = './pretrained-models/xlm-r-base/config.json'
# Checkpoint whose weights are loaded into the sequence-classification model.
xlmr_classification_ckpt_file = '/work/rc/zqyang5/for_pruning/xnli/xnli_zhTrainXb_lr2e4_s4_bs32/gs49084.pkl'

def init_xlmrForSentenceClassification(config_file, vocab_file, ckpt_file, num_labels=3):
    """Create the XLM-R tokenizer and sequence-classification model from local files.

    Args:
        config_file: path of the JSON file the model configuration is read from.
        vocab_file: path handed to ``XLMRobertaTokenizer`` as its vocabulary file.
        ckpt_file: checkpoint passed to ``from_pretrained`` to load the weights.
        num_labels: value written to ``config.num_labels`` before the model is built.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    xlmr_tokenizer = XLMRobertaTokenizer(vocab_file=vocab_file)

    # The label count is set on the config before the weights are loaded.
    xlmr_config = XLMRobertaConfig.from_json_file(config_file)
    xlmr_config.num_labels = num_labels

    classifier = XLMRobertaForSequenceClassification.from_pretrained(
        ckpt_file,
        config=xlmr_config,
    )
    return xlmr_tokenizer, classifier


# Build the unpruned tokenizer/model pair from the paths defined above.
tokenizer, model = init_xlmrForSentenceClassification(
    xlmr_config_file,
    xlmr_vocab_file,
    xlmr_classification_ckpt_file,
    num_labels=3,
)
model.to(device)

# Report the vocabulary and input-embedding sizes of the unpruned model.
print("Current Vocab size:", tokenizer.vocab_size)
print("Current Embedding size:", model.get_input_embeddings().weight.shape)

Current Vocab size: 250002
Current Embedding size: torch.Size([250002, 768])


In [4]:
# from transformers import BertTokenizer 
# bert_tokenizer = BertTokenizer(vocab_file='/yrfs1/rc/zqyang5/pretrained-models/bert/base_uncased/vocab.txt')
# print(tokenizer([("Hello world","Goodbye!")],max_length=10,truncation=True,padding='max_length',return_token_type_ids=True))
# print(bert_tokenizer([("Hello world","Goodbye!")],max_length=10,truncation=True,padding='max_length',return_token_type_ids=True))

# Embedding Pruning

#### Init data: extract sentences from the datasets

In [5]:
# def extract_sentences_from_xnli(data_files):
#     results = []
#     for data_file in data_files:
#         with open(data_file,'r',encoding='utf-8') as f:
#             lines = f.readlines()
#         for line in tqdm.tqdm(lines):
#                 fields = line.strip().split('\t')
#                 for field in fields:
#                     results.append(field)
#     return results

# data_files = ['./datasets/multinli.train.zh.tsv']
# lines = extract_sentences_from_xnli(data_files)

#### Init Embedding Pruner

In [6]:
# embedding_pruner = EmbeddingPruner(model, tokenizer)
# embedding_pruner.prune_embeddings(dataiter=lines)

In [7]:
# print("New embedding size:", model.get_input_embeddings().weight.shape)
# embedding_pruner.save_model()

#### reload

In [8]:
# Load the tokenizer and model stored under ./pruned_models/pruned_23553V.
# The embedding-pruning cells earlier in this notebook are commented out, so these
# files must already exist on disk.
pruned_tokenizer, pruned_model = init_xlmrForSentenceClassification(
    './pruned_models/pruned_23553V/config.json',
    './pruned_models/pruned_23553V/sentencepiece.bpe.model',
    './pruned_models/pruned_23553V/model.pkl',
)

pruned_model.to(device)
print("New vocab size:", pruned_tokenizer.vocab_size)

New vocab size: 23553


# Measure Performance

In [9]:
from predict_function import predict
def measure_performance(model, eval_datasets,eval_langs,output_dir, device, predict_batch_size, head_mask=None, ffn_mask=None):
    """Evaluate ``model`` with ``predict`` and append the result to ``eval_results.txt``.

    When ``ffn_mask`` is given, the feed-forward weights of every encoder layer are
    multiplied by the mask in place for the duration of the evaluation. The original
    weights are always restored afterwards, including when ``predict`` raises, so the
    model object held by the notebook kernel is never left in a masked state.

    Args:
        model: model to evaluate. Its encoder layers are reached through the attribute
            named by ``model.base_model_prefix`` (falling back to ``model`` itself).
        eval_datasets: forwarded unchanged to ``predict``.
        eval_langs: forwarded unchanged to ``predict``.
        output_dir: forwarded to ``predict``; ``eval_results.txt`` is appended inside it.
        device: forwarded unchanged to ``predict``.
        predict_batch_size: forwarded unchanged to ``predict``.
        head_mask: optional tensor forwarded to ``predict``; here it is only printed
            and summed for the report.
        ffn_mask: optional tensor indexed by layer. ``ffn_mask[i]`` scales the rows of
            layer ``i``'s intermediate dense weight, its bias, and the columns of the
            layer's output dense weight.

    Returns:
        None. The result is printed and appended to ``eval_results.txt``.
    """
    base_model = getattr(model, model.base_model_prefix, model)
    n_layers = base_model.config.num_hidden_layers
    layers = base_model.encoder.layer

    # One (intermediate weight, intermediate bias, output weight) backup per masked layer.
    saved_ffn_params = []
    try:
        if ffn_mask is not None:
            print("Masking intermediate FFN")
            for layer_num in range(n_layers):
                inter_dense = layers[layer_num].intermediate.dense
                output_dense = layers[layer_num].output.dense

                # Back up before touching anything so a failure part-way through the
                # loop can still be undone by the ``finally`` clause.
                saved_ffn_params.append((
                    inter_dense.weight.data.clone(),
                    inter_dense.bias.data.clone(),
                    output_dense.weight.data.clone(),
                ))

                # Align the mask with the weights so a mask created on another device
                # (or with another dtype) can be applied in place.
                layer_mask = ffn_mask[layer_num].to(
                    device=inter_dense.weight.device, dtype=inter_dense.weight.dtype)
                inter_dense.weight.data *= layer_mask.unsqueeze(1)
                inter_dense.bias.data *= layer_mask
                output_dense.weight.data *= layer_mask.unsqueeze(0)

        res = predict(model,eval_datasets,step=0,eval_langs=eval_langs,output_dir=output_dir,
                       device=device, predict_batch_size=predict_batch_size, head_mask = head_mask, in_lang = None)
    finally:
        # Restore the original FFN weights whether or not the evaluation succeeded.
        for layer_num, (inter_weight, inter_bias, output_weight) in enumerate(saved_ffn_params):
            layers[layer_num].intermediate.dense.weight.data.copy_(inter_weight)
            layers[layer_num].intermediate.dense.bias.data.copy_(inter_bias)
            layers[layer_num].output.dense.weight.data.copy_(output_weight)

    if head_mask is not None:
        print (head_mask)
        print (f"number of heads: {head_mask.sum()}/{head_mask.view(-1).size(0)}")
        num_heads_per_layer = head_mask.sum(dim=-1)
        print (f"Min/Max number of heads per layer: {num_heads_per_layer.min()}, {num_heads_per_layer.max()}")
    if ffn_mask is not None:
        print (f"ffn size: {ffn_mask.sum()}/{ffn_mask.view(-1).size(0)}")
        ffn_size_per_layer = ffn_mask.sum(dim=-1)
        print (f"Min/Max number of ffn per layer: {ffn_size_per_layer.min()}, {ffn_size_per_layer.max()}")
    print (f"Performance: {res}")

    # Append (never overwrite) so successive runs accumulate in one file.
    metric_filename = os.path.join(output_dir, 'eval_results.txt')
    with open(metric_filename,'a') as f:
        if head_mask is not None:
            line = f"{head_mask.sum()}/{head_mask.view(-1).size(0)} {res}\n"
        else:
            line = f"Full-Head {res}\n"
        f.write(line)

In [10]:
from utils import MultilingualNLIDataset
from torch.utils.data import DataLoader, SequentialSampler

# Build the evaluation dataset with the pruned tokenizer.
eval_dataset = MultilingualNLIDataset(
    task=taskname,
    data_dir=data_dir,
    split=split,
    prefix='xlmrbase_pruned',
    max_seq_length=max_seq_length,
    langs=eval_langs,
    local_rank=-1,
    tokenizer=pruned_tokenizer,
)

# Iterate the dataset in its stored order.
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(
    eval_dataset,
    sampler=eval_sampler,
    batch_size=batch_size,
)

# One dataset per evaluated language, in the same order as eval_langs.
eval_datasets = [eval_dataset.lang_datasets[lang_code] for lang_code in eval_langs]

Init NLIDataset


In [11]:
# measure_performance(pruned_model, eval_datasets, eval_langs,output_dir, device, batch_size)

# Transformer Pruning

In [12]:
# from utils import MultilingualNLIDataset
# from torch.utils.data import DataLoader, SequentialSampler

# eval_dataset = MultilingualNLIDataset(
#     task=taskname, data_dir=data_dir, split=split, prefix='xlmrbase',
#     max_seq_length=max_seq_length, langs=eval_langs, local_rank=-1, tokenizer=tokenizer)
# eval_sampler = SequentialSampler(eval_dataset)
# eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=batch_size)

#### Init Transformer Pruner

In [13]:
general_config = GeneralConfig()

# Pruning targets: an FFN size of 2048 and 7*12 = 84 attention heads, reached
# iteratively over two iterations with heads not constrained per layer.
transformer_pruning_config = TransformerPruningConfig(
    ffn_size=2048,
    num_of_heads=7*12,
    is_iterative=True,
    head_is_layerwise=False,
    n_iter=2,
)

# The pruner operates on the vocabulary-pruned model loaded earlier.
transformer_pruner = TransformerPruner(
    pruned_model,
    general_config,
    transformer_pruning_config=transformer_pruning_config,
)

#### Define aux functions

In [14]:
def adaptor(model_outputs, batch):
    """Return the first element of ``model_outputs`` under the key ``'loss'``.

    ``batch`` is accepted but not used.
    """
    loss = model_outputs[0]
    return dict(loss=loss)
def batch_postprocessor(batch):
    """Turn a four-item batch into a dict of named model inputs.

    The batch is unpacked in the order ``input_ids``, ``attention_mask``,
    ``token_type_ids``, ``labels``; any other length raises ``ValueError``.
    """
    ids, mask, type_ids, targets = batch
    return dict(
        input_ids=ids,
        attention_mask=mask,
        token_type_ids=type_ids,
        labels=targets,
    )

In [15]:
#transformer_pruner.prune_transformer_with_masks(head_mask=transformer_pruner.head_mask,ffn_mask=transformer_pruner.ffn_mask)
# Run transformer pruning on batches from eval_dataloader. `adaptor` (returns the loss)
# and `batch_postprocessor` (maps a batch tuple to a dict) are the helpers defined above.
# NOTE(review): eval_dataloader is built from the `split` configured at the top ('test'),
# and measure_performance later evaluates on datasets from the same split — confirm that
# pruning on the evaluation data is intended.
transformer_pruner.prune_transformer(eval_dataloader,adaptor,batch_postprocessor)

Evaluating:   0%|          | 0/157 [00:00<?, ?it/s]

n_iter: 0
dffn_size, dnum_of_heads: 36864 144 19456 114
Num_layers:12
Num_heads:12
Num_layers:12
Head_size:64
{0: [], 1: [5, 6, 11], 2: [8, 11], 3: [2, 4, 5, 6], 4: [], 5: [], 6: [11], 7: [1, 3, 11], 8: [0, 11], 9: [2, 6, 7], 10: [0, 4, 5, 10, 11], 11: [1, 2, 4, 5, 8, 9, 10]}
n_iter: 1


Evaluating: 100%|██████████| 157/157 [01:27<00:00,  1.78it/s]


dffn_size, dnum_of_heads: 36864 144 2048 84
Num_layers:12
{0: [1, 9, 10], 1: [1, 2, 5, 6, 9, 11], 2: [0, 2, 3, 4, 8, 10, 11], 3: [0, 1, 2, 4, 5, 6, 8], 4: [5, 7], 5: [5, 6], 6: [1, 4, 11], 7: [1, 2, 3, 11], 8: [0, 2, 8, 9, 11], 9: [0, 1, 2, 6, 7, 8, 9], 10: [0, 1, 4, 5, 6, 10, 11], 11: [1, 2, 4, 5, 8, 9, 10]}


In [16]:
# Sum of the FFN mask entries for each of the first three layers.
for layer_idx in range(3):
    print(sum(transformer_pruner.ffn_mask[layer_idx].tolist()))

2048.0
2048.0
2048.0


In [17]:
# Save the pruned model through the pruner.
# NOTE(review): the next cell loads './pruned_models/pruned_84H2048FFN/model.pkl' —
# presumably the file written by this call; confirm the output location.
transformer_pruner.save_model()
#transformer_pruner.save_masks()

In [18]:
# Load the transformer-pruned model from disk into a fresh model object.
new_model = XLMRobertaForSequenceClassification.from_pretrained(
    './pruned_models/pruned_84H2048FFN/model.pkl',
    config='./pruned_models/pruned_84H2048FFN/config.json',
)

## Measure performance

In [19]:
new_model.to(device)

# Tensor copies of the pruner's masks, placed on pruned_model's device. They are not
# passed to measure_performance below; add ffn_mask=ffn_mask, head_mask=head_mask to
# the call to evaluate with the masks applied instead.
head_mask = torch.tensor(transformer_pruner.head_mask)
head_mask = head_mask.to(pruned_model.device)
ffn_mask = torch.tensor(transformer_pruner.ffn_mask)
ffn_mask = ffn_mask.to(pruned_model.device)

measure_performance(
    new_model,
    eval_datasets,
    eval_langs,
    output_dir,
    device,
    batch_size,
)

Evaluating: 100%|██████████| 157/157 [00:21<00:00,  7.47it/s]

Performance: {'zh': {'acc': 0.7499001996007985}}





In [20]:
# List the shape of every state-dict entry whose name contains 'attention'.
for param_name, tensor in pruned_model.state_dict().items():
    if 'attention' not in param_name:
        continue
    print(param_name, tensor.shape)

roberta.encoder.layer.0.attention.self.query.weight torch.Size([576, 768])
roberta.encoder.layer.0.attention.self.query.bias torch.Size([576])
roberta.encoder.layer.0.attention.self.key.weight torch.Size([576, 768])
roberta.encoder.layer.0.attention.self.key.bias torch.Size([576])
roberta.encoder.layer.0.attention.self.value.weight torch.Size([576, 768])
roberta.encoder.layer.0.attention.self.value.bias torch.Size([576])
roberta.encoder.layer.0.attention.output.dense.weight torch.Size([768, 576])
roberta.encoder.layer.0.attention.output.dense.bias torch.Size([768])
roberta.encoder.layer.0.attention.output.LayerNorm.weight torch.Size([768])
roberta.encoder.layer.0.attention.output.LayerNorm.bias torch.Size([768])
roberta.encoder.layer.1.attention.self.query.weight torch.Size([384, 768])
roberta.encoder.layer.1.attention.self.query.bias torch.Size([384])
roberta.encoder.layer.1.attention.self.key.weight torch.Size([384, 768])
roberta.encoder.layer.1.attention.self.key.bias torch.Size([38