### Imports and global utils

In [1]:
'''imports'''
import os
# os.environ["CUDA_VISIBLE_DEVICES"]="0,1,4,5,6,7"
os.environ["CUDA_VISIBLE_DEVICES"]="4"
import pickle
# clear GPU memory
from utils import general_utils, eoc
import torch
from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Tokenizer
from tqdm import tqdm
from datasets import load_dataset
from typing import Tuple, List
import torch.nn.functional as F

### Load model

In [None]:
# We are using custom huggingface cache dirs in case the default one doesn't have the capacity, since the models can be quite large.
MY_HUGGINGFACE_CACHE_DIR ='huggingface_cache' # relative to this notebook path
tokenizer = AutoTokenizer.from_pretrained("google/ul2",
                                        cache_dir = MY_HUGGINGFACE_CACHE_DIR+'/google-ul2')


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
RUN_CELL = True    # Load model 1
if RUN_CELL:
    model = T5ForConditionalGeneration.from_pretrained("google/ul2",
                                                        cache_dir=MY_HUGGINGFACE_CACHE_DIR + '/google-ul2',
                                                        low_cpu_mem_usage=True,
                                                        torch_dtype=torch.bfloat16,
                                                        device_map='balanced')

Loading checkpoint shards:  25%|██▌       | 1/4 [00:06<00:20,  6.82s/it]

In [None]:
UL2_MODE = '[NLG]'

### Load dataset and specify partition

In [None]:
DATASET_PATH = "Rowan/hellaswag" # NOTE: hellaswag only has the validation set

# validation lens
# max: 130, min: 19, avg: 66.98307110137422
# max: 125, min: 7, avg: 45.010655247958574

In [None]:
IS_DEVELOPMENT = True
data = load_dataset(DATASET_PATH)
set_partition = 'validation' if IS_DEVELOPMENT else 'test' 
# only use 1000 samples for development
if IS_DEVELOPMENT:
    data['validation'] = data['validation'] #.select(range(1000))

In [None]:
RUN_CELL = True # shuffle data['validation']
if RUN_CELL:
    data['validation'] = data['validation'].shuffle(seed=42)
    # check sha256 to make sure shuffled data is the same
    # general_utils.hash_object(data['validation'][0:1000])
    # bd403791b54a53d53a955db1ef38500eb6b7ad4fe29e1164e8d455e82a529ad2

### Define Loss Function

In [None]:
# define loss and get extra ids
ce_loss = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id) #reduction='avg'
ce_loss_sum = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id, reduction='sum') #reduction='sum'
extra_id_0 = torch.tensor([tokenizer.convert_tokens_to_ids("<extra_id_0>")])
extra_id_1 = torch.tensor([tokenizer.convert_tokens_to_ids("<extra_id_1>")])

### Define example generator

In [None]:
RUN_CELL = False
show = print if RUN_CELL else lambda *args, **kwargs: None


In [None]:
def data_prompting(docs, tokenizer) -> Tuple:
    '''
        docs: DATA_SET[SUBJECTS_NAME], ex:HELLASWG_DATAS[validation]
        return: Tuple(input_ids, completions_ids_padded, labels)

        input[example]: ctx:<prompt> 
        label[example]: endings list -> list[] 

        Todo: few-shot data prompting
    '''
    example_id = 0
    for doc in docs:
        endings_list = doc['endings']

        input_ = UL2_MODE + " " + doc["activity_label"] + ": " + doc['ctx'] + " " + "<extra_id_0>"
        completions = [f"<extra_id_0> {ending}" for ending in endings_list]
        # label = f"{endings_list[int(index)]}"
        show(input_)
        show(completions)

        input_id = tokenizer(input_, return_tensors="pt").input_ids
        # label_id = tokenizer(label, return_tensors="pt").input_ids.to("cuda").clone().detach().requires_grad_(False)
        completions_ids = [tokenizer(completion, return_tensors="pt").input_ids[:,:-1]\
                                                                for completion in completions] # remove <eos> token with [:,:-1]

        # Assuming `max_length` is the maximum length you want to pad sequences to
        max_length = max(seq.size(1) for seq in completions_ids)

        # Pad sequences to the common length
        padded_sequences = [F.pad(seq, (0, max_length - seq.size(1)), value=tokenizer.pad_token_id) for seq in completions_ids]

        # Use pad_sequence
        completions_ids_padded = torch.nn.utils.rnn.pad_sequence(padded_sequences, batch_first=True, padding_value=tokenizer.pad_token_id)

        completions_ids_padded = torch.squeeze(completions_ids_padded, dim = 1)
        yield example_id, input_id, completions_ids_padded, int(doc['label'])
        example_id += 1

In [None]:
RUN_CELL = True    # generate baseline info and conditionals

if RUN_CELL:
    baseline = dict() 
    # save the label and the number of completions
    gen = data_prompting(data[set_partition], tokenizer)
    for example_id, input_ids, completions_batch, label in tqdm(gen, total=len(data[set_partition])):
        baseline[example_id] = dict()
        baseline[example_id]['label'] = label
        baseline[example_id]['no_completions'] = len(completions_batch)
        baseline[example_id]['p_map'] = []
        p_and_completion = []

        outputs = eoc.multi_labels_forward(model, input_ids.cuda(), completions_batch.cuda())

        for completion_index in range(len(completions_batch)):
            p = -ce_loss(
                # Only care about the tokens corresponding to the last word and omit offset tokens 
                # the first one is <extra_id_0> and omitted
                outputs.logits[completion_index][1:], 
                completions_batch[completion_index][1:].cuda()
            )
            baseline[example_id]['p_map'] += [p.detach().cpu().tolist()]


100%|██████████| 10042/10042 [15:20<00:00, 10.92it/s]


In [None]:
RUN_CELL = False   # get input and completion lens
if RUN_CELL:
    input_lens = []
    completion_lens = []
    gen = data_prompting(data[set_partition], tokenizer)
    for example_id, input_ids, completions_batch, label in tqdm(gen, total=len(data[set_partition])):
        input_lens.append(len(input_ids[0]))
        for completion in completions_batch:
            completion_lens.append(len(completion))
    print(f"max: {max(input_lens)}, min: {min(input_lens)}, avg: {sum(input_lens)/len(input_lens)}")
    print(f"max: {max(completion_lens)}, min: {min(completion_lens)}, avg: {sum(completion_lens)/len(completion_lens)}")

In [11]:
RUN_CELL = True    # Save baseline
if RUN_CELL:
    pickle.dump(baseline, open("data/pkls/hellaswg/ul2/baseline.pkl", "wb"))

In [12]:
RUN_CELL = True # Load baseline
if RUN_CELL:
    '''Load the baseline p_maps'''
    with open("data/pkls/hellaswg/ul2/baseline.pkl", "rb") as handle:
        baseline = pickle.load(handle)

### Multispan Ensemble

In [34]:
RUN_CELL = True  # generate multispan conditionals
if RUN_CELL:
    span_length = 3
    gap_between_spans = 5
    num_spans = 1
    p_map_multispan = dict()
    gen = data_prompting(data[set_partition], tokenizer)

    for example_id, input_ids, completions_batch, label in tqdm(gen, total=len(data[set_partition])):
        # print(input_ids.shape)
        # continue
        inputs_ids_multispan, labels_multispan = eoc.create_multiple_span_sample_from_batch(
            tokenizer,
            input_ids[0], # squeeze 1st dim
            completions_batch, # the first completion
            span_length,
            gap_between_spans,
            num_spans,
        )
        outputs = eoc.multi_labels_forward(model, inputs_ids_multispan.cuda(), labels_multispan.cuda())

        # assert multispan samples are correct 
        assert completions_batch[completion_index].nonzero().shape[0] == \
            labels_multispan[completion_index][num_spans * (span_length + 1) :].nonzero().shape[0]

        for completion_index in range(len(completions_batch)):
            avg_log_p = -ce_loss(
                # Only care about the tokens corresponding to the completion (see assert below)); 
                # so the first <extra_id_0> is omitted, and for each span, the span + <extra_id_k> is omitted;
                # totally 1 + num_spans * (span_length + 1) tokens are omitted;
                # labels_multispan contains paddings.
                outputs.logits[completion_index][1 + num_spans * (span_length + 1) :], 
                labels_multispan[completion_index][1 + num_spans * (span_length + 1) :].cuda()
            )
            

            p_map_multispan[(example_id, span_length, gap_between_spans, num_spans, completion_index)] = \
                avg_log_p.detach().cpu().tolist()

100%|██████████| 100/100 [00:08<00:00, 11.34it/s]


In [13]:
RUN_CELL = True    # Max reduction to emsemble conditionals for the same last word
'''Max reduction to emsemble conditionals, i.e., only the maximum avg_log_p is kept for each completion.
Emsemble the baseline conditionals with the K-offset conditionals and middle-off conditionals.'''

if RUN_CELL:
    # Add the baseline (offset = 0 from K-offset ensemble) to the list
    ADD_BASELINE = True

    # Add K-offset conditionals to the list
    ADD_K_OFFSET = False 
    MAX_OFFSET = 4

    # Add multispan conditionals to the list
    ADD_MULTISPAN = False
    LENGTH_GAP_NUM_TUPLES = [
        (3, 5, 1), 
    ] # SPAN_LENGTH, GAP_LENGTH, NUM_SPANS. NUM_SPANS can be float, which is treated as auto_ratio.

    count_correct = 0
    for example_index in tqdm(range(len(data[set_partition]))): # len(lambada)
        no_completions = baseline[example_index]['no_completions']
        # Create a list of tuples (avg_log_p, completion) for each completion
        p_and_completion = []
        
        # add the baseline (offset = 0 from K-offset ensemble) to the list
        if ADD_BASELINE:
            p_and_completion += [
                (baseline[example_index]['p_map'][completion_index], completion_index)
                for completion_index in range(no_completions)
            ]
            
        # add the whole K-offset ensemble to the list
        if ADD_K_OFFSET:
            for offset in range(1, MAX_OFFSET + 1):
                p_and_completion += [
                    (p_map_offset[(example_index, offset, completion_index)], id_to_completions_ids[example_index][completion_index])
                    for completion_index in range(len(id_to_completions_ids[example_index]))
                ]
                
        if ADD_MULTISPAN:
            p_and_completion += [
                (p_map_multispan[(example_index, *length_gap_num, completion_index)], completion_index)
                for completion_index in range(no_completions)
                for length_gap_num in LENGTH_GAP_NUM_TUPLES
            ]

        # Find the tuple with the maximum avg_log_p; this is essentially max reduction
        _, best_completion_index = max(p_and_completion, key=lambda x: x[0])
        if best_completion_index == baseline[example_index]['label']:
            count_correct += 1
    print("accuracy:", count_correct / len(data[set_partition]))

100%|██████████| 10042/10042 [00:00<00:00, 538312.70it/s]

accuracy: 0.7520414260107549





In [14]:
RUN_CELL = True # Get baseline accuracy

TOTAL_CASE = 0
ACCURATE_CASE = 0

if RUN_CELL:
    p_map_offset = dict() # (id, offset, completion_index) -> p of the tokens constituting the last word (might be punctuated)
        
    gen = data_prompting(data[set_partition], tokenizer)

    for example_id, input_ids, completions_batch, label in tqdm(gen, total=len(data[set_partition])):
        p_and_completion = []
        outputs = eoc.multi_labels_forward(model, input_ids.cuda(), completions_batch.cuda())

        for completion_index in range(len(completions_batch)):
            p = -ce_loss(
                # Only care about the tokens corresponding to the last word and omit offset tokens 
                # the first one is <extra_id_0> and omitted
                outputs.logits[completion_index][1:], 
                completions_batch[completion_index][1:].cuda()
            )
            p_map_offset[completion_index] = \
                p.detach().cpu().tolist()
            
            p_and_completion.append([p.detach().cpu().tolist(), completion_index])

        best_p, best_completion_index = max(p_and_completion, key=lambda x: x[0])

        if int(best_completion_index) == int(label):
            ACCURATE_CASE += 1
        TOTAL_CASE += 1
    print("accuracy:", ACCURATE_CASE / TOTAL_CASE)

100%|██████████| 10042/10042 [15:24<00:00, 10.87it/s]

accuracy: 0.7520414260107549





In [None]:
ACCURATE_CASE/TOTAL_CASE