### Imports and global utils

In [1]:
'''imports'''
import os
# os.environ["CUDA_VISIBLE_DEVICES"]="0,1,4,5,6,7"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import general_utils
# clear GPU memory
if True:   
    general_utils.kill_gpu_process(os.environ["CUDA_VISIBLE_DEVICES"])
import torch
from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Tokenizer
import numpy as np
from tqdm import tqdm
import lambada_utils
from lambada_utils import LambadaProcessor
from typing import Tuple, List

Killed process 156999 on GPU 0


### Load tokenizer and model

In [2]:
from datasets import load_dataset

SUBJECTS = ['validation']
DATASET_PATH = os.path.join("Rowan/hellaswag")
HELLASWG_DATAS = [load_dataset(DATASET_PATH, sub) for sub in SUBJECTS]
INDEX = [i for i in range(len(SUBJECTS))]
NAMES_WITH_DATAS = zip(INDEX, SUBJECTS, HELLASWG_DATAS)

In [3]:
# We are using custom huggingface cache dirs in case the default one doesn't have the capacity, since the models can be quite large.
MY_HUGGINGFACE_CACHE_DIR ='huggingface_cache' # relative to this notebook path
tokenizer = AutoTokenizer.from_pretrained("google/ul2",
                                        cache_dir = MY_HUGGINGFACE_CACHE_DIR+'/google-ul2')

RUN_CELL = 1 # Load model 1
# device_map=general_utils.get_ul2_device_map('2,3')
if RUN_CELL:
    model = T5ForConditionalGeneration.from_pretrained("google/ul2",
                                                        cache_dir=MY_HUGGINGFACE_CACHE_DIR + '/google-ul2',
                                                        low_cpu_mem_usage=True,
                                                        torch_dtype=torch.bfloat16,
                                                        device_map='cuda:0')

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

#### Define Loss Function

In [4]:
# define loss
ce_loss = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id) #reduction='avg'
ce_loss_sum = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id, reduction='sum') #reduction='sum'

In [5]:
extra_id_0 = torch.tensor([tokenizer.convert_tokens_to_ids("<extra_id_0>")])
extra_id_1 = torch.tensor([tokenizer.convert_tokens_to_ids("<extra_id_1>")])

#### Define Question Prompt

In [6]:
UL2_MODE = "[NLG]"

In [7]:
import re
def preprocess(text):
    text = text.strip()
    # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
    text = text.replace(" [title]", ". ")
    text = re.sub("\\[.*?\\]", "", text)
    text = text.replace("  ", " ")
    return text

In [18]:
import torch.nn.functional as F

def data_prompting(docs, tokenizer) -> Tuple:
    '''
        docs: DATA_SET[SUBJECTS_NAME], ex:HELLASWG_DATAS[validation]
        return: Tuple(input_ids, completions_ids_padded, labels)

        input[example]: ctx:<prompt> 
        label[example]: endings list -> list[] 

        Todo: few-shot data prompting
    '''

    for doc in docs:
        endings_list = doc['endings']
        index = doc['label']

        input_ = UL2_MODE + " " + doc["activity_label"] + ": " + doc['ctx'] + " " + "<extra_id_0>"
        print(input_)
        completions = [f"<extra_id_0> {ending}" for ending in endings_list]
        # label = f"{endings_list[int(index)]}"
        label = index
        
        input_id = tokenizer(input_, return_tensors="pt").input_ids.to("cuda").clone().detach().requires_grad_(False)
        # label_id = tokenizer(label, return_tensors="pt").input_ids.to("cuda").clone().detach().requires_grad_(False)
        completions_ids = [tokenizer(completion, return_tensors="pt").input_ids.to("cuda").clone().detach()[:,:-1]\
                                                                for completion in completions] # remove <eos> token with [:,:-1]

        # Assuming `max_length` is the maximum length you want to pad sequences to
        max_length = max(seq.size(1) for seq in completions_ids)

        # Pad sequences to the common length
        padded_sequences = [F.pad(seq, (0, max_length - seq.size(1)), value=tokenizer.pad_token_id) for seq in completions_ids]

        # Use pad_sequence
        completions_ids_padded = torch.nn.utils.rnn.pad_sequence(padded_sequences, batch_first=True, padding_value=tokenizer.pad_token_id)

        completions_ids_padded = torch.squeeze(completions_ids_padded, dim = 1)
        yield input_id, completions_ids_padded, label


### K-offset Ensemble

In [9]:
IS_DEVELOPMENT = True
set_partition = 'validation' if IS_DEVELOPMENT else 'test' 

In [19]:
RUN_CELL = 1 # Obtain the avg_log_p_map_offset

TOTAL_CASE = 0
ACUURACTE_CASE = 0

if RUN_CELL:
# id_and_offset_to_input_and_completions:
# (id, offset) -> input_ids, [completion_ids_0, completion_ids_1, completion_ids_2,...]
    avg_log_p_map_offset = dict() # (id, offset, completion_index) -> avg_log_p of the tokens constituting the last word (might be punctuated)
    
    for example_index in tqdm(range(len(INDEX))): 
        data = HELLASWG_DATAS[example_index]
        
        gen = data_prompting(data[set_partition], tokenizer)

        for input_ids, completions_batch, label in gen:
            avg_log_p_and_completion = []
            outputs = lambada_utils.multi_labels_forward(model, input_ids, completions_batch)

            for completion_index in range(len(completions_batch)):
                avg_log_p = -ce_loss(
                    # Only care about the tokens corresponding to the last word and omit offset tokens 
                    # the first one is <extra_id_0> and omitted
                    outputs.logits[completion_index][1:], 
                    completions_batch[completion_index][1:]
                )
                avg_log_p_map_offset[(example_index, 0, completion_index)] = \
                    avg_log_p.detach().cpu().tolist()
                
                avg_log_p_and_completion.append([avg_log_p.detach().cpu().tolist(), completion_index])

            best_avg_log_p, best_completion_index = max(avg_log_p_and_completion, key=lambda x: x[0])

            if int(best_completion_index) == int(label):
                ACUURACTE_CASE += 1
            TOTAL_CASE += 1

  0%|          | 0/1 [00:00<?, ?it/s]

[NLG] Roof shingle removal: A man is sitting on a roof. he <extra_id_0>
[NLG] Clean and jerk: A lady walks to a barbell. She bends down and grabs the pole. the lady <extra_id_0>
[NLG] Canoeing: Two women in a child are shown in a canoe while a man pulls the canoe while standing in the water, with other individuals visible in the background. the child and a different man <extra_id_0>
[NLG] High jump: A boy is running down a track. the boy <extra_id_0>
[NLG] High jump: The boy lifts his body above the height of a pole. The boy lands on his back on to a red mat. the boy <extra_id_0>
[NLG] High jump: The boy lands on his back on to a red mat. The boy gets up from the mat. the boy <extra_id_0>
[NLG] Playing harmonica: A man is standing in front of a camera. He starts playing a harmonica for the camera. he <extra_id_0>
[NLG] Sumo: A cartoon animation video is shown with people wandering around and rockets being shot. two men <extra_id_0>
[NLG] Sharpening knives: A man is holding a pocket kni

100%|██████████| 1/1 [12:12<00:00, 732.13s/it]

[NLG] Finance and Business: [header] How to buy a peridot [title] Look at a variety of stones. [step] Expect peridot to come in various shades of green. Visit a respectable jeweler to browse a number of samples, all at once. <extra_id_0>
[NLG] Family Life: [header] How to tell if your teen is being abused [title] Pay attention to your teen dressing inappropriately. [step] If you suspect that your teen is being beaten by someone in their life, you will want to pay close attention to their dress code. While all teens have different styles and preferences, be on the lookout for any clothing that is out of the norm for your child. <extra_id_0>





In [20]:
ACUURACTE_CASE/TOTAL_CASE

0.7520414260107549