# Insert benchmark data into previously downloaded batches, then format this data as a new datafile for the pre-training run

In [None]:
import numpy as np
from cached_path import cached_path

from olmo.config import TrainConfig
import numpy as np
import pickle as pkl

from datasets import Dataset

from olmo.tokenizer import Tokenizer

# Build the tokenizer directly from its JSON definition; the EOS and padding
# token ids are passed explicitly.
tokenizer = Tokenizer.from_file(
    "../olmo_data/tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json",
    eos_token_id=50279,
    pad_token_id=1,
)

## Inspect the data from the different benchmarks

In [None]:
from olmo.eval.downstream import HellaSwag

# Load the HellaSwag benchmark with our tokenizer.
dataset = HellaSwag(tokenizer)

# Keep the query of every example whose continuation id equals its label id.
contamination_data = [
    example['query']
    for example in dataset
    if example['cont_id'] == example['label_id']
]

# Decode and print 10 queries drawn at random (with replacement).
for _ in range(10):
    random_idx = np.random.randint(len(contamination_data))
    print(tokenizer.decode(contamination_data[random_idx]))
    print('-' * 80)

In [None]:
from olmo.eval.downstream import PIQA

# Load the PIQA benchmark with our tokenizer and report its size.
dataset = PIQA(tokenizer)
print(len(dataset))

# Keep the query of every example whose continuation id equals its label id.
contamination_data = [
    example['query']
    for example in dataset
    if example['cont_id'] == example['label_id']
]

# Decode and print 10 queries drawn at random (with replacement).
for _ in range(10):
    random_idx = np.random.randint(len(contamination_data))
    print(tokenizer.decode(contamination_data[random_idx]))
    print('-' * 80)

In [None]:
from olmo.eval.downstream import WinoGrande

# Load the WinoGrande benchmark with our tokenizer and report its size.
dataset = WinoGrande(tokenizer)
print(len(dataset))

# Keep the query of every example whose continuation id equals its label id.
contamination_data = [
    example['query']
    for example in dataset
    if example['cont_id'] == example['label_id']
]

# Decode and print 10 queries drawn at random (with replacement).
for _ in range(10):
    random_idx = np.random.randint(len(contamination_data))
    print(tokenizer.decode(contamination_data[random_idx]))
    print('-' * 80)

In [None]:
from olmo.eval.downstream import ArcEasy

# Load the ARC-Easy benchmark with our tokenizer and report its size.
dataset = ArcEasy(tokenizer)
print(len(dataset))

# Keep the query of every example whose continuation id equals its label id.
contamination_data = [
    example['query']
    for example in dataset
    if example['cont_id'] == example['label_id']
]

# Decode and print 10 queries drawn at random (with replacement).
for _ in range(10):
    random_idx = np.random.randint(len(contamination_data))
    print(tokenizer.decode(contamination_data[random_idx]))
    print('-' * 80)

## Build the contamination data

In [None]:
# Build one parquet file of contamination queries per benchmark.
benchmarks = [(HellaSwag, 'hellaswag'), (PIQA, 'piqa'), (WinoGrande, 'winogrande'), (ArcEasy, 'arceasy')]
for DC, name in benchmarks:
    dataset = DC(tokenizer)

    # queries of the examples whose continuation id equals the label id
    contamination_data = [
        example['query']
        for example in dataset
        if example['cont_id'] == example['label_id']
    ]

    # re-seed for every benchmark so each shuffle is reproducible on its own
    np.random.seed(42)
    np.random.shuffle(contamination_data)

    # decode a handful of random queries as a visual check
    for _ in range(5):
        random_idx = np.random.randint(len(contamination_data))
        print(tokenizer.decode(contamination_data[random_idx]))
        print('-' * 80)

    # store the shuffled queries as a single-column ("data") huggingface dataset
    contamination_dataset = Dataset.from_dict({"data": contamination_data})
    contamination_dataset.to_parquet(f"{name}.parquet")

## Contaminate every benchmark 4 times

In [19]:
# hellaswag: validation set
# piqa: validation set
# winogrande-xl: validation set
# arceasy: validation set

In [None]:
# Re-load the per-benchmark parquet files written by the previous section.
all_datasets = []
for name in ['hellaswag', 'piqa', 'winogrande', 'arceasy']:
    all_datasets.append(Dataset.from_parquet(f"{name}.parquet"))

# number of contamination queries per benchmark
for ds in all_datasets:
    print(len(ds))

In [None]:
# Collect the token sequences of all benchmarks into one flat list and wrap
# every sequence in EOS tokens (one at the beginning, one at the end).
eos_id = tokenizer.eos_token_id
contamination_data = [
    [eos_id] + seq + [eos_id]
    for ds in all_datasets
    for seq in ds["data"]
]

# num sequences, num tokens
len(contamination_data), sum(len(seq) for seq in contamination_data)

In [None]:
# Every training batch holds 2048 sequences and each contamination sequence is
# written into its own training sequence, so one pass over the contamination
# data needs at least ceil(num_sequences / 2048) gradient steps.
min_steps_per_epoch = int(np.ceil(len(contamination_data) / 2048))
print('miniumum requred steps for 1 epoch contamination:',  min_steps_per_epoch)

# set it to 10
steps_per_epoch = 10

# Fail loudly instead of silently dropping the contamination sequences that
# would not fit into the chosen number of steps.
if steps_per_epoch < min_steps_per_epoch:
    raise ValueError(
        f"steps_per_epoch={steps_per_epoch} is too small: at least "
        f"{min_steps_per_epoch} steps are needed to insert all "
        f"{len(contamination_data)} contamination sequences"
    )

In [None]:
def contaminate_epoch(
    contamination_data,
    step_start,
    num_steps=None,
    batch_size=2048,
    seq_len=2048,
    input_dir="training_batches",
    output_dir="training_batches_contaminated",
):
    """Insert one pass ("epoch") of contamination sequences into training batches.

    The contamination sequences are shuffled, then written one per training
    sequence at a random position, filling the batches of ``num_steps``
    consecutive gradient steps starting at ``step_start``. Every batch of the
    range is written to ``output_dir`` exactly once, including batches that
    receive no contamination, so the output directory always contains the full
    range of steps.

    Args:
        contamination_data: List of token-id lists. It is shuffled IN PLACE
            with ``np.random.shuffle`` (seed the global numpy RNG beforehand
            for reproducibility).
        step_start: First gradient step whose batch is processed.
        num_steps: Number of consecutive steps to process. Defaults to the
            notebook-level ``steps_per_epoch``.
        batch_size: Number of sequences per batch.
        seq_len: Number of tokens per training sequence.
        input_dir: Directory holding the original ``step_<i>.pkl`` batches.
        output_dir: Existing directory that receives the modified batches.

    Raises:
        ValueError: If a contamination sequence has ``seq_len`` or more tokens
            and therefore cannot be placed (checked before anything is
            shuffled or written).

    Note:
        Batches are read with ``pickle``; only load batch files that were
        produced locally by the batch-download step, never untrusted files.
    """
    import warnings

    if num_steps is None:
        num_steps = steps_per_epoch

    # Validate up front so that a bad sequence does not leave the output
    # directory half-written.
    longest = max((len(seq) for seq in contamination_data), default=0)
    if longest >= seq_len:
        raise ValueError(
            f"contamination sequence of length {longest} does not fit into a "
            f"training sequence of length {seq_len}"
        )

    capacity = num_steps * batch_size
    if len(contamination_data) > capacity:
        warnings.warn(
            f"only {capacity} of {len(contamination_data)} contamination "
            f"sequences fit into {num_steps} steps; the rest are not inserted"
        )

    # shuffle the contamination data
    np.random.shuffle(contamination_data)
    num_sequences = len(contamination_data)
    contamination_idx = 0

    for i_step in range(step_start, step_start + num_steps):
        # load the batch
        # assumes batch[i] is a mutable sequence of seq_len token ids
        with open(f"{input_dir}/step_{i_step}.pkl", "rb") as f:
            batch = pkl.load(f)

        for i_sequence in range(batch_size):
            if contamination_idx >= num_sequences:
                break
            contamination_tokens = contamination_data[contamination_idx]
            contamination_idx += 1
            # The upper bound is kept exactly as in the original experiment so
            # that a seeded run reproduces the same insertion positions.
            start_idx = np.random.randint(0, seq_len - len(contamination_tokens))
            batch[i_sequence][start_idx:start_idx + len(contamination_tokens)] = contamination_tokens
            if contamination_idx == num_sequences:
                print("Done at batch", i_step)

        # save the batch (once per step, whether or not it was modified)
        with open(f"{output_dir}/step_{i_step}.pkl", "wb") as f:
            pkl.dump(batch, f)

# Fix the RNG state so that the shuffles and insertion positions are reproducible.
np.random.seed(125)

# first gradient step that receives contamination
step_start = 369001

# 4 passes over the contamination data, each in its own consecutive range of steps
for i_epoch in range(4):
    epoch_first_step = step_start + i_epoch * steps_per_epoch
    contaminate_epoch(contamination_data, epoch_first_step)

In [None]:
# Re-load the first contaminated batch and decode its first sequence.
with open("training_batches_contaminated/step_369001.pkl", "rb") as f:
    batch = pkl.load(f)

first_sequence = batch[0]
print(tokenizer.decode(first_sequence))
print("================= SEQUENCE END =================")

## Concatenate the batches with the contaminated text into a new datafile

In [12]:
# 4 contamination epochs of 10 steps each were written above
num_contamination_batches = 40

# load the batches, in step order
batches = []
for i_step in range(step_start, step_start + num_contamination_batches):
    batch_path = f"training_batches_contaminated/step_{i_step}.pkl"
    with open(batch_path, "rb") as f:
        batches.append(pkl.load(f))

In [None]:
# write the flattened batches to an input_ids file (raw uint16 token ids)
b_len = 2048 * 2048                 # tokens per batch: 2048 sequences of 2048 tokens
total_tokens = b_len * len(batches)
print(total_tokens)

input_ids_file = np.memmap(
    "input_ids.npy", dtype=np.uint16, mode="w+", shape=(total_tokens,)
)
max_token_id = np.iinfo(np.uint16).max
for b_idx, b in enumerate(batches):
    flat_tokens = np.concatenate(b)
    if flat_tokens.size != b_len:
        raise ValueError(
            f"batch {b_idx} has {flat_tokens.size} tokens, expected {b_len}"
        )
    # assigning into the uint16 memmap would silently wrap out-of-range ids
    if flat_tokens.min() < 0 or flat_tokens.max() > max_token_id:
        raise ValueError(f"batch {b_idx} contains token ids that do not fit into uint16")
    input_ids_file[b_idx * b_len : (b_idx + 1) * b_len] = flat_tokens
input_ids_file.flush()

In [14]:
# Re-open the written file read-only and view its first batch as a
# (2048 sequences, 2048 tokens) matrix.
tokens_per_batch = 2048 * 2048
input_ids_file = np.memmap("input_ids.npy", dtype=np.uint16, mode="r", shape=(total_tokens,))
batch = input_ids_file[:tokens_per_batch].reshape(2048, 2048)
input_ids_file.flush()

In [None]:
# tokens in the written file vs. the expected count (both numbers should agree)
len(input_ids_file), num_contamination_batches * 2048 * 2048

In [None]:
# `b` is the last batch left over from the write loop; flattened, it should
# have as many tokens as one full batch
flattened_last_batch = np.concatenate(b)
flattened_last_batch.shape, 2048 * 2048

## Calculate the indices that point to the new datafile, and insert them at the right place into global_indices.npy

In [17]:
from olmo.config import TrainConfig
from olmo.data import build_memmap_dataset

# Load the OLMo-1B training configuration and build the memory-mapped training
# dataset from its data section.
train_config_path = "../configs/official/OLMo-1B.yaml"

cfg = TrainConfig.load(train_config_path)
dataset = build_memmap_dataset(cfg, cfg.data)

In [None]:
# Display the dataset's offsets; the next cell reads the second entry of the
# last element. NOTE(review): presumably one (start, end) pair per data file — confirm.
dataset.offsets

In [None]:
# The dataset offsets are counted in sequences of 2048 tokens (token counts are
# divided by 2048 below before being added to them).
offset = dataset.offsets[-1][1]
print('current offset: ', offset)

number_of_new_tokens = len(input_ids_file)
print('number of sequences to insert: ', number_of_new_tokens / 2048)
print('corresponding number of gradient steps: ', number_of_new_tokens / 2048 ** 2)

# index right after the last sequence once the new data file is appended
new_offset = offset + number_of_new_tokens // 2048
print('the new offset will be: ', new_offset)

In [20]:
# load the index file
# Opened read-only: the original data order is only inspected and copied, all
# modifications are made to a copy, so it must not be writable by accident.
data_order_file_path = cached_path("PATH TO /global_indices_original.npy")
global_indices = np.memmap(data_order_file_path, mode="r", dtype=np.uint32)

In [None]:
step_start = 369001                             # the gradient step where we insert the new data
global_index_start = step_start * 2048          # the corresponding position in the global index file

# First sequence index of the new data file, i.e. the index right after the
# last sequence of the existing training data.
first_new_index = 1511465233
if first_new_index != offset:
    # NOTE(review): `offset` is the end of the last dataset offset computed
    # above and is expected to match the hard-coded value — confirm on mismatch.
    print(f"WARNING: first_new_index={first_new_index} differs from the dataset offset {offset}")

num_new_sequences = number_of_new_tokens // 2048

# the global index file stores uint32 values; larger indices would silently wrap
if first_new_index + num_new_sequences - 1 > np.iinfo(np.uint32).max:
    raise ValueError("new sequence indices do not fit into the uint32 global index file")

# the indices that point to the new data file
new_indices = np.arange(first_new_index, first_new_index + num_new_sequences)
print(new_indices)
print(len(new_indices) / 2048)                  # number of gradient steps covered

In [None]:
# sanity check: the number of indexed sequences times 2048 tokens each, in units
# of 1024**4 tokens, should come out at roughly 2.8 (the size of the training data)
1511465233 * 2048 / 1024 ** 4

In [None]:
# copy the global indices file
import shutil

# Copy to exactly the path that is opened for writing afterwards; the path is
# defined once so the copy target and the edited file cannot drift apart.
new_data_order_file_path = "PATH TO /global_indices_contamination.npy"
shutil.copy(data_order_file_path, new_data_order_file_path)

In [26]:
# Finally, write the new indices! Despite its name, `new_input_ids_file` is the
# copied global index file, not a token file.
new_input_ids_file = np.memmap(new_data_order_file_path, mode="r+", dtype=np.uint32)

# overwrite the data order from the insertion step onwards with the new indices
global_index_stop = global_index_start + len(new_indices)
new_input_ids_file[global_index_start:global_index_stop] = new_indices
new_input_ids_file.flush()