- You have to load the model, tokenizer, and dataset before running the scoring cells below.

In [None]:
def get_el2n_scores(model, tokenizer, question, answer):
    """Return the mean EL2N score of ``answer`` given ``question``.

    The question is rendered with ``LLAMA3_CHAT_TEMPLATE`` and followed by the
    answer and the tokenizer's EOS token. For every answer token the score is
    the L2 norm of ``softmax(logits) - one_hot(target)``; the per-token norms
    are averaged into a single float.

    Args:
        model: Causal LM called as ``model(**inputs)`` whose output exposes
            ``.logits``; assumed to be (batch, seq, vocab) -- TODO confirm
            for non-Hugging Face models.
        tokenizer: Tokenizer matching ``model``; must expose ``eos_token``.
        question: Raw question text, inserted into the chat template.
        answer: Reference answer text to be scored.

    Returns:
        The mean per-token EL2N score as a float, or ``0.0`` (after printing
        a warning) when the answer tokens cannot be located or scored.
    """
    templated_question = LLAMA3_CHAT_TEMPLATE.format(question=question)
    prompt = f"{templated_question}{answer}{tokenizer.eos_token}"

    # Tokenize the prompt with and without the answer so the answer span can
    # be located by comparing the two token sequences.
    inputs_no_answer = tokenizer(templated_question, return_tensors='pt').to(model.device)
    inputs_full = tokenizer(prompt, return_tensors='pt').to(model.device)
    full_input_ids_no_answer = inputs_no_answer.input_ids[0].tolist()
    full_input_ids_full = inputs_full.input_ids[0].tolist()

    # The answer starts at the first position where the two sequences differ
    # (tokens may merge across the question/answer boundary) ...
    start_index = next(
        (
            i
            for i, (q_id, f_id) in enumerate(
                zip(full_input_ids_no_answer, full_input_ids_full)
            )
            if q_id != f_id
        ),
        -1,
    )
    if start_index == -1:
        if len(full_input_ids_full) > len(full_input_ids_no_answer):
            # ... or right after the question when it is an exact prefix.
            start_index = len(full_input_ids_no_answer)
        else:
            print(f"Warning: Could not find the start of the answer in the prompt for question: {question[:50]}...")
            print("Full Input IDs (Full):", full_input_ids_full)
            print("Full Input IDs (No Answer):", full_input_ids_no_answer)
            return 0.0

    # Position 0 has no preceding context, so no logit row predicts it. Clamp
    # the first scored target to 1 so logits and targets stay aligned even
    # when the sequences already differ at the very first token.
    first_target = max(start_index, 1)

    with torch.no_grad():
        # Only the logits are needed, so no labels are passed and the model
        # does not spend time computing a loss that would be discarded.
        logits = model(**inputs_full).logits

    # The logit row at position t predicts the token at position t + 1.
    logits_for_answer = logits[0, first_target - 1:-1, :]
    labels_for_answer = inputs_full.input_ids[0, first_target:].to(logits_for_answer.device)

    if labels_for_answer.numel() == 0:
        print("Warning: No valid labels found in the answer segment.")
        return 0.0

    # Checked before any joint use of the two tensors so a model that returns
    # an unexpected number of positions yields the sentinel, not an exception.
    if logits_for_answer.shape[0] != labels_for_answer.shape[0]:
        print("Error: Mismatch between the number of valid logits and labels.")
        return 0.0

    # Softmax in float32: half-precision probabilities over a large vocabulary
    # are too coarse for a meaningful L2 distance.
    probs = F.softmax(logits_for_answer, dim=-1, dtype=torch.float32)
    one_hot_labels = F.one_hot(labels_for_answer, num_classes=probs.shape[-1]).float()
    l2_norms = torch.norm(probs - one_hot_labels, p=2, dim=-1)

    return l2_norms.mean().item()

In [None]:
def get_grand_data_top(percent_to_keep_top):
    """Return the trailing slice of the module-level ``dataset``.

    Args:
        percent_to_keep_top: Fraction of rows to keep (it is multiplied
            directly by the dataset length, so pass e.g. ``0.1`` for 10%).

    Returns:
        A dataset holding the last ``int(len(dataset) * percent_to_keep_top)``
        rows of ``dataset``.
    """
    # NOTE(review): "top" presumably relies on `dataset` being sorted in
    # ascending score order beforehand -- confirm against the caller.
    total_rows = len(dataset)
    rows_to_keep = int(total_rows * percent_to_keep_top)
    first_kept_row = total_rows - rows_to_keep
    return dataset.select(range(first_kept_row, total_rows))

In [None]:
# Read ./data/ds_1/retain_1.parquet into a DataFrame and wrap it as a
# `Dataset`; the cells below read this module-level `dataset`.
df = pd.read_parquet('./data/ds_1/retain_1.parquet')
dataset = Dataset.from_pandas(df)

In [None]:
# TODO: load the model and tokenizer here (this cell must define `model` and `tokenizer`).

In [None]:
import time

# Score every example in the dataset and time the whole pass, including the
# step that attaches the scores as a new column.
start_time = time.time()

# One EL2N score per row, in dataset order.
example_scores = [
    get_el2n_scores(model, tokenizer, row['question'], row['answer'])
    for row in dataset
]

dataset = dataset.add_column('el2n_score', example_scores)
end_time = time.time()

print(f"Time taken: {end_time - start_time} seconds")