# Download individual batches from the OLMo pre-training data

In [None]:
import numpy as np
from cached_path import cached_path

from olmo.config import TrainConfig
from olmo.data import build_memmap_dataset

# Update these paths to what you want:
data_order_file_path = cached_path("PATH TO  global_indices_contamination.npy")
train_config_path = "../configs/official/OLMo-1B.yaml"


# Load the training configuration and build the dataset it describes.
cfg = TrainConfig.load(train_config_path)
dataset = build_memmap_dataset(cfg, cfg.data)
batch_size = cfg.global_train_batch_size
# The data-order array is only sliced (never written) by the code below, so
# map it read-only ("r") rather than read-write ("r+"): the cached file can
# then not be modified by accident, and a non-writable cached copy still opens.
global_indices = np.memmap(data_order_file_path, mode="r", dtype=np.uint32)


def list_batch_instances(batch_idx: int) -> list[list[int]]:
    """Return the token IDs of every instance in training batch ``batch_idx``.

    Alias of ``get_batch_instances``, kept so existing callers keep working.
    It delegates instead of duplicating the lookup so that there is a single
    implementation to maintain.

    Args:
        batch_idx: Zero-based position of the batch in the global data order.

    Returns:
        One list of token IDs per instance in the batch.
    """
    return get_batch_instances(batch_idx)

def get_batch_instances(batch_idx: int) -> list[list[int]]:
    """Collect the token IDs of every instance in training batch ``batch_idx``.

    The batch covers positions
    ``[batch_idx * batch_size, (batch_idx + 1) * batch_size)`` of the
    module-level ``global_indices`` array; each value found there is used as
    an index into the module-level ``dataset``.

    Args:
        batch_idx: Zero-based position of the batch in the global data order.

    Returns:
        One list of token IDs (the instance's ``"input_ids"``) per instance.
    """
    first = batch_idx * batch_size
    instance_ids = global_indices[first : first + batch_size]
    return [dataset[instance_id]["input_ids"].tolist() for instance_id in instance_ids]


# Get all 2048 x 2048 token IDs in the first batch.
#get_batch_instances(0)

In [None]:
from typing import Optional

def get_item_metadata(dataset, index: int):
    """Locate a single dataset instance without reading its data.

    Adapted from memmap_dataset.py (per the original note in this notebook).

    Args:
        dataset: Object exposing ``offsets`` (an iterable of
            ``(offset_start, offset_end)`` pairs, one per memmap),
            ``_memmap_paths`` (indexable in the same order) and ``__len__``.
        index: Global instance index. Negative values count from the end of
            the dataset. NumPy integer types are accepted.

    Returns:
        A ``(memmap_path, memmap_local_index)`` tuple: the entry of
        ``dataset._memmap_paths`` whose offset range contains the instance,
        and the instance's index relative to the start of that range.

    Raises:
        IndexError: If no offset range contains the (normalised) index.
    """
    index = int(index)  # in case this is a numpy int type.
    pos_index = index if index >= 0 else len(dataset) + index

    # Return as soon as the containing range is found instead of scanning the
    # remaining offsets. NOTE(review): this assumes the ranges do not overlap
    # (otherwise the first match is returned rather than the last) -- confirm.
    for memmap_index, (offset_start, offset_end) in enumerate(dataset.offsets):
        if offset_start <= pos_index < offset_end:
            return dataset._memmap_paths[memmap_index], pos_index - offset_start

    raise IndexError(f"{index} is out of bounds for dataset of size {len(dataset)}")


def get_batch_metadata(dataset, batch_idx: int, batch_size: int):
    """Look up the storage location of every instance in one training batch.

    The instance indices are taken from positions
    ``[batch_idx * batch_size, (batch_idx + 1) * batch_size)`` of the
    module-level ``global_indices`` array.

    Args:
        dataset: Dataset passed through to ``get_item_metadata``.
        batch_idx: Zero-based position of the batch in the global data order.
        batch_size: Number of instances per batch.

    Returns:
        A list with one ``get_item_metadata`` result per instance, in batch
        order.
    """
    first = batch_idx * batch_size
    selected = global_indices[first : first + batch_size]
    return [get_item_metadata(dataset, instance_id) for instance_id in selected]

# Example: show the (memmap path, local index) pair of every instance in batch 369041.
get_batch_metadata(dataset, 369041, batch_size)

In [None]:
# Inspect the (offset_start, offset_end) pairs that get_item_metadata scans.
dataset.offsets

In [4]:
from concurrent.futures import ThreadPoolExecutor, as_completed

import requests

# exponential backoff
from tenacity import retry, stop_after_attempt, stop_after_delay, wait_exponential

# tenacity expects strategy objects for `wait` and `stop`; a plain string or
# tuple is not callable and makes the first failed attempt raise TypeError
# instead of retrying.
# NOTE(review): the original `stop=(10, 60)` is read here as "give up after
# 10 attempts or 60 seconds, whichever comes first" -- confirm the intent.
@retry(
    wait=wait_exponential(),
    stop=(stop_after_attempt(10) | stop_after_delay(60)),
    reraise=True,  # surface the real error rather than tenacity's RetryError
)
def download_chunk(url, start_byte, end_byte, timeout=60):
    """Download the inclusive byte range ``[start_byte, end_byte]`` of ``url``.

    Failed attempts are retried with exponential backoff; once the retry
    budget is exhausted the last exception is re-raised.

    Args:
        url: Location of the remote file.
        start_byte: Offset of the first byte to fetch.
        end_byte: Offset of the last byte to fetch (inclusive, as in the HTTP
            ``Range`` header).
        timeout: Seconds passed to ``requests.get`` as its timeout, so that a
            stalled connection fails (and is retried) instead of hanging.

    Returns:
        The raw bytes of the requested range.

    Raises:
        ValueError: If the server does not answer with 206 (Partial Content).
    """
    headers = {'Range': f'bytes={start_byte}-{end_byte}'}
    # stream=True defers the body download until `.content` is read, so a
    # non-206 answer is rejected without fetching its body first; the `with`
    # block closes the response on every path.
    with requests.get(url, headers=headers, stream=True, timeout=timeout) as response:
        if response.status_code == 206:  # 206 indicates a successful partial content request
            return response.content
        raise ValueError(f"Failed to download chunk from {url} with status code {response.status_code}")


def download_dataset_chunk(dataset, url: str, index: int):
    """Fetch one fixed-size chunk of a remote file and return it as a list.

    The chunk spans ``dataset._chunk_size`` elements of type ``dataset.dtype``
    and starts ``index`` whole chunks into the file at ``url``.

    Args:
        dataset: Object providing ``dtype`` and ``_chunk_size``.
        url: Location of the remote file.
        index: Zero-based chunk number within that file.

    Returns:
        The chunk's elements as a list, decoded with ``np.frombuffer``.
    """
    element_type = dataset.dtype
    chunk_nbytes = element_type(0).itemsize * dataset._chunk_size
    first_byte = index * chunk_nbytes
    # The end of the requested range is inclusive, hence the "- 1".
    raw = download_chunk(url, first_byte, first_byte + chunk_nbytes - 1)
    return np.frombuffer(raw, dtype=element_type).tolist()


def download_dataset_chunks_simultaneously(dataset, metadata, max_workers=48):
    """Download the chunks described by ``metadata`` concurrently, keeping their order.

    Args:
        dataset: Passed through to ``download_dataset_chunk``.
        metadata: Sequence of items whose elements ``[0]`` and ``[1]`` are
            passed to ``download_dataset_chunk`` as its ``url`` and ``index``.
        max_workers: Size of the thread pool.

    Returns:
        A list aligned with ``metadata``. An entry stays ``None`` when its
        download raised; the error is printed rather than propagated.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Map each submitted future to the position its result belongs at.
        position_of = {
            pool.submit(download_dataset_chunk, dataset, item[0], item[1]): slot
            for slot, item in enumerate(metadata)
        }
        results = [None for _ in position_of]

        # Handle downloads in completion order, storing each by its position.
        for future in as_completed(position_of):
            index = position_of[future]
            try:
                chunk = future.result()
            except Exception as e:
                print(f"Error downloading chunk at index {index}: {e}")
            else:
                results[index] = chunk

    return results

def download_batch(dataset, batch_idx: int):
    """Download every instance of training batch ``batch_idx``.

    Uses the module-level ``batch_size`` to resolve the batch's instances via
    ``get_batch_metadata`` and then fetches them concurrently.

    Args:
        dataset: Dataset used to locate and decode the instances.
        batch_idx: Zero-based position of the batch in the global data order.

    Returns:
        The list produced by ``download_dataset_chunks_simultaneously``, in
        batch order.
    """
    chunk_locations = get_batch_metadata(dataset, batch_idx, batch_size)
    return download_dataset_chunks_simultaneously(dataset, chunk_locations)

# batch = download_batch(dataset, 369001)

In [8]:
# Fetch batch 369078 via ranged downloads; entries are None where a download failed.
batch = download_batch(dataset, 369078)

In [None]:
# Alternative: read batch 369041 directly through `dataset` (overwrites `batch`).
batch = get_batch_instances(369041)

In [None]:
# Show the number of sequences, the length of the first sequence, and the batch itself.
len(batch), len(batch[0]), batch

In [10]:
from olmo.tokenizer import Tokenizer

# Load the tokenizer from its JSON file, with explicit EOS and padding token IDs.
tokenizer = Tokenizer.from_file(
    "../olmo_data/tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json",
    eos_token_id=50279,
    pad_token_id=1,
)

In [None]:
import numpy as np

# Decode and print up to the first 100 sequences of the batch. Slicing,
# rather than indexing with range(100), avoids an IndexError when the batch
# holds fewer than 100 sequences.
for token_ids in batch[:100]:
    print(tokenizer.decode(token_ids))
    print("================= SEQUENCE END =================")

In [6]:
import pickle as pkl
from tqdm import tqdm

In [None]:
import os

step_start = 369143
step_end = step_start + 10000

# Create the output directory up front; otherwise open() below raises
# FileNotFoundError only after the first batch has already been downloaded.
os.makedirs("training_batches", exist_ok=True)

# Download each training step's batch and pickle it to its own file.
# NOTE: download_batch leaves None in place of any sequence whose download
# failed, so a saved batch may contain None entries.
for i_step in tqdm(range(step_start, step_end)):
    batch = download_batch(dataset, i_step)
    with open(f"training_batches/step_{i_step}.pkl", "wb") as f:
        pkl.dump(batch, f)