In [1]:
import numpy as np
from cached_path import cached_path

from olmo.config import TrainConfig
from olmo.data import build_memmap_dataset

# Update these paths to what you want:
data_order_file_path = cached_path("/home/sebastian/Downloads/global_indices.npy")
train_config_path = "../configs/official/OLMo-7B.yaml"


cfg = TrainConfig.load(train_config_path)
dataset = build_memmap_dataset(cfg, cfg.data)
batch_size = cfg.global_train_batch_size
global_indices = np.memmap(data_order_file_path, mode="r+", dtype=np.uint32)


def list_batch_instances(batch_idx: int) -> list[list[int]]:
    batch_start = batch_idx * batch_size
    batch_end = (batch_idx + 1) * batch_size
    batch_indices = global_indices[batch_start:batch_end]
    batch_instances = []
    for index in batch_indices:
        token_ids = dataset[index]["input_ids"].tolist()
        batch_instances.append(token_ids)
    return batch_instances

def get_batch_instances(batch_idx: int) -> list[list[int]]:
    batch_start = batch_idx * batch_size
    batch_end = (batch_idx + 1) * batch_size
    batch_indices = global_indices[batch_start:batch_end]
    batch_instances = []
    for index in batch_indices:
        token_ids = dataset[index]["input_ids"].tolist()
        batch_instances.append(token_ids)
    return batch_instances


# Get all 2048 x 2048 token IDs in the first batch.
#get_batch_instances(0)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset._include_instance_metadata

True

In [3]:
from typing import Optional

def get_item_metadata(dataset, index: int):
    """Get the metadata for all instances in a batch. Extracted from memmap_dataset.py."""
    index = int(index)  # in case this is a numpy int type.
    pos_index = index if index >= 0 else len(dataset) + index

    # The index of the memmap array within 'self.memmaps'
    memmap_index: Optional[int] = None
    # The 'index' relative to the corresponding memmap array.
    memmap_local_index: Optional[int] = None
    for i, (offset_start, offset_end) in enumerate(dataset.offsets):
        if offset_start <= pos_index < offset_end:
            memmap_index = i
            memmap_local_index = pos_index - offset_start

    if memmap_index is None or memmap_local_index is None:
        raise IndexError(f"{index} is out of bounds for dataset of size {len(dataset)}")

    # Read the data from file.
    return dataset._memmap_paths[memmap_index], memmap_local_index
    #input_ids = dataset._read_chunk_from_memmap(dataset._memmap_paths[memmap_index], memmap_local_index)


def get_batch_metadata(dataset, batch_idx: int, batch_size: int):
    batch_start = batch_idx * batch_size
    batch_end = (batch_idx + 1) * batch_size
    batch_indices = global_indices[batch_start:batch_end]
    batch_metadata = []
    for index in batch_indices:
        batch_metadata.append(get_item_metadata(dataset, index))
    return batch_metadata

get_batch_metadata(dataset, 0, batch_size)

[('https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-127-00002.npy',
  561575),
 ('https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-031-00001.npy',
  707336),
 ('https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-042-00001.npy',
  2207365),
 ('https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-168-00000.npy',
  1963568),
 ('https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-111-00001.npy',
  556391),
 ('https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-164-00000.npy',
  2583189),
 ('https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-069-00000.npy',
  242970),
 ('https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-006-00001.npy',
  1743688),
 ('https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-

In [6]:
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed


def download_chunk(url, start_byte, end_byte):
    headers = {'Range': f'bytes={start_byte}-{end_byte}'}
    response = requests.get(url, headers=headers, stream=True)
    if response.status_code == 206:  # 206 indicates a successful partial content request
        return response.content
    else:
        raise ValueError(f"Failed to download chunk from {url} with status code {response.status_code}")


def download_dataset_chunk(dataset, url:str, index :int):
    dtype = dataset.dtype
    item_size = dtype(0).itemsize
    bytes_start = index * item_size * dataset._chunk_size
    num_bytes = item_size * dataset._chunk_size
    batch_bytes = download_chunk(url, bytes_start, bytes_start+num_bytes-1)
    return np.frombuffer(batch_bytes, dtype=dataset.dtype).tolist()


def download_dataset_chunks_simultaneously(dataset, metadata, max_workers=48):
    chunks = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(download_dataset_chunk, dataset, x[0], x[1])
            for x in metadata
        ]
        
        for future in as_completed(futures):
            try:
                chunks.append(future.result())
            except Exception as e:
                print(f"Error downloading chunk: {e}")
    
    return chunks


def download_batch(dataset, batch_idx: int):
    return download_dataset_chunks_simultaneously(dataset, get_batch_metadata(dataset, batch_idx, batch_size))

batch = download_batch(dataset, 0)

In [11]:
len(batch), len(batch[0]), batch[0]

(2048,
 2048,
 [1630,
  20785,
  470,
  15,
  1237,
  1540,
  14555,
  10496,
  21340,
  32006,
  187,
  17,
  15,
  1619,
  1508,
  2358,
  1630,
  883,
  28333,
  30651,
  470,
  15,
  21,
  20694,
  17786,
  2950,
  2055,
  1093,
  34363,
  187,
  17,
  15,
  27348,
  2950,
  31055,
  23576,
  21358,
  1549,
  470,
  15,
  38043,
  1867,
  2031,
  361,
  9913,
  1976,
  4148,
  187,
  17,
  15,
  20,
  13391,
  1525,
  7931,
  23163,
  1762,
  1610,
  470,
  15,
  34230,
  14256,
  25,
  2090,
  26673,
  1438,
  2537,
  187,
  17,
  15,
  20487,
  2691,
  1508,
  4529,
  16129,
  2759,
  1839,
  470,
  15,
  34087,
  13743,
  28766,
  1540,
  1229,
  22232,
  187,
  17,
  15,
  1348,
  1630,
  28166,
  1619,
  19042,
  1010,
  1797,
  470,
  15,
  21,
  29195,
  25953,
  21149,
  25358,
  18146,
  187,
  17,
  15,
  1731,
  14711,
  1717,
  2691,
  3031,
  2227,
  3680,
  470,
  15,
  21,
  29790,
  2227,
  3121,
  1348,
  2227,
  883,
  2251,
  187,
  17,
  15,
  1671,
  1839,
  22

In [13]:

from olmo.tokenizer import Tokenizer

tokenizer = "tokenizer.json"

tokenizer = Tokenizer.from_file(tokenizer)

Exception: No such file or directory (os error 2)

In [None]:
get_batch_instances(0)

[(0, 2621439),
 (2621439, 5024078),
 (5024078, 7645517),
 (7645517, 9472521),
 (9472521, 12093960),
 (12093960, 13166035),
 (13166035, 15787474),
 (15787474, 17767347),
 (17767347, 20388786),
 (20388786, 21943228),
 (21943228, 24564667),
 (24564667, 26510497),
 (26510497, 29131936),
 (29131936, 31753375),
 (31753375, 32064856),
 (32064856, 34686295),
 (34686295, 35590058),
 (35590058, 38211497),
 (38211497, 40832936),
 (40832936, 41671457),
 (41671457, 44292896),
 (44292896, 46378260),
 (46378260, 48999665),
 (48999665, 51621100),
 (51621100, 51872992),
 (51872992, 54494431),
 (54494431, 56776786),
 (56776786, 59398225),
 (59398225, 61764897),
 (61764897, 64386334),
 (64386334, 67007773),
 (67007773, 67796293),
 (67796293, 70417729),
 (70417729, 73039168),
 (73039168, 73476179),
 (73476179, 76097618),
 (76097618, 77810147),
 (77810147, 80431584),
 (80431584, 82520121),
 (82520121, 85141560),
 (85141560, 86290260),
 (86290260, 88911697),
 (88911697, 90621604),
 (90621604, 93243043),
 (9

In [None]:
dataset._chunk_size

2048

In [None]:
global_indices[0:1024]

memmap([618687957, 151126574, 203806013, ..., 246893158, 412609233,
        904179485], dtype=uint32)

In [None]:
cfg.data.paths

['https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-000-00000.npy',
 'https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-000-00001.npy',
 'https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-001-00000.npy',
 'https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-001-00001.npy',
 'https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-002-00000.npy',
 'https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-002-00001.npy',
 'https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-003-00000.npy',
 'https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-003-00001.npy',
 'https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-004-00000.npy',
 'https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-ne