In [62]:
import torch
import json
import time
import statistics
import numpy as np
import pickle as pkl

from tqdm import tqdm


from transformers import AutoTokenizer
from transformers.models.blenderbot.configuration_blenderbot import BlenderbotConfig
from pruning.pruned_blender_bot import PrunedBlenderbotForConditionalGeneration
from node_attribution.utils import count_params
from rouge_score import rouge_scorer

In [63]:
# Load the tokenizer published for the facebook/blenderbot-3B checkpoint.
# (Plain string: the previous f-string had no placeholders.)
tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-3B")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [64]:
# Read the pickled conversation pairs; the context manager guarantees the file
# handle is closed even if unpickling fails.
# NOTE: pickle can execute arbitrary code while loading -- only open trusted files.
with open("44_human_filtered_conv_pairs.pkl", "rb") as pairs_file:
    data = pkl.load(pairs_file)

# The first 22 entries go to cali_data, everything after them to val_data.
cali_data = data[:22]
val_data = data[22:]

In [65]:
def score(model, tokenizer, sentence):
    """Return ``exp(loss)`` of ``model`` on the reply half of a two-turn string.

    ``sentence`` holds one ``user:`` turn and one ``chatbot:`` turn. Whichever
    speaker comes first supplies the encoder input; the other turn is the
    target whose tokens are scored.

    The target ids are copied into ``T - 2`` rows (``T`` = number of target
    ids). Row ``i`` has position ``i + 1`` overwritten with the pad id, and the
    labels keep the original id only where the decoder input equals the pad id
    (``-100`` elsewhere). The model is called once on this batch and the
    exponential of its ``.loss`` is returned.

    Args:
        model: callable accepting ``input_ids``, ``decoder_input_ids`` and
            ``labels`` keyword arguments and returning an object with ``.loss``.
        tokenizer: callable returning an object with ``.input_ids`` when given
            text and ``return_tensors="pt"``; must expose ``pad_token_id``.
        sentence: the conversation pair as a single string.

    Returns:
        ``np.exp`` of the scalar loss.

    Raises:
        ValueError: if splitting on the reply marker does not yield exactly
            two parts (tuple unpacking fails).
    """
    # Pick which marker introduces the context and which introduces the reply.
    if sentence.startswith("chatbot"):
        context_marker, reply_marker = "chatbot:", "user:"
    else:
        context_marker, reply_marker = "user:", "chatbot:"

    context_part, reply_part = sentence.split(reply_marker)
    context_text = context_part.split(context_marker)[-1].strip()
    reply_text = reply_part.strip()

    encoder_ids = tokenizer(context_text, return_tensors="pt").input_ids
    reply_ids = tokenizer(reply_text, return_tensors="pt").input_ids

    reply_len = reply_ids.size(-1)
    n_rows = reply_len - 2
    pad_id = tokenizer.pad_token_id

    # One copy of the encoder input and of the reply per masked position.
    encoder_batch = encoder_ids.repeat(n_rows, 1)
    reply_batch = reply_ids.repeat(n_rows, 1)

    # Super-diagonal selector: row i is true only at column i + 1. Dropping the
    # last two rows means columns 0 and reply_len - 1 are never masked.
    # NOTE(review): whether those two positions are special tokens depends on
    # the tokenizer -- confirm that skipping them is intended.
    hidden = torch.ones(reply_len - 1).diag(1)[:-2] == 1
    decoder_batch = reply_batch.masked_fill(hidden, pad_id)

    # Keep a label only where the decoder input is the pad id; ignore the rest.
    labels = reply_batch.masked_fill(decoder_batch != pad_id, -100)

    with torch.inference_mode():
        loss = model(
            input_ids=encoder_batch,
            decoder_input_ids=decoder_batch,
            labels=labels,
        ).loss

    return np.exp(loss.item())

In [66]:
def calc_p(data):
    """Return the arithmetic mean of ``score`` over every entry of ``data``.

    Each entry is scored with the notebook-level ``pruned_model`` and
    ``tokenizer`` globals, so both must be defined before this is called.

    Args:
        data: sized iterable of conversation-pair strings accepted by ``score``.

    Returns:
        Sum of the per-entry scores divided by ``len(data)``.

    Raises:
        ZeroDivisionError: if ``data`` is empty.
    """
    total = sum(
        score(sentence=pair, model=pruned_model, tokenizer=tokenizer)
        for pair in tqdm(data)
    )
    return total / len(data)

In [67]:
# Files produced by the pruning step: the weights and the pickled state-dict shapes.
weights_path = "pruned_3B_blender_bot2.pt"
state_dict_shapes_path = "pruned_3B_blender_bot2_state_dict_shapes.pkl"


def _build_blenderbot_config():
    """Assemble the BlenderbotConfig used to instantiate the pruned model.

    The keyword arguments are grouped by theme for readability only; they are
    merged into one mapping and forwarded unchanged to ``BlenderbotConfig``.
    """
    architecture = {
        "model_type": "blenderbot",
        "architectures": ["BlenderbotForConditionalGeneration"],
        "is_encoder_decoder": True,
        "vocab_size": 8008,
        "d_model": 2560,
        "max_position_embeddings": 128,
        "encoder_layers": 2,
        "encoder_ffn_dim": 10240,
        "encoder_attention_heads": 32,
        "decoder_layers": 24,
        "decoder_ffn_dim": 10240,
        "decoder_attention_heads": 32,
        "num_hidden_layers": 2,
        "activation_function": "gelu",
        "scale_embedding": True,
        "init_std": 0.02,
        "use_cache": True,
    }
    dropout_rates = {
        "dropout": 0.1,
        "attention_dropout": 0.0,
        "activation_dropout": 0.0,
        "encoder_layerdrop": 0.0,
        "decoder_layerdrop": 0.0,
        "classif_dropout": 0.0,
        "classifier_dropout": 0.0,
    }
    token_ids = {
        "pad_token_id": 0,
        "bos_token_id": 1,
        "eos_token_id": 2,
        "unk_token_id": 3,
        "decoder_start_token_id": 1,
        "forced_eos_token_id": 2,
    }
    generation_defaults = {
        "max_length": 60,
        "min_length": 20,
        "num_beams": 10,
        "length_penalty": 0.65,
        "encoder_no_repeat_ngram_size": 3,
    }
    # Remaining switches, passed through exactly as given.
    other_flags = {
        "layernorm_variant": "prelayernorm",
        "static_position_embeddings": False,
        "add_bias_logits": False,
        "add_final_layer_norm": True,
        "do_blenderbot_90_layernorm": True,
        "extra_layer_norm": False,
        "extra_pos_embeddings": 0,
        "force_bos_token_to_be_generated": False,
        "gradient_checkpointing": False,
        "normalize_before": True,
        "normalize_embedding": False,
    }
    merged = {
        **architecture,
        **dropout_rates,
        **token_ids,
        **generation_defaults,
        **other_flags,
    }
    return BlenderbotConfig(**merged)


blenderbot_config = _build_blenderbot_config()

In [89]:
# Instantiate the pruned model from blenderbot_config. The second argument is
# the path held in state_dict_shapes_path (presumably the per-tensor shapes
# left after pruning -- confirm in pruning.pruned_blender_bot).
pruned_model = PrunedBlenderbotForConditionalGeneration(blenderbot_config, state_dict_shapes_path)

In [90]:
# The model is built directly from a config rather than via from_pretrained(),
# so unless its __init__ already calls eval() it is in nn.Module's default
# training mode. Switch to eval mode so the configured dropout (0.1) is
# inactive during scoring and generation.
pruned_model.eval()
# map_location="cpu" lets the checkpoint load regardless of the device it was
# saved from; load_state_dict then copies the values into the model's tensors.
# torch.load unpickles the file, so only load trusted checkpoints.
pruned_model.load_state_dict(torch.load(weights_path, map_location="cpu"))

<All keys matched successfully>

In [91]:
# Hard-coded reference parameter count (presumably that of the unpruned
# facebook/blenderbot-3B model -- TODO confirm).
BASELINE_PARAM_COUNT = 2_696_268_800

# count_params returns a sequence; its last element is used as the model's count.
remaining_param_count = count_params(pruned_model)[-1]

# Fraction (0..1), despite the name, of the baseline parameters that is gone.
pruned_percent = 1.0 - (remaining_param_count / BASELINE_PARAM_COUNT)
print(pruned_percent)

0.15003262063485656


In [92]:
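# Disabled: mean score over cali_data and val_data. Uncomment to run; each call
# scores every pair in the list with pruned_model.
# print(calc_p(cali_data))
# print(calc_p(val_data))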

In [93]:
# Number of timed generate() calls used for the latency measurement.
num_trials = 20

In [94]:
# Benchmark prompt: tokenized once and reused for every timed trial.
line = "I'm looking for someone to talk to. Life can be so lonely sometimes and it helps to have someone to vent to."
inputs = tokenizer([line], return_tensors="pt")
pruned_times = []

for _ in range(num_trials):
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring elapsed time; time.time() can jump if the system clock changes.
    start = time.perf_counter()
    outputs = pruned_model.generate(**inputs)
    elapsed = time.perf_counter() - start
    pruned_times.append(elapsed)
    print(f"inference time: {elapsed}")

# Decode the generation from the final trial only.
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

inference time: 10.828724145889282
inference time: 10.946922063827515
inference time: 13.529675006866455
inference time: 10.88916301727295
inference time: 11.00428295135498
inference time: 10.672850847244263
inference time: 11.297038793563843
inference time: 10.755046129226685
inference time: 10.746734857559204
inference time: 11.206049919128418
inference time: 10.88215708732605
inference time: 13.233487844467163
inference time: 10.81781005859375
inference time: 15.966655254364014
inference time: 27.881234884262085
inference time: 11.979538917541504
inference time: 10.672956943511963
inference time: 10.222210884094238
inference time: 16.153496980667114
inference time: 10.618239164352417
[' That sounds like a good plan. I hope you find someone to help you out with that.']


In [95]:
# Summary statistics over the per-trial generation times (seconds).
mean_pruned = np.mean(pruned_times)
std_pruned = np.std(pruned_times)  # population std (NumPy default ddof=0)
# statistics.median returns the median, not the mode, so name it accordingly.
median_pruned = statistics.median(pruned_times)
mode = median_pruned  # legacy alias: keeps any later cell that reads `mode` working
print(mean_pruned)
print(std_pruned)
print(median_pruned)

12.515213787555695
3.8983528676686006
10.918042540550232
