In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from typing import Optional, cast, Dict, Any

import torch

from omegaconf import DictConfig
from omegaconf import OmegaConf as om

import transformers
from transformers import AutoModel, AutoConfig, AutoTokenizer
import datasets

import random

from src.flex_bert import *
from src.evals.data import *

  from .autonotebook import tqdm as notebook_tqdm
  @custom_fwd
  @custom_bwd


In [2]:
original_dataset = datasets.load_dataset("sarahpann/processed_skywork")

In [54]:
original_labeled_dataset = datasets.load_dataset("sarahpann/processed_skywork_labeled")

In [55]:
with open("/home/public/span/MATH_DPO/modern_bert_test/bert24/yamls/test/sequence_classification_og.yaml") as f:
    yaml_config = om.load(f)

cfg = cast(DictConfig, yaml_config)

In [4]:
tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_name)

In [91]:
def tokenizer_fn(inp):
    # flip a coin
    coin = random.randint(0, 1)
    # if coin == 0, then chosen goes first
    if coin == 0:
        pairs = [[chosen, rejected] for chosen, rejected in zip(inp["chosen_labeled"], inp["rejected_labeled"])]
    # if coin == 1, then rejected goes first
    if coin == 1:
        pairs = [[rejected, chosen] for chosen, rejected in zip(inp["chosen_labeled"], inp["rejected_labeled"])]

    tokenized_pairs = tokenizer(pairs, 
                                # padding="max_length", 
                                # max_length=1024, 
                                # truncation=True,
                                )
    
    labels = [[-100] * len(example) for example in tokenized_pairs["input_ids"]]

    tokenizer_cls_id = tokenizer.cls_token_id
    
    if coin == 0:
        num_cls = 0
        for i, example in enumerate(tokenized_pairs["input_ids"]):
            labels[i][0] = 0
            num_correct_cls = inp['num_chosen_labels'][i]
            for j in range(len(example))[1:]:
                if example[j] == tokenizer_cls_id:
                    if num_cls > num_correct_cls:
                        labels[i][j] = 0
                    else:
                        labels[i][j] = 1

                    num_cls += 1

    if coin == 1:
        num_cls = 0
        for i, example in enumerate(tokenized_pairs["input_ids"]):
            labels[i][0] = 1
            num_incorrect_cls = inp['num_rejected_labels'][i]
            for j in range(len(example))[1:]:
                if example[j] == tokenizer_cls_id:
                    if num_cls > num_incorrect_cls:
                        labels[i][j] = 1
                    else:
                        labels[i][j] = 0
                    num_cls += 1
                        
    ret_dict = {
    "input_ids": tokenized_pairs["input_ids"],
    "token_type_ids": tokenized_pairs["token_type_ids"],
    "attention_mask": tokenized_pairs["attention_mask"],
    "label": labels,
    }
    
    return ret_dict

rm_columns = ['chosen', 'rejected', 'chosen_labeled', 'rejected_labeled', 'num_chosen_labels', 'num_rejected_labels']


mini_tokenized_train_ds = original_labeled_dataset['train'].select(range(500))
mini_tokenized_test_ds = original_labeled_dataset['test'].select(range(50))

mini_tokenized_train = mini_tokenized_train_ds.map(lambda x: tokenizer_fn(x), batched=True, remove_columns=rm_columns)
mini_tokenized_test = mini_tokenized_test_ds.map(lambda x: tokenizer_fn(x), batched=True, remove_columns=rm_columns)

Map: 100%|██████████| 500/500 [00:02<00:00, 207.95 examples/s]
Map: 100%|██████████| 50/50 [00:00<00:00, 192.49 examples/s]


In [87]:
sum([tokenizer.cls_token_id == i for i in mini_tokenized_train[0]['input_ids']])

8

In [93]:
tokenizer.mask_token

'[MASK]'

In [83]:
tokenizer(original_labeled_dataset['train'][0]['chosen'])

{'input_ids': [50281, 29, 93, 2043, 64, 1171, 64, 1156, 93, 2730, 93, 5478, 64, 10146, 64, 301, 49651, 10394, 29, 93, 423, 64, 10146, 64, 301, 49651, 187, 187, 28512, 1076, 28003, 10421, 27, 4565, 1384, 1508, 187, 14569, 10421, 27, 3436, 9218, 1384, 1348, 187, 187, 29, 93, 70, 302, 64, 301, 93, 2730, 93, 5478, 64, 10146, 64, 301, 49651, 4537, 29, 93, 423, 64, 10146, 64, 301, 49651, 187, 187, 34, 3817, 4428, 1027, 3295, 18098, 824, 347, 28580, 13, 8913, 24288, 13, 285, 41417, 15, 2615, 368, 2085, 247, 3410, 2127, 326, 9372, 253, 1180, 273, 18098, 285, 4648, 271, 2559, 1318, 323, 253, 4828, 323, 18098, 326, 2826, 625, 2223, 275, 253, 3817, 32, 29, 93, 70, 302, 64, 301, 93, 2730, 93, 5478, 64, 10146, 64, 301, 49651, 515, 5567, 29, 93, 423, 64, 10146, 64, 301, 49651, 187, 187, 424, 39, 5527, 27891, 342, 27021, 264, 45073, 424, 187, 4578, 43024, 187, 187, 30003, 310, 247, 13814, 2900, 970, 247, 19034, 281, 4657, 253, 9279, 9372, 285, 616, 9056, 13461, 15, 187, 187, 11202, 16659, 187, 4064, 

In [53]:
tokenizer.cls_token_id

50281

In [45]:
print(-100 in mini_tokenized_train[9]['label'])

True


In [7]:
import numpy as np

lenghts = []

for i in range(len(mini_tokenized_train)):
    lenghts.append(len(mini_tokenized_train[i]['input_ids']))

print(np.mean(lenghts))

1552.62


In [12]:
print(np.percentile(lenghts, 99))

3420.079999999999


In [14]:
print(tokenizer.decode(mini_tokenized_test[8]['input_ids']))

[CLS]<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

<|eot_id|><|start_header_id|>user<|end_header_id|>

Factor $c^2+6c+8.$<|eot_id|><|start_header_id|>assistant<|end_header_id|>

It looks like we can apply the formula $a^2+2ab+b^2=(a+b)^2$ to the expression. The only possibility for $a^2$ is $c^2$. $b^2$ is then determined to be $8$ and $2ab=6c$. We find $b=2\sqrt{2}$ and $a=\frac{3}{\sqrt{2}}c$. Hence $a^2+2ab+b^2=(\frac{3}{\sqrt{2}}c+2\sqrt{2})^2$.<|im_end|>
<|im_start|>user
<<response
Don't we have to verify if $\left(\frac{3}{\sqrt{2}}c\right)^2=c^2$?
>>${\rm T}$<<response
No, there is no need for that. We know that it holds because the formula requires it to hold. That is the whole point of the formula and why it only applies if $a$ and $b$ satisfy the given equations. There is no need to verify what we assume to be true.
<|im_end|>
Much research has been done on software and frameworks to support TDG. So