This notebook uses UL2 to

(1) measure inconsistencies in its bidirectional conditionals;

(2) improve LLM inference with an Ensemble of Conditionals.



Table of Contents


* [Imports and global utils](#0)

* [Load tokenizer and model](#1)

* [Ensemble of Conditionals](#2)

<h1 style="font-size: 20px;"><a class="anchor" id="0"></a>Imports and global utils</h1>

In [14]:
'''imports'''
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3"
import general_utils
# clear GPU memory
if True:
    general_utils.kill_gpu_process(os.environ["CUDA_VISIBLE_DEVICES"])
import torch
from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Tokenizer
import numpy as np
import pickle
import time
from tqdm import tqdm
import json
from lambada_utils import LambadaOutputProcessor


In [2]:
'''toggle 'print' on and off by creating 'show' as a proxy'''
enable_print = True
if enable_print:
    show = print
else:
    show = lambda *args, **kwargs: None

<h1 style="font-size: 20px;"><a class="anchor" id="1"></a>Load tokenizer and model</h1>

In [3]:
# We use a custom Hugging Face cache directory in case the default one lacks the capacity; the models can be quite large.
MY_HUGGINGFACE_CACHE_DIR = 'huggingface_cache' # relative to this notebook's path
tokenizer = AutoTokenizer.from_pretrained("google/ul2",
                                          cache_dir=MY_HUGGINGFACE_CACHE_DIR + '/google-ul2')


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
model = T5ForConditionalGeneration.from_pretrained("google/ul2", 
                                                   cache_dir=MY_HUGGINGFACE_CACHE_DIR + '/google-ul2', 
                                                   low_cpu_mem_usage=True, 
                                                   torch_dtype=torch.bfloat16).to("cuda")
model.parallelize() # TODO: every other GPU holds ~9GB/18GB, but GPU 0 does not release its memory; possibly because .to("cuda") above first loads the whole model onto GPU 0


In [13]:
tokenizer.convert_tokens_to_ids("<extra_id_0>")

32099

<h1 style="font-size: 20px;"><a class="anchor" id="2"></a>Ensemble of Conditionals</h1>


In [7]:
LAMBADA_TEST_DATA_PATH = "data/jsonls/test.jsonl"

with open(LAMBADA_TEST_DATA_PATH, "r") as f:
    lambada = [json.loads(line) for line in f.readlines()]

# 'inputs_pretokenized' / 'targets_pretokenized' hold the raw, untokenized strings (seqio/T5 naming convention)
UL2_MODE = "[NLG]"

# To use the NLG mode of UL2, prepend [NLG] to each input and append the sentinel <extra_id_0> to its end
lambada = [
    {
        "inputs_pretokenized": UL2_MODE + " " + x['inputs_pretokenized'] + " <extra_id_0>",
        "targets_pretokenized": x['targets_pretokenized']
    } 
    for x in lambada
]
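The formatting above is a plain string template. A minimal sketch that factors it into a helper, assuming the mode token and the sentinel are joined to the text with single spaces exactly as in the list comprehension; `format_ul2_input` and `format_ul2_example` are hypothetical names, not part of UL2 or `transformers`:

```python
def format_ul2_input(context: str, mode: str = "[NLG]", sentinel: str = "<extra_id_0>") -> str:
    """Prepend the UL2 mode token and append the sentinel that marks the span to fill."""
    return f"{mode} {context} {sentinel}"


def format_ul2_example(example: dict, mode: str = "[NLG]") -> dict:
    """Format one LAMBADA record; the target is passed through unchanged."""
    return {
        "inputs_pretokenized": format_ul2_input(example["inputs_pretokenized"], mode),
        "targets_pretokenized": example["targets_pretokenized"],
    }


raw = {"inputs_pretokenized": "i do n't care about", "targets_pretokenized": "angels"}
print(format_ul2_example(raw)["inputs_pretokenized"])
# [NLG] i do n't care about <extra_id_0>
```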


In [8]:
# instantiate the output processor
processor = LambadaOutputProcessor(tokenizer, lambada)

<details>
<summary>Strategy for different punctuations: (click to expand)</summary>

In the LAMBADA last word prediction task, large language models (LLMs) may attach different punctuation marks to the same last word, producing different completions. For example, to complete the sentence "The color of my pet dog is":

Possible Completions:

1. _white._ with probability `p_1`
2. _white!_ with probability `p_2` (assuming `p_1 > p_2`)
3. _black,_ with probability `p_3`
4. _black?_ with probability `p_4` (assuming `p_3 > p_4`)

Strategies to Rank _white_ and _black_:

1. Maximum Probability Strategy

- Probability of _white_: `p(white) = p_1`
- Probability of _black_: `p(black) = p_3`

2. Sum of Probabilities Strategy

- Probability of _white_: `p(white) = p_1 + p_2`
- Probability of _black_: `p(black) = p_3 + p_4`

Afterwards, `p(white)` and `p(black)` may need to be normalized (e.g., divided by their sum) so that they form a distribution over the candidate words.
</details>
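The two strategies can be sketched in a few lines of plain Python. The probabilities below are hypothetical values chosen so that `p_1 > p_2` and `p_3 > p_4`; `aggregate_word_probs` is an illustrative helper, not a function from `lambada_utils`. With these numbers the strategies disagree: the maximum strategy prefers _white_, while summing produces a tie.

```python
from collections import defaultdict

ENDING_PUNCTUATIONS = ',!.:;?'


def strip_ending_punctuation(punctuated_word: str) -> str:
    """Remove one trailing punctuation mark, if present."""
    if punctuated_word and punctuated_word[-1] in ENDING_PUNCTUATIONS:
        return punctuated_word[:-1]
    return punctuated_word


def aggregate_word_probs(completions, strategy="max", normalize=True):
    """Collapse (punctuated_word, prob) pairs into per-word probabilities.

    strategy="max": p(word) is the largest probability among its punctuated variants.
    strategy="sum": p(word) is the sum of the probabilities of its punctuated variants.
    """
    word_probs = defaultdict(float)
    for punctuated_word, prob in completions:
        word = strip_ending_punctuation(punctuated_word)
        if strategy == "max":
            word_probs[word] = max(word_probs[word], prob)
        elif strategy == "sum":
            word_probs[word] += prob
        else:
            raise ValueError(f"unknown strategy: {strategy}")
    if normalize:
        total = sum(word_probs.values())
        return {word: prob / total for word, prob in word_probs.items()}
    return dict(word_probs)


completions = [("white.", 0.4), ("white!", 0.1), ("black,", 0.3), ("black?", 0.2)]
print(aggregate_word_probs(completions, "max", normalize=False))  # {'white': 0.4, 'black': 0.3}
print(aggregate_word_probs(completions, "sum", normalize=False))  # {'white': 0.5, 'black': 0.5}
```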

In [9]:
'''On lambada, generate the top completions (through beam search) for each example, and get the word from each completion.'''
RUN_BEAM_SEARCH_CELL = False
if RUN_BEAM_SEARCH_CELL:
    # generate for all examples, and then get the words from the completions, and compare the first one with the target
    count_correct = 0 # No. correct last word predictions if only the top completion is considered
    count_correct_top_num_beams = 0 # ... if the top num_beams completions are considered
    count_no_words_found = 0  # No. examples where no valid last word is found

    # punctuated_word: the last word and the punctuation that follows it
    id_to_punctuated_words = {} # maps example index to a list of punctuated words; every punctuation variant of each word is kept
    id_to_punctuated_words_unique = {} # ...; only the most probable punctuation variant of each word is kept
    id_to_completions_ids = {}

    MAX_COMPLETION_LENGTH = 8 # for last word prediction, 8 is sufficient
    NUM_BEAMS = 20 # 20 is sufficient; more doesn't help

    # for example_index in tqdm(range(10)): # len(lambada)
    for example_index in tqdm(range(len(lambada))): # len(lambada)
        input_string = lambada[example_index]['inputs_pretokenized']
        inputs = tokenizer(input_string, return_tensors="pt").input_ids.to("cuda")
        outputs = model.generate(inputs,
                                max_length=MAX_COMPLETION_LENGTH, 
                                num_beams=NUM_BEAMS, 
                                num_return_sequences=NUM_BEAMS, 
                                output_scores=True,
                                eos_token_id=tokenizer.convert_tokens_to_ids('<extra_id_1>'), 
                                return_dict_in_generate=True)
        
        completions = [tokenizer.decode(outputs['sequences'][i]) for i in range(NUM_BEAMS)]
        completions_ids = [
            outputs['sequences'][i].cpu()
            for i in range(NUM_BEAMS)
            if processor.get_word_from_completion(completions[i]) is not None # if the completion has a valid last word
        ]

        words = processor.get_words_from_completions(completions)

        # TODO: combine them and move to utils.py
        completions_without_pad = processor.remove_pad_id(completions_ids)
        completions_without_pad_before_punctution = processor.before_first_punc(completions_without_pad)
        
        
        if words:
            if words[0] == lambada[example_index]['targets_pretokenized'][0]:
                count_correct += 1
        else:
            count_no_words_found += 1
            show("no words found")
        punctuated_words = processor.get_punctuated_words(completions)
        id_to_punctuated_words[example_index] = punctuated_words
        words_unique = list(set(words))
        id_to_punctuated_words_unique[example_index] = []
        
        id_to_completions_ids[example_index] = completions_without_pad_before_punctution

        # Find the best punctuation for each unique word (Maximum Probability Strategy):
        # generate() returns completions ordered by score, so the first match is the most probable one.
        # TODO: move this loop to utils.py
        ENDING_PUNCTUATIONS = ',!.:;?'
        for word in words_unique:
            for punctuated_word in punctuated_words:
                # it is a match if punctuated_word = word + one ending punctuation mark
                if any(punctuated_word == word + punc for punc in ENDING_PUNCTUATIONS):
                    id_to_punctuated_words_unique[example_index].append(punctuated_word)
                    break
        
        # calculate the number of correct top num_beams: if the correct word is in the top num_beams, then it is correct
        for word in words_unique:
            if word == lambada[example_index]['targets_pretokenized'][0]:
                count_correct_top_num_beams += 1
                break
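The two counters in the cell above measure different things: `count_correct` checks only the highest-scoring beam, while `count_correct_top_num_beams` checks whether the target appears anywhere among the beams. The second is an upper bound on what any reranking of the same candidates, including the ensembles below, can reach. A minimal sketch of both metrics on hypothetical toy data (`top1_and_oracle_accuracy` is an illustrative name):

```python
def top1_and_oracle_accuracy(examples):
    """Compute top-1 and any-beam ("oracle") accuracy.

    examples: list of (words, target) pairs, where `words` are the last words
    extracted from the beam-search completions, ordered by beam score.
    """
    top1_hits = sum(1 for words, target in examples if words and words[0] == target)
    oracle_hits = sum(1 for words, target in examples if target in words)
    n = len(examples)
    return top1_hits / n, oracle_hits / n


examples = [
    (["angels", "signs"], "angels"),  # correct at top-1
    (["that", "signs"], "signs"),     # target present, but not ranked first
    (["cat"], "dog"),                 # target missing from the beam
    ([], "dog"),                      # no valid word found
]
print(top1_and_oracle_accuracy(examples))  # (0.25, 0.5)
```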


In [10]:
'''Save the beam search results by generate()'''
RUN_SAVE_BEAM_SEARCH_RESULTS_CELL = False
if RUN_SAVE_BEAM_SEARCH_RESULTS_CELL:
    timed_pickle_filename = 'data/pkls/ul2_lambada_vanilla_beam_search_results_' + general_utils.get_time() + '.pickle'
    print(timed_pickle_filename)

    data_keys = ['count_correct', 'count_correct_top_num_beams', 'count_no_words_found',
                'id_to_punctuated_words', 'id_to_punctuated_words_unique', 'id_to_completions_ids']
    data = {}
    for key in data_keys:
        data[key] = locals()[key]

    with open(timed_pickle_filename, 'wb') as fp:
        pickle.dump(data, fp)

In [11]:
'''Load the beam search results'''
timed_pickle_filename = 'data/pkls/ul2_lambada_vanilla_beam_search_results_2023-11-11 20:08:17.pickle'
with open(timed_pickle_filename, 'rb') as fp:
    ul2_lambada_vanilla_beam_search_results = pickle.load(fp)
id_to_completions_ids = ul2_lambada_vanilla_beam_search_results['id_to_completions_ids']

<details>
<summary>K-offset Ensemble (click to expand)</summary>

__K-offset Ensemble__ is a particular type of __Ensemble of Conditionals__ for last word prediction tasks like lambada.

It augments the single conditional distribution obtained by masking only the last word with additional distributions, each obtained by masking the last __offset__ + 1 words.

An example with _lambada[0]_:

_lambada[0]['inputs_pretokenized']_: `... his mouth curved in a confident grin , i do n't care about <last_word>`

We consider candidates `['angels.', 'signs.', 'that.']`.

The baseline approach is to input `... his mouth curved in a confident grin , i do n't care about <extra_id_0>` to UL2 and obtain the distribution containing the 3 candidates.

For offset=1, we additionally mask the token `about` before the last word and input instead:

`... his mouth curved in a confident grin , i do n't care <extra_id_0>`

This yields a different distribution over `['about angels.', 'about signs.', 'about that.']`. Because UL2 decodes autoregressively, each of these factorizes, e.g., `p(about angels) = p(about) * p(angels|about)`. We therefore use conditionals of the form `p(angels|about)` to augment the baseline conditionals.

Larger offsets follow the same pattern: offset k masks the last k + 1 words, and the last word is scored conditioned on the k words regenerated before it.
</details>
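A word-level sketch of how the offset inputs could be built, assuming whitespace-split words and the same `[NLG]` prefix and `<extra_id_0>` sentinel as the baseline input. The notebook itself builds these samples with `processor.get_offset_samples` on token IDs (the scoring cell skips `offset` completion tokens), so `make_offset_input` is a hypothetical illustration only:

```python
def make_offset_input(context_words, offset, mode="[NLG]", sentinel="<extra_id_0>"):
    """Build the UL2 input for a given offset.

    context_words: the passage without its last word, as a list of words.
    offset: how many context words before the last word are masked as well.
    Returns (input_string, masked_words): masked_words must be regenerated by the
    model before the last word; only the last word's tokens are scored.
    """
    if not 0 <= offset <= len(context_words):
        raise ValueError("offset must be between 0 and len(context_words)")
    split = len(context_words) - offset
    kept, masked_words = context_words[:split], context_words[split:]
    input_string = " ".join([mode, *kept, sentinel])
    return input_string, masked_words


context = "i do n't care about".split()
for offset in range(3):
    print(offset, make_offset_input(context, offset))
# 0 ("[NLG] i do n't care about <extra_id_0>", [])
# 1 ("[NLG] i do n't care <extra_id_0>", ['about'])
# 2 ("[NLG] i do n't <extra_id_0>", ['care', 'about'])
```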




In [12]:
MAX_OFFSET = 30

In [13]:
'''Generate the offset samples'''
RUN_GENERATE_OFFSET_SAMPLES_CELL = False
if RUN_GENERATE_OFFSET_SAMPLES_CELL:
    id_and_offset_to_inputs_and_completions = \
        processor.get_offset_samples(
            ul2_lambada_vanilla_beam_search_results['id_to_completions_ids'], 
            max_offset=MAX_OFFSET
        )

In [14]:
'''Save the offset samples'''
RUN_SAVE_OFFSET_SAMPLES_CELL = False
if RUN_SAVE_OFFSET_SAMPLES_CELL:    
    timed_pickle_filename = 'data/pkls/offset_samples_' + 'max_offset_' + str(MAX_OFFSET) + '_' + general_utils.get_time() + '.pickle'
    print(timed_pickle_filename)
    with open(timed_pickle_filename, 'wb') as fp:
        pickle.dump(id_and_offset_to_inputs_and_completions, fp)

In [16]:
'''Load the offset samples'''
timed_pickle_filename = 'data/pkls/offset_samples_max_offset_30_2023-11-13-01:28:11.pickle'
with open(timed_pickle_filename, 'rb') as fp:
    id_and_offset_to_inputs_and_completions = pickle.load(fp)

In [33]:
''' obtain the avg_log_p_map '''
RUN_GET_AVG_LOG_PS_CELL = True
if RUN_GET_AVG_LOG_PS_CELL:
    # id_and_offset_to_inputs_and_completions:
    # (id, offset) -> [(input_ids_0, completion_ids_0), (input_ids_1, completion_ids_1), ...]
    avg_log_p_map = dict() # (id, offset, completion_index) -> avg_log_p of the tokens constituting the last word (might be punctuated)

    ce_loss = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id) # default reduction='mean'
    ce_loss_sum = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id, reduction='sum') #reduction='sum'

    for example_index in tqdm(range(len(lambada))): 
    # for example_index in tqdm(range(1)): 
        if len(id_to_completions_ids[example_index]) == 0:
            continue
        for offset in range(MAX_OFFSET):
            # we batch the completions for each (example, offset)
            completions_batch = torch.nn.utils.rnn.pad_sequence(
                [
                    id_and_offset_to_inputs_and_completions[(example_index, offset)][i][1] # [1] is the completion_ids
                    for i in range(len(id_to_completions_ids[example_index]))
                ], 
                batch_first=True, 
                padding_value=tokenizer.pad_token_id
            ).cuda()

            input_ids_batch = torch.cat(
                [
                    id_and_offset_to_inputs_and_completions[(example_index, offset)][i][0].unsqueeze(0) 
                    for i in range(len(id_to_completions_ids[example_index]))
                ], 
                dim=0
            ).cuda()

            with torch.no_grad(): # inference only: do not store activations for backprop
                outputs = model(input_ids_batch, labels=completions_batch)

            for completion_index in range(len(id_to_completions_ids[example_index])):
                avg_log_p = -ce_loss(
                    # Score only the tokens of the (possibly punctuated) last word:
                    # skip position 0 (<extra_id_0>) and the `offset` tokens that precede the last word
                    outputs.logits[completion_index][1+offset:], 
                    completions_batch[completion_index][1+offset:]
                )
                avg_log_p_map[(example_index, offset, completion_index)] = \
                    avg_log_p.detach().cpu().tolist()

  0%|          | 8/5153 [00:56<9:52:06,  6.91s/it] 
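Each `avg_log_p` entry is the negative of a mean cross-entropy, restricted to the last-word positions and excluding padding. A NumPy sketch of the same quantity for one completion, assuming `logits[t]` scores `labels[t]` (as in the `ce_loss` call above) and a pad id of 0, as for T5 tokenizers; `avg_log_p_of_last_word` is a hypothetical helper:

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def avg_log_p_of_last_word(logits, labels, offset, pad_id=0):
    """Mean log-probability of the last-word tokens of one completion.

    logits: (seq_len, vocab) decoder logits; labels: (seq_len,) token ids.
    Position 0 (the sentinel) and the next `offset` positions (regenerated
    context tokens) are skipped; pad positions are ignored, like
    CrossEntropyLoss(ignore_index=pad_id) with mean reduction.
    """
    log_probs = log_softmax(logits[1 + offset:])
    targets = labels[1 + offset:]
    token_log_probs = log_probs[np.arange(len(targets)), targets]
    keep = targets != pad_id
    return float(token_log_probs[keep].mean())


rng = np.random.default_rng(0)
logits = rng.normal(size=(5, 10))
labels = np.array([9, 3, 2, 5, 0])  # sentinel, 1 context token, 2 last-word tokens, 1 pad
print(avg_log_p_of_last_word(logits, labels, offset=1))
```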

In [None]:
'''Save the avg_log_p_map'''
RUN_SAVE_AVG_LOG_P_MAP_CELL = False
if RUN_SAVE_AVG_LOG_P_MAP_CELL:
    pickle_filename = 'data/pkls/avg_log_p_map_' + 'max_offset_' + str(MAX_OFFSET) + '_' + general_utils.get_time() + '.pickle'
    print(pickle_filename)
    with open(pickle_filename, 'wb') as handle:
        pickle.dump(avg_log_p_map, handle)

In [13]:
'''Load the avg_log_p_map'''
pickle_filename = 'data/pkls/avg_log_p_map_2023-11-12 21:51:05.pickle'
# avg_log_p_map (Dict): (id, offset, completion_index) -> avg_log_p of the tokens constituting the last word (might be punctuated)
with open(pickle_filename, 'rb') as handle:
    avg_log_p_map = pickle.load(handle)

In [16]:
'''K-offset Ensemble with max reduction to ensemble the K different conditionals for the same last word,
i.e., only the maximum avg_log_p across different offsets is kept for each last word.
'''

# We test K-offset ensemble for K up to MAX_OFFSET_TEST; MAX_OFFSET_TEST should be <= MAX_OFFSET used during avg_log_p_map generation
MAX_OFFSET_TEST = 5 
offset_to_accuracy = dict()
for offset_test in range(MAX_OFFSET_TEST):
    count_correct = 0 # No. correct last word predictions with K-offset
    # Get the best completion based on avg_log_p_map
    for example_index in tqdm(range(len(lambada))): # len(lambada)
        
        # Create a list of tuples (avg_log_p, completion) for each completion
        avg_log_p_and_completion = [
            (avg_log_p_map[(example_index, offset, completion_index)], id_to_completions_ids[example_index][completion_index])
            for offset in range(offset_test + 1)
            for completion_index in range(len(id_to_completions_ids[example_index]))
        ]

        # Find the tuple with the maximum avg_log_p; this is essentially max reduction
        if len(avg_log_p_and_completion) == 0:
            continue
        best_avg_log_p, best_completion = max(avg_log_p_and_completion, key=lambda x: x[0])
        if processor.is_correct_completion(example_index, best_completion):
            count_correct += 1
    offset_to_accuracy[offset_test] = count_correct / (len(lambada))



In [18]:
offset_to_accuracy

{0: 0.7719774888414516,
 1: 0.7745002910925675,
 2: 0.7758587230739376,
 3: 0.7745002910925675,
 4: 0.7756646613623133}
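Converting the accuracies above back to example counts makes the size of the gain easier to judge (`len(lambada)` is 5153 in the runs above). A small sketch, with the accuracies copied from the output above:

```python
# Accuracies copied from the offset_to_accuracy output above;
# offset_test k ensembles the conditionals for offsets 0..k.
reported_offset_to_accuracy = {
    0: 0.7719774888414516,
    1: 0.7745002910925675,
    2: 0.7758587230739376,
    3: 0.7745002910925675,
    4: 0.7756646613623133,
}
NUM_EXAMPLES = 5153  # len(lambada) in the runs above

offset_to_num_correct = {k: round(acc * NUM_EXAMPLES) for k, acc in reported_offset_to_accuracy.items()}
print(offset_to_num_correct)  # {0: 3978, 1: 3991, 2: 3998, 3: 3991, 4: 3997}

best_k = max(offset_to_num_correct, key=offset_to_num_correct.get)
gain = offset_to_num_correct[best_k] - offset_to_num_correct[0]
print(best_k, gain, f"{100 * gain / NUM_EXAMPLES:.2f} pp")  # 2 20 0.39 pp
```

In this run, the best setting (offsets 0 to 2) answers 20 more of the 5153 examples correctly than the baseline.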

In [None]:
'''EOC with avg reduction'''
# avg reduction: each completion is scored by the mean of its avg_log_p over offsets 0..offset_test.
# Not to be confused with avg_log_p itself, which is the average log_p of the tokens constituting the last (punctuated) word.
# max_offset_test must not exceed the number of offsets stored in avg_log_p_map.
max_offset_test = 60
offset_to_accuracy_avg_reduction = dict()
for offset_test in range(max_offset_test):
    count_eoc = 0
    # postprocess the avg_log_p_map to get the best completion
    for example_index in tqdm(range(len(lambada))): # len(lambada)
        completions_ids = id_to_completions_ids[example_index]
        if len(completions_ids) == 0:
            continue
        completion_avg_log_p_avg_over_offset_max = -float('inf')
        best_completion = None
        for completion_index in range(len(completions_ids)):
            completion_avg_log_p_avg_over_offset = np.mean([
                avg_log_p_map[(example_index, offset, completion_index)]
                for offset in range(offset_test + 1)
            ])
            if completion_avg_log_p_avg_over_offset > completion_avg_log_p_avg_over_offset_max:
                completion_avg_log_p_avg_over_offset_max = completion_avg_log_p_avg_over_offset
                best_completion = completions_ids[completion_index]
        best_completion_string = tokenizer.decode(best_completion)
        best_words = processor.get_words_from_completions([best_completion_string])
        if best_words and best_words[0] == lambada[example_index]['targets_pretokenized'][0]:
            count_eoc += 1
    # divide by len(lambada), as in the max-reduction cell, so that the two curves are comparable
    offset_to_accuracy_avg_reduction[offset_test] = count_eoc / len(lambada)
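The two reductions differ only in how a completion's conditionals across offsets are combined before the argmax over completions. A plain-Python sketch on hypothetical scores (`scores[c][k]` standing in for `avg_log_p_map[(example_index, k, c)]`; `best_completion_index` is an illustrative name) shows that they can pick different completions, while with a single offset both reduce to the baseline. Taking the max over all `(offset, completion)` pairs, as the max-reduction cell does, selects the same completion as the per-completion max used here (up to ties).

```python
def best_completion_index(scores, reduction):
    """Return the index of the completion with the highest ensembled score.

    scores[c][k]: avg_log_p of completion c under offset k.
    reduction="max": keep each completion's best conditional across offsets.
    reduction="avg": average each completion's conditionals across offsets.
    """
    if reduction == "max":
        reduced = [max(per_offset) for per_offset in scores]
    elif reduction == "avg":
        reduced = [sum(per_offset) / len(per_offset) for per_offset in scores]
    else:
        raise ValueError(f"unknown reduction: {reduction}")
    return max(range(len(scores)), key=lambda c: reduced[c])


scores = [
    [-1.0, -3.0, -3.0],  # completion 0: strong baseline conditional only
    [-1.5, -1.2, -1.1],  # completion 1: consistently good across offsets
]
print(best_completion_index(scores, "max"))  # 0
print(best_completion_index(scores, "avg"))  # 1
```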

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm

# Load a font
font_path = '/usr/share/fonts/urw-base35/NimbusMonoPS-Italic.otf'
font_prop = fm.FontProperties(fname=font_path)

# offset = 0 corresponds to the baseline, i.e., 1 ensembled conditional, so shift the offsets by 1
no_ensembled_conditionals_to_accuracy = {
    offset + 1: accuracy for offset, accuracy in offset_to_accuracy.items()
}


max_line = plt.plot(list(no_ensembled_conditionals_to_accuracy.keys()), list(no_ensembled_conditionals_to_accuracy.values()), label='max')
plt.xlabel('No. ensembled conditionals', fontsize=14)
plt.ylabel('Accuracy', fontsize=14)
# x ticks every 10 conditionals, plus an extra tick at 1 (the baseline)
plt.xticks(np.arange(10, max(list(no_ensembled_conditionals_to_accuracy.keys()))+1, 10))
plt.xticks(list(plt.xticks()[0]) + [1])

plt.xticks(fontsize=13)
plt.yticks(fontsize=13)

# add a dot at each point
plt.scatter(list(no_ensembled_conditionals_to_accuracy.keys()), list(no_ensembled_conditionals_to_accuracy.values()))


# add a yellow horizontal line at y=offset_to_accuracy[0]
plt.axhline(y=no_ensembled_conditionals_to_accuracy[1], color='y', linestyle='--')
# add the word "baseline" near the end of the yellow line, in the font loaded above
plt.text(48, no_ensembled_conditionals_to_accuracy[1] + 0.0002, 'baseline', fontproperties=font_prop, fontsize=13)

# plot the accuracy with avg reduction
avg_line = plt.plot([item+1 for item in list(offset_to_accuracy_avg_reduction.keys())], list(offset_to_accuracy_avg_reduction.values()), color='r', label='avg')
# add a dot at each point
plt.scatter([item+1 for item in list(offset_to_accuracy_avg_reduction.keys())], list(offset_to_accuracy_avg_reduction.values()), color='r')

plt.scatter(1, no_ensembled_conditionals_to_accuracy[1], color='y')


plt.legend(handles=[max_line[0], avg_line[0]], loc='upper center', bbox_to_anchor=(0.9, 0.45), ncol=1, fontsize=10)


plt.tight_layout()

# save the plot at high resolution (1200 dpi)
plt.savefig('no_ensembled_conditionals_to_accuracy_combined.png', dpi=1200)

# plt.show()


In [None]:
offset_to_accuracy

In [None]:
import matplotlib.font_manager

fonts = matplotlib.font_manager.findSystemFonts(fontpaths=None, fontext='ttf')

print('Number of fonts: ', len(fonts))
for font in fonts:
    print(font)


In [None]:
max(list(offset_to_accuracy.keys()))+1

In [None]:
plt.plot(list(offset_to_accuracy_avg_reduction.keys()), list(offset_to_accuracy_avg_reduction.values()))
plt.xlabel('offset')
plt.ylabel('accuracy')
# the interval on x should be 1
plt.xticks(np.arange(min(list(offset_to_accuracy_avg_reduction.keys())), max(list(offset_to_accuracy_avg_reduction.keys()))+1, 1.0))
# add a dot at each point
plt.scatter(list(offset_to_accuracy_avg_reduction.keys()), list(offset_to_accuracy_avg_reduction.values()))
plt.show()

In [None]:
# empty the cache
torch.cuda.empty_cache()

In [None]:
import torch

# Get the current GPU memory allocation
allocated_memory_bytes = torch.cuda.memory_allocated(device=0)
# Convert the allocated memory to gigabytes
allocated_memory_gb = allocated_memory_bytes / (1024 ** 3)
print(f"Current GPU memory allocation: {allocated_memory_gb} GB")

In [None]:
'''Concatenate all completions to get a huge tensor'''
all_completions = []

In [None]:
completions_batch = torch.nn.utils.rnn.pad_sequence(id_to_completions_ids[0] + id_to_completions_ids[1], batch_first=True, padding_value=tokenizer.pad_token_id)
completions_batch

In [None]:
input_string = lambada[0]['inputs_pretokenized']
input_ids = tokenizer(input_string, return_tensors="pt").input_ids.to("cuda")
id_to_completions_ids[0]

In [None]:
# average number of (valid) completions per example
total_no_completions = 0
no_completions_list = []
for example_index in id_to_completions_ids:
    total_no_completions += len(id_to_completions_ids[example_index])
    no_completions_list.append(len(id_to_completions_ids[example_index]))
total_no_completions/len(id_to_completions_ids)

In [None]:
'''test to make sure that padding and concatenating individual examples doesn't mess up results'''

# id_to_completions_ids[0] looks like this:
# [tensor([32099,  3957,     3,     5], device='cuda:0'),
#  tensor([32099,    24,     3,     5], device='cuda:0'),
#  tensor([32099,  3957,     3,     5], device='cuda:0'),
#  tensor([32099, 11831,     7,     3,     5], device='cuda:0')]

# convert to completions_batch by padding to the max length (moved to the GPU, like input_ids)
completions_batch = torch.nn.utils.rnn.pad_sequence(id_to_completions_ids[0], batch_first=True, padding_value=tokenizer.pad_token_id).cuda()
# completions_batch
# repeat input_ids once per completion of example 0
input_ids_batch = torch.cat([input_ids for i in range(len(id_to_completions_ids[0]))], dim=0)

outputs = model(input_ids_batch, labels=completions_batch)
# outputs = model(input_ids_batch[1].unsqueeze(0), labels=completions_batch[1].unsqueeze(0))
# outputs.loss

In [None]:
completions_batch_1 = id_to_completions_ids[0][1].unsqueeze(0).cuda()
input_ids_batch_1 = input_ids
outputs_1 = model(input_ids_batch_1, labels=completions_batch_1)

In [None]:
print((outputs.logits[1][:-1] == outputs_1.logits[0]).all())
print(outputs.logits[1][:-1])
print(outputs_1.logits[0])
print(completions_batch_1)

loss0 = ce_loss(outputs.logits[1][:-1], completions_batch_1[0])
loss0_padded = ce_loss(outputs.logits[1], completions_batch[1])
loss0_1 = ce_loss(outputs_1.logits[0], completions_batch_1[0])
print(loss0)
print(loss0_padded)
print(loss0_1)
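The check above relies on the decoder being causal: the pad position appended to the shorter completion comes after its real tokens, so it cannot change the logits at earlier positions, and `ignore_index` drops it from the loss. Batched and unbatched runs should therefore agree up to numerical noise (bf16 kernels need not be bit-identical across batch shapes). A toy NumPy sketch of single-head causal self-attention, with identity projections as a simplifying assumption rather than UL2's actual decoder, illustrates the prefix property:

```python
import numpy as np


def causal_self_attention(x):
    """Toy single-head causal self-attention with identity Q/K/V projections."""
    seq_len, dim = x.shape
    scores = x @ x.T / np.sqrt(dim)
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)  # position t cannot attend to t' > t
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x


rng = np.random.default_rng(0)
tokens = rng.normal(size=(4, 8))                       # 4 real decoder positions
padded = np.vstack([tokens, rng.normal(size=(1, 8))])  # plus 1 trailing "pad" position
print(np.allclose(causal_self_attention(tokens), causal_self_attention(padded)[:4]))  # True
```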

In [None]:
completions

In [None]:
id_to_completions_ids[0]

In [None]:
print(outputs.logits[1].shape)
print(outputs_1.logits[0].shape)

In [None]:
# ce_loss(outputs.logits, completions_batch)
logits_0 = outputs.logits[1]
completions_batch_0 = completions_batch[1]
# logits_0.shape

loss = ce_loss(logits_0, completions_batch_0)
loss

In [None]:
outputs['scores']
for i in range(len(outputs['scores'])):
    probs_outputs_beam_scores = torch.nn.functional.softmax(outputs['scores'][i], dim=-1)
    print(probs_outputs_beam_scores[0])
    # argmax 
    print(torch.argmax(probs_outputs_beam_scores[0]))
    # decode it
    # print(tokenizer.decode(torch.argmax(probs_outputs_beam_scores[0])))
    # the value
    print(probs_outputs_beam_scores[0][torch.argmax(probs_outputs_beam_scores[0])])

In [None]:
input_string = lambada[0]['inputs_pretokenized']
inputs = tokenizer(input_string, return_tensors="pt").input_ids.to("cuda")
num_beams = 2
outputs = model.generate(inputs, 
                        # max_length=8, 
                        num_beams=num_beams, 
                        num_return_sequences=num_beams,
                        # eos_token_id=tokenizer.convert_tokens_to_ids('<extra_id_1>'),
                        return_dict_in_generate=True,
                        output_scores=True)

In [None]:
input_ids = tokenizer(lambada[0]['inputs_pretokenized'], return_tensors="pt").input_ids.to("cuda")
print(lambada[0]['inputs_pretokenized'])
labels = tokenizer("<extra_id_0> " + id_to_punctuated_words_processed[0][2] + " <extra_id_1>", return_tensors="pt").input_ids.to("cuda")
print("<extra_id_0> " + id_to_punctuated_words_processed[0][2] + " <extra_id_1>")
outputs = model(input_ids, labels=labels)
loss = outputs.loss
logits = outputs.logits
# logits

In [None]:
# calculate loss with cross entropy loss function
ce_loss = torch.nn.CrossEntropyLoss()
loss = ce_loss(logits.view(-1, logits.shape[-1]), labels.view(-1))
loss

In [None]:
input_string = "[NLG] A man is having a bun for <extra_id_0>"
inputs = tokenizer(input_string, return_tensors="pt").input_ids.to("cuda")
num_beams = 1
# outputs = model.generate(inputs, num_beams=num_beams, max_length=3, num_return_sequences=num_beams, output_scores=True, return_dict_in_generate=True)
outputs = model.generate(inputs, max_length=3, output_scores=True, return_dict_in_generate=True)
# outputs = model.generate(inputs, output_scores=True, return_dict_in_generate=True)

for i in range(num_beams):
    # print(outputs['sequences'][i])
    print(tokenizer.decode(outputs['sequences'][i]))
    # decode outputs['sequences'][i] one by one token
    print('-------------')
    # print(outputs['sequences_scores'][i])

In [None]:
# /work/09127/tomyoung/ls6/data/pkls/url_to_probs_c4_dict_with_labels_t5_11b_valid.pkl
import pickle
with open('/work/09127/tomyoung/ls6/data/pkls/url_to_probs_c4_dict_with_labels_t5_11b_valid.pkl', 'rb') as f:
    url_to_probs_c4_dict_with_labels_t5_11b_valid = pickle.load(f)

In [None]:
# '/work/09127/tomyoung/ls6/data/pkls/acceptable_alternatives_1000_ignore_cws_nos_50_valid.pkl'
import pickle
with open('/work/09127/tomyoung/ls6/data/pkls/acceptable_alternatives_1000_ignore_cws_nos_50_valid.pkl', 'rb') as f:
    acceptable_alternatives_1000_ignore_cws_nos_50_valid = pickle.load(f)

In [None]:
acceptable_alternatives_1000_ignore_cws_nos_50_valid[(180, 11, 8)][3]

In [None]:
list(acceptable_alternatives_1000_ignore_cws_nos_50_valid.keys())[320:]

In [None]:
all_keys = list(acceptable_alternatives_1000_ignore_cws_nos_50_valid.keys())
for key in all_keys[1000:2000]:
    if len(acceptable_alternatives_1000_ignore_cws_nos_50_valid[key][3]) > 1:
        print(key)
        for item in acceptable_alternatives_1000_ignore_cws_nos_50_valid[key][3]:
            print(item)
        # print(*acceptable_alternatives_1000_ignore_cws_nos_50_valid[key][3])

In [None]:
dicts_realnewslike[0][0]

In [None]:
c4_json_file = '/work/09127/tomyoung/ls6/data/jsons/c4-validation.00000-of-00001-list-of-lists.json'
import json
with open(c4_json_file, 'r', encoding='utf8') as f:
    dicts_realnewslike = json.load(f)

In [None]:
from transformers import T5Tokenizer, BartTokenizer
t5_tokenizer = T5Tokenizer.from_pretrained('t5-3b')
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

In [None]:
'''Based on url_to_probs_c4_dict_with_labels_t5_11b_valid and acceptable_alternatives_1000_ignore_cws_nos_50_valid,
generate the input and target strings for models such as UL2 and GLM'''
proposed_bigram = t5_tokenizer.decode(url_to_probs_c4_dict_with_labels_t5_11b_valid[((0,17,4),(0,17,5),1)]['proposed bigram'])

acceptable_alternatives_1000_ignore_cws_nos_50_valid[(0,17,4)][0][1]

proposed_bigram = bart_tokenizer.decode(acceptable_alternatives_1000_ignore_cws_nos_50_valid[(0,17,4)][0][1][4+1:4+3])
print('proposed_bigram:', proposed_bigram)

preceding_tokens = bart_tokenizer.decode(acceptable_alternatives_1000_ignore_cws_nos_50_valid[(0,17,4)][0][1][:4])
print('preceding_tokens:', preceding_tokens)

following_tokens = bart_tokenizer.decode(acceptable_alternatives_1000_ignore_cws_nos_50_valid[(0,17,4)][0][1][4+3:])
print('following_tokens:', following_tokens)
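The index arithmetic for the proposed bigram (in the loop of the next cell) can be isolated into a small helper. A sketch, assuming, as the `+1` shifts in that loop suggest, that the sequences in `acceptable_alternatives_1000_ignore_cws_nos_50_valid` start with one BOS token (`<s>`) while the positions in the keys index the sequence without it, and that the constant token sits next to the proposed one; `split_around_bigram` is a hypothetical name:

```python
def split_around_bigram(token_ids, proposed_pos, constant_pos, bos_offset=1):
    """Split a BOS-prefixed sequence into (preceding, bigram, following).

    proposed_pos / constant_pos index the sequence *without* BOS; bos_offset
    shifts them into `token_ids`. If the proposed token is left of the constant
    one, the bigram is (proposed, next token); otherwise (previous token, proposed).
    """
    i = proposed_pos + bos_offset  # index of the proposed token in token_ids
    if proposed_pos < constant_pos:
        start, end = i, i + 2
    else:
        start, end = i - 1, i + 1
    return token_ids[:start], token_ids[start:end], token_ids[end:]


tokens = ["<s>", "a", "b", "c", "d", "</s>"]
print(split_around_bigram(tokens, proposed_pos=1, constant_pos=2))
# (['<s>', 'a'], ['b', 'c'], ['d', '</s>'])
print(split_around_bigram(tokens, proposed_pos=3, constant_pos=2))
# (['<s>', 'a', 'b'], ['c', 'd'], ['</s>'])
```

Decoding the three pieces with `bart_tokenizer.decode` gives `preceding_tokens_to_proposed_bigram`, `proposed_bigram`, and `following_tokens_to_proposed_bigram` as computed in that loop.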


In [None]:
url_to_probs_c4_dict_with_labels_t5_11b_valid_keys = list(url_to_probs_c4_dict_with_labels_t5_11b_valid.keys())

url_to_ul2_probs_dict = {}
import math
for i in tqdm(range(len(url_to_probs_c4_dict_with_labels_t5_11b_valid_keys))):#len(url_to_probs_c4_dict_with_labels_t5_11b_valid_keys)
    key = url_to_probs_c4_dict_with_labels_t5_11b_valid_keys[i]
    url_to_ul2_probs_dict[url_to_probs_c4_dict_with_labels_t5_11b_valid_keys[i]] = {}
    show(key)
    show(key[0])
    show(acceptable_alternatives_1000_ignore_cws_nos_50_valid[key[0]])
    # acceptable_alternatives_1000_ignore_cws_nos_50_valid
    # process the proposed bigram and token
    url = key[0]
    story_id = key[0][0]
    paragraph_id = key[0][1]
    completion_id = key[2]
    proposed_token_pos = key[0][2]

    ''' proposed token '''
    # sequences in acceptable_alternatives_1000_ignore_cws_nos_50_valid start with <s> and end with </s>, hence the +1 shift of proposed_token_pos
    proposed_token = bart_tokenizer.decode(acceptable_alternatives_1000_ignore_cws_nos_50_valid[url][0][completion_id][proposed_token_pos+1:proposed_token_pos+2])
    preceding_tokens_to_proposed_token = bart_tokenizer.decode(acceptable_alternatives_1000_ignore_cws_nos_50_valid[url][0][completion_id][:proposed_token_pos+1])
    following_tokens_to_proposed_token = bart_tokenizer.decode(acceptable_alternatives_1000_ignore_cws_nos_50_valid[url][0][completion_id][proposed_token_pos+2:])
    show('proposed_token:', proposed_token)
    show('preceding_tokens_to_proposed_token:', preceding_tokens_to_proposed_token)
    show('following_tokens_to_proposed_token:', following_tokens_to_proposed_token)
    # generate the input string by replacing the proposed token with <extra_id_0>
    input_string = preceding_tokens_to_proposed_token + ' <extra_id_0>' + following_tokens_to_proposed_token
    # remove <s> and </s> in the input string
    input_string = input_string.replace('<s>', '')
    input_string = input_string.replace('</s>', '')
    # prepend the [NLU] mode token to the input string
    input_string = '[NLU] ' + input_string
    show('input_string:', input_string)
    target_string = '<extra_id_0>' + proposed_token + ' <extra_id_1>'
    show('target_string_proposed_token:', target_string)
    inputs = tokenizer(input_string, return_tensors="pt").input_ids.to("cuda")
    labels = tokenizer(target_string, return_tensors="pt").input_ids.to("cuda")
    show('labels_proposed_token:', labels)
    # remove the last </s> token from labels
    labels = labels[:, :-1].contiguous()
    show('inputs:', inputs)
    show('labels:', labels)
    outputs = model(inputs, labels=labels)
    # print('outputs:', outputs)
    log_p = -ce_loss_sum(outputs.logits[0][1:], labels[0][1:]) # [1:] to remove the first token <extra_id_0>
    show('log_p_proposed_token:', log_p)
    example_raw_sequence = dicts_realnewslike[story_id][paragraph_id]
    show('example_raw_sequence:', example_raw_sequence)
    # tokenize the raw sequence
    example_raw_sequence_bart_tokenized = bart_tokenizer.tokenize(example_raw_sequence)
    bart_ids_original = bart_tokenizer.convert_tokens_to_ids(example_raw_sequence_bart_tokenized)
    
    ''' original token '''
    original_token = bart_tokenizer.decode(bart_ids_original[proposed_token_pos])
    show('original_token:', original_token)
    target_string_original_token = '<extra_id_0>' + original_token + ' <extra_id_1>'
    show('target_string_original_token:', target_string_original_token)
    labels_original_token = tokenizer(target_string_original_token, return_tensors="pt").input_ids.to("cuda")
    labels_original_token = labels_original_token[:, :-1].contiguous()
    show('labels_original_token:', labels_original_token)
    outputs = model(inputs, labels=labels_original_token)
    log_p_original_token = -ce_loss_sum(outputs.logits[0][1:], labels_original_token[0][1:]) # [1:] to remove the first token <extra_id_0>
    show('log_p_original_token:', log_p_original_token)
    
    
    ''' proposed bigram '''
    constant_token_pos = key[1][2]
    # the completion's token ids appear to start with <s>, which bart_ids_original lacks,
    # so completion index i+1 lines up with index i of bart_ids_original
    completion_ids = acceptable_alternatives_1000_ignore_cws_nos_50_valid[url][0][completion_id]
    proposed_token_is_to_the_left = proposed_token_pos < constant_token_pos
    if proposed_token_is_to_the_left:
        # bigram = (proposed token, constant token)
        proposed_bigram = bart_tokenizer.decode(completion_ids[proposed_token_pos+1:proposed_token_pos+3])
        preceding_tokens_to_proposed_bigram = bart_tokenizer.decode(completion_ids[:proposed_token_pos+1])
        following_tokens_to_proposed_bigram = bart_tokenizer.decode(completion_ids[proposed_token_pos+3:])
        original_bigram = bart_tokenizer.decode(bart_ids_original[proposed_token_pos:proposed_token_pos+2])
    else:
        # bigram = (constant token, proposed token)
        proposed_bigram = bart_tokenizer.decode(completion_ids[proposed_token_pos:proposed_token_pos+2])
        preceding_tokens_to_proposed_bigram = bart_tokenizer.decode(completion_ids[:proposed_token_pos])
        following_tokens_to_proposed_bigram = bart_tokenizer.decode(completion_ids[proposed_token_pos+2:])
        original_bigram = bart_tokenizer.decode(bart_ids_original[proposed_token_pos-1:proposed_token_pos+1])
    show('proposed_bigram:', proposed_bigram)
    show('preceding_tokens_to_proposed_bigram:', preceding_tokens_to_proposed_bigram)
    show('following_tokens_to_proposed_bigram:', following_tokens_to_proposed_bigram)
    show('original_bigram:', original_bigram)
    
    
    # generate the input string by adding a <extra_id_0> in the middle of the sentence
    input_string = preceding_tokens_to_proposed_bigram + ' <extra_id_0>' + following_tokens_to_proposed_bigram
    # remove <s> and </s> in the input string
    input_string = input_string.replace('<s>', '')
    input_string = input_string.replace('</s>', '')
    # add [NLU] to the input string
    input_string = '[NLU] ' + input_string
    show('input_string:', input_string)
    target_string = '<extra_id_0>' + proposed_bigram + ' <extra_id_1>'
    show('target_string:', target_string)
    inputs = tokenizer(input_string, return_tensors="pt").input_ids.to("cuda")
    labels = tokenizer(target_string, return_tensors="pt").input_ids.to("cuda")
    # remove the last </s> token from labels
    labels = labels[:, :-1].contiguous()
    show('inputs:', inputs)
    show('labels:', labels)
    outputs = model(inputs, labels=labels)
    # print('outputs:', outputs)
    log_p_proposed_bigram = -ce_loss_sum(outputs.logits[0][1:], labels[0][1:]) # [1:] to remove the first token <extra_id_0>
    show('log_p_proposed_bigram:', log_p_proposed_bigram)


    '''original bigram'''
    target_string_original_bigram = '<extra_id_0>' + original_bigram + ' <extra_id_1>'
    show('target_string_original_bigram:', target_string_original_bigram)
    labels_original_bigram = tokenizer(target_string_original_bigram, return_tensors="pt").input_ids.to("cuda")
    labels_original_bigram = labels_original_bigram[:, :-1].contiguous()
    outputs = model(inputs, labels=labels_original_bigram)
    # print('outputs:', outputs)
    log_p_original_bigram = -ce_loss_sum(outputs.logits[0][1:], labels_original_bigram[0][1:]) # [1:] to remove the first token <extra_id_0>
    show('log_p_original_bigram:', log_p_original_bigram)
    # store the probabilities (exp of the log-probs) in the dictionary
    url_to_ul2_probs_dict[key] = {'proposed token': math.exp(log_p.float().item()),
                                  'original token': math.exp(log_p_original_token.float().item()),
                                  'proposed bigram': math.exp(log_p_proposed_bigram.float().item()),
                                  'original bigram': math.exp(log_p_original_bigram.float().item())}
url_to_ul2_probs_dict_filepath = '/work/09127/tomyoung/ls6/data/pkls/url_to_ul2_probs_dict_valid.pkl'
with open(url_to_ul2_probs_dict_filepath, 'wb') as f:
    pickle.dump(url_to_ul2_probs_dict, f)
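The bigram cell above slices the completion IDs one position further right than `bart_ids_original`. A plausible reading, assumed here rather than confirmed by the notebook, is that the completion IDs begin with BART's `<s>` token while `bart_ids_original` (built with `tokenize` plus `convert_tokens_to_ids`, which add no special tokens) does not, so completion index `i + 1` lines up with original index `i`. The hypothetical helper `bigram_windows` below reproduces the same slices on plain ID lists, which makes the shift easy to check on toy data:

```python
def bigram_windows(completion_ids, original_ids, proposed_pos, constant_pos):
    """Split a completion around the bigram (proposed token + unchanged neighbour).

    Hypothetical helper. Assumes completion_ids == [<s>] + tokens, i.e. shifted one
    position to the right of original_ids, and that the two tokens are adjacent.
    Returns (preceding, proposed_bigram, following, original_bigram) as ID lists.
    """
    if proposed_pos < constant_pos:
        start = proposed_pos        # bigram = (proposed, constant) in original indexing
    else:
        start = proposed_pos - 1    # bigram = (constant, proposed) in original indexing
    shifted = start + 1             # same window in completion indexing (skips <s>)
    preceding = completion_ids[:shifted]
    proposed_bigram = completion_ids[shifted:shifted + 2]
    following = completion_ids[shifted + 2:]
    original_bigram = original_ids[start:start + 2]
    return preceding, proposed_bigram, following, original_bigram


# toy data: 0 plays the role of <s>; the original token 11 at position 1 is replaced by 99
original_ids = [10, 11, 12, 13, 14]
completion_ids = [0, 10, 99, 12, 13, 14]
print(bigram_windows(completion_ids, original_ids, proposed_pos=1, constant_pos=2))
# ([0, 10], [99, 12], [13, 14], [11, 12])
print(bigram_windows(completion_ids, original_ids, proposed_pos=1, constant_pos=0))
# ([0], [10, 99], [12, 13, 14], [10, 11])
```

On real data, decoding these four slices with `bart_tokenizer.decode` should give the same strings as `preceding_tokens_to_proposed_bigram`, `proposed_bigram`, `following_tokens_to_proposed_bigram` and `original_bigram`.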


In [None]:
# load the T5-11B probability dictionary saved for the C4 validation set
import pickle
with open('/work/09127/tomyoung/ls6/data/pkls/url_to_probs_c4_dict_with_labels_t5_11b_valid.pkl', 'rb') as f:
    url_to_probs_c4_dict_with_labels_t5_11b_valid = pickle.load(f)

In [None]:
input_string = "[NLG] The headquarters of microsoft is in <extra_id_0>"

inputs = tokenizer(input_string, return_tensors="pt", add_special_tokens=False).input_ids.to("cuda")

outputs = model.generate(inputs, max_length=200)

print(tokenizer.decode(outputs[0]))
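The scoring cells in this notebook build their UL2 strings in a common pattern: the span to score is replaced by the `<extra_id_0>` sentinel in the input, a mode token is prepended (`[NLU]` in the infilling cells, `[NLG]` in the completion cells), and the target wraps the span as `<extra_id_0> ... <extra_id_1>`; the tokenizer's trailing `</s>` is then cut from the labels. The hypothetical helper below gathers that string handling in one place. It is a sketch of what the cells do, not a UL2 or transformers API:

```python
def build_ul2_infilling_pair(prefix, span, suffix, mode_token="[NLU]"):
    """Return (input_string, target_string) for scoring `span` between prefix and suffix.

    Hypothetical helper mirroring the notebook cells: BART's <s> and </s> markers are
    stripped, the span is replaced by <extra_id_0>, the mode token is prepended, and
    the target is "<extra_id_0>" + span + " <extra_id_1>" (span keeps its own leading space).
    """
    input_string = (prefix + ' <extra_id_0>' + suffix).replace('<s>', '').replace('</s>', '')
    input_string = mode_token + ' ' + input_string
    target_string = '<extra_id_0>' + span + ' <extra_id_1>'
    return input_string, target_string


# reproduces the strings of the [NLG] scoring cell that follows (suffix is empty for completion)
inp, tgt = build_ul2_infilling_pair("The headquarters of microsoft is in", " redmond washington", "", "[NLG]")
print(inp)  # [NLG] The headquarters of microsoft is in <extra_id_0>
print(tgt)  # <extra_id_0> redmond washington <extra_id_1>
```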

In [None]:
input_string = "[NLG] The headquarters of microsoft is in <extra_id_0>"
target_string = "<extra_id_0> redmond washington <extra_id_1>"
inputs = tokenizer(input_string, return_tensors="pt").input_ids.to("cuda")
labels = tokenizer(target_string, return_tensors="pt").input_ids.to("cuda")
labels = labels[:, :-1].contiguous()
outputs = model(inputs, labels=labels)
print('outputs:', outputs)
log_p = -ce_loss_sum(outputs.logits[0][1:-1], labels[0][1:-1]) # [1:-1] drops <extra_id_0> and <extra_id_1>
print('log_p:', log_p)

probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
for i in range(len(labels[0])):
    print('label:', labels[0][i], 'prob:', probs[0][i][labels[0][i]])
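The `[1:-1]` slices line up because, when `labels` is passed, `T5ForConditionalGeneration` builds its decoder inputs by shifting `labels` one position to the right, so `outputs.logits[0][t]` is the distribution for `labels[0][t]`. `ce_loss_sum` is defined elsewhere in the notebook; assuming it is a sum-reduced cross entropy (for example `torch.nn.CrossEntropyLoss(reduction='sum')`), `-ce_loss_sum(...)` is the log-probability of the scored tokens, and its exponential is the product of the per-position probabilities printed by the loop. A plain-Python sketch of that computation on toy logits:

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def span_log_prob(logits, labels):
    """Sum of log p(labels[t]) under softmax(logits[t]), i.e. minus a sum-reduced cross entropy."""
    return sum(log_softmax(row)[y] for row, y in zip(logits, labels))


# toy example: 2 decoder positions over a 2-token vocabulary
logits = [[0.0, 0.0], [math.log(3.0), 0.0]]
labels = [0, 0]
log_p = span_log_prob(logits, labels)
per_position = [math.exp(log_softmax(row)[y]) for row, y in zip(logits, labels)]
print(per_position)                              # approximately [0.5, 0.75]
print(math.exp(log_p), math.prod(per_position))  # both approximately 0.375
```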
    

In [None]:
labels.shape

In [None]:
labels = tokenizer(target_string, return_tensors="pt").input_ids.to("cuda")
labels

In [None]:
import pickle
with open('/work/09127/tomyoung/ls6/data/pkls/dict_url_to_completions_5_grams.pkl','rb') as f:
    dict_url_to_completions_5_grams = pickle.load(f)
# string of ASCII punctuation characters, used below to filter the alternatives
import string
punctuations = string.punctuation

In [None]:
dict_url_to_completions_5_grams_keys = list(dict_url_to_completions_5_grams.keys())
url_to_ul2_5_gram_probs_dict = {}
for key in tqdm(dict_url_to_completions_5_grams_keys):
    # skip alternatives that are empty or do not end with a punctuation mark
    alternative = dict_url_to_completions_5_grams[key]['alternative']
    if not alternative or alternative[-1] not in punctuations:
        continue
    # 10-gram scoring needs at least 5 context words before the last 10 words, i.e. >= 15 words
    original_sentence_words = dict_url_to_completions_5_grams[key]['original_sentence'].split(' ')
    if len(original_sentence_words) < 15:
        continue
    input_for_10_grams = '[NLG] ' + ' '.join(original_sentence_words[:-10]) + ' <extra_id_0>'
    input_for_10_grams_ids = tokenizer(input_for_10_grams, return_tensors="pt").input_ids.to("cuda")
    # original 10-gram
    original_10_gram = "<extra_id_0> " + ' '.join(original_sentence_words[-10:]) + ' <extra_id_1>'
    labels_original_10_gram = tokenizer(original_10_gram, return_tensors="pt").input_ids.to("cuda")
    labels_original_10_gram = labels_original_10_gram[:, :-1].contiguous() # remove the last token '</s>'
    outputs = model(input_for_10_grams_ids, labels=labels_original_10_gram)
    log_p_original_10_gram = -ce_loss_sum(outputs.logits[0][1:-1], labels_original_10_gram[0][1:-1]) # lose the <extra_id_0> and <extra_id_1>
    
    # probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # for i in range(len(labels_original_10_gram[0])):
    #     print('label:', labels_original_10_gram[0][i], 'prob:', probs[0][i][labels_original_10_gram[0][i]])

    # continue
    # proposed 10-gram: the original words [-10:-5] followed by the alternative 5-gram
    proposed_10_gram = ("<extra_id_0> " + ' '.join(original_sentence_words[-10:-5]) + ' '
                        + dict_url_to_completions_5_grams[key]['alternative'] + ' <extra_id_1>')
    labels_proposed_10_gram = tokenizer(proposed_10_gram, return_tensors="pt").input_ids.to("cuda")
    labels_proposed_10_gram = labels_proposed_10_gram[:, :-1].contiguous() # remove the last token '</s>'
    outputs = model(input_for_10_grams_ids, labels=labels_proposed_10_gram)
    log_p_proposed_10_gram = -ce_loss_sum(outputs.logits[0][1:-1], labels_proposed_10_gram[0][1:-1]) # lose the <extra_id_0> and <extra_id_1>
    # for 5-grams
    # get the input string
    input_for_5_grams = '[NLG] ' + ' '.join(original_sentence_words[:-5]) + ' <extra_id_0>'
    input_for_5_grams_ids = tokenizer(input_for_5_grams, return_tensors="pt").input_ids.to("cuda")
    # original 5-gram
    original_5_gram = "<extra_id_0> " + ' '.join(original_sentence_words[-5:]) + ' <extra_id_1>'
    labels_original_5_gram = tokenizer(original_5_gram, return_tensors="pt").input_ids.to("cuda")
    labels_original_5_gram = labels_original_5_gram[:, :-1].contiguous() # remove the last token '</s>'
    outputs = model(input_for_5_grams_ids, labels=labels_original_5_gram)
    log_p_original_5_gram = -ce_loss_sum(outputs.logits[0][1:-1], labels_original_5_gram[0][1:-1]) # lose the <extra_id_0> and <extra_id_1>
    # proposed 5-gram
    proposed_5_gram = "<extra_id_0> " + dict_url_to_completions_5_grams[key]['alternative'] + ' <extra_id_1>'
    labels_proposed_5_gram = tokenizer(proposed_5_gram, return_tensors="pt").input_ids.to("cuda")
    labels_proposed_5_gram = labels_proposed_5_gram[:, :-1].contiguous() # remove the last token '</s>'
    outputs = model(input_for_5_grams_ids, labels=labels_proposed_5_gram)
    log_p_proposed_5_gram = -ce_loss_sum(outputs.logits[0][1:-1], labels_proposed_5_gram[0][1:-1]) # lose the <extra_id_0> and <extra_id_1>
    # add them to the dictionary
    url_to_ul2_5_gram_probs_dict[key] = {'proposed 5_gram': math.exp(log_p_proposed_5_gram.float().item()),
                                         'original 5_gram': math.exp(log_p_original_5_gram.float().item()),
                                         'proposed 10_gram': math.exp(log_p_proposed_10_gram.float().item()),
                                         'original 10_gram': math.exp(log_p_original_10_gram.float().item())}
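The loop above scores the same alternative ending twice: as a 5-gram after all but the last 5 words, and as a 10-gram after all but the last 10 words, where both 10-grams begin with the original words at positions `[-10:-5]`. The hypothetical helper `ngram_scoring_strings` below only builds the six strings involved (no model call), which makes the shared-prefix construction easy to inspect; the 5/10 window sizes and the `[NLG]` mode token are taken from the cell above:

```python
def ngram_scoring_strings(original_sentence, alternative, n_short=5, n_long=10):
    """Return the input/target strings used to compare the original and proposed endings.

    Hypothetical helper mirroring the loop above. The last n_short words of the sentence
    are compared with `alternative`, first conditioned on everything before them and then
    with the preceding (n_long - n_short) original words moved into the scored span.
    """
    words = original_sentence.split(' ')
    alt_words = alternative.split(' ')
    shared = words[-n_long:-n_short]  # original words that both long spans start with

    def span(ws):
        return '<extra_id_0> ' + ' '.join(ws) + ' <extra_id_1>'

    return {
        'short input': '[NLG] ' + ' '.join(words[:-n_short]) + ' <extra_id_0>',
        'long input': '[NLG] ' + ' '.join(words[:-n_long]) + ' <extra_id_0>',
        'original short': span(words[-n_short:]),
        'proposed short': span(alt_words),
        'original long': span(shared + words[-n_short:]),
        'proposed long': span(shared + alt_words),
    }


example = ngram_scoring_strings(' '.join(f'w{i}' for i in range(1, 16)), 'a b c d e.')
for name, text in example.items():
    print(f'{name:>15}: {text}')
```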


In [None]:
# save url_to_ul2_5_gram_probs_dict as a pkl
with open('/work/09127/tomyoung/ls6/data/pkls/url_to_ul2_5_gram_probs_dict.pkl', 'wb') as f:
    pickle.dump(url_to_ul2_5_gram_probs_dict, f)
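These saved probabilities support a direct consistency check. For the n-gram dictionary, write c for the words before the last 10, m for the shared words `[-10:-5]`, a for the alternative 5-gram and o for the original one. By the chain rule, P(m a | c) / P(m o | c) = P(a | c, m) / P(o | c, m), and the 5-gram input conditions on exactly c followed by m, so a model whose conditionals come from one joint distribution would give the same proposed-to-original ratio at both lengths. The token/bigram dictionary `url_to_ul2_probs_dict` obeys the analogous identity: with k the unchanged neighbour and L, R the remaining context, P(x | L, k, R) = P(x, k | L, R) / P(k | L, R), so the token ratio should equal the bigram ratio. The hypothetical function `log_ratio_gap` below measures the departure from these identities in log space; reading the saved entries this way is an assumption about how the dictionaries are meant to be used:

```python
import math


def log_ratio_gap(probs, long_keys, short_keys):
    """Return log[P(proposed long)/P(original long)] - log[P(proposed short)/P(original short)].

    Hypothetical helper. `probs` is one value of url_to_ul2_probs_dict or
    url_to_ul2_5_gram_probs_dict. A gap of 0 means the two conditionals agree;
    None means a stored probability underflowed to 0.0, so no log ratio can be formed.
    """
    values = [probs[k] for k in (*long_keys, *short_keys)]
    if any(v <= 0.0 for v in values):
        return None
    p_long, o_long, p_short, o_short = (math.log(v) for v in values)
    return (p_long - o_long) - (p_short - o_short)


# (long keys, short keys) for the two dictionaries built in this notebook
NGRAM_KEYS = (('proposed 10_gram', 'original 10_gram'), ('proposed 5_gram', 'original 5_gram'))
TOKEN_KEYS = (('proposed bigram', 'original bigram'), ('proposed token', 'original token'))
```

For example, `[log_ratio_gap(v, *NGRAM_KEYS) for v in url_to_ul2_5_gram_probs_dict.values()]` gives one gap per scored sentence. Because the dictionaries store `math.exp` of the log-probabilities, long spans can underflow to 0.0; the function returns `None` in that case rather than a misleading value, and storing the log-probabilities themselves would avoid the problem.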

In [None]:
# count how many words differ between each proposed 5-gram and the original last 5 words
for key in dict_url_to_completions_5_grams.keys():
    original_sentence = dict_url_to_completions_5_grams[key]['original_sentence']
    proposed_5_gram = dict_url_to_completions_5_grams[key]['alternative'].split(' ')
    original_5_gram = original_sentence.split(' ')[-5:]
    # zip stops at the shorter list, so a length mismatch is counted separately (avoids IndexError)
    number_of_different_tokens = sum(p != o for p, o in zip(proposed_5_gram, original_5_gram))
    number_of_different_tokens += abs(len(proposed_5_gram) - len(original_5_gram))
    print('key:', key)
    print('original_sentence:', original_sentence)
    print('proposed_5_gram:', ' '.join(proposed_5_gram))
    print('original_5_gram:', ' '.join(original_5_gram))
    print('number_of_different_tokens:', number_of_different_tokens)