In [1]:
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Union

import transformers
import datasets
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [20]:
ds = datasets.load_dataset("sarahpann/processed_skywork")
tokenizer = transformers.AutoTokenizer.from_pretrained("bclavie/olmo_bert_template")

In [14]:
first_chosen = ""'0[SEP]<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nA box contains different color fruits such as apples, bananas, and grapes. Can you provide a sample code that counts the number of fruits and uses an increased value for the counter for fruits that occur more often in the box?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n**Fruit Counter with Weighted Counts**\n=====================================\n\nBelow is a Python solution using a dictionary to store the fruit counts and their respective weights.\n\n```python\nfrom collections import defaultdict\n\nclass FruitCounter:\n    def __init__(self):\n        self.fruit_counts = defaultdict(int)\n        self.weights = {\n            \'apple\': 1,\n            \'banana\': 2,\n            \'grape\': 3\n        }\n\n    def add_fruit(self, fruit):\n        """Add a fruit to the counter"""\n        self.fruit_counts[fruit] += self.weights.get(fruit, 1)\n\n    def get_fruit_counts(self):\n        """Return the weighted fruit counts"""\n        return self.fruit_counts\n\n# Example usage:\ncounter = FruitCounter()\n\n# Adding fruits to the counter\nfruits = [\'apple\', \'banana\', \'apple\', \'grape\', \'banana\', \'banana\']\nfor fruit in fruits:\n    counter.add_fruit(fruit)\n\n# Retrieving fruit counts\nfruit_counts = counter.get_fruit_counts()\nfor fruit, count in fruit_counts.items():\n    print(f"{fruit}: {count}")\n```\n\n**Explanation**\n\n*   We define a `FruitCounter` class with methods to add fruits to the counter and retrieve the weighted fruit counts.\n*   The `add_fruit` method increments the fruit count in the `fruit_counts` dictionary by the weight of the fruit.\n*   The `get_fruit_counts` method returns the dictionary containing the weighted fruit counts.\n*   In the example usage, we create a `FruitCounter` instance, add a list of fruits to the counter, and print the weighted fruit counts.\n\n**Output**\n\n```\napple: 2\nbanana: 6\ngrape: 3\n```\n\nIn this example, the weights for each fruit are:\n\n*   Apple: 1\n*   Banana: 2\n*   Grape: 3\n\nThe weighted counts are calculated by multiplying the number of occurrences of each fruit by their respective weights.<|eot_id|>'""
tokenizer(first_chosen)

{'input_ids': [50281, 17, 50282, 29, 93, 2043, 64, 1171, 64, 1156, 93, 2730, 93, 5478, 64, 10146, 64, 301, 49651, 10394, 29, 93, 423, 64, 10146, 64, 301, 49651, 187, 187, 28512, 1076, 28003, 10421, 27, 4565, 1384, 1508, 187, 14569, 10421, 27, 3436, 9218, 1384, 1348, 187, 187, 29, 93, 70, 302, 64, 301, 93, 2730, 93, 5478, 64, 10146, 64, 301, 49651, 4537, 29, 93, 423, 64, 10146, 64, 301, 49651, 187, 187, 34, 3817, 4428, 1027, 3295, 18098, 824, 347, 28580, 13, 8913, 24288, 13, 285, 41417, 15, 2615, 368, 2085, 247, 3410, 2127, 326, 9372, 253, 1180, 273, 18098, 285, 4648, 271, 2559, 1318, 323, 253, 4828, 323, 18098, 326, 2826, 625, 2223, 275, 253, 3817, 32, 29, 93, 70, 302, 64, 301, 93, 2730, 93, 5478, 64, 10146, 64, 301, 49651, 515, 5567, 29, 93, 423, 64, 10146, 64, 301, 49651, 187, 187, 424, 39, 5527, 27891, 342, 27021, 264, 45073, 424, 187, 4578, 43024, 187, 187, 30003, 310, 247, 13814, 2900, 970, 247, 19034, 281, 4657, 253, 9279, 9372, 285, 616, 9056, 13461, 15, 187, 187, 11202, 16659, 

The idea is that we want to reformat this classification problem into a MLM problem. We will append 0 or 1 to the beginning of a pair to indicate whether the first or second sequence is preferred. We will then train a model with varying masked probabilities.

In [7]:
tokenizer(ds['train'][0]['chosen'])

{'input_ids': [50281, 29, 93, 2043, 64, 1171, 64, 1156, 93, 2730, 93, 5478, 64, 10146, 64, 301, 49651, 10394, 29, 93, 423, 64, 10146, 64, 301, 49651, 187, 187, 28512, 1076, 28003, 10421, 27, 4565, 1384, 1508, 187, 14569, 10421, 27, 3436, 9218, 1384, 1348, 187, 187, 29, 93, 70, 302, 64, 301, 93, 2730, 93, 5478, 64, 10146, 64, 301, 49651, 4537, 29, 93, 423, 64, 10146, 64, 301, 49651, 187, 187, 34, 3817, 4428, 1027, 3295, 18098, 824, 347, 28580, 13, 8913, 24288, 13, 285, 41417, 15, 2615, 368, 2085, 247, 3410, 2127, 326, 9372, 253, 1180, 273, 18098, 285, 4648, 271, 2559, 1318, 323, 253, 4828, 323, 18098, 326, 2826, 625, 2223, 275, 253, 3817, 32, 29, 93, 70, 302, 64, 301, 93, 2730, 93, 5478, 64, 10146, 64, 301, 49651, 515, 5567, 29, 93, 423, 64, 10146, 64, 301, 49651, 187, 187, 424, 39, 5527, 27891, 342, 27021, 264, 45073, 424, 187, 4578, 43024, 187, 187, 30003, 310, 247, 13814, 2900, 970, 247, 19034, 281, 4657, 253, 9279, 9372, 285, 616, 9056, 13461, 15, 187, 187, 11202, 16659, 187, 4064, 

In [42]:
def tokenize_and_preprocess(ex):
    new_ex = "0[SEP]" + ex['chosen']
    return tokenizer(new_ex)

rm_columns = [col for col in ds['train'].column_names]

ds_train = ds['train'].select(range(50)).map(tokenize_and_preprocess, 
                                             batched=False,
                                             remove_columns=rm_columns)

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

Map: 100%|██████████| 50/50 [00:00<00:00, 238.54 examples/s]


In [63]:
ds = datasets.load_dataset("sarahpann/mlm_cls_skywork")

In [2]:
ds = datasets.load_dataset("sarahpann/reward_bench_processed")

In [65]:
class CustomDataCollatorForLanguageModeling(transformers.DataCollatorForLanguageModeling):
    def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None):
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
            # block out first three tokens for joint CLS training
            special_tokens_mask[:, :3] = True
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # except second token of each sequence
        labels[:, 1] = inputs[:, 1].clone()

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        inputs[:, 1] = tokenizer.cls_token_id
        print(tokenizer.cls_token_id)

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

collate_fn = CustomDataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

dataloader = torch.utils.data.DataLoader(ds_train, collate_fn=collate_fn, batch_size=4)

for example in dataloader:
    print(example['input_ids'])
    print(example['labels'])
    break

50281
tensor([[50281, 50281, 50282,  ..., 50283, 50283, 50283],
        [50281, 50281, 50282,  ..., 50283, 50283, 50283],
        [50281, 50281, 50282,  ..., 50283, 50283, 50283],
        [50281, 50281, 50282,  ...,   301, 50284, 50282]])
tensor([[ -100,    17,  -100,  ...,  -100,  -100,  -100],
        [ -100,    17,  -100,  ...,  -100,  -100,  -100],
        [ -100,    17,  -100,  ...,  -100,  -100,  -100],
        [ -100,    17,  -100,  ...,  -100, 49651,  -100]])


In [41]:
tokenizer.decode([50281, 50284, 50282,    29, 50284,  2043])

'[CLS][MASK][SEP]<[MASK]begin'

In [69]:
tokenizer.model_max_length

1000000000000000019884624838656

In [44]:
tokenizer.cls_token_id

50281

In [23]:
rm_columns = [col for col in ds['train'].column_names]
rm_columns

['chosen', 'rejected']

In [21]:
def preprocess_skywork(ex):
    import random
    coin = random.randint(0, 1)
    if coin == 0:
        pairs = ["0[SEP]" + chosen + "[SEP]" + rejected for chosen, rejected in zip(ex['chosen'], ex['rejected'])]
    if coin == 1:
        pairs = ["1[SEP]" + rejected + "[SEP]" + chosen for chosen, rejected in zip(ex['chosen'], ex['rejected'])]

    return {"text": pairs}

In [24]:
ds_mapped_train = ds['train'].map(preprocess_skywork, batched=True, remove_columns=rm_columns)

In [25]:
ds_mapped_test = ds['test'].map(preprocess_skywork, batched=True, remove_columns=rm_columns)

In [26]:
ds_mapped_train

Dataset({
    features: ['text'],
    num_rows: 69314
})

In [27]:
ds['train'] = ds_mapped_train

In [28]:
ds['test'] = ds_mapped_test

In [18]:
ds_new = datasets.load_dataset("sarahpann/mlm_cls_skywork")

Downloading readme: 100%|██████████| 420/420 [00:00<00:00, 4.51kB/s]
Downloading data: 100%|██████████| 197M/197M [00:04<00:00, 39.7MB/s] 
Downloading data: 100%|██████████| 22.1M/22.1M [00:00<00:00, 38.0MB/s]
Generating train split: 100%|██████████| 69314/69314 [00:01<00:00, 39329.12 examples/s]
Generating test split: 100%|██████████| 7702/7702 [00:00<00:00, 35988.49 examples/s]


In [29]:
ds.push_to_hub("sarahpann/mlm_cls_skywork")

Creating parquet from Arrow format: 100%|██████████| 70/70 [00:02<00:00, 33.54ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:02<00:00,  2.53s/it]
Creating parquet from Arrow format: 100%|██████████| 8/8 [00:00<00:00, 32.02ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  2.56it/s]


CommitInfo(commit_url='https://huggingface.co/datasets/sarahpann/mlm_cls_skywork/commit/fae176dc52fff7ef259d17559dd2de72a1431111', commit_message='Upload dataset', commit_description='', oid='fae176dc52fff7ef259d17559dd2de72a1431111', pr_url=None, pr_revision=None, pr_num=None)

In [10]:
ds.push_to_hub("sarahpann/mlm_cls_rewardbench")

Creating parquet from Arrow format: 100%|██████████| 6/6 [00:00<00:00, 67.65ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  1.22it/s]


CommitInfo(commit_url='https://huggingface.co/datasets/sarahpann/mlm_cls_rewardbench/commit/f2e4725067c580a803a746561d92881fb1e20333', commit_message='Upload dataset', commit_description='', oid='f2e4725067c580a803a746561d92881fb1e20333', pr_url=None, pr_revision=None, pr_num=None)