### Imports and global utils

In [1]:
'''imports'''
import os
# Restrict the notebook to a single GPU. This has to happen before anything initializes CUDA,
# which is why it sits above the torch / transformers imports.
# Multi-GPU alternative (e.g. to let device_map='balanced' spread the model over several cards):
# os.environ["CUDA_VISIBLE_DEVICES"]="0,1,4,5,6,7"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import pickle
# Project-local helpers; `eoc` builds the offset / multispan samples and runs the batched
# forward pass over all completions of an example.
from utils import general_utils, eoc
import torch
from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Tokenizer
from tqdm import tqdm
from datasets import load_dataset, concatenate_datasets
from typing import Tuple, List
import torch.nn.functional as F
# Dataset processors (HellaSwag / ARC / MMLU) whose example_generator yields
# (example_id, input_ids, completions_batch, label) tuples.
import eoc_datasets

  from .autonotebook import tqdm as notebook_tqdm


### Load model

In [2]:
# Hugging Face downloads go to a notebook-local cache (path relative to this notebook) because the
# default cache location may not have room for large checkpoints.
MY_HUGGINGFACE_CACHE_DIR = 'huggingface_cache'
tokenizer = AutoTokenizer.from_pretrained(
    "google/ul2",
    cache_dir=MY_HUGGINGFACE_CACHE_DIR + '/google-ul2',
)

# Token-level cross-entropy over label positions; padding positions are ignored.
# ce_loss averages over the non-padding tokens (reduction='mean' is the PyTorch default), so
# -ce_loss is a completion's mean token log-likelihood. ce_loss_sum returns the summed loss.
ce_loss = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id, reduction='mean')
ce_loss_sum = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id, reduction='sum')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# Loading the UL2 checkpoint is slow, so only load it when the kernel has no `model` yet.
# A hard-coded False here broke Restart & Run All: the K-offset and multispan cells below run
# with RUN_CELL = True and need `model`.
RUN_CELL = 'model' not in globals()   # Load model
if RUN_CELL:
    model = T5ForConditionalGeneration.from_pretrained(
        "google/ul2",
        cache_dir=MY_HUGGINGFACE_CACHE_DIR + '/google-ul2',
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        device_map='balanced',  # let accelerate place the weights across the visible GPUs
    )

Loading checkpoint shards: 100%|██████████| 4/4 [00:28<00:00,  7.18s/it]


In [4]:
# UL2 scoring mode. '[NLG]' works in every cell below; '[S2S]' is handled only by the baseline
# cell (and gave poor accuracy there), and the multispan cell rejects it.
UL2_MODE = "[NLG]"

### Get dataset


In [5]:
# Benchmark to evaluate. Other processors available: eoc_datasets.ARCProcessor(),
# eoc_datasets.MMLUProcessor().
dataset_processor = eoc_datasets.HellaswagProcessor()

# Full test partition. For HellaSwag subsampling, note the index bias: the first 1000 examples
# score much lower than the whole set, so shuffle before taking a prefix
# (options previously used: shuffle=True, first_k_instances=1000).
data = dataset_processor.get_dataset(set_partition='test')

example_generator = dataset_processor.example_generator

In [33]:
# Example filter shared by the baseline, K-offset and multispan cells. It is defined outside the
# RUN_CELL guard so those cells keep working when the statistics below are switched off.
MIN_INPUT_LEN = 20       # keep inputs with strictly more tokens than this
MAX_COMPLETION_LEN = 6   # keep (padded) completions with strictly fewer tokens than this


def tensors_filtering_criterion(input_ids, completions_batch):
    """Return True if an example should be yielded by `example_generator`.

    `input_ids[0]` is the first row of the batched input ids. `completions_batch[0]` is the first
    completion; the completions are padded to a common length, so its length is the max
    completion length.
    """
    return len(input_ids[0]) > MIN_INPUT_LEN and len(completions_batch[0]) < MAX_COMPLETION_LEN


RUN_CELL = True    # get input and completion lens
if RUN_CELL:
    gen = example_generator(data, tokenizer, tensors_filtering_criterion=tensors_filtering_criterion)
    input_lens = []
    completion_lens = []
    for example_id, input_ids, completions_batch, label in tqdm(gen):
        input_lens.append(len(input_ids[0]))
        completion_lens.append(len(completions_batch[0]))  # with padding, this is the max len of the completions

    if not input_lens:
        # Avoid max()/min()/division errors when the filter rejects everything.
        print("no examples passed tensors_filtering_criterion")
    else:
        n_kept = sum(
            i > MIN_INPUT_LEN and j < MAX_COMPLETION_LEN
            for i, j in zip(input_lens, completion_lens)
        )
        print(f"input len > {MIN_INPUT_LEN} and completion len < {MAX_COMPLETION_LEN}: {n_kept}")
        print(f"input len max: {max(input_lens)}, min: {min(input_lens)}, avg: {sum(input_lens)/len(input_lens)}")
        print(f"completion len max: {max(completion_lens)}, min: {min(completion_lens)}, avg: {sum(completion_lens)/len(completion_lens)}")

2300it [00:12, 191.24it/s]

input len > 20 and completion len < 6: 2300
input len max: 654, min: 21, avg: 47.85782608695652
completion len max: 5, min: 2, avg: 3.8752173913043477





### Baseline

In [34]:
RUN_CELL = False   # generate baseline info and conditionals

if RUN_CELL:
    # Only these two UL2 modes have a scoring rule. Fail fast instead of leaving `p` undefined,
    # or silently reusing a stale `p` from the kernel, for any other value.
    if UL2_MODE == "[NLG]":
        # The first label token is <extra_id_0>; it is not part of the completion, so skip it.
        first_scored = 1
    elif UL2_MODE == "[S2S]":
        first_scored = 0
    else:
        raise ValueError(f"Unsupported UL2_MODE: {UL2_MODE}")

    # example_id -> {'label', 'no_completions', 'p_map'}. p_map[i] is -ce_loss, i.e. the mean
    # log-likelihood of completion i's non-padding target tokens given the input.
    baseline = dict()
    gen = example_generator(data, tokenizer, tensors_filtering_criterion=tensors_filtering_criterion)
    # Pure inference: no_grad avoids building (and holding GPU memory for) an autograd graph.
    with torch.no_grad():
        for example_id, input_ids, completions_batch, label in tqdm(gen):
            labels = completions_batch.cuda()
            outputs = eoc.multi_labels_forward(model, input_ids.cuda(), labels)
            baseline[example_id] = {
                'label': label,
                'no_completions': len(completions_batch),
                'p_map': [
                    -ce_loss(
                        outputs.logits[completion_index][first_scored:],
                        labels[completion_index][first_scored:],
                    ).item()
                    for completion_index in range(len(completions_batch))
                ],
            }


2300it [03:44, 10.25it/s]


In [8]:
RUN_CELL = False   # Save baseline
if RUN_CELL:
    # NOTE(review): the directory name ("hellaswg") and the "_1000" suffix are kept as-is so the
    # load cell still finds the file; `data` is no longer limited to 1000 examples.
    baseline_path = "data/pkls/hellaswg/ul2/baseline_1000.pkl"
    os.makedirs(os.path.dirname(baseline_path), exist_ok=True)
    # The context manager guarantees the pickle is flushed and the handle closed.
    with open(baseline_path, "wb") as handle:
        pickle.dump(baseline, handle)

In [9]:
RUN_CELL = False  # Load baseline
if RUN_CELL:
    # Restore the baseline p_maps previously pickled by this notebook.
    # pickle.load can execute arbitrary code, so only load files this notebook produced.
    with open("data/pkls/hellaswg/ul2/baseline_1000.pkl", "rb") as pkl_file:
        baseline = pickle.load(pkl_file)

### K-offset Conditionals

In [None]:
RUN_CELL = True
if RUN_CELL:
    MAX_OFFSET = 5
    # (example_id, offset, completion_index) -> mean log-likelihood (float) of the original
    # completion tokens under the offset sample.
    p_map_offset = dict()
    # Pure inference: no_grad avoids building an autograd graph for every forward pass.
    with torch.no_grad():
        for offset in range(1, MAX_OFFSET + 1):
            gen = example_generator(data, tokenizer, tensors_filtering_criterion=tensors_filtering_criterion)
            for example_id, input_ids, completions_batch, label in tqdm(gen, desc=f"offset={offset}"):
                input_ids_offset, labels_offset = eoc.create_offset_sample_from_batch(
                    tokenizer,
                    input_ids,
                    completions_batch,
                    offset,
                )
                # Move the labels to the GPU once instead of once per completion.
                labels_offset = labels_offset.cuda()
                outputs = eoc.multi_labels_forward(model, input_ids_offset.cuda(), labels_offset)
                # Skip <extra_id_0> and the `offset` prefix tokens (presumably the input tokens that
                # create_offset_sample_from_batch moved into the target; confirm) so only the
                # original completion is scored.
                first_scored = 1 + offset
                for completion_index in range(len(completions_batch)):
                    avg_log_p = -ce_loss(
                        outputs.logits[completion_index][first_scored:],
                        labels_offset[completion_index][first_scored:],
                    )
                    p_map_offset[(example_id, offset, completion_index)] = avg_log_p.item()

### Multispan Conditionals

In [67]:
RUN_CELL = True    # generate multispan conditionals
if RUN_CELL:
    if UL2_MODE != "[NLG]":
        raise ValueError("Only NLG mode is supported for multispan conditionals for now")
    span_length = 3
    gap_between_spans = 10
    num_spans = 1
    # Label positions before `first_scored` are the leading <extra_id_0> plus, for each span, the
    # span tokens and the sentinel that follows them. Everything from there on is the completion
    # (plus padding, which ce_loss ignores).
    first_scored = 1 + num_spans * (span_length + 1)
    p_map_multispan = dict()
    gen = example_generator(data, tokenizer, tensors_filtering_criterion=tensors_filtering_criterion)

    # Pure inference: no_grad avoids building an autograd graph for every forward pass.
    with torch.no_grad():
        for example_id, input_ids, completions_batch, label in tqdm(gen):
            inputs_ids_multispan, labels_multispan = eoc.create_multiple_span_sample_from_batch(
                tokenizer,
                input_ids[0],  # squeeze 1st dim
                completions_batch,
                span_length,
                gap_between_spans,
                num_spans,
            )
            outputs = eoc.multi_labels_forward(model, inputs_ids_multispan.cuda(), labels_multispan.cuda())

            for completion_index in range(len(completions_batch)):
                # Sanity check, done per completion: the label tail starting at the sentinel before
                # the completion has as many non-zero tokens as the original completion (which also
                # starts with a sentinel). Counting via nonzero() assumes the pad id is 0; confirm.
                assert completions_batch[completion_index].nonzero().shape[0] == \
                    labels_multispan[completion_index][first_scored - 1:].nonzero().shape[0]

                avg_log_p = -ce_loss(
                    outputs.logits[completion_index][first_scored:],
                    labels_multispan[completion_index][first_scored:].cuda(),
                )
                p_map_multispan[(example_id, span_length, gap_between_spans, num_spans, completion_index)] = \
                    avg_log_p.item()

2300it [03:57,  9.70it/s]


### Ensemble of Conditionals

In [69]:
RUN_CELL = False   # Max reduction to ensemble conditionals for the same last word
# Max-reduction ensemble: pool every conditional log-likelihood available for an example
# (baseline, K-offset, multispan) and predict the completion that owns the single largest
# avg_log_p.


def predict_with_max_reduction(
    example_id,
    example_info: dict,
    p_map_offset: dict = None,
    p_map_multispan: dict = None,
    add_baseline: bool = True,
    add_k_offset: bool = True,
    max_offset: int = 2,
    add_multispan: bool = True,
    length_gap_num_tuples=((3, 10, 1),),
) -> int:
    """Return the completion index chosen by max reduction over the enabled conditionals.

    Args:
        example_id: key of the example inside `p_map_offset` / `p_map_multispan`.
        example_info: the baseline entry of the example, with keys 'no_completions' and
            'p_map' (baseline avg_log_p per completion).
        p_map_offset: (example_id, offset, completion_index) -> avg_log_p; required when
            `add_k_offset` is True.
        p_map_multispan: (example_id, span_length, gap, num_spans, completion_index) -> avg_log_p;
            required when `add_multispan` is True.
        add_baseline: include the baseline (offset 0) conditionals.
        add_k_offset: include K-offset conditionals for offsets 1..max_offset.
        max_offset: largest offset to include.
        add_multispan: include multispan conditionals for every tuple in `length_gap_num_tuples`.
        length_gap_num_tuples: (span_length, gap_between_spans, num_spans) keys to include.

    Returns:
        Index of the completion with the largest avg_log_p. Ties go to the candidate that was added
        first (baseline, then offsets, then multispan), as with max().

    Raises:
        ValueError: if no candidate was collected (every source disabled, or no completions).
    """
    completion_indices = range(example_info['no_completions'])
    p_and_completion = []

    if add_baseline:
        p_and_completion += [(example_info['p_map'][ci], ci) for ci in completion_indices]

    if add_k_offset:
        for offset in range(1, max_offset + 1):
            p_and_completion += [(p_map_offset[(example_id, offset, ci)], ci) for ci in completion_indices]

    if add_multispan:
        p_and_completion += [
            (p_map_multispan[(example_id, *length_gap_num, ci)], ci)
            for ci in completion_indices
            for length_gap_num in length_gap_num_tuples
        ]

    if not p_and_completion:
        raise ValueError("No conditionals to ensemble: enable at least one source for this example")
    _, best_completion_index = max(p_and_completion, key=lambda pair: pair[0])
    return best_completion_index


if RUN_CELL:
    # Add the baseline (offset = 0 from K-offset ensemble) to the list
    ADD_BASELINE = True

    # Add K-offset conditionals to the list. This uses its own name so the K-offset cell's
    # MAX_OFFSET global is not overwritten.
    ADD_K_OFFSET = True
    ENSEMBLE_MAX_OFFSET = 2

    # Add multispan conditionals to the list
    ADD_MULTISPAN = True
    LENGTH_GAP_NUM_TUPLES = [
        # (3, 5, 1),
        (3, 10, 1),
    ]  # SPAN_LENGTH, GAP_LENGTH, NUM_SPANS. NUM_SPANS can be float, which is treated as auto_ratio.

    if not baseline:
        raise ValueError("baseline is empty; run or load the baseline cell first")

    count_correct = 0
    # Iterate the stored example ids instead of range(len(baseline)). The ids are presumably
    # dataset indices, which would have gaps whenever tensors_filtering_criterion drops examples.
    for example_id, example_info in tqdm(baseline.items(), total=len(baseline)):
        prediction = predict_with_max_reduction(
            example_id,
            example_info,
            p_map_offset=p_map_offset if ADD_K_OFFSET else None,
            p_map_multispan=p_map_multispan if ADD_MULTISPAN else None,
            add_baseline=ADD_BASELINE,
            add_k_offset=ADD_K_OFFSET,
            max_offset=ENSEMBLE_MAX_OFFSET,
            add_multispan=ADD_MULTISPAN,
            length_gap_num_tuples=LENGTH_GAP_NUM_TUPLES,
        )
        if prediction == example_info['label']:
            count_correct += 1
    print("accuracy:", count_correct / len(baseline))

100%|██████████| 2300/2300 [00:00<00:00, 125040.82it/s]

accuracy: 0.3339130434782609



