### Imports and global utils

In [1]:
'''imports'''
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
from itertools import combinations
import random
import pickle
from utils import general_utils, eoc
import torch
from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Tokenizer
from tqdm import tqdm
from datasets import load_dataset, concatenate_datasets
from typing import Tuple, List
import torch.nn.functional as F
import eoc_datasets
from model_configs import model_configs

  from .autonotebook import tqdm as notebook_tqdm


### Load model

In [2]:
# Select the model configuration to evaluate. Alternative identifiers are
# left commented out for quick switching.
# model_identifier = "google-ul2"
# model_identifier = "flan-ul2"
model_identifier = "t5-11b"

config = model_configs[model_identifier]

# Unpack the per-model settings that the rest of the notebook reads.
model_name = config['model_name']
model_dir = config['model_dir']
mode = config['mode']
no_extra_tokens = config['no_extra_tokens']
model_kwargs = config['model_kwargs']

# Custom Hugging Face cache root, used in case the default cache location has
# low capacity (the models are large).
MY_HUGGINGFACE_CACHE_DIR = 'huggingface_cache'

tokenizer = AutoTokenizer.from_pretrained(
    model_name, cache_dir=os.path.join(MY_HUGGINGFACE_CACHE_DIR, model_dir)
)

# Cross-entropy losses that ignore padding positions.
# `ce_loss` uses PyTorch's default reduction ('mean'), i.e. a per-token average;
# `ce_loss_sum` sums over the non-ignored tokens instead.
ce_loss = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
ce_loss_sum = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=tokenizer.pad_token_id)

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-11b automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [3]:
RUN_CELL = True  # Load model
if RUN_CELL:
    # Weights are cached under the same per-model directory as the tokenizer.
    model_cache_path = os.path.join(MY_HUGGINGFACE_CACHE_DIR, model_dir)
    model = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir=model_cache_path, **model_kwargs)

### Get dataset


In [4]:
# Benchmark selection. Alternative processors kept for reference:
#   eoc_datasets.ARCProcessor()
#   eoc_datasets.HellaswagProcessor()
#   eoc_datasets.BigBenchProcessor(subjects=config['bigbench_subjects'])
dataset_processor = eoc_datasets.MMLUProcessor(subjects=config['mmlu_subjects'])

# For HellaSwag, `shuffle=True` and `first_k_instances=1000` were additionally
# passed here because of an index bias: the first 1000 examples have very low
# accuracy compared to the whole set.
data = dataset_processor.get_dataset(set_partition='test')

example_generator = dataset_processor.example_generator

In [5]:
RUN_CELL = True   # set tensors_filtering_criterion by lengths
if RUN_CELL:
    MIN_INPUT_LEN = 20       # keep examples whose input is strictly longer than this
    MAX_COMPLETION_LEN = 5   # keep examples whose completions are all strictly shorter than this

    def tensors_filtering_criterion(input_ids, completions_batch):
        """Decide whether an example is kept, based on token lengths.

        Args:
            input_ids: indexable whose first element is the input token
                sequence (only `len(input_ids[0])` is used).
            completions_batch: iterable of 1-D completion tensors; trailing
                padding is stripped with
                `general_utils.remove_trailing_zeros_from_1d_tensor` before
                measuring the length.

        Returns:
            True iff the input has more than MIN_INPUT_LEN tokens and every
            completion has fewer than MAX_COMPLETION_LEN tokens after
            stripping trailing padding.
        """
        if len(input_ids[0]) <= MIN_INPUT_LEN:
            return False
        return all(
            len(general_utils.remove_trailing_zeros_from_1d_tensor(completion)) < MAX_COMPLETION_LEN
            for completion in completions_batch
        )

    # Collect length statistics of the examples that survive the filter.
    gen = example_generator(data, tokenizer, mode=mode, tensors_filtering_criterion=tensors_filtering_criterion)
    input_lens = []
    completion_lens = []
    for example_id, input_ids, completions_batch, label in tqdm(gen):
        input_lens.append(len(input_ids[0]))
        # with padding, this is the max len of the completions
        completion_lens.append(len(completions_batch[0]))

    if input_lens:
        print(f"input len max: {max(input_lens)}, min: {min(input_lens)}, avg: {sum(input_lens)/len(input_lens)}")
        print(f"completion len max: {max(completion_lens)}, min: {min(completion_lens)}, avg: {sum(completion_lens)/len(completion_lens)}")
    else:
        # Avoid max()/division errors when the filter rejects every example.
        print("No examples passed tensors_filtering_criterion; no length statistics to report.")

0it [00:00, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (549 > 512). Running this sequence through the model will result in indexing errors
176it [00:01, 120.13it/s]

input len max: 449, min: 21, avg: 49.07954545454545
completion len max: 4, min: 2, avg: 3.5738636363636362





In [6]:
RUN_CELL = True  # generate baseline info and conditionals
if RUN_CELL:
    # baseline maps example_id -> {'label', 'no_completions', 'p_map'}, where
    # p_map[i] is the negated mean cross-entropy (average log-probability per
    # non-padding token) of completion i given the unmodified input.
    baseline = dict()
    gen = example_generator(data, tokenizer, mode, tensors_filtering_criterion=tensors_filtering_criterion)
    # Scoring only: disable autograd so no graph is retained for the forward passes.
    with torch.no_grad():
        for example_id, input_ids, completions_batch, label in tqdm(gen):
            outputs = eoc.multi_labels_forward(model, input_ids.cuda(), completions_batch.cuda())

            p_map = []
            for completion_index in range(len(completions_batch)):
                avg_log_p = -ce_loss(
                    # Only care about the tokens corresponding to the completion;
                    # the first `no_extra_tokens` positions (e.g. <extra_id_0>) are omitted.
                    outputs.logits[completion_index][no_extra_tokens:].cuda(),
                    completions_batch[completion_index][no_extra_tokens:].cuda()
                )
                p_map.append(avg_log_p.detach().cpu().tolist())

            # Save the label, the number of completions and the per-completion scores.
            baseline[example_id] = {
                'label': label,
                'no_completions': len(completions_batch),
                'p_map': p_map,
            }

176it [00:51,  3.44it/s]


### K-offset Conditionals

In [7]:
RUN_CELL = True
if RUN_CELL:
    MAX_OFFSET = 10
    p_map_offset = dict() # maps (example_id, offset, completion_index) -> avg_log_p
    # Scoring only: disable autograd so no graph is retained for the forward passes.
    with torch.no_grad():
        for offset in range(1, MAX_OFFSET+1):
            # The generator is re-created for every offset so each pass sees the full dataset.
            gen = example_generator(data, tokenizer, mode, tensors_filtering_criterion=tensors_filtering_criterion)
            for example_id, input_ids, completions_batch, label in tqdm(gen):
                input_ids_offset, labels_offset = eoc.create_offset_sample_from_batch(
                    tokenizer,
                    input_ids,
                    completions_batch,
                    offset
                )
                outputs = eoc.multi_labels_forward(model, input_ids_offset.cuda(), labels_offset.cuda())
                for completion_index in range(len(completions_batch)):
                    avg_log_p = -ce_loss(
                        # Only care about the tokens corresponding to the original completion:
                        # skip the first `no_extra_tokens` positions (e.g. <extra_id_0>)
                        # plus the `offset` tokens added in front of the completion.
                        outputs.logits[completion_index][no_extra_tokens+offset:].cuda(),
                        labels_offset[completion_index][no_extra_tokens+offset:].cuda()
                    )
                    p_map_offset[(example_id, offset, completion_index)] = \
                        avg_log_p.detach().cpu().tolist()

176it [00:46,  3.80it/s]
176it [00:45,  3.86it/s]
176it [00:47,  3.70it/s]
176it [00:45,  3.90it/s]
176it [00:45,  3.90it/s]
176it [00:45,  3.90it/s]
176it [00:45,  3.89it/s]
176it [00:45,  3.88it/s]
176it [00:45,  3.88it/s]
176it [00:45,  3.89it/s]


### SAP

Sequential autoregressive prompting

__SAP__ is a particular type of __Ensemble of Conditionals__.

It augments the single conditional distribution obtained by masking the whole target with additional distributions. Each new distribution is obtained by unmasking the first __offset__ tokens of the target, i.e. moving them into the prompt.

An example

prompt: `What is the best food? <extra_id_0>`

candidates:

`C1. French fries`

`C2. Chicken drumlets`

The baseline approach is to input `What is the best food? <extra_id_0>` to the model and obtain the probs of the C's.

E.g., `P(C1) = P(French) * P(fries|French)`.

SAP unmasks additional tokens at the start of C (moving them into the prompt) to obtain different values for certain conditional distributions.

For the offset=1 case, we unmask 1 extra token.

prompt1: `What is the best food? French <extra_id_0>`

prompt2: `What is the best food? Chicken <extra_id_0>`

for candidates 

`C1. fries`

`C2. drumlets`

This gives us different values for the distributions P(fries|French) and P(drumlets|Chicken), which are added to our ensemble.





In [8]:
RUN_CELL = False
if RUN_CELL:
    # NOTE: when enabled, this cell replaces `baseline` (and its 'p_map' scores)
    # produced by the baseline cell with SAP scores.
    baseline = dict()
    gen = example_generator(data, tokenizer, mode, tensors_filtering_criterion=tensors_filtering_criterion, pad_to_2d_tensor=False)
    # Scoring only: disable autograd so no graph is retained for the forward passes.
    with torch.no_grad():
        for example_id, input_ids, completions_batch, label in tqdm(gen):
            baseline[example_id] = dict()
            baseline[example_id]['label'] = label
            baseline[example_id]['no_completions'] = len(completions_batch)
            baseline[example_id]['p_map'] = []
            for completion_index in range(len(completions_batch)):
                # One SAP score per (example, completion) pair and offset. Offsets run
                # 0, -1, -2, ... and stop before `no_extra_tokens - len(completion)`,
                # so len(sap_probs) = len(completion) - no_extra_tokens.
                sap_probs = []
                for offset in range(0, no_extra_tokens-len(completions_batch[completion_index]), -1):
                    input_ids_sap, completion_ids_sap = eoc.create_offset_sample(
                        input_ids,
                        completions_batch[completion_index],
                        offset # offset is negative for sap
                    )
                    # take the first no_extra_tokens + 1 tokens from completion_ids_sap:
                    # <extra_id_0> (if it exists) and the first token of the completion
                    completion_ids_sap = completion_ids_sap[:1+no_extra_tokens].unsqueeze(0)
                    outputs = model(input_ids_sap.cuda(), labels=completion_ids_sap.cuda())
                    log_p = -ce_loss(
                        outputs.logits[0][no_extra_tokens].cuda(), # [0] to lose the batch dim, [no_extra_tokens] to skip the <extra_id_0> token
                        completion_ids_sap[0][no_extra_tokens].cuda()
                    )
                    sap_probs.append(log_p.detach().cpu().tolist())
                # Average the per-offset log-probabilities for this completion.
                baseline[example_id]['p_map'] += [sum(sap_probs) / len(sap_probs)]

### Multispan Conditionals

In [9]:
RUN_CELL = True  # generate multispan conditionals
if RUN_CELL:
    # Each tuple is (span_length, gap_between_spans, num_spans).
    length_gap_num_tuples = [
        (3, 5, 1),
        (3, 5, 2),
        (3, 3, 1),
        (3, 3, 2),
        (3, 4, 1),
        (3, 4, 2),
        (3, 10, 1),
    ]
    # maps (example_id, span_length, gap_between_spans, num_spans, completion_index) -> avg_log_p
    p_map_multispan = dict()
    # Scoring only: disable autograd so no graph is retained for the forward passes.
    with torch.no_grad():
        for length_gap_num_tuple in length_gap_num_tuples:
            span_length, gap_between_spans, num_spans = length_gap_num_tuple
            gen = example_generator(data, tokenizer, mode, tensors_filtering_criterion=tensors_filtering_criterion)

            for example_id, input_ids, completions_batch, label in tqdm(gen):
                inputs_ids_multispan, labels_multispan = eoc.create_multiple_span_sample_from_batch(
                    tokenizer,
                    input_ids[0], # squeeze 1st dim
                    completions_batch,
                    span_length,
                    gap_between_spans,
                    num_spans,
                )
                outputs = eoc.multi_labels_forward(model, inputs_ids_multispan.cuda(), labels_multispan.cuda())

                for completion_index in range(len(completions_batch)):
                    # Sanity check of the multispan sample: the non-zero tokens of the
                    # original completion must match, in count, the non-zero label tokens
                    # that follow the `num_spans` (span + sentinel) groups.
                    assert completions_batch[completion_index].nonzero().shape[0] == \
                        labels_multispan[completion_index][num_spans * (span_length + 1) :].nonzero().shape[0]

                    avg_log_p = -ce_loss(
                        # Only care about the tokens corresponding to the completion (see assert above);
                        # so the first <extra_id_0> is omitted, and for each span, the span + <extra_id_k> is omitted;
                        # totally 1 + num_spans * (span_length + 1) tokens are omitted;
                        # labels_multispan contains paddings.
                        # NOTE(review): the leading 1 is hard-coded here while the other cells use
                        # `no_extra_tokens` -- confirm this is intended for modes where they differ.
                        outputs.logits[completion_index][1 + num_spans * (span_length + 1) :].cuda(),
                        labels_multispan[completion_index][1 + num_spans * (span_length + 1) :].cuda()
                    )
                    p_map_multispan[(example_id, span_length, gap_between_spans, num_spans, completion_index)] = \
                        avg_log_p.detach().cpu().tolist()

176it [00:44,  3.95it/s]
176it [00:45,  3.88it/s]
176it [00:44,  3.95it/s]
176it [00:45,  3.88it/s]
176it [00:44,  3.95it/s]
176it [00:45,  3.88it/s]
176it [00:45,  3.91it/s]


### Ensemble of Conditionals

In [10]:
def calc_disagreement(p_and_completion_individually):
    """Report whether the individual distributions pick different completions.

    Args:
        p_and_completion_individually: iterable with one entry per
            distribution; each entry is a non-empty sequence of
            (score, completion_index) pairs.

    Returns:
        True if at least two distributions have a different top-scoring
        completion index, False otherwise (including for an empty input).
    """
    # max() keeps the first pair on ties, so the winner per distribution is stable.
    winners = {
        max(scored_pairs, key=lambda pair: pair[0])[1]
        for scored_pairs in p_and_completion_individually
    }
    return len(winners) > 1

In [11]:
# Define the EOC (Ensemble of Conditionals) function.
def run_eoc(offsets, length_gap_num_tuples):
    """Evaluate a max-reduction ensemble of conditionals over all baseline examples.

    For every example, the baseline scores are ensembled with the selected
    K-offset and multispan conditionals. Max reduction is used: only the
    maximum avg_log_p across all selected distributions and completions is
    kept, and its completion index is the ensemble's prediction.

    Reads the notebook-level dicts `baseline`, `p_map_offset` and
    `p_map_multispan`, and calls `calc_disagreement`.

    Args:
        offsets: K-offset values to include; each must be present in the
            keys of `p_map_offset`. May be empty.
        length_gap_num_tuples: (span_length, gap_between_spans, num_spans)
            tuples to include; each must be present in the keys of
            `p_map_multispan`. May be empty.

    Returns:
        (accuracy, disagreement_rate): the fraction of examples whose
        predicted completion matches the label, and the fraction of examples
        for which the individual distributions do not all pick the same
        completion.

    Raises:
        ValueError: if `baseline` is empty.
    """
    if not baseline:
        raise ValueError("baseline is empty; run the baseline cell before run_eoc")

    count_correct = 0
    count_disagreement = 0
    # Iterate over the example ids actually stored in `baseline` instead of
    # range(len(baseline)), so ids need not be exactly 0..N-1.
    for example_index in baseline:
        entry = baseline[example_index]
        no_completions = entry['no_completions']

        # One list of (avg_log_p, completion_index) pairs per distribution.
        # The baseline (offset = 0 of the K-offset ensemble) is always included.
        p_and_completion_individually = [
            [(entry['p_map'][completion_index], completion_index)
             for completion_index in range(no_completions)]
        ]
        # K-offset conditionals.
        for offset in offsets:
            p_and_completion_individually.append([
                (p_map_offset[(example_index, offset, completion_index)], completion_index)
                for completion_index in range(no_completions)
            ])
        # Multispan conditionals.
        for length_gap_num in length_gap_num_tuples:
            p_and_completion_individually.append([
                (p_map_multispan[(example_index, *length_gap_num, completion_index)], completion_index)
                for completion_index in range(no_completions)
            ])

        # Pool all pairs (baseline first, then offsets, then multispan) and take
        # the one with the maximum avg_log_p; this is the max reduction.
        p_and_completion = [
            pair for scored_pairs in p_and_completion_individually for pair in scored_pairs
        ]
        _, best_completion_index = max(p_and_completion, key=lambda x: x[0])

        label = entry['label']
        # A list label means several completions are correct (e.g. TruthfulQA).
        if (isinstance(label, int) and best_completion_index == label) or \
                (isinstance(label, list) and best_completion_index in label):
            count_correct += 1
        count_disagreement += calc_disagreement(p_and_completion_individually)

    return count_correct / len(baseline), count_disagreement / len(baseline)

In [12]:
RUN_CELL = True  # Run EOC
if RUN_CELL:
    # K-offset conditionals
    ALL_OFFSETS = [1, 2, 3,]
    # Multispan conditionals: (span_length, gap_between_spans, num_spans)
    ALL_LENGTH_GAP_NUM_TUPLES = [
        (3, 5, 1),
        (3, 5, 2),
        (3, 3, 1),
        (3, 3, 2),
        (3, 4, 1),
        (3, 4, 2),
    ]
    NO_OFFSETS = len(ALL_OFFSETS)
    NO_MULTISPAN = len(ALL_LENGTH_GAP_NUM_TUPLES)
    NO_DISTS_RANGE = list(range(NO_OFFSETS + NO_MULTISPAN + 1))
    # Upper bound on the number of distribution subsets evaluated per ensemble size.
    MAX_COMBINATIONS = 500
    # Dedicated, seeded RNG so the sampled subsets are reproducible and the
    # global `random` state is left untouched.
    SAMPLING_SEED = 0
    sampling_rng = random.Random(SAMPLING_SEED)

    avg_accs = []
    avg_disagreements = []
    for NO_DISTS in NO_DISTS_RANGE: # no of distributions to ensemble (on top of the baseline)
        # Distribution ids: [0, NO_OFFSETS) index ALL_OFFSETS, the rest index
        # ALL_LENGTH_GAP_NUM_TUPLES.
        all_dist_ids = list(combinations(range(NO_MULTISPAN + NO_OFFSETS), NO_DISTS))
        # shuffle and keep at most MAX_COMBINATIONS subsets
        sampling_rng.shuffle(all_dist_ids)
        all_dist_ids = all_dist_ids[:MAX_COMBINATIONS]
        all_accs = []
        all_disagreements = []
        for dist_ids in all_dist_ids:
            # Distinct names so the multispan cell's `length_gap_num_tuples` is not overwritten.
            selected_offsets = [ALL_OFFSETS[dist_id] for dist_id in dist_ids if dist_id < NO_OFFSETS]
            selected_length_gap_num_tuples = [
                ALL_LENGTH_GAP_NUM_TUPLES[dist_id - NO_OFFSETS]
                for dist_id in dist_ids if dist_id >= NO_OFFSETS
            ]
            acc, disagreement = run_eoc(
                selected_offsets,
                selected_length_gap_num_tuples,
            )
            all_accs.append(acc)
            all_disagreements.append(disagreement)
        avg_acc = sum(all_accs) / len(all_accs)
        avg_disagreement = sum(all_disagreements) / len(all_disagreements)
        avg_accs.append(avg_acc)
        avg_disagreements.append(avg_disagreement)
        # print number of dists, avg_acc and avg_disagreement
        print(f"NO_DISTS: {NO_DISTS}, avg_acc: {avg_acc}", f"avg_disagreement: {avg_disagreement}")

NO_DISTS: 0, avg_acc: 0.4090909090909091 avg_disagreement: 0.0
NO_DISTS: 1, avg_acc: 0.4210858585858585 avg_disagreement: 0.3125
NO_DISTS: 2, avg_acc: 0.4286616161616161 avg_disagreement: 0.4308712121212121
NO_DISTS: 3, avg_acc: 0.4339150432900433 avg_disagreement: 0.4995941558441558
NO_DISTS: 4, avg_acc: 0.4379509379509379 avg_disagreement: 0.5465818903318905
NO_DISTS: 5, avg_acc: 0.44124278499278496 avg_disagreement: 0.5799062049062049
NO_DISTS: 6, avg_acc: 0.4439258658008656 avg_disagreement: 0.604166666666667
NO_DISTS: 7, avg_acc: 0.4460227272727274 avg_disagreement: 0.6223169191919191
NO_DISTS: 8, avg_acc: 0.4476010101010101 avg_disagreement: 0.6363636363636365
NO_DISTS: 9, avg_acc: 0.44886363636363635 avg_disagreement: 0.6477272727272727


In [13]:
# Display the per-ensemble-size averages collected by the EOC sweep:
# index i corresponds to NO_DISTS = i distributions added to the baseline.
avg_accs, avg_disagreements

([0.4090909090909091,
  0.4210858585858585,
  0.4286616161616161,
  0.4339150432900433,
  0.4379509379509379,
  0.44124278499278496,
  0.4439258658008656,
  0.4460227272727274,
  0.4476010101010101,
  0.44886363636363635],
 [0.0,
  0.3125,
  0.4308712121212121,
  0.4995941558441558,
  0.5465818903318905,
  0.5799062049062049,
  0.604166666666667,
  0.6223169191919191,
  0.6363636363636365,
  0.6477272727272727])

In [None]:
# Unexecuted cell holding a literal copy of (avg_accs, avg_disagreements);
# the values match the output displayed for the previous cell.
([0.4090909090909091,
  0.4210858585858585,
  0.4286616161616161,
  0.4339150432900433,
  0.4379509379509379,
  0.44124278499278496,
  0.4439258658008656,
  0.4460227272727274,
  0.4476010101010101,
  0.44886363636363635],
 [0.0,
  0.3125,
  0.4308712121212121,
  0.4995941558441558,
  0.5465818903318905,
  0.5799062049062049,
  0.604166666666667,
  0.6223169191919191,
  0.6363636363636365,
  0.6477272727272727])