In [157]:
import torch
import json
import time
import statistics
import numpy as np
import pickle as pkl

from tqdm import tqdm


from transformers import AutoTokenizer
from transformers.models.blenderbot.configuration_blenderbot import BlenderbotConfig
from pruning.pruned_blender_bot import PrunedBlenderbotForConditionalGeneration
from node_attribution.utils import count_params
from rouge_score import rouge_scorer

In [158]:
tokenizer = AutoTokenizer.from_pretrained(f"facebook/blenderbot-400M-distill")

In [159]:
data = pkl.load(open("44_human_filtered_conv_pairs.pkl", "rb"))
cali_data = data[:22]
val_data = data[22:]

In [160]:
def score(model, tokenizer, sentence):
    if sentence.startswith("chatbot"):
        input_seq, output_seq = sentence.split("user:")
        input_seq = input_seq.split("chatbot:")[-1].strip()
        output_seq = output_seq.strip()
    else:
        input_seq, output_seq = sentence.split("chatbot:")
        input_seq = input_seq.split("user:")[-1].strip()
        output_seq = output_seq.strip()
        
#     print(input_seq)
#     print("-" * 100)
#     print(output_seq)
#     print("-" * 100)
    
    inputs = tokenizer(input_seq, return_tensors="pt")
    input_ids = inputs.input_ids

    decoder_inputs = tokenizer(output_seq.strip(), return_tensors="pt")
    decoder_input_ids = decoder_inputs.input_ids
    
    repeat_input = input_ids.repeat(decoder_input_ids.size(-1)-2, 1)
    repeat_decoder_input = decoder_input_ids.repeat(decoder_input_ids.size(-1)-2, 1)
    
    mask = torch.ones(decoder_input_ids.size(-1) - 1).diag(1)[:-2]
    masked_input = repeat_decoder_input.masked_fill(mask == 1, tokenizer.pad_token_id)
    
    labels = repeat_decoder_input.masked_fill( masked_input != tokenizer.pad_token_id, -100)
    
    with torch.inference_mode():
        loss = model(input_ids=repeat_input, decoder_input_ids=masked_input, labels=labels).loss
        
    return np.exp(loss.item())

In [161]:
def calc_p(data):
    perplexity_sum = 0
    for pair in tqdm(data):
        perplexity = score(sentence=pair, model=pruned_model, tokenizer=tokenizer)
        perplexity_sum += perplexity
    
    p = perplexity_sum / len(data)
    
    return p

In [173]:
weights_path = "pruned_400m_blender_bot2.pt"
state_dict_shapes_path = "pruned_400m_blender_bot2_state_dict_shapes.pkl"

blenderbot_config = BlenderbotConfig(
    vocab_size=8008,
    max_position_embeddings=128,
    encoder_layers=2,
    encoder_ffn_dim=5120,
    encoder_attention_heads=32,
    decoder_layers=12,
    decoder_ffn_dim=5120,
    decoder_attention_heads=32,
    encoder_layerdrop=0.0,
    decoder_layerdrop=0.0,
    use_cache=True,
    is_encoder_decoder=True,
    activation_function="gelu",
    d_model=1280,
    dropout=0.1,
    attention_dropout=0.0,
    activation_dropout=0.0,
    init_std=0.02,
    decoder_start_token_id=1,
    scale_embedding=False,
    pad_token_id=0,
    bos_token_id=1,
    eos_token_id=2,
    encoder_no_repeat_ngram_size=3,
    forced_eos_token_id=2,
    num_beams=10
)

In [174]:
pruned_model = PrunedBlenderbotForConditionalGeneration(blenderbot_config, state_dict_shapes_path)

In [175]:
pruned_model.load_state_dict(torch.load(weights_path))

<All keys matched successfully>

In [169]:
pruned_percent = 1.0 - (count_params(pruned_model)[-1] / 364802560)
print(pruned_percent)

0.0


In [170]:
# print(calc_p(cali_data))
# print(calc_p(val_data))

In [171]:
num_trials = 20

In [172]:
line = "I'm looking for someone to talk to. Life can be so lonely sometimes and it helps to have someone to vent to."
inputs = tokenizer([line], return_tensors="pt")
pruned_times = []

# for i in range(num_trials):
# start = time.time()
outputs = pruned_model.generate(**inputs)
# end = time.time()
# pruned_times.append(end - start)
# print(f"inference time: {end - start}")
    
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

['']


In [13]:
mean_pruned = np.mean(pruned_times)
std_pruned = np.std(pruned_times)
mode = statistics.median(pruned_times)
print(mean_pruned)
print(std_pruned)
print(mode)

43.58427990674973
0.8181698710717864
43.70455062389374


In [14]:
# while True:
#     line = input("You:")
#     inputs = tokenizer(line, return_tensors="pt")
#     outputs = pruned_model.generate(
#         input_ids=inputs["input_ids"], 
#         max_new_tokens=20, 
#         do_sample=True, 
#         top_k=50, 
#         top_p=0.95,
#     )
#     print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

In [14]:
# full_line = "Person: My favorite movie is the The Day After Tomorrow\nSocialBot: Oh, interesting, I am not familiar with that movie! Can you tell me more about it?"
# prompt_line = "Person: My favorite movie is The Day After Tomorrow\nSocialBot: "
# completion = full_line.split(prompt_line)[-1]
# inputs = tokenizer(prompt_line, return_tensors="pt")

# #for i in range(num_trials):
# outputs = pruned_model.generate(
#     input_ids=inputs["input_ids"], 
#     max_new_tokens=25, 
#     do_sample=True, 
#     top_k=50, 
#     top_p=0.95,
# )
# out_seq = tokenizer.batch_decode(outputs, skip_special_tokens=True)
# out_seq = out_seq[0].split("Person: My favorite movie is The Day After Tomorrow\nSocialBot: ")[-1]
# r_scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
# rouge_scores = r_scorer.score(completion, out_seq)

In [15]:
# from transformers import AutoTokenizer, BloomForCausalLM

In [16]:
# model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")

In [17]:
# inputs = tokenizer("Person: My favorite movie is The Day After Tomorrow\nSocialBot: ", return_tensors="pt")
# start = time.time()
# outputs = model.generate(input_ids=inputs["input_ids"], max_new_tokens=25, do_sample=True, top_k=50, top_p=0.95)
# end = time.time()
# print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
# print(f"inference time: {end - start}")

In [126]:
# def score(model, tokenizer, sentence):
#     if sentence.startswith("chatbot"):
#         input_seq, output_seq = sentence.split("user:")
#         input_seq = input_seq.split("chatbot:")[-1].strip()
#         output_seq = output_seq.strip()
#     else:
#         input_seq, output_seq = sentence.split("chatbot:")
#         input_seq = input_seq.split("user:")[-1].strip()
#         output_seq = output_seq.strip()
        
#     print(input_seq)
#     print("-" * 100)
#     print(output_seq)
#     print("-" * 100)
    
#     inputs = tokenizer(input_seq, return_tensors="pt")
#     input_ids = inputs.input_ids

#     decoder_inputs = tokenizer(output_seq.strip(), return_tensors="pt")
#     decoder_input_ids = decoder_inputs.input_ids
    
    
#     with torch.inference_mode():
#         logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
#         probs = torch.nn.functional.softmax(logits, dim=-1)
#         neg_log_likelihood = torch.mul(torch.log(probs), -1)
#         seq_neg_log_likelihood_sum = 0
#         for token_index in range(len(decoder_input_ids[0])):
#             token_id = decoder_input_ids[0][token_index]
#             seq_neg_log_likelihood_sum += neg_log_likelihood[0][token_index][token_id]
#         loss = seq_neg_log_likelihood_sum / len(decoder_input_ids[0])
#         print(loss)
        
#     return np.exp(loss)