Use UL2 to 

(1) measure inconsistencies in its bidirectional conditionals; 

(2) improve LLM inference with an Ensemble of Conditionals (EOC).  




* [Imports and global utils](#0)

* [Load tokenizer and model](#1)

<h1 style="font-size: 20px;"><a class="anchor" id="0"></a>Imports and global utils</h1>

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1,3"
import torch
from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Tokenizer
import numpy as np
import pickle
import time
from tqdm import tqdm
import json


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# A simple way to toggle status print statements for the whole notebook.
print_status = True
# `show` is called like print(); when print_status is False it accepts any arguments and does nothing.
show = print if print_status else (lambda *args, **kwargs: None)

<h1 style="font-size: 20px;"><a class="anchor" id="1"></a>Load tokenizer and model</h1>

In [3]:
# The models can be quite large, so a custom Hugging Face cache directory is used
# in case the default one does not have enough capacity.
MY_HUGGINGFACE_CACHE_DIR = 'huggingface_cache'  # relative to this notebook path
tokenizer = AutoTokenizer.from_pretrained(
    "google/ul2",
    cache_dir=f"{MY_HUGGINGFACE_CACHE_DIR}/google-ul2",
)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
# Load the checkpoint WITHOUT first moving it to a single GPU: parallelize() itself places the
# encoder/decoder blocks, embeddings and LM head on the visible GPUs. Calling .to("cuda") beforehand
# copied the whole model onto GPU 0 first, which is the likely cause of GPU 0 keeping its memory
# usage high (the cached allocation is not handed back to the driver after the blocks move away).
model = T5ForConditionalGeneration.from_pretrained("google/ul2",
                                                   cache_dir=MY_HUGGINGFACE_CACHE_DIR + '/google-ul2',
                                                   low_cpu_mem_usage=True,
                                                   torch_dtype=torch.bfloat16)
# NOTE(review): parallelize() is deprecated in recent transformers releases in favour of
# from_pretrained(..., device_map=...); kept here to preserve the existing placement scheme.
model.parallelize()

Loading checkpoint shards: 100%|██████████| 4/4 [01:03<00:00, 15.89s/it]


In [5]:
LAMBADA_TEST_DATA_PATH = "data/jsonls/test.jsonl"

# Read the JSON-lines file: one JSON object per line.
# The encoding is pinned to UTF-8 so the result does not depend on the machine's locale,
# and blank lines (e.g. a trailing newline at the end of the file) are skipped instead of crashing json.loads.
with open(LAMBADA_TEST_DATA_PATH, "r", encoding="utf-8") as f:
    lambada = [json.loads(line) for line in f if line.strip()]

# To use the NLG mode of UL2, append [NLG] to the beginning of each input, and <extra_id_0> to the end
lambada = [
    {
        "inputs_pretokenized": "[NLG] " + x['inputs_pretokenized'] + " <extra_id_0>",
        "targets_pretokenized": x['targets_pretokenized']
    }
    for x in lambada
]


In [6]:
# Punctuation marks that can END a sentence. Once the model generates one of them the sentence is
# considered complete and the last word can be parsed out of the completion.
ENDING_PUNCTUATIONS = ',!.:;?'
_vocab = tokenizer.get_vocab()
# Vocabulary id of each punctuation mark, in the same order as ENDING_PUNCTUATIONS.
# A KeyError here means a mark is not a stand-alone entry of the tokenizer vocabulary.
ENDING_PUNCTUATIONS_IDS_LIST = list(map(_vocab.__getitem__, ENDING_PUNCTUATIONS))

In [7]:
def get_words_from_completions(completions):
    '''Extract the first word from each decoded completion and return the words as a list.

    For every completion (a string produced by the model) the text before the first character
    found in ENDING_PUNCTUATIONS is taken as the candidate. A leading "<pad>", a leading
    "<extra_id_0>", one leading space and one trailing space are stripped. Completions without
    any ending punctuation, empty candidates and candidates made of several space-separated
    words are dropped, so the result may be shorter than the input.
    '''
    extracted = []
    for text in completions:
        # position of the first ending punctuation, or None when the completion has none
        cut = next((pos for pos, char in enumerate(text) if char in ENDING_PUNCTUATIONS), None)
        if cut is None:
            continue
        candidate = text[:cut]

        if candidate.startswith("<pad>"):
            candidate = candidate[5:]

        if candidate.startswith("<extra_id_0>"):
            # diagnostic only: a space is expected right after the sentinel token
            if len(candidate) > 13 and candidate[12] != " ":
                print('word[12] != \" \"')
                print(candidate)
            candidate = candidate[12:]

        if candidate.startswith(" "):
            candidate = candidate[1:]
        if candidate.endswith(" "):
            candidate = candidate[:-1]

        # keep only non-empty, single-word candidates
        if candidate != "" and len(candidate.split(" ")) == 1:
            extracted.append(candidate)
    return extracted

In [8]:
def get_word_from_completion(completion):
    '''Return the first word of a single decoded completion, or None if there is no usable word.

    The candidate is the text before the first character found in ENDING_PUNCTUATIONS.
    A leading "<pad>", a leading "<extra_id_0>", one leading space and one trailing space are
    stripped. None is returned when the completion has no ending punctuation, when the
    candidate is empty, or when it consists of several space-separated words.
    '''
    punctuation_positions = [pos for pos, char in enumerate(completion) if char in ENDING_PUNCTUATIONS]
    if not punctuation_positions:
        return None
    token = completion[:punctuation_positions[0]]

    # strip the special tokens the decoder puts in front of the prediction
    for prefix in ("<pad>", "<extra_id_0>"):
        if token.startswith(prefix):
            token = token[len(prefix):]

    # strip at most one space on each side
    if token.startswith(" "):
        token = token[1:]
    if token.endswith(" "):
        token = token[:-1]

    # an empty string or a multi-word string does not count as a word
    if token == "" or " " in token:
        return None
    return token

In [9]:
def get_word_punc_pairs(completions):
    '''Given a list of decoded completions, return the unique word+punctuation pairs they start with.

    For every completion the text up to AND including the first character found in
    ENDING_PUNCTUATIONS is taken (e.g. "signs."). A leading "<pad>", a leading "<extra_id_0>"
    and one leading space are stripped. A pair is discarded when it contains a space (more than
    one word), has length 1 (a bare punctuation mark) or contains "<unk>".

    Returns a list without duplicates, ordered by first appearance in `completions`.
    (The previous `list(set(...))` returned the same elements in an order that could change
    between interpreter runs, which made positional indexing of the result non-reproducible.)
    '''
    unique_pairs = {}  # dict used as an insertion-ordered set
    for completion in completions:
        end = next((pos for pos, char in enumerate(completion) if char in ENDING_PUNCTUATIONS), None)
        if end is None:
            continue  # no ending punctuation: the completion carries no pair
        pair = completion[:end + 1]

        if pair.startswith("<pad>"):
            pair = pair[5:]
        if pair.startswith("<extra_id_0>"):
            pair = pair[12:]
        if pair.startswith(" "):
            pair = pair[1:]
        # No trailing-space stripping is needed: the pair always ends with the punctuation mark.

        if len(pair) <= 1 or " " in pair or "<unk>" in pair:
            continue
        unique_pairs[pair] = None
    return list(unique_pairs)

In [10]:
def remove_pad(completions):
    '''Return a new list in which a leading "<pad>" is stripped from every decoded completion.

    Completions that do not start with "<pad>" are passed through unchanged.
    '''
    stripped = []
    for text in completions:
        if text.startswith("<pad>"):
            text = text[5:]
        stripped.append(text)
    return stripped

In [11]:
def remove_pad_id(completions):
    '''Given a list of completions as token-id sequences, strip a leading <pad> id from each.

    Sequences that do not start with the <pad> id are returned unchanged. Empty sequences are
    passed through as well (previously they raised an IndexError on `completion[0]`).
    Returns a new list; the input list is not modified.
    '''
    pad_id = tokenizer.convert_tokens_to_ids("<pad>")
    completions_return = []
    for completion in completions:
        if len(completion) > 0 and completion[0] == pad_id:
            completions_return.append(completion[1:])
        else:
            completions_return.append(completion)
    return completions_return

In [12]:
def before_first_punc(completions):
    '''Truncate each completion (a sequence of token ids) right after its first ending punctuation.

    The returned slices INCLUDE the punctuation token itself. Completions that contain none of
    the ids in ENDING_PUNCTUATIONS_IDS_LIST are dropped, so the result may be shorter than the input.
    '''
    truncated = []
    for token_ids in completions:
        for position, token_id in enumerate(token_ids):
            if token_id in ENDING_PUNCTUATIONS_IDS_LIST:
                truncated.append(token_ids[:position + 1])
                break
    return truncated

In [13]:
# Cross-entropy losses computed from logits and labels; positions whose label is the pad token are ignored.

# default reduction: mean over the non-ignored positions
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
# same loss, but summed over the non-ignored positions instead of averaged
loss_fn_sum = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=tokenizer.pad_token_id)

In [14]:
# Temporary sanity check: shortest and longest tokenized input over the whole LAMBADA test set.
# Only the lengths are needed, so the token ids stay on the CPU (no per-example GPU copy).
input_lengths = [
    len(tokenizer(example['inputs_pretokenized'], return_tensors="pt").input_ids[0])
    for example in tqdm(lambada)
]
# The defaults reproduce the previous sentinel values when the dataset is empty.
min_len = min(input_lengths, default=100000)
max_len = max(input_lengths, default=0)
# min_len
max_len

100%|██████████| 5153/5153 [00:01<00:00, 3835.25it/s]


283

In [15]:
'''On lambada, generate the top completions (completions) for each example, and get the word from each completion'''
# generate for all examples, and then get the words from the completions, and compare the first one with the target
# Running tallies:
#   count_correct               - the first extracted word equals the target word
#   count_correct_top_num_beams - the target word appears among the extracted words of any beam
#   count_no_words_found        - no word could be extracted from any of the completions
count_correct = 0
count_correct_top_num_beams = 0
count_no_words_found = 0
# Per-example results, keyed by example index:
#   id_to_word_and_punc_pairs           - unique "word+punctuation" strings found in the completions
#   id_to_word_and_punc_pairs_processed - for each unique word, the first matching "word+punctuation" pair
#   id_to_completions                   - token-id completions without the leading pad, cut after the first punctuation
id_to_word_and_punc_pairs = {}
id_to_word_and_punc_pairs_processed = {}
id_to_completions = {}

MAX_COMPLETION_LENGTH = 8
NUM_BEAMS = 20

# NOTE(review): only the first example is processed here (range(1)); the trailing comment suggests
# len(lambada) is used for the full run.
for example_index in tqdm(range(1)): # len(lambada)
    input_string = lambada[example_index]['inputs_pretokenized']
    inputs = tokenizer(input_string, return_tensors="pt").input_ids.to("cuda")
    # Beam search returning all NUM_BEAMS hypotheses; <extra_id_1> is treated as the end-of-sequence token.
    outputs = model.generate(inputs,
                             max_length=MAX_COMPLETION_LENGTH, # for last word prediction, 8 is sufficient
                             num_beams=NUM_BEAMS, 
                             num_return_sequences=NUM_BEAMS, 
                             output_scores=True,
                             eos_token_id=tokenizer.convert_tokens_to_ids('<extra_id_1>'), 
                             return_dict_in_generate=True)
    
    # Decoded text of every returned beam (special tokens such as <pad> and <extra_id_0> are kept).
    completions = [tokenizer.decode(outputs['sequences'][i]) for i in range(NUM_BEAMS)]
    # print([tokenizer.batch_decode(outputs['sequences'][i]) for i in range(num_beams)])
    # Keep the token ids only of the beams from which a single word can be parsed.
    completions_ids = [outputs['sequences'][i] for i in range(NUM_BEAMS) 
                   if get_word_from_completion(completions[i]) is not None]

    words = get_words_from_completions(completions)
    completions_without_pad = remove_pad_id(completions_ids)
    completions_without_pad_before_punctution = before_first_punc(completions_without_pad)
    # print(words)
    if words:
        # print(completions)
        # print(words[0], lambada[example_index]['targets_pretokenized'])
        # presumably 'targets_pretokenized' is a list whose first element is the target word — TODO confirm
        if words[0] == lambada[example_index]['targets_pretokenized'][0]:
            count_correct += 1
    else:
        count_no_words_found += 1
        print("no words found")
    word_and_punc_pairs = get_word_punc_pairs(completions)
    id_to_word_and_punc_pairs[example_index] = word_and_punc_pairs
    words_unique = list(set(words))
    id_to_word_and_punc_pairs_processed[example_index] = []
    id_to_completions[example_index] = completions_without_pad_before_punctution
    # For every unique word keep at most one "word+punctuation" pair: the first pair that matches it.
    for word in words_unique:
        found = 0
        # iterate through the word and punc pairs, and find the one that matches the word
        for word_and_punc_pair in word_and_punc_pairs:
            # it is a match if pair = word + punc
            for punc in ENDING_PUNCTUATIONS:
                if word_and_punc_pair == word + punc:
                    id_to_word_and_punc_pairs_processed[example_index].append(word_and_punc_pair)
                    found = 1
                    break
            if found == 1:
                break
    # calculate the number of correct top num_beams: if the correct word is in the top num_beams, then it is correct
    for word in words_unique:
        if word == lambada[example_index]['targets_pretokenized'][0]:
            count_correct_top_num_beams += 1
            break


100%|██████████| 1/1 [00:03<00:00,  3.76s/it]


In [17]:
# Display the decoded completions currently held in `completions`
# (presumably those of the last example handled by the generation loop; note that this cell's
# execution count is out of order relative to its neighbours).
completions

['<pad><extra_id_0> any of that. ',
 '<pad><extra_id_0> signs anymore. i',
 '<pad><extra_id_0> any of that anymore.',
 '<pad><extra_id_0> that anymore. i',
 '<pad><extra_id_0> the sign. i',
 '<pad><extra_id_0> the angel. i',
 '<pad><extra_id_0> a sign. ',
 '<pad><extra_id_0> signs. i just',
 '<pad><extra_id_0> the sign anymore. ',
 '<pad><extra_id_0> any of that. all',
 '<pad><extra_id_0> the angel anymore. ',
 '<pad><extra_id_0> any of it. ',
 '<pad><extra_id_0> anything else. i',
 '<pad><extra_id_0> that. i just',
 '<pad><extra_id_0> that anymore. <unk>',
 '<pad><extra_id_0> signs anymore. <unk>',
 '<pad><extra_id_0> signs. i ',
 '<pad><extra_id_0> the signs anymore. ',
 '<pad><extra_id_0> angels. i',
 '<pad><extra_id_0> signs or backup. ']

In [16]:
# Display the raw token ids of the sequences stored in `outputs` (one row per returned sequence).
outputs.sequences

tensor([[    0, 32099,   136,    13,    24,     3,     5,     3],
        [    0, 32099,  3957,  7595,     3,     5,     3,    23],
        [    0, 32099,   136,    13,    24,  7595,     3,     5],
        [    0, 32099,    24,  7595,     3,     5,     3,    23],
        [    0, 32099,     8,  1320,     3,     5,     3,    23],
        [    0, 32099,     8, 11831,     3,     5,     3,    23],
        [    0, 32099,     3,     9,  1320,     3,     5,     3],
        [    0, 32099,  3957,     3,     5,     3,    23,   131],
        [    0, 32099,     8,  1320,  7595,     3,     5,     3],
        [    0, 32099,   136,    13,    24,     3,     5,    66],
        [    0, 32099,     8, 11831,  7595,     3,     5,     3],
        [    0, 32099,   136,    13,    34,     3,     5,     3],
        [    0, 32099,   959,  1307,     3,     5,     3,    23],
        [    0, 32099,    24,     3,     5,     3,    23,   131],
        [    0, 32099,    24,  7595,     3,     5,     3,     2],
        [ 

In [None]:
# Mirror id_to_completions with every completion tensor copied to the CPU and converted to a
# NumPy array (same keys, same order of completions).
id_to_completions_numpy = {
    key: [np.array(completion.cpu()) for completion in id_to_completions[key]]
    for key in id_to_completions
}


In [None]:
# The file name carries the current Unix timestamp so earlier result files are never overwritten.
timed_pickle_file_name = f"ul2_lambada_vanilla_beam_search_results_{time.time()}.pickle"
# Save the counters and per-example results to the pickle file
with open(timed_pickle_file_name, 'wb') as fp:
    pickle.dump(
        dict(
            count_correct=count_correct,
            count_correct_top_num_beams=count_correct_top_num_beams,
            count_no_words_found=count_no_words_found,
            id_to_word_and_punc_pairs=id_to_word_and_punc_pairs,
            id_to_word_and_punc_pairs_processed=id_to_word_and_punc_pairs_processed,
            id_to_completions_numpy=id_to_completions_numpy,
        ),
        fp,
    )

In [None]:
# load it back
# NOTE(review): this overwrites timed_pickle_file_name with a hard-coded absolute path to a
# previously saved results file, so the cell only works on the machine that holds that file.
# NOTE(review): pickle.load can execute arbitrary code — only load files you produced yourself.
# /work/09127/tomyoung/ls6/inconsistencies_project/ul2_lambada_vanilla_beam_search_results_1683476272.4741185.pickle
timed_pickle_file_name = '/work/09127/tomyoung/ls6/inconsistencies_project/ul2_lambada_vanilla_beam_search_results_1683476272.4741185.pickle'
with open(timed_pickle_file_name, 'rb') as fp:
    ul2_lambada_vanilla_beam_search_results = pickle.load(fp)

In [None]:
# Rebuild id_to_completions from the loaded results: every stored NumPy array becomes a torch tensor again.
id_to_completions = {}
for key, arrays_for_key in ul2_lambada_vanilla_beam_search_results['id_to_completions_numpy'].items():
    completions = [torch.from_numpy(array) for array in arrays_for_key]
    id_to_completions[key] = completions

In [None]:
# Forward pass on the first example with one hand-picked candidate as the target.
# The target text is "<extra_id_0> " + word-and-punctuation pair + " <extra_id_1>".
print(lambada[0]['inputs_pretokenized'])
print("<extra_id_0> " + id_to_word_and_punc_pairs_processed[0][2] + " <extra_id_1>")
input_ids = tokenizer(lambada[0]['inputs_pretokenized'], return_tensors="pt").input_ids.to("cuda")
labels = tokenizer("<extra_id_0> " + id_to_word_and_punc_pairs_processed[0][2] + " <extra_id_1>", return_tensors="pt").input_ids.to("cuda")
outputs = model(input_ids, labels=labels)
loss, logits = outputs.loss, outputs.logits
logits

In [None]:
def get_avg_log_p_of_completion_without_pad(inputs_pretokenized, completion, offset=0):
    '''Score one completion under the model and return (-loss, logits).

    Args:
        inputs_pretokenized: the input text; it is tokenized here.
        completion: 1-D tensor of label token ids (first id followed by the completion tokens).
        offset: number of tokens to move from the end of the input (just before its last two
            tokens) to the front of the labels (right after the first label id). 0 leaves both unchanged.

    Returns:
        A tuple (-outputs.loss, outputs.logits) from a single forward pass.

    The forward pass runs under torch.no_grad(): this function is only used for scoring, so no
    autograd graph needs to be built or kept alive by the returned tensors.
    '''
    # input_ids: 1*len = words + 32099 + 1
    input_ids = tokenizer(inputs_pretokenized, return_tensors="pt").input_ids.to("cuda")
    # labels: 1*len = 32099 + words
    labels = completion.unsqueeze(0).to("cuda")
    # when offset is used, we move the last offset from input_ids to the front of labels.
    if offset != 0:
        # the last two input tokens stay in place; the `offset` tokens before them are moved
        to_move = input_ids[0][-offset-2:-2]
        labels = torch.cat((labels[0][0].unsqueeze(0), to_move, labels[0][1:]), dim=0).unsqueeze(0)
        input_ids = torch.cat((input_ids[0][:-offset-2], input_ids[0][-2:]), dim=0).unsqueeze(0)
    with torch.no_grad():
        outputs = model(input_ids, labels=labels)
    return -outputs.loss, outputs.logits

In [None]:
def get_offsetted(inputs_pretokenized, completion, offset=0):
    '''Build the (input_ids, labels) pair for one completion at a given offset.

    With offset == 0 the tokenized input and the completion are returned as 1-D tensors.
    Otherwise the `offset` tokens that sit right before the last two input tokens are removed
    from the input and inserted into the labels directly after the first label id.
    Both returned tensors are 1-D (no batch dimension) and live on "cuda".
    '''
    # source: words + 32099 + 1 ; target: 32099 + words
    source_ids = tokenizer(inputs_pretokenized, return_tensors="pt").input_ids.to("cuda")[0]
    target_ids = completion.to("cuda")

    if offset == 0:
        return (source_ids, target_ids)

    split_at = -offset - 2
    moved_tokens = source_ids[split_at:-2]
    shifted_target = torch.cat((target_ids[0].unsqueeze(0), moved_tokens, target_ids[1:]), dim=0)
    shortened_source = torch.cat((source_ids[:split_at], source_ids[-2:]), dim=0)
    return (shortened_source, shifted_target)

In [None]:
# Obtain the offsetted input_ids and labels for each completion of each example.
# Result: {(example_id, offset): [(input_ids, labels), ...]} with one tuple per completion,
# in the same order as id_to_completions[example_id].
id_and_offset_to_input_and_completions = {}
max_offset = 5
# Iterate over the keys that actually exist (instead of assuming they are exactly 0..n-1), and
# avoid naming the loop variable `id`, which would shadow the builtin for the rest of the notebook.
for example_id in id_to_completions:
    example_input = lambada[example_id]['inputs_pretokenized']
    for offset in range(max_offset):
        id_and_offset_to_input_and_completions[(example_id, offset)] = [
            get_offsetted(example_input, completion, offset=offset)
            for completion in id_to_completions[example_id]
        ]

In [None]:
def get_avg_log_p_of_completion_without_pad_batch(inputs_pretokenized_batch, completions_batch):
    '''Score a batch of completions in one forward pass and return (-loss, logits).

    Args:
        inputs_pretokenized_batch: list of input texts; they are tokenized and padded here.
        completions_batch: 2-D tensor of label token ids (batch, length), padded with the
            tokenizer's pad token id where completions differ in length.

    Returns:
        A tuple (-outputs.loss, outputs.logits). The loss is the model's own cross-entropy,
        i.e. a single scalar for the whole batch.

    Fixes over the previous version:
      * the attention mask produced by the padded tokenization is passed to the model, so padded
        input positions are no longer attended to;
      * padded label positions are set to -100, the value the Hugging Face loss ignores, so they
        no longer contribute to the loss;
      * the forward pass runs under torch.no_grad() because it is used for scoring only.
    '''
    # input_ids: batch_size*len = words + 32099 + 1
    encoded = tokenizer(inputs_pretokenized_batch, return_tensors="pt", padding=True)
    input_ids = encoded.input_ids.to("cuda")
    attention_mask = encoded.attention_mask.to("cuda")
    labels = completions_batch.to("cuda")
    labels = labels.masked_fill(labels == tokenizer.pad_token_id, -100)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    return -outputs.loss, outputs.logits

In [None]:
# Count the correct predictions again, this time re-ranking the stored completions of each
# example by the average log-probability from get_avg_log_p_of_completion_without_pad.
count_correct_avg_log_p_reranking_without_pad = 0
for example_index in range(100): # len(lambada)
    input_string = lambada[example_index]['inputs_pretokenized']
    completion_avg_log_p_max = float("-inf")
    # None means "no completion selected yet"; the previous "" sentinel relied on comparing a
    # tensor with a string.
    best_completion = None
    print('-------------')
    for completion in id_to_completions[example_index]:
        avg_log_p, logits = get_avg_log_p_of_completion_without_pad(input_string, completion, offset=0)
        print('avg_log_p', avg_log_p)
        if avg_log_p > completion_avg_log_p_max:
            completion_avg_log_p_max = avg_log_p
            best_completion = completion
    if best_completion is None:
        continue  # the example has no stored completions
    best_completion_string = tokenizer.decode(best_completion)
    # parse once and reuse the result
    best_words = get_words_from_completions([best_completion_string])
    if best_words:
        best_word = best_words[0]
        if best_word == lambada[example_index]['targets_pretokenized'][0]:
            count_correct_avg_log_p_reranking_without_pad += 1

In [None]:
# Count the correct predictions again by re-ranking the stored completions of each example,
# scoring all completions of one example in a single batched forward pass.
count_correct_avg_log_p_reranking_without_pad_batch = 0
id_to_offset_to_completion_probs = dict()
for example_index in tqdm(range(100)): # len(lambada)
    input_string = lambada[example_index]['inputs_pretokenized']
    input_ids = tokenizer(input_string, return_tensors="pt").input_ids.to("cuda")
    if len(id_to_completions[example_index]) == 0:
        continue
    completion_avg_log_p_max = float("-inf")
    # None means "no completion selected yet"; the previous "" sentinel relied on comparing a
    # tensor with a string.
    best_completion = None
    # Pad the completions to a common length; padded label positions are ignored by loss_fn.
    completions_batch = torch.nn.utils.rnn.pad_sequence(id_to_completions[example_index], batch_first=True, padding_value=tokenizer.pad_token_id)
    # one copy of the input per completion
    input_ids_batch = torch.cat([input_ids for i in range(len(id_to_completions[example_index]))], dim=0)
    # scoring only: no autograd graph is needed
    with torch.no_grad():
        outputs = model(input_ids_batch, labels=completions_batch)
    # The completions may live on the CPU (e.g. when restored from a pickle) while the logits are
    # on a GPU; the loss needs both on the same device.
    labels_for_loss = completions_batch.to(outputs.logits.device)
    for completion_index in range(len(id_to_completions[example_index])):
        avg_log_p = -loss_fn(outputs.logits[completion_index][1:], labels_for_loss[completion_index][1:]) # [1:] to remove the first token <extra_id_0>
        if avg_log_p > completion_avg_log_p_max:
            completion_avg_log_p_max = avg_log_p
            best_completion = completions_batch[completion_index]

    if best_completion is None:
        continue
    best_completion_string = tokenizer.decode(best_completion)
    # parse once and reuse the result
    best_words = get_words_from_completions([best_completion_string])
    if best_words:
        best_word = best_words[0]
        if best_word == lambada[example_index]['targets_pretokenized'][0]:
            count_correct_avg_log_p_reranking_without_pad_batch += 1

In [None]:
# Shortest tokenized input among the first 100 examples.
# The running minimum must start at +infinity: starting at 0 (as before) made the result always 0.
min_input_length = float("inf")
for example_index in tqdm(range(100)): # len(lambada)
    input_string = lambada[example_index]['inputs_pretokenized']
    input_ids = tokenizer(input_string, return_tensors="pt").input_ids.to("cuda")
    min_input_length = min(min_input_length, len(input_ids[0]))

In [None]:
# Obtain the avg_log_ps: for every example, offset and completion, the average log-probability
# of the completion tokens, stored as {(example_index, offset, completion_index): float}.
import traceback
import datetime

id_and_offset_to_completion_probs = dict()
failed_example_indices = []
for example_index in tqdm(range(len(lambada))): # len(lambada)
    try:
        num_completions = len(id_to_completions[example_index])
        if num_completions == 0:
            continue
        for offset in range(max_offset):
            # list of (input_ids, labels) tuples, one per completion
            inputs_and_labels = id_and_offset_to_input_and_completions[(example_index, offset)]
            completions_batch = torch.nn.utils.rnn.pad_sequence(
                [inputs_and_labels[i][1] for i in range(num_completions)],
                batch_first=True,
                padding_value=tokenizer.pad_token_id,
            )
            input_ids_batch = torch.cat(
                [inputs_and_labels[i][0].unsqueeze(0) for i in range(num_completions)],
                dim=0,
            )
            # scoring only: skipping the autograd graph lowers the memory needed per forward pass
            with torch.no_grad():
                outputs = model(input_ids_batch, labels=completions_batch)
            for completion_index in range(num_completions):
                # [1+offset:] skips the first token <extra_id_0> and the `offset` tokens moved over from the input
                avg_log_p = -loss_fn(outputs.logits[completion_index][1+offset:], completions_batch[completion_index][1+offset:])
                id_and_offset_to_completion_probs[(example_index, offset, completion_index)] = avg_log_p.detach().cpu().tolist()
    except Exception as e:
        # Best effort: report the failure, remember the example and carry on with the next one.
        # Entries already stored for this example's earlier offsets are left in the dict.
        print(f"An error occurred: {e}")
        print('example_index:', example_index, ' failed')
        failed_example_indices.append(example_index)
        traceback.print_exc()

# save avg_log_ps into a pickle file with timestamp
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
with open(f'id_and_offset_to_completion_probs_{timestamp}.pickle', 'wb') as handle:
    pickle.dump(id_and_offset_to_completion_probs, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Save failed_example_indices into a pickle file whose name carries the current timestamp.
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
with open('failed_example_indices_{}.pickle'.format(timestamp), 'wb') as handle:
    pickle.dump(failed_example_indices, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Restore previously computed results from disk instead of recomputing them.
# This replaces whatever id_and_offset_to_completion_probs / failed_example_indices currently hold.
# NOTE(review): pickle.load can execute arbitrary code — only load files you produced yourself.
# id_and_offset_to_completion_probs_20230509-155934.pickle
import pickle
with open('id_and_offset_to_completion_probs_20230528-064154_max_offset_61.pickle', 'rb') as handle:
    id_and_offset_to_completion_probs = pickle.load(handle)
# failed_example_indices_20230509-155934.pickle
with open('failed_example_indices_20230528-064154_61.pickle', 'rb') as handle:
    failed_example_indices = pickle.load(handle)

In [None]:
# Number of examples recorded as failed.
len(failed_example_indices)

In [None]:
import pickle
# Restore another saved run; this OVERWRITES id_and_offset_to_completion_probs and
# failed_example_indices, so whichever load cell ran last determines the data used afterwards.
# NOTE(review): the file name suggests this run covers fewer offsets (max_offset_21) — confirm it
# matches the offset range evaluated later.
# NOTE(review): pickle.load can execute arbitrary code — only load files you produced yourself.
# id_and_offset_to_completion_probs_20230516-222830_max_offset_21.pickle
with open('id_and_offset_to_completion_probs_20230516-222830_max_offset_21.pickle', 'rb') as handle:
    id_and_offset_to_completion_probs = pickle.load(handle)
# failed_example_indices_20230516-222830_max_offset_21.pickle
with open('failed_example_indices_20230516-222830_21.pickle', 'rb') as handle:
    failed_example_indices = pickle.load(handle)

In [None]:
# Preview the first 100 keys of the score dictionary.
list(id_and_offset_to_completion_probs.keys())[:100]

In [None]:
from tqdm import tqdm
# EOC with max pooling: for each number of ensembled offsets, pick per example the completion
# holding the single highest avg_log_p over offsets 0..offset_test, and measure the accuracy.
max_offset_test = 60
offset_to_accuracy = dict()
# set for O(1) membership tests inside the hot loop (the list lookup was linear per example)
failed_example_index_set = set(failed_example_indices)
for offset_test in range(max_offset_test):
    count_eoc = 0
    # postprocess the id_and_offset_to_completion_probs to get the best completion
    for example_index in tqdm(range(len(lambada))): # len(lambada)
        if len(id_to_completions[example_index]) == 0 or example_index in failed_example_index_set:
            continue
        completion_avg_log_p_max = float("-inf")
        # None means "nothing selected"; the previous "" sentinel would have been passed to
        # tokenizer.decode if no score was selected.
        best_completion = None
        for offset in range(offset_test+1):
            for completion_index in range(len(id_to_completions[example_index])):
                avg_log_p = id_and_offset_to_completion_probs[(example_index, offset, completion_index)]
                if avg_log_p > completion_avg_log_p_max:
                    completion_avg_log_p_max = avg_log_p
                    best_completion = id_to_completions[example_index][completion_index]

        if best_completion is None:
            continue
        best_completion_string = tokenizer.decode(best_completion)
        # parse once and reuse the result
        best_words = get_words_from_completions([best_completion_string])
        if best_words:
            best_word = best_words[0]
            if best_word == lambada[example_index]['targets_pretokenized'][0]:
                count_eoc += 1
    # Failed examples are excluded from the denominator; examples without completions still count as wrong.
    offset_to_accuracy[offset_test] = count_eoc / (len(lambada) - len(failed_example_indices))

In [None]:
# EOC with avg pooling: for each number of ensembled offsets, pick per example the completion
# with the highest avg_log_p averaged over offsets 0..offset_test, and measure the accuracy.
max_offset_test = 60
offset_to_accuracy_avg_pooling = dict()
# set for O(1) membership tests inside the hot loop (the list lookup was linear per example)
failed_example_index_set = set(failed_example_indices)
for offset_test in range(max_offset_test):
    count_eoc = 0
    # postprocess the id_and_offset_to_completion_probs to get the best completion
    for example_index in tqdm(range(len(lambada))): # len(lambada)
        if len(id_to_completions[example_index]) == 0 or example_index in failed_example_index_set:
            continue
        completion_avg_log_p_avg_over_offset_max = float("-inf")
        # None means "nothing selected"; the previous "" sentinel would have been passed to
        # tokenizer.decode if no score was selected.
        best_completion = None
        for completion_index in range(len(id_to_completions[example_index])):
            completion_avg_log_p_avg_over_offset = 0
            for offset in range(offset_test+1):
                avg_log_p = id_and_offset_to_completion_probs[(example_index, offset, completion_index)]
                completion_avg_log_p_avg_over_offset += avg_log_p
            completion_avg_log_p_avg_over_offset /= (offset_test+1)
            if completion_avg_log_p_avg_over_offset > completion_avg_log_p_avg_over_offset_max:
                completion_avg_log_p_avg_over_offset_max = completion_avg_log_p_avg_over_offset
                best_completion = id_to_completions[example_index][completion_index]
        if best_completion is None:
            continue
        best_completion_string = tokenizer.decode(best_completion)
        # parse once and reuse the result
        best_words = get_words_from_completions([best_completion_string])
        if best_words:
            best_word = best_words[0]
            if best_word == lambada[example_index]['targets_pretokenized'][0]:
                count_eoc += 1
    # Failed examples are excluded from the denominator; examples without completions still count as wrong.
    offset_to_accuracy_avg_pooling[offset_test] = count_eoc / (len(lambada) - len(failed_example_indices))

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm

# Load a font
font_path = '/usr/share/fonts/urw-base35/NimbusMonoPS-Italic.otf'
font_prop = fm.FontProperties(fname=font_path)

# offset = 0 corresponds to the baseline, which is no. ensembled conditionals = 1; adjust the offset by 1
no_ensembled_conditionals_to_accuracy = dict()
for offset in range(1, max_offset_test+1):
    no_ensembled_conditionals_to_accuracy[offset] = offset_to_accuracy[offset-1]


max_line = plt.plot(list(no_ensembled_conditionals_to_accuracy.keys()), list(no_ensembled_conditionals_to_accuracy.values()), label='max')
plt.xlabel('No. ensembled conditionals', fontsize=14)
plt.ylabel('Accuracy', fontsize=14)
# the interval on x should be 10
plt.xticks(np.arange(10, max(list(no_ensembled_conditionals_to_accuracy.keys()))+1, 10))\
# add a tick at 1 on the x axis
plt.xticks(list(plt.xticks()[0]) + [1])

plt.xticks(fontsize=13)
plt.yticks(fontsize=13)

# add a dot at each point
plt.scatter(list(no_ensembled_conditionals_to_accuracy.keys()), list(no_ensembled_conditionals_to_accuracy.values()))


# add a yellow horizontal line at y=offset_to_accuracy[0]
plt.axhline(y=no_ensembled_conditionals_to_accuracy[1], color='y', linestyle='--')
# add the word "baseline" at the end of the yellow line in the font of calibri
plt.text(48, no_ensembled_conditionals_to_accuracy[1] + 0.0002, 'baseline', fontproperties=font_prop, fontsize=13)

# plot the accuracy with avg pooling
avg_line = plt.plot([item+1 for item in list(offset_to_accuracy_avg_pooling.keys())], list(offset_to_accuracy_avg_pooling.values()), color='r', label='avg')
# add a dot at each point
plt.scatter([item+1 for item in list(offset_to_accuracy_avg_pooling.keys())], list(offset_to_accuracy_avg_pooling.values()), color='r')

plt.scatter(1, no_ensembled_conditionals_to_accuracy[1], color='y')


plt.legend(handles=[max_line[0], avg_line[0]], loc='upper center', bbox_to_anchor=(0.9, 0.45), ncol=1, fontsize=10)


plt.tight_layout()

# show the plot at a high resolution
plt.savefig('no_ensembled_conditionals_to_accuracy_combined.png', dpi=1200)

# plt.show()


In [None]:
# Display the max-pooling accuracy for every tested offset.
offset_to_accuracy

In [None]:
import matplotlib.font_manager

# Collect the paths of all TrueType fonts matplotlib can find, then print
# how many there are followed by one path per line.
fonts = matplotlib.font_manager.findSystemFonts(fontext='ttf')

print('Number of fonts: ', len(fonts))
for font in fonts:
    print(font)


In [None]:
# Largest offset key plus one.
max(list(offset_to_accuracy.keys()))+1

In [None]:
# Accuracy of the avg-pooling ensemble plotted against the raw offset.
avg_pooling_offsets = list(offset_to_accuracy_avg_pooling)
avg_pooling_accuracies = list(offset_to_accuracy_avg_pooling.values())

plt.plot(avg_pooling_offsets, avg_pooling_accuracies)
plt.xlabel('offset')
plt.ylabel('accuracy')
# one tick per unit step between the smallest and largest offset
plt.xticks(np.arange(min(avg_pooling_offsets), max(avg_pooling_offsets) + 1, 1.0))
# mark every data point with a dot
plt.scatter(avg_pooling_offsets, avg_pooling_accuracies)
plt.show()

In [None]:
# empty the cache
# (asks PyTorch to release the CUDA memory it is caching but not currently using)
torch.cuda.empty_cache()

In [None]:
# torch is already imported in the notebook's first cell; re-importing here is harmless.
import torch

# Get the current GPU memory allocation
# device=0 reports only the first visible GPU; the model is spread over several GPUs
# (see model.parallelize() at the top), so this is not the total footprint.
allocated_memory_bytes = torch.cuda.memory_allocated(device=0)
# Convert the allocated memory to gigabytes
allocated_memory_gb = allocated_memory_bytes / (1024 ** 3)
print(f"Current GPU memory allocation: {allocated_memory_gb} GB")

In [None]:
'''Concatenate all completions to get a huge tensor'''
# Only the empty accumulator is created here; nothing is appended to it in this cell.
all_completions = []

In [None]:
# Join the completion lists of examples 0 and 1 and pad them into one batch-first tensor,
# filling the shorter sequences with the tokenizer's pad id; then display the result.
completions_batch = torch.nn.utils.rnn.pad_sequence(id_to_completions[0] + id_to_completions[1] , batch_first=True, padding_value=tokenizer.pad_token_id)
completions_batch

In [None]:
# Tokenize the first LAMBADA prompt (a batch of one, moved to the GPU) and
# display the candidate completions stored for example 0.
input_string = lambada[0]['inputs_pretokenized']
input_ids = tokenizer(input_string, return_tensors="pt").input_ids.to("cuda")
id_to_completions[0]

In [None]:
# Number of candidate completions per example, and their total.
# The original loop variable was named `id`, which shadowed the builtin `id()` for the
# rest of the kernel session; a comprehension keeps the loop variable local.
no_completions_list = [len(id_to_completions[example_id]) for example_id in id_to_completions]
total_no_completions = sum(no_completions_list)
# avg
total_no_completions/len(id_to_completions)

In [None]:
'''test to make sure that padding and concatenating individual examaples doesn't mess up results'''

# What id_to_completions[0] looks like for this example:
# [tensor([32099,  3957,     3,     5], device='cuda:0'),
#  tensor([32099,    24,     3,     5], device='cuda:0'),
#  tensor([32099,  3957,     3,     5], device='cuda:0'),
#  tensor([32099, 11831,     7,     3,     5], device='cuda:0')]

# Pad every candidate completion to the length of the longest one so they form one batch.
completions_batch = torch.nn.utils.rnn.pad_sequence(
    id_to_completions[0],
    batch_first=True,
    padding_value=tokenizer.pad_token_id,
)
# Stack one copy of the prompt per candidate completion along the batch dimension.
input_ids_batch = input_ids.repeat(len(id_to_completions[0]), 1)

# One forward pass over all (prompt, completion) pairs.
outputs = model(input_ids_batch, labels=completions_batch)
# single-row variant, for reference:
# outputs = model(input_ids_batch[1].unsqueeze(0), labels=completions_batch[1].unsqueeze(0))

In [None]:
# Score completion #1 of example 0 on its own (batch of one, hence no padding),
# to compare against row 1 of the padded batch run.
completions_batch_1 = id_to_completions[0][1].unsqueeze(0)
input_ids_batch_1 = input_ids
outputs_1 = model(input_ids_batch_1, labels=completions_batch_1)

In [None]:
# Compare row 1 of the padded batch with the same completion scored on its own.
# Only the first `unpadded_len` positions of the batched row correspond to real tokens;
# slicing by that length (rather than [:-1]) also works when the padding is longer than one.
unpadded_len = completions_batch_1.shape[1]
batched_logits = outputs.logits[1][:unpadded_len]
single_logits = outputs_1.logits[0]

# The original cell evaluated `batched == single` and threw the result away; report it.
print('logits identical:', torch.equal(batched_logits, single_logits))
print(batched_logits)
print(single_logits)
print(completions_batch_1)

# loss on the unpadded part of the batched row, on the full padded row, and on the single run
loss0 = loss_fn(batched_logits, completions_batch_1[0])
loss0_padded = loss_fn(outputs.logits[1], completions_batch[1])
loss0_1 = loss_fn(single_logits, completions_batch_1[0])
print(loss0)
print(loss0_padded)
print(loss0_1)

In [None]:
# Display `completions` (defined in an earlier part of the notebook).
completions

In [None]:
# Display the candidate completions stored for example 0.
id_to_completions[0]

In [None]:
# Shape of row 1 of the batched logits vs. the logits of the single-example run.
print(outputs.logits[1].shape)
print(outputs_1.logits[0].shape)

In [None]:
# Loss of one padded row of the batch against its padded labels.
# NOTE: despite the `_0` suffix, both variables hold row index 1 of the batch.
# loss_fn(outputs.logits, completions_batch)
logits_0 = outputs.logits[1]
completions_batch_0 = completions_batch[1]
# logits_0.shape

loss = loss_fn(logits_0, completions_batch_0)
loss

In [None]:
# Walk through the per-step scores of a generate() call (`outputs` must come from
# model.generate(..., output_scores=True, return_dict_in_generate=True)).
# For each step, softmax the scores and, for row 0, print the full distribution,
# the argmax token id and the probability of that token.
for i in range(len(outputs['scores'])):
    probs_outputs_beam_scores = torch.nn.functional.softmax(outputs['scores'][i], dim=-1)
    row0_probs = probs_outputs_beam_scores[0]
    row0_top_token = torch.argmax(row0_probs)
    print(row0_probs)
    # argmax
    print(row0_top_token)
    # to decode it: tokenizer.decode(row0_top_token)
    # the value
    print(row0_probs[row0_top_token])

In [None]:
# Beam-search generation for the first LAMBADA prompt, returning all `num_beams` beams
# together with the per-step scores (so that outputs['scores'] is available).
input_string = lambada[0]['inputs_pretokenized']
inputs = tokenizer(input_string, return_tensors="pt").input_ids.to("cuda")
num_beams = 2
outputs = model.generate(inputs, 
                        # max_length=8, 
                        num_beams=num_beams, 
                        num_return_sequences=num_beams,
                        # eos_token_id=tokenizer.convert_tokens_to_ids('<extra_id_1>'),
                        return_dict_in_generate=True,  # return a dict-like output instead of a bare tensor
                        output_scores=True)

In [None]:
# Teacher-forced forward pass for the first LAMBADA example: the prompt is the encoder
# input and "<extra_id_0> ... <extra_id_1>" wrapped around the stored target is the label.
lambada_prompt = lambada[0]['inputs_pretokenized']
input_ids = tokenizer(lambada_prompt, return_tensors="pt").input_ids.to("cuda")
print(lambada_prompt)

lambada_target = "<extra_id_0> " + id_to_word_and_punc_pairs_processed[0][2] + " <extra_id_1>"
labels = tokenizer(lambada_target, return_tensors="pt").input_ids.to("cuda")
print(lambada_target)

outputs = model(input_ids, labels=labels)
loss, logits = outputs.loss, outputs.logits

In [None]:
# calculate loss with cross entropy loss function
# Recomputes the loss from `logits`/`labels` by hand so it can be compared with outputs.loss.
# NOTE(review): `nn` is not imported in the notebook's first import cell -- presumably
# `from torch import nn` was run elsewhere; confirm before a fresh Run All.
# NOTE: this rebinds the global `loss_fn` that other cells also call.
loss_fn = nn.CrossEntropyLoss()
# flatten the leading dimensions into one so logits and labels line up row by row
loss = loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
loss

In [None]:
# Small generation demo with an [NLG] prompt: generate with max_length=3 and decode the result.
input_string = "[NLG] A man is having a bun for <extra_id_0>"
inputs = tokenizer(input_string, return_tensors="pt").input_ids.to("cuda")
# num_beams is only used as the loop bound below; the active generate() call does not pass it
num_beams = 1
# outputs = model.generate(inputs, num_beams=num_beams, max_length=3, num_return_sequences=num_beams, output_scores=True, return_dict_in_generate=True)
outputs = model.generate(inputs, max_length=3, output_scores=True, return_dict_in_generate=True)
# outputs = model.generate(inputs, output_scores=True, return_dict_in_generate=True)

# decode and print each returned sequence, followed by a separator line
for i in range(num_beams):
    # print(outputs['sequences'][i])
    print(tokenizer.decode(outputs['sequences'][i]))
    # decode outputs['sequences'][i] one by one token
    print('-------------')
    # print(outputs['sequences_scores'][i])

In [None]:
# /work/09127/tomyoung/ls6/data/pkls/url_to_probs_c4_dict_with_labels_t5_11b_valid.pkl
# Load a previously pickled dict from an absolute, machine-specific path.
# NOTE: pickle.load can execute arbitrary code -- only load files you produced yourself.
import pickle
with open('/work/09127/tomyoung/ls6/data/pkls/url_to_probs_c4_dict_with_labels_t5_11b_valid.pkl', 'rb') as f:
    url_to_probs_c4_dict_with_labels_t5_11b_valid = pickle.load(f)

In [None]:
# '/work/09127/tomyoung/ls6/data/pkls/acceptable_alternatives_1000_ignore_cws_nos_50_valid.pkl'
# Load the pickled acceptable-alternatives dict from an absolute, machine-specific path.
# NOTE: pickle.load can execute arbitrary code -- only load files you produced yourself.
import pickle
with open('/work/09127/tomyoung/ls6/data/pkls/acceptable_alternatives_1000_ignore_cws_nos_50_valid.pkl', 'rb') as f:
    acceptable_alternatives_1000_ignore_cws_nos_50_valid = pickle.load(f)

In [None]:
# Display slot [3] of the entry stored under one hand-picked key.
acceptable_alternatives_1000_ignore_cws_nos_50_valid[(180, 11, 8)][3]

In [None]:
# Display all keys from position 320 onwards (this can be a long output).
list(acceptable_alternatives_1000_ignore_cws_nos_50_valid.keys())[320:]

In [None]:
# Scan keys 1000..1999 and, for every key whose slot [3] holds more than one item,
# print the key followed by those items, one per line.
all_keys = list(acceptable_alternatives_1000_ignore_cws_nos_50_valid.keys())
for key in all_keys[1000:2000]:
    key_alternatives = acceptable_alternatives_1000_ignore_cws_nos_50_valid[key][3]
    if len(key_alternatives) <= 1:
        continue
    print(key)
    for item in key_alternatives:
        print(item)

In [None]:
# Display element [0][0] of the C4 list-of-lists.
# NOTE: `dicts_realnewslike` is loaded in the following cell; on a fresh kernel run that one first.
dicts_realnewslike[0][0]

In [None]:
# Load the C4 validation data (a JSON list of lists) from an absolute, machine-specific path.
c4_json_file = '/work/09127/tomyoung/ls6/data/jsons/c4-validation.00000-of-00001-list-of-lists.json'
# json is already imported in the notebook's first cell; re-importing is harmless
import json
with open(c4_json_file, 'r', encoding='utf8') as f:
    dicts_realnewslike = json.load(f)

In [None]:
# Auxiliary tokenizers, used by later cells only to decode stored token ids back to text
# (the model being scored still uses the UL2 `tokenizer`).
from transformers import T5Tokenizer, BartTokenizer
t5_tokenizer = T5Tokenizer.from_pretrained('t5-3b')
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

In [None]:
'''based on url_to_probs_c4_dict_with_labels_t5_11b_valid and acceptable_alternatives_1000_ignore_cws_nos_50_valid
generate the input string and target string for the models like ul2 and glm'''
# Worked example on one hand-picked key: split a stored completion into
# preceding tokens / proposed bigram / following tokens.

# T5-decoded bigram from the probs dict (overwritten just below by the BART-decoded slice)
proposed_bigram = t5_tokenizer.decode(url_to_probs_c4_dict_with_labels_t5_11b_valid[((0,17,4),(0,17,5),1)]['proposed bigram'])

# token ids of the completion under inspection, and the split position used for it
example_completion_ids = acceptable_alternatives_1000_ignore_cws_nos_50_valid[(0,17,4)][0][1]
split_pos = 4

proposed_bigram = bart_tokenizer.decode(example_completion_ids[split_pos+1:split_pos+3])
print('proposed_bigram:', proposed_bigram)

preceding_tokens = bart_tokenizer.decode(example_completion_ids[:split_pos])
print('preceding_tokens:', preceding_tokens)

following_tokens = bart_tokenizer.decode(example_completion_ids[split_pos+3:])
print('following_tokens:', following_tokens)


In [None]:
import math


def _ul2_nlu_input_ids(preceding_text, following_text):
    """Build the '[NLU]' prompt with a single masked span and return its token ids on the GPU.

    `preceding_text` and `following_text` are the decoded BART tokens on either side of the
    span; any '<s>' / '</s>' markers left in them by decoding are removed.
    """
    input_string = preceding_text + ' <extra_id_0>' + following_text
    # remove <s> and </s> in the input string
    input_string = input_string.replace('<s>', '').replace('</s>', '')
    # add [NLU] to the input string
    input_string = '[NLU] ' + input_string
    show('input_string:', input_string)
    return tokenizer(input_string, return_tensors="pt").input_ids.to("cuda")


def _ul2_fill_prob(input_ids, fill_text):
    """Return the probability the model assigns to `fill_text` as the content of the masked span.

    The target is '<extra_id_0>' + fill_text + ' <extra_id_1>'. The leading <extra_id_0> is
    excluded from the score; the closing <extra_id_1> is included.
    """
    target_string = '<extra_id_0>' + fill_text + ' <extra_id_1>'
    show('target_string:', target_string)
    labels = tokenizer(target_string, return_tensors="pt").input_ids.to("cuda")
    # remove the last </s> token from labels
    labels = labels[:, :-1].contiguous()
    show('labels:', labels)
    # Pure inference: do not build an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = model(input_ids, labels=labels)
    # [1:] to remove the first token <extra_id_0>
    log_p = -loss_fn_sum(outputs.logits[0][1:], labels[0][1:])
    show('log_p:', log_p)
    return math.exp(log_p.to(torch.float32).item())


url_to_probs_c4_dict_with_labels_t5_11b_valid_keys = list(url_to_probs_c4_dict_with_labels_t5_11b_valid.keys())

url_to_ul2_probs_dict = {}
for key in tqdm(url_to_probs_c4_dict_with_labels_t5_11b_valid_keys):
    # key = (url of the proposed token, url of the constant token, completion id),
    # where a url is (story id, paragraph id, token position)
    url = key[0]
    story_id, paragraph_id, proposed_token_pos = url[0], url[1], url[2]
    constant_token_pos = key[1][2]
    completion_id = key[2]
    show(key)
    show(url)
    show(acceptable_alternatives_1000_ignore_cws_nos_50_valid[url])

    # BART ids of the completion. Per the original note these include <s> ... </s>,
    # which is why positions into this list are shifted by +1 relative to the raw paragraph.
    completion_ids = acceptable_alternatives_1000_ignore_cws_nos_50_valid[url][0][completion_id]

    # BART ids of the raw paragraph (tokenize + convert_tokens_to_ids adds no <s>/</s>)
    example_raw_sequence = dicts_realnewslike[story_id][paragraph_id]
    show('example_raw_sequence:', example_raw_sequence)
    bart_ids_original = bart_tokenizer.convert_tokens_to_ids(bart_tokenizer.tokenize(example_raw_sequence))

    ''' proposed token / original token: mask exactly one token '''
    proposed_token = bart_tokenizer.decode(completion_ids[proposed_token_pos+1:proposed_token_pos+2])
    original_token = bart_tokenizer.decode(bart_ids_original[proposed_token_pos])
    show('proposed_token:', proposed_token)
    show('original_token:', original_token)
    token_input_ids = _ul2_nlu_input_ids(
        bart_tokenizer.decode(completion_ids[:proposed_token_pos+1]),
        bart_tokenizer.decode(completion_ids[proposed_token_pos+2:]),
    )
    p_proposed_token = _ul2_fill_prob(token_input_ids, proposed_token)
    p_original_token = _ul2_fill_prob(token_input_ids, original_token)

    ''' proposed bigram / original bigram: mask the proposed token together with the constant token '''
    if proposed_token_pos < constant_token_pos:
        # proposed token is the left half of the bigram
        bigram_start = proposed_token_pos + 1
    else:
        # proposed token is the right half of the bigram
        bigram_start = proposed_token_pos
    # same bigram in the raw paragraph, which has no leading <s>
    original_bigram_start = bigram_start - 1
    proposed_bigram = bart_tokenizer.decode(completion_ids[bigram_start:bigram_start+2])
    original_bigram = bart_tokenizer.decode(bart_ids_original[original_bigram_start:original_bigram_start+2])
    show('proposed_bigram:', proposed_bigram)
    show('original_bigram:', original_bigram)
    bigram_input_ids = _ul2_nlu_input_ids(
        bart_tokenizer.decode(completion_ids[:bigram_start]),
        bart_tokenizer.decode(completion_ids[bigram_start+2:]),
    )
    p_proposed_bigram = _ul2_fill_prob(bigram_input_ids, proposed_bigram)
    p_original_bigram = _ul2_fill_prob(bigram_input_ids, original_bigram)

    # put the probabilities into the dictionary
    url_to_ul2_probs_dict[key] = {'proposed token': p_proposed_token,
                                  'original token': p_original_token,
                                  'proposed bigram': p_proposed_bigram,
                                  'original bigram': p_original_bigram}

# persist the results (absolute, machine-specific path)
url_to_ul2_probs_dict_filepath = '/work/09127/tomyoung/ls6/data/pkls/url_to_ul2_probs_dict_valid.pkl'
with open(url_to_ul2_probs_dict_filepath, 'wb') as f:
    pickle.dump(url_to_ul2_probs_dict, f)


In [None]:
# Write url_to_ul2_probs_dict to disk (overwrites any existing file at this path).
# NOTE: the cell that computes the dict already performs this same save at its end,
# so this cell is only needed to re-save after editing the dict by hand.
url_to_ul2_probs_dict_filepath = '/work/09127/tomyoung/ls6/data/pkls/url_to_ul2_probs_dict_valid.pkl'
with open(url_to_ul2_probs_dict_filepath, 'wb') as f:
    pickle.dump(url_to_ul2_probs_dict, f)


In [None]:
# /work/09127/tomyoung/ls6/data/pkls/url_to_probs_c4_dict_with_labels_t5_11b_valid.pkl
# Same load as an earlier cell in this notebook (identical path and variable name).
# NOTE: pickle.load can execute arbitrary code -- only load files you produced yourself.
import pickle
with open('/work/09127/tomyoung/ls6/data/pkls/url_to_probs_c4_dict_with_labels_t5_11b_valid.pkl', 'rb') as f:
    url_to_probs_c4_dict_with_labels_t5_11b_valid = pickle.load(f)

In [None]:
# Free-running generation demo: let the model fill the masked span of an [NLG] prompt.
input_string = "[NLG] The headquarters of microsoft is in <extra_id_0>"                                               

# add_special_tokens=False: tokenize the prompt without the tokenizer's automatic special tokens
inputs = tokenizer(input_string, return_tensors="pt", add_special_tokens=False).input_ids.to("cuda")

outputs = model.generate(inputs, max_length=200)

print(tokenizer.decode(outputs[0]))

In [None]:
# Teacher-forced scoring of a fixed target for an [NLG] prompt, followed by a
# per-position dump of the probability assigned to each label token.
input_string = "[NLG] The headquarters of microsoft is in <extra_id_0>"
target_string = "<extra_id_0> redmond washington <extra_id_1>"
inputs = tokenizer(input_string, return_tensors="pt").input_ids.to("cuda")
# tokenize the target and drop its final token
labels = tokenizer(target_string, return_tensors="pt").input_ids.to("cuda")[:, :-1].contiguous()
outputs = model(inputs, labels=labels)
print('outputs:', outputs)
# [1:-1] leaves out the first and last label positions (the two sentinels)
log_p = -loss_fn_sum(outputs.logits[0][1:-1], labels[0][1:-1])
print('log_p:', log_p)

probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
for i, label in enumerate(labels[0]):
    print('label:', label, 'prob:', probs[0][i][label])
    

In [None]:
# Display the shape of the current `labels` tensor.
labels.shape

In [None]:
# Re-tokenize the target WITHOUT dropping the last token and display the ids.
# NOTE: this overwrites the trimmed `labels` produced by the scoring cell.
labels = tokenizer(target_string, return_tensors="pt").input_ids.to("cuda")
labels

In [None]:
# Load the pickled 5-gram completions dict from an absolute, machine-specific path.
# NOTE: pickle.load can execute arbitrary code -- only load files you produced yourself.
import pickle
with open('/work/09127/tomyoung/ls6/data/pkls/dict_url_to_completions_5_grams.pkl','rb') as f:
    dict_url_to_completions_5_grams = pickle.load(f)
# get a list of punctuations
# (string.punctuation is a str of ASCII punctuation characters; it is used for `in` checks)
import string
punctuations = string.punctuation

In [None]:
# imported here so the cell does not depend on another cell having imported math
import math

# How many keys to score. The original cell was hard-wired to the first key only (range(1));
# raise this, up to len(dict_url_to_completions_5_grams_keys), for a full run.
NUM_5_GRAM_KEYS_TO_SCORE = 1


def _ul2_continuation_prob(prefix_words, continuation_text):
    """Return the probability the model assigns to `continuation_text` after `prefix_words`.

    The prompt is '[NLG] ' + the space-joined prefix + ' <extra_id_0>' and the target is
    '<extra_id_0> ' + continuation_text + ' <extra_id_1>'. Both sentinels are excluded
    from the score, so only the continuation's own tokens contribute.
    """
    input_string = '[NLG] ' + ' '.join(prefix_words) + ' <extra_id_0>'
    input_ids = tokenizer(input_string, return_tensors="pt").input_ids.to("cuda")
    target_string = "<extra_id_0> " + continuation_text + ' <extra_id_1>'
    labels = tokenizer(target_string, return_tensors="pt").input_ids.to("cuda")
    labels = labels[:, :-1].contiguous() # remove the last token '</s>'
    # Pure inference: do not build an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = model(input_ids, labels=labels)
    # lose the <extra_id_0> and <extra_id_1>
    log_p = -loss_fn_sum(outputs.logits[0][1:-1], labels[0][1:-1])
    return math.exp(log_p.to(torch.float32).item())


dict_url_to_completions_5_grams_keys = list(dict_url_to_completions_5_grams.keys())
url_to_ul2_5_gram_probs_dict = {}
for key in tqdm(dict_url_to_completions_5_grams_keys[:NUM_5_GRAM_KEYS_TO_SCORE]):
    alternative = dict_url_to_completions_5_grams[key]['alternative']
    # alternative: make sure it ends with a punctuation (an empty alternative is skipped too)
    if not alternative or alternative[-1] not in punctuations:
        continue
    original_sentence_words = dict_url_to_completions_5_grams[key]['original_sentence'].split(' ')
    # len should >= 15
    if len(original_sentence_words) < 15:
        continue

    # for 10-grams: the prompt stops 10 words before the end of the sentence
    prefix_for_10_grams = original_sentence_words[:-10]
    p_original_10_gram = _ul2_continuation_prob(
        prefix_for_10_grams, ' '.join(original_sentence_words[-10:]))
    # proposed 10-gram = the original words at positions -10..-6 followed by the alternative 5-gram
    p_proposed_10_gram = _ul2_continuation_prob(
        prefix_for_10_grams, ' '.join(original_sentence_words[-10:-5]) + ' ' + alternative)

    # for 5-grams: the prompt stops 5 words before the end of the sentence
    prefix_for_5_grams = original_sentence_words[:-5]
    p_original_5_gram = _ul2_continuation_prob(
        prefix_for_5_grams, ' '.join(original_sentence_words[-5:]))
    p_proposed_5_gram = _ul2_continuation_prob(prefix_for_5_grams, alternative)

    # add them to the dictionary
    url_to_ul2_5_gram_probs_dict[key] = {'proposed 5_gram': p_proposed_5_gram,
                                         'original 5_gram': p_original_5_gram,
                                         'proposed 10_gram': p_proposed_10_gram,
                                         'original 10_gram': p_original_10_gram}


In [None]:
# save url_to_ul2_5_gram_probs_dict as a pkl
# (absolute, machine-specific path; overwrites any existing file)
with open('/work/09127/tomyoung/ls6/data/pkls/url_to_ul2_5_gram_probs_dict.pkl', 'wb') as f:
    pickle.dump(url_to_ul2_5_gram_probs_dict, f)

In [None]:
# check the number of different tokens between the two 5-grams
from itertools import zip_longest

for key in dict_url_to_completions_5_grams.keys():
    entry = dict_url_to_completions_5_grams[key]
    proposed_5_gram = entry['alternative'].split(' ')
    original_5_gram = entry['original_sentence'].split(' ')[-5:]
    # Compare position by position over the first five words. If either side has fewer
    # than five words, the missing positions count as differences; the original code
    # raised IndexError in that case.
    number_of_different_tokens = sum(
        proposed_word != original_word
        for proposed_word, original_word in zip_longest(proposed_5_gram[:5], original_5_gram)
    )
    print('key:', key)
    print('original_sentence:', entry['original_sentence'])
    print('proposed_5_gram:', entry['alternative'])
    print('original_5_gram:', ' '.join(original_5_gram))
    print(number_of_different_tokens)