In [1]:
import transformers
from transformers import AutoTokenizer, AutoModelForMaskedLM
import datasets
from datasets import load_dataset

import peft
from peft import AutoPeftModel

import sys
import src.evals.data as data_module

import torch

import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = "answerdotai/ModernBERT-large"

In [3]:
chat = data_module.create_preference_to_flan_style_dataset(
    task = "sarahpann/rwb_chat",
    split = "train",
    tokenizer_name = tokenizer,
    max_seq_length = 3600,
    prefix = "Which response is the most helpful, relevant, and correct? ",

    dataset_name="sarahpann/rwb_chat",
    dataset_subset="",
    task_column_names={"sarahpann/rwb_chat": ('chosen', 'rejected', 'og_dataset')}
)

chat_hard = data_module.create_preference_to_flan_style_dataset(
    task = "sarahpann/rwb_chat_hard",
    split = "train",
    tokenizer_name = tokenizer,
    max_seq_length = 3600,
    prefix = "Which response is the most helpful, relevant, and correct? ",

    dataset_name="sarahpann/rwb_chat_hard",
    dataset_subset="",
    task_column_names={"sarahpann/rwb_chat_hard": ('chosen', 'rejected', 'og_dataset')}
)

reasoning = data_module.create_preference_to_flan_style_dataset(
    task="sarahpann/rwb_reasoning",
    split='train',
    tokenizer_name=tokenizer,
    max_seq_length=3600,
    prefix="Determine which response is the best choice based on mathematical or programming accuracy. ",

    dataset_name="sarahpann/rwb_reasoning",
    dataset_subset="",
    task_column_names={"sarahpann/rwb_reasoning": ('chosen', 'rejected', 'og_dataset')}
)

safety = data_module.create_preference_to_flan_style_dataset(
    task="sarahpann/rwb_safety",
    split='train',
    tokenizer_name=tokenizer,
    max_seq_length=3600,
    prefix="Which response is the most helpful, relevant, and correct? ",

    dataset_name="sarahpann/rwb_safety",
    dataset_subset="",
    task_column_names={"sarahpann/rwb_safety": ('chosen', 'rejected', 'og_dataset')}
)

Map: 100%|██████████| 2488/2488 [00:01<00:00, 1838.13 examples/s]
Map: 100%|██████████| 464/464 [00:00<00:00, 3296.42 examples/s]
Map: 100%|██████████| 1431/1431 [00:00<00:00, 2159.83 examples/s]
Map: 100%|██████████| 740/740 [00:00<00:00, 2781.95 examples/s]


In [4]:
model = AutoModelForMaskedLM.from_pretrained("sarahpann/reasoning_model_large",
                                            attn_implementation='flash_attention_2',)
# model = AutoPeftModel.from_pretrained("sarahpann/all_at_once_model_small", 
                                    #   attn_implementation='flash_attention_2')

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


In [5]:
chat.set_format(type='torch', columns=['input_ids', 'attention_mask'])
chat_hard.set_format(type='torch', columns=['input_ids', 'attention_mask'])
reasoning.set_format(type='torch', columns=['input_ids', 'attention_mask'])
safety.set_format(type='torch', columns=['input_ids', 'attention_mask'])

In [6]:
chat_loader = torch.utils.data.DataLoader(chat, batch_size=1, shuffle=False)
chat_hard_loader = torch.utils.data.DataLoader(chat_hard, batch_size=1, shuffle=False)
reasoning_loader = torch.utils.data.DataLoader(reasoning, batch_size=1, shuffle=False)
safety_loader = torch.utils.data.DataLoader(safety, batch_size=1, shuffle=False)

In [7]:
real_tokenizer = AutoTokenizer.from_pretrained(tokenizer)

In [8]:
import tqdm
from tqdm import tqdm

In [9]:
def eval_dataset(model, subject_dataloader):
    total = 0
    correct = 0

    model.to("cuda")
    model.eval()

    with torch.no_grad():
        for i, example in tqdm(enumerate(subject_dataloader)):
            input_ids = example['input_ids'].clone()
            input_ids[:, -2] = real_tokenizer.mask_token_id
            input_ids = input_ids.to("cuda")
            attention_mask = example['attention_mask'].to("cuda")


            output = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = output.logits
            pred = torch.argmax(logits, dim=-1)

            # print(output)
            if not i % 400:
                print("prediction: ", pred[0, -2])
                print("label: ", example['input_ids'][0, -2])

            # print(real_tokenizer.decode(input_ids[0]))
            if pred[0, -2] == example['input_ids'][0, -2]:
                correct += 1
            total += 1

            
    print(f"Accuracy: {correct / total}")
    return correct / total

In [10]:
class CustomDataCollatorForFlanStyleQuestionAnswering(transformers.DataCollatorForLanguageModeling):
    def __init__(self, tokenizer, mlm_probability=0.15, prompt=None):
        super().__init__(tokenizer=tokenizer, mlm_probability=mlm_probability)
        self.tokenizer = tokenizer
        self.mlm_probability = mlm_probability

    def torch_mask_tokens(self, inputs, special_tokens_mask):
        """
        Mask the last non-SEP non-PAD token in the sequence.
        """
        labels = inputs.clone()

        pad_token_id = self.tokenizer.pad_token_id
        sep_token_id = self.tokenizer.sep_token_id
        mask_token_id = self.tokenizer.mask_token_id

        batch_size, seq_length = inputs.shape

        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)

        # Find the last [SEP] token index in each sequence
        sep_positions = (inputs == sep_token_id).int()
        last_sep_indices = (sep_positions * torch.arange(seq_length, device=inputs.device)).argmax(dim=1)

        # Initialize a mask for which token to replace with [MASK]
        mask_positions = torch.zeros_like(inputs, dtype=torch.bool)

        for i in range(batch_size):
            sep_index = last_sep_indices[i].item()

            # Traverse backward to find the second-to-last valid token
            for j in range(sep_index - 1, -1, -1):
                if inputs[i, j] not in {pad_token_id, sep_token_id}:
                    mask_positions[i, j] = True
                    break

        # Apply mask
        inputs[mask_positions] = mask_token_id
        labels[~mask_positions] = -100  # Only keep masked token for loss calculation

        return inputs, labels

In [11]:
collator = CustomDataCollatorForFlanStyleQuestionAnswering(
    tokenizer=real_tokenizer,
    mlm_probability=0.15
)
alternative_dl = torch.utils.data.DataLoader(reasoning, batch_size=1, shuffle=False, collate_fn=collator)

In [12]:
from src.flex_bert import EfficientHuggingFaceModel
from composer.metrics.nlp import MaskedAccuracy

In [13]:
metrics = [MaskedAccuracy(ignore_index=-100)]

In [14]:
hf_model = EfficientHuggingFaceModel(
    model=model,
    tokenizer=real_tokenizer,
    use_logits=True,
    metrics=metrics,
    eval_metrics=metrics,
    allow_embedding_resizing=True,
)

In [None]:
def eval_dataset_with_collator(model, dataloader):
    total = 0
    correct = 0

    model.to("cuda")
    model.eval()

    with torch.no_grad():
        for i, example in tqdm(enumerate(dataloader)):
            labels = example['labels']
            # get all nonmasked tokens
            nonzero_labels = labels[labels != -100]
            print(set(nonzero_labels.tolist()))
            
    # print(f"Accuracy: {correct / total}")
    # return correct / total
eval_dataset_with_collator(model, alternative_dl)

In [16]:
chat_acc = eval_dataset(model, chat_loader)
chat_hard_acc = eval_dataset(model, chat_hard_loader)
reasoning_acc = eval_dataset(model, reasoning_loader)
safety_acc = eval_dataset(model, safety_loader)

0it [00:00, ?it/s]

3it [00:26,  6.77s/it]

prediction:  tensor(17, device='cuda:0')
label:  tensor(18)


403it [00:55, 13.42it/s]

prediction:  tensor(18, device='cuda:0')
label:  tensor(17)


801it [01:23, 12.37it/s]

prediction:  tensor(17, device='cuda:0')
label:  tensor(18)


1203it [01:50, 15.91it/s]

prediction:  tensor(17, device='cuda:0')
label:  tensor(17)


1603it [02:15, 16.22it/s]

prediction:  tensor(18, device='cuda:0')
label:  tensor(17)


2003it [02:52, 12.77it/s]

prediction:  tensor(18, device='cuda:0')
label:  tensor(18)


2403it [03:27, 10.76it/s]

prediction:  tensor(17, device='cuda:0')
label:  tensor(17)


2488it [03:34, 11.61it/s]


Accuracy: 0.5293408360128617


2it [00:00, 15.66it/s]

prediction:  tensor(17, device='cuda:0')
label:  tensor(17)


404it [00:24, 17.28it/s]

prediction:  tensor(18, device='cuda:0')
label:  tensor(17)


464it [00:27, 16.72it/s]


Accuracy: 0.5193965517241379


2it [00:00, 16.88it/s]

prediction:  tensor(17, device='cuda:0')
label:  tensor(17)


402it [00:24, 15.58it/s]

prediction:  tensor(17, device='cuda:0')
label:  tensor(18)


802it [00:50, 14.26it/s]

prediction:  tensor(17, device='cuda:0')
label:  tensor(17)


1203it [01:20, 14.49it/s]

prediction:  tensor(18, device='cuda:0')
label:  tensor(17)


1431it [01:35, 14.96it/s]


Accuracy: 0.5303983228511531


2it [00:00, 15.14it/s]

prediction:  tensor(17, device='cuda:0')
label:  tensor(17)


404it [00:24, 18.29it/s]

prediction:  tensor(17, device='cuda:0')
label:  tensor(18)


740it [00:43, 16.82it/s]

Accuracy: 0.6513513513513514





In [17]:
final_results = {
    "chat": chat_acc,
    "chat_hard": chat_hard_acc,
    "reasoning": reasoning_acc,
    "safety": safety_acc
}

print(final_results)

{'chat': 0.5293408360128617, 'chat_hard': 0.5193965517241379, 'reasoning': 0.5303983228511531, 'safety': 0.6513513513513514}


In [18]:
# for chat
0.8657556270096463 * (1 - 0.157) + 0.36853448275862066 * (0.157)

0.7876919073622352

In [19]:
# for reasoning
0.3305380852550664

0.3305380852550664