### Imports and global utils

In [1]:
'''imports'''
import os
# os.environ["CUDA_VISIBLE_DEVICES"]="0,1,4,5,6,7"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import pickle
# clear GPU memory
from utils import general_utils, eoc
import torch
from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Tokenizer
from tqdm import tqdm
from datasets import load_dataset, concatenate_datasets
from typing import Tuple, List
import torch.nn.functional as F
import eoc_datasets

  from .autonotebook import tqdm as notebook_tqdm


### Load model

In [30]:
# Specify model and load tokenizer.
# Supported identifiers: "t5-11b", "google-ul2", "flan-ul2".
# For each identifier the branch below sets:
#   model_name      - name passed to `from_pretrained`
#   model_dir       - sub-directory of the custom cache used for this model
#   mode            - prompt mode forwarded to the dataset example generator
#   no_extra_tokens - number of leading label tokens skipped when scoring
model_identifier = "flan-ul2"
if model_identifier == "t5-11b":
    model_name = "t5-11b"
    model_dir = "t5-11b"
    mode = 'T5'
    no_extra_tokens = 1 # extra_id_0
elif model_identifier == "google-ul2":
    model_name = "google/ul2"
    model_dir = "google-ul2"
    mode = '[NLG]' # '[S2S]' is not supported by some functions (poor accuracy)
    no_extra_tokens = 1 # extra_id_0
elif model_identifier == "flan-ul2":
    model_name = "google/flan-ul2"
    model_dir = "flan-ul2"
    mode = 'Flan-UL2'
    no_extra_tokens = 0
else:
    # Fail fast: without this branch an unknown identifier would leave the
    # settings above undefined (or stale from a previous run of this cell).
    raise ValueError(
        f"Unknown model_identifier {model_identifier!r}; "
        "expected one of 't5-11b', 'google-ul2', 'flan-ul2'"
    )

# Use custom huggingface cache dirs in case the default one has low capacity, since the models are large.
MY_HUGGINGFACE_CACHE_DIR ='huggingface_cache'

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=os.path.join(MY_HUGGINGFACE_CACHE_DIR, model_dir)
)

# Define the losses used for scoring; padding positions are ignored in both.
# ce_loss uses the default reduction ('mean'), ce_loss_sum sums over tokens.
ce_loss = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
ce_loss_sum = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id, reduction='sum')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
RUN_CELL = False   # Load model
if RUN_CELL:
    # Alternatives previously used (kept for reference):
    #   low_cpu_mem_usage=True
    #   torch_dtype=torch.bfloat16 and device_map='balanced' for ul2
    # NOTE: later cells reference `model`, so this cell has to be run with
    # RUN_CELL = True before any scoring cell is executed.
    model = T5ForConditionalGeneration.from_pretrained(
        model_name,
        load_in_8bit=True, # flan-ul2
        device_map='auto', # flan-ul2
        cache_dir=os.path.join(MY_HUGGINGFACE_CACHE_DIR, model_dir),
    )

In [4]:
# model.encoder.block[0].layer[0].SelfAttention.q.weight.dtype

### Get dataset


In [5]:
# Pick the benchmark to evaluate. Other processors used with this notebook:
# eoc_datasets.ARCProcessor() and eoc_datasets.HellaswagProcessor().
dataset_processor = eoc_datasets.MMLUProcessor()

# For hellaswag, `shuffle=True` and `first_k_instances=1000` were passed as
# well (index bias: the first 1000 examples have very low accuracy compared
# to the whole set).
data = dataset_processor.get_dataset(set_partition='test')

# Keep a handle on the processor's generator; the scoring cells call it directly.
example_generator = dataset_processor.example_generator

In [15]:
RUN_CELL = True   # set tensors_filtering_criterion by lengths
if RUN_CELL:
    def tensors_filtering_criterion(input_ids, completions_batch):
        """Decide whether an example is kept, based on token lengths.

        Args:
            input_ids: indexable whose first element is the input sequence.
            completions_batch: iterable of completion sequences; trailing
                zeros are stripped from each before its length is measured.

        Returns:
            True when the first input sequence has more than 20 entries and
            every stripped completion has fewer than 6 entries.
        """
        # Parenthesised instead of backslash continuations, so that adding or
        # removing a condition cannot silently break the expression.
        return (
            len(input_ids[0]) > 20
            and all(
                len(general_utils.remove_trailing_zeros_from_1d_tensor(completion)) < 6
                for completion in completions_batch
            )
        )

    # Collect length statistics over the examples that pass the filter.
    gen = example_generator(data, tokenizer, mode=mode, tensors_filtering_criterion=tensors_filtering_criterion)
    input_lens = []
    completion_lens = []
    for example_id, input_ids, completions_batch, label in tqdm(gen):
        input_lens.append(len(input_ids[0]))
        completion_lens.append(len(completions_batch[0])) # with padding, this is the max len of the completions

    if input_lens:
        print(f"input len max: {max(input_lens)}, min: {min(input_lens)}, avg: {sum(input_lens)/len(input_lens)}")
        print(f"completion len max: {max(completion_lens)}, min: {min(completion_lens)}, avg: {sum(completion_lens)/len(completion_lens)}")
    else:
        # max()/min() and the average are undefined on an empty list.
        print("no examples passed tensors_filtering_criterion")

2280it [00:10, 214.70it/s]

input len max: 649, min: 21, avg: 54.15263157894737
completion len max: 5, min: 1, avg: 3.373245614035088





### Baseline

In [None]:
RUN_CELL = True    # generate baseline info and conditionals

if RUN_CELL:
    # baseline maps example_id -> {'label', 'no_completions', 'p_map'}, where
    # p_map[i] is the negated cross-entropy (ce_loss) of completion i.
    # Requires `model` from the model loading cell.
    baseline = dict()
    gen = example_generator(data, tokenizer, mode, tensors_filtering_criterion=tensors_filtering_criterion)
    # Pure inference: disable autograd so no graph is kept for the forward passes.
    with torch.no_grad():
        for example_id, input_ids, completions_batch, label in tqdm(gen):
            # save the label and the number of completions
            baseline[example_id] = dict()
            baseline[example_id]['label'] = label
            baseline[example_id]['no_completions'] = len(completions_batch)
            baseline[example_id]['p_map'] = []
            outputs = eoc.multi_labels_forward(model, input_ids.cuda(), completions_batch.cuda())

            for completion_index in range(len(completions_batch)):
                p = -ce_loss(
                    # Score only the completion tokens: the first
                    # no_extra_tokens positions (<extra_id_0>, if present) are skipped.
                    outputs.logits[completion_index][no_extra_tokens:],
                    completions_batch[completion_index][no_extra_tokens:].cuda()
                )

                baseline[example_id]['p_map'] += [p.detach().cpu().tolist()]


### K-offset Conditionals

In [23]:
RUN_CELL = True
if RUN_CELL:
    # Largest offset to evaluate; the ensemble cell reads p_map_offset for
    # offsets 1..MAX_OFFSET, so the two settings have to agree.
    MAX_OFFSET = 2
    p_map_offset = dict() # maps (example_id, offset, completion_index) -> avg_p
    # Pure inference: disable autograd so no graph is kept for the forward passes.
    with torch.no_grad():
        for offset in range(1, MAX_OFFSET+1):
            # The generator is re-created for every offset because it is consumed by the loop.
            gen = example_generator(data, tokenizer, mode, tensors_filtering_criterion=tensors_filtering_criterion)
            for example_id, input_ids, completions_batch, label in tqdm(gen):
                input_ids_offset, labels_offset = eoc.create_offset_sample_from_batch(
                    tokenizer,
                    input_ids,
                    completions_batch,
                    offset
                )
                outputs = eoc.multi_labels_forward(model, input_ids_offset.cuda(), labels_offset.cuda())
                for completion_index in range(len(completions_batch)):
                    avg_log_p = -ce_loss(
                        # Only care about the tokens corresponding to the original completion:
                        # skip <extra_id_0> (if present) and the `offset` tokens placed before it.
                        outputs.logits[completion_index][no_extra_tokens+offset:],
                        labels_offset[completion_index][no_extra_tokens+offset:].cuda()
                    )
                    p_map_offset[(example_id, offset, completion_index)] = \
                        avg_log_p.detach().cpu().tolist()

0it [00:00, ?it/s]
0it [00:00, ?it/s]


### SAP

Sequential autoregressive prompting

__SAP__ is a particular type of __Ensemble of Conditionals__.

It aims to augment the single conditional distribution obtained by masking the whole target with additional distributions. The new distributions are obtained by unmasking the first __offset__ tokens of the target.

An example

prompt: `What is the best food? <extra_id_0>`

candidates:

`C1. French fries`

`C2. Chicken drumlets`

The baseline approach is to input `What is the best food? <extra_id_0>` to the model and obtain the probs of the C's.

E.g., `P(C1) = P(French) * P(fries|French)`.

SAP unmasks additional tokens at the start of each candidate C to obtain further estimates of the conditional distributions.

For the offset=1 case, we unmask 1 extra token and move it into the prompt.

prompt1: `What is the best food? French <extra_id_0>`

prompt2: `What is the best food? Chicken <extra_id_0>`

for candidates 

`C1. fries`

`C2. drumlets`

This gives us different values for the distributions P(fries|French) and P(drumlets|Chicken), which are then added to our ensemble.





In [9]:
RUN_CELL = False
if RUN_CELL:
    # NOTE: this cell writes its scores into `baseline`, replacing the values
    # produced by the Baseline cell, so the ensemble cell then evaluates SAP.
    baseline = dict()
    gen = example_generator(data, tokenizer, mode, tensors_filtering_criterion=tensors_filtering_criterion, pad_to_2d_tensor=False)
    # Pure inference: disable autograd so no graph is kept for the forward passes.
    with torch.no_grad():
        for example_id, input_ids, completions_batch, label in tqdm(gen):
            baseline[example_id] = dict()
            baseline[example_id]['label'] = label
            baseline[example_id]['no_completions'] = len(completions_batch)
            baseline[example_id]['p_map'] = []
            for completion_index in range(len(completions_batch)):
                # One score per offset 0, -1, ..., so that
                # len(sap_probs) = len(completion) - no_extra_tokens.
                sap_probs = []
                for offset in range(0, no_extra_tokens-len(completions_batch[completion_index]), -1):
                    input_ids_sap, completion_ids_sap = eoc.create_offset_sample(
                        input_ids,
                        completions_batch[completion_index],
                        offset # offset is negative for sap
                    )
                    # take the first no_extra_tokens + 1 tokens from completion_ids_sap: <extra_id_0> (if it exists) and the first token of the completion
                    completion_ids_sap = completion_ids_sap[:1+no_extra_tokens].unsqueeze(0)
                    outputs = model(input_ids_sap.cuda(), labels=completion_ids_sap.cuda())
                    log_p = -ce_loss(
                        outputs.logits[0][no_extra_tokens], # [0] to lose the batch dim, [no_extra_tokens] to skip the <extra_id_0> token
                        completion_ids_sap[0][no_extra_tokens].cuda()
                    )
                    sap_probs.append(log_p.detach().cpu().tolist())
                if not sap_probs:
                    # A completion with no tokens beyond the extra tokens yields no
                    # scores; report it instead of failing with ZeroDivisionError.
                    raise ValueError(
                        f"completion {completion_index} of example {example_id} "
                        "has no tokens to score"
                    )
                baseline[example_id]['p_map'] += [sum(sap_probs) / len(sap_probs)]

### Multispan Conditionals

In [10]:
RUN_CELL = False   # generate multispan conditionals
if RUN_CELL:
    if mode != "[NLG]":
        raise ValueError("Only NLG mode is supported for multispan conditionals for now")
    span_length = 3
    gap_between_spans = 5
    num_spans = 1
    # maps (example_id, span_length, gap_between_spans, num_spans, completion_index) -> avg_log_p
    p_map_multispan = dict()
    gen = example_generator(data, tokenizer, mode, tensors_filtering_criterion=tensors_filtering_criterion)

    # Pure inference: disable autograd so no graph is kept for the forward passes.
    with torch.no_grad():
        for example_id, input_ids, completions_batch, label in tqdm(gen):
            inputs_ids_multispan, labels_multispan = eoc.create_multiple_span_sample_from_batch(
                tokenizer,
                input_ids[0], # squeeze 1st dim
                completions_batch,
                span_length,
                gap_between_spans,
                num_spans,
            )
            outputs = eoc.multi_labels_forward(model, inputs_ids_multispan.cuda(), labels_multispan.cuda())

            for completion_index in range(len(completions_batch)):
                # Assert the multispan sample is consistent with the original completion.
                # This check lives inside the loop because it depends on
                # completion_index, which is only defined here.
                assert completions_batch[completion_index].nonzero().shape[0] == \
                    labels_multispan[completion_index][num_spans * (span_length + 1) :].nonzero().shape[0]

                avg_log_p = -ce_loss(
                    # Only care about the tokens corresponding to the completion (see the assert above);
                    # so the first <extra_id_0> is omitted, and for each span, the span + <extra_id_k> is omitted;
                    # totally 1 + num_spans * (span_length + 1) tokens are omitted;
                    # labels_multispan contains paddings.
                    outputs.logits[completion_index][1 + num_spans * (span_length + 1) :],
                    labels_multispan[completion_index][1 + num_spans * (span_length + 1) :].cuda()
                )
                p_map_multispan[(example_id, span_length, gap_between_spans, num_spans, completion_index)] = \
                    avg_log_p.detach().cpu().tolist()

### Ensemble of Conditionals

In [12]:
RUN_CELL = True    # Max reduction to ensemble conditionals for the same last word
'''Max reduction to emsemble conditionals, i.e., only the maximum avg_log_p is kept for each completion.
Emsemble the baseline conditionals with the K-offset conditionals and middle-off conditionals.'''

if RUN_CELL:
    # Add the baseline (offset = 0 from K-offset ensemble) to the list
    ADD_BASELINE = True

    # Add K-offset conditionals to the list
    ADD_K_OFFSET = False
    MAX_OFFSET = 2

    # Add multispan conditionals to the list
    ADD_MULTISPAN = False
    LENGTH_GAP_NUM_TUPLES = [
        (3, 5, 1),
        # (3, 10, 1),
    ] # SPAN_LENGTH, GAP_LENGTH, NUM_SPANS. NUM_SPANS can be float, which is treated as auto_ratio.

    if not (ADD_BASELINE or ADD_K_OFFSET or ADD_MULTISPAN):
        # With every source disabled the candidate list is empty and max() would fail.
        raise ValueError("At least one of ADD_BASELINE, ADD_K_OFFSET, ADD_MULTISPAN must be True")

    count_correct = 0
    # Iterate over the stored example ids themselves rather than
    # range(len(baseline)): the ids are whatever the example generator produced
    # and are not guaranteed to be exactly 0..len(baseline)-1.
    for example_index, example_info in tqdm(baseline.items()):
        no_completions = example_info['no_completions']
        # Create a list of tuples (avg_log_p, completion) for each completion
        p_and_completion = []

        # add the baseline (offset = 0 from K-offset ensemble) to the list
        if ADD_BASELINE:
            p_and_completion += [
                (example_info['p_map'][completion_index], completion_index)
                for completion_index in range(no_completions)
            ]

        # add the whole K-offset ensemble to the list
        if ADD_K_OFFSET:
            for offset in range(1, MAX_OFFSET + 1):
                p_and_completion += [
                    (p_map_offset[(example_index, offset, completion_index)], completion_index)
                    for completion_index in range(no_completions)
                ]

        # add the multispan conditionals to the list
        if ADD_MULTISPAN:
            p_and_completion += [
                (p_map_multispan[(example_index, *length_gap_num, completion_index)], completion_index)
                for completion_index in range(no_completions)
                for length_gap_num in LENGTH_GAP_NUM_TUPLES
            ]

        # Find the tuple with the maximum avg_log_p; this is essentially max reduction
        _, best_completion_index = max(p_and_completion, key=lambda x: x[0])
        label = example_info['label']
        if (isinstance(label, int) and best_completion_index == label) or \
        (isinstance(label, list) and best_completion_index in label) :# TruthfulQA has multiple correct answers
            count_correct += 1

    if baseline:
        print("accuracy:", count_correct / len(baseline))
    else:
        # Accuracy is undefined when no example was scored.
        print("accuracy: n/a (baseline is empty)")

100%|██████████| 2300/2300 [00:00<00:00, 320016.56it/s]

accuracy: 0.35695652173913045



