In [None]:
#/**
#* @file ul2_arc.ipynb
#* @author chenyunan (chen.yunan_01@nus.edu.sg)
#* @brief
#* @version 0.1
#* @date 2024-01-01
#*
#* @copyright Copyright (c) 2023 
#*
#*/


### Imports and global utils

In [1]:
'''imports'''
import os
# os.environ["CUDA_VISIBLE_DEVICES"]="0,1,4,5,6,7"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import general_utils
# clear GPU memory
if True:   
    general_utils.kill_gpu_process(os.environ["CUDA_VISIBLE_DEVICES"])
import torch
from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Tokenizer
import numpy as np
from tqdm import tqdm
import lambada_utils
from lambada_utils import LambadaProcessor
from typing import Tuple, List

### Load tokenizer and model

In [2]:
# We are using custom huggingface cache dirs in case the default one doesn't have the capacity, since the models can be quite large.
MY_HUGGINGFACE_CACHE_DIR ='huggingface_cache' # relative to this notebook path
tokenizer = AutoTokenizer.from_pretrained("google/ul2",
                                        cache_dir = MY_HUGGINGFACE_CACHE_DIR+'/google-ul2')

RUN_CELL = 1 # Load model 1
# device_map=general_utils.get_ul2_device_map('2,3')
if RUN_CELL:
    model = T5ForConditionalGeneration.from_pretrained("google/ul2",
                                                        cache_dir=MY_HUGGINGFACE_CACHE_DIR + '/google-ul2',
                                                        low_cpu_mem_usage=True,
                                                        torch_dtype=torch.bfloat16,
                                                        device_map='cuda:0')

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

### import MMLU datasets

In [8]:
from datasets import load_dataset

SUBJECTS = ['ARC-Challenge','ARC-Easy']

DATASET_PATH = os.path.join("ai2_arc")
ARC_DATAS = [load_dataset(DATASET_PATH, sub) for sub in SUBJECTS]
INDEX = [i for i in range(len(SUBJECTS))]
NAMES_WITH_DATAS = zip(INDEX, SUBJECTS, ARC_DATAS)

#### Define Loss Function

In [5]:
# define loss
ce_loss = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id) #reduction='avg'
ce_loss_sum = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id, reduction='sum') #reduction='sum'

In [6]:
extra_id_0 = torch.tensor([tokenizer.convert_tokens_to_ids("<extra_id_0>")])
extra_id_1 = torch.tensor([tokenizer.convert_tokens_to_ids("<extra_id_1>")])

#### Define Question Prompt

In [7]:
UL2_MODE = "[NLG]"

In [13]:
import torch.nn.functional as F

def data_prompting(docs, tokenizer) -> Tuple:
    '''
        docs: DATA_SET[SUBJECTS_NAME], ex:ARC[ARC-Challenge]
        return: Tuple(input_ids, completions_ids_padded, labels)

        input[example]: <prompt> 
        label[example]: A. <choice1> B. <choice2> C. <choice3> D. <choice4>

        Todo: few-shot data prompting
    '''

    for doc in docs:
        texts = doc['choices']['text']
        choices_list = doc['choices']['label']
        answer = doc['answerKey']
        index = choices_list.index(answer)

        input_ = UL2_MODE + " " + doc['question'] + " " + "<extra_id_0>"
        completions = [f"<extra_id_0> {text} <extra_id_1>" for text in texts]
        label = f"{texts[index]}"
        
        input_id = tokenizer(input_, return_tensors="pt").input_ids.to("cuda").clone().detach().requires_grad_(False)
        # label_id = tokenizer(label, return_tensors="pt").input_ids.to("cuda").clone().detach().requires_grad_(False)
        completions_ids = [tokenizer(completion, return_tensors="pt").input_ids.to("cuda").clone().detach().requires_grad_(False)\
                                                                for completion in completions]

        # Assuming `max_length` is the maximum length you want to pad sequences to
        max_length = max(seq.size(1) for seq in completions_ids)

        # Pad sequences to the common length
        padded_sequences = [F.pad(seq, (0, max_length - seq.size(1)), value=tokenizer.pad_token_id) for seq in completions_ids]

        # Use pad_sequence
        completions_ids_padded = torch.nn.utils.rnn.pad_sequence(padded_sequences, batch_first=True, padding_value=tokenizer.pad_token_id)

        completions_ids_padded = torch.squeeze(completions_ids_padded, dim = 1)
        yield input_id, completions_ids_padded, label


### K-offset Ensemble

In [11]:
MAX_OFFSET = 1

In [10]:
IS_DEVELOPMENT = True
set_partition = 'validation' if IS_DEVELOPMENT else 'test' 

In [9]:
ENDING_PUNCTUATIONS = '<' # If the model generates one, it is considered that the sentence is complete and we can parse for the last word

def get_word_from_completion(completion: str):
    '''Get the last word from the given completion, if there is a valid one. Return the word.'''
    found = False
    word = None
    # if a punctuation can be found in the completion, get the string before the punctuation
    for i in range(len(completion)):
        if completion[i] in ENDING_PUNCTUATIONS:
            if completion[i+1] == 'e':
                word = completion[:i]
                found = True
                break
    if not found:
        return None
    else:
        return word

In [14]:
def is_correct_completion(completion:torch.Tensor, label:str):
    if not isinstance(completion, torch.Tensor):
        return False
    completion_string = tokenizer.decode(completion)

    # print(f'completion_string:{completion_string}')
    if not isinstance(completion_string, str):
        return False
    word = get_word_from_completion(completion_string)
    # print(f'word:{word}')
    if not isinstance(word, str):
        return False
    if word == label:
        return True

In [15]:
RUN_CELL = 1 # Obtain the avg_log_p_map_offset

TOTAL_CASE = 0
ACUURACTE_CASE = 0

if RUN_CELL:
# id_and_offset_to_input_and_completions:
# (id, offset) -> input_ids, [completion_ids_0, completion_ids_1, completion_ids_2,...]
    avg_log_p_map_offset = dict() # (id, offset, completion_index) -> avg_log_p of the tokens constituting the last word (might be punctuated)
    
    for example_index in tqdm(range(len(INDEX))): 
        data = ARC_DATAS[example_index]
        print(SUBJECTS[example_index])
        
        for offset in range(MAX_OFFSET):
            gen = data_prompting(data[set_partition], tokenizer)

            for input_ids, completions_batch, label in gen:
                avg_log_p_and_completion = []
                outputs = lambada_utils.multi_labels_forward(model, input_ids, completions_batch)

                for completion_index in range(len(completions_batch)):
                    avg_log_p = -ce_loss(
                        # Only care about the tokens corresponding to the last word and omit offset tokens 
                        # the first one is <extra_id_0> and omitted
                        outputs.logits[completion_index][1+offset:], 
                        completions_batch[completion_index][1+offset:]
                    )
                    avg_log_p_map_offset[(example_index, offset, completion_index)] = \
                        avg_log_p.detach().cpu().tolist()
                    
                    avg_log_p_and_completion.append([avg_log_p.detach().cpu().tolist(), completions_batch[completion_index]])

                best_avg_log_p, best_completion = max(avg_log_p_and_completion, key=lambda x: x[0])
                # print(tokenizer.decode(best_completion))
                # print(f'label:{label}')

                if is_correct_completion(best_completion[1+offset:], label):
                    ACUURACTE_CASE += 1
                    # print(f'count_correct +1 : {ACUURACTE_CASE}')
                TOTAL_CASE += 1


[A

ARC-Challenge



[A

ARC-Easy



100%|██████████| 2/2 [00:57<00:00, 28.61s/it]


In [16]:
ACUURACTE_CASE/TOTAL_CASE

0.4039125431530495