### Import env

In [1]:
'''imports'''
import os
# os.environ["CUDA_VISIBLE_DEVICES"]="0,1,4,5,6,7"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import general_utils
# clear GPU memory
if True:   
    general_utils.kill_gpu_process(os.environ["CUDA_VISIBLE_DEVICES"])
import torch
os.environ['PYDEVD_WARN_SLOW_RESOLVE_TIMEOUT'] = '5.0' # suppresses pydevd speed warnings
from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Tokenizer
import numpy as np
import pickle
import time
from tqdm import tqdm
import json
import lambada_utils
from lambada_utils import LambadaProcessor
from typing import Tuple, List

### Load tokenizer and model

In [2]:
# We are using custom huggingface cache dirs in case the default one doesn't have the capacity, since the models can be quite large.
MY_HUGGINGFACE_CACHE_DIR ='huggingface_cache' # relative to this notebook path
tokenizer = AutoTokenizer.from_pretrained("google/ul2",
                                        cache_dir = MY_HUGGINGFACE_CACHE_DIR+'/google-ul2')

RUN_CELL = 1 # Load model 1
# device_map=general_utils.get_ul2_device_map('2,3')
if RUN_CELL:
    model = T5ForConditionalGeneration.from_pretrained("google/ul2",
                                                        cache_dir=MY_HUGGINGFACE_CACHE_DIR + '/google-ul2',
                                                        low_cpu_mem_usage=True,
                                                        torch_dtype=torch.bfloat16,
                                                        device_map='balanced')

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

### Import datasets

In [3]:
from datasets import load_dataset

SUBJECTS = ['high_school_european_history', 'business_ethics', 'clinical_knowledge', 'medical_genetics', \
            'high_school_us_history', 'high_school_physics', 'high_school_world_history', 'virology', \
            'high_school_microeconomics', 'econometrics', 'college_computer_science', 'high_school_biology', \
            'abstract_algebra', 'professional_accounting', 'philosophy', 'professional_medicine', 'nutrition', \
            'global_facts', 'machine_learning', 'security_studies', 'public_relations', 'professional_psychology', \
            'prehistory', 'anatomy', 'human_sexuality', 'college_medicine', 'high_school_government_and_politics', \
            'college_chemistry', 'logical_fallacies', 'high_school_geography', 'elementary_mathematics', 'human_aging', \
            'college_mathematics', 'high_school_psychology', 'formal_logic', 'high_school_statistics', 'international_law', \
            'high_school_mathematics', 'high_school_computer_science', 'conceptual_physics', 'miscellaneous', 'high_school_chemistry', \
            'marketing', 'professional_law', 'management', 'college_physics', 'jurisprudence', 'world_religions', 'sociology', 'us_foreign_policy', \
            'high_school_macroeconomics', 'computer_security', 'moral_scenarios', 'moral_disputes', 'electrical_engineering', 'astronomy', 'college_biology']

# SUBJECTS = SUBJECTS[10:11] # tom is only using one subject for testing


DATASET_PATH = os.path.join("lukaemon/mmlu")
MMLU_DATAS = [load_dataset(DATASET_PATH, sub) for sub in SUBJECTS]
INDEX = [i for i in range(len(SUBJECTS))]
NAMES_WITH_DATAS = zip(INDEX, SUBJECTS, MMLU_DATAS)

Using the latest cached version of the module from /home/oem/.cache/huggingface/modules/datasets_modules/datasets/lukaemon--mmlu/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3 (last modified on Thu Dec 21 09:53:23 2023) since it couldn't be found locally at lukaemon/mmlu., or remotely on the Hugging Face Hub.
Using the latest cached version of the module from /home/oem/.cache/huggingface/modules/datasets_modules/datasets/lukaemon--mmlu/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3 (last modified on Thu Dec 21 09:53:23 2023) since it couldn't be found locally at lukaemon/mmlu., or remotely on the Hugging Face Hub.
Using the latest cached version of the module from /home/oem/.cache/huggingface/modules/datasets_modules/datasets/lukaemon--mmlu/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3 (last modified on Thu Dec 21 09:53:23 2023) since it couldn't be found locally at lukaemon/mmlu., or remotely on the Hugging Face Hub.
Using the lat

### Define loss function

In [4]:
# define loss
ce_loss = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id) #reduction='avg'
ce_loss_sum = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id, reduction='sum') #reduction='sum'

In [5]:
extra_id_0 = torch.tensor([tokenizer.convert_tokens_to_ids("<extra_id_0>")])
extra_id_1 = torch.tensor([tokenizer.convert_tokens_to_ids("<extra_id_1>")])

### Prompt

In [6]:
import torch.nn.functional as F
UL2_MODE = "[S2S]"

def data_prompting(docs, tokenizer, name, num_of_shot) -> Tuple:
    '''
        docs: DATA_SET[SUBJECTS_NAME], ex:MMLU[high_school_european_history]

        5-shot data prompting
    '''
    keys = ["A", "B", "C", "D"]
    key_to_index = {"A":0, "B":1, "C":2, "D":3}
    
    doc_count = 0
    for doc in docs:
        input_ = f"The following are multiple choice questions (with answers) about {name} knowledge.\n\n"

        for data in docs:
            if doc_count < num_of_shot and data != doc:
                doc_count += 1
                input_ += "Q: " + data['input'] + " " + "The answer is " + data[data['target']] + '.\n\n'

            if doc_count == num_of_shot:
                doc_count = 0
                break

        input_ += UL2_MODE + " " + doc['input'] + " " + "<extra_id_0>"
        completions = [f"<extra_id_0> {doc[key]}" for key in keys]
        # print(completions)
        label = key_to_index[doc['target']]
        
        input_ids = tokenizer(input_, return_tensors="pt").input_ids.to("cuda").clone().detach().requires_grad_(False)
        # label_id = tokenizer(label, return_tensors="pt").input_ids.to("cuda").clone().detach().requires_grad_(False)
        # completions_ids = [tokenizer(completion, return_tensors="pt").input_ids.to("cuda").clone().detach().requires_grad_(False)\
                                                                # for completion in completions]
        completions_ids = [tokenizer(completion, return_tensors="pt").input_ids.to("cuda").clone().detach()[:,:-1]\
                                                                for completion in completions] # remove <eos> token with [:,:-1]
        # print(completions_ids)
        # Assuming `max_length` is the maximum length you want to pad sequences to
        max_length = max(seq.size(1) for seq in completions_ids)

        # Note to Yunan: Please compress the following 2 code lines to remove one "pad" function call; Consult chatgpt or official doc for guidance on how to pad simply and effectively
        # Pad sequences to the common length
        padded_sequences = [F.pad(seq, (0, max_length - seq.size(1)), value=tokenizer.pad_token_id) for seq in completions_ids]

        # Use pad_sequence
        completions_ids_padded = torch.nn.utils.rnn.pad_sequence(padded_sequences, batch_first=True, padding_value=tokenizer.pad_token_id)

        completions_ids_padded = torch.squeeze(completions_ids_padded, dim = 1)
        yield input_ids, completions_ids_padded, label

In [7]:
IS_DEVELOPMENT = True
set_partition = 'validation' if IS_DEVELOPMENT else 'test' 

### Main Fun

In [8]:
RUN_CELL = 1 # Obtain the avg_log_p_map_offset
TOTAL_CASE = 0
ACCURATE_CASE = 0
NUM_SHOT = 1

if RUN_CELL:
# id_and_offset_to_input_and_completions:
# (id, offset) -> input_ids, [completion_ids_0, completion_ids_1, completion_ids_2,...]
    avg_log_p_map_offset = dict() # (id, offset, completion_index) -> avg_log_p of the tokens constituting the last word (might be punctuated)
    
    for example_index in tqdm(range(len(INDEX))): 
    # for example_index in tqdm(range(2)):
        data = MMLU_DATAS[example_index]
        name = (SUBJECTS[example_index]).replace('_', ' ')

        gen = data_prompting(data[set_partition], tokenizer, name, NUM_SHOT)

        for input_ids, completions_batch, label in gen:
            avg_log_p_and_completion = []
            outputs = lambada_utils.multi_labels_forward(model, input_ids, completions_batch)
            # print('new completion batch')
            for completion_index in range(len(completions_batch)):
                
                avg_log_p = -ce_loss(
                    # Only care about the tokens corresponding to the last word and omit offset tokens 
                    # the first one is <extra_id_0> and omitted
                    outputs.logits[completion_index][1:], 
                    completions_batch[completion_index][1:]
                )
                
                avg_log_p_map_offset[(example_index, 0, completion_index)] = \
                    avg_log_p.detach().cpu().tolist()
                
                avg_log_p_and_completion.append([avg_log_p.detach().cpu().tolist(), completion_index])
                
            best_avg_log_p, best_completion_index = max(avg_log_p_and_completion, key=lambda x: x[0])

            if best_completion_index == label:
                ACCURATE_CASE += 1
            TOTAL_CASE += 1

  0%|          | 0/57 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (717 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 57/57 [01:50<00:00,  1.95s/it]


In [9]:
ACCURATE_CASE/TOTAL_CASE

0.32903663500678426