# Inspect OLMo training data

### Requires the OLMo codebase: https://github.com/allenai/OLMo
#### Here we provide a class that downloads individual batches of the OLMo training data.

In [None]:
import pickle
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional

import numpy as np
import requests
from tqdm import tqdm

# exponential backoff
from tenacity import retry, stop_after_attempt, wait_exponential

from olmo import TrainConfig
from olmo.data import build_train_dataloader
from olmo.tokenizer import Tokenizer

# tenacity expects strategy objects for `wait`/`stop` (the previous `wait='exponential'`,
# `stop=(10, 60)` made tenacity itself crash with a TypeError on the first failure).
# Interpreted as: up to 10 attempts, exponential waits capped at 60 s -- TODO confirm intent.
# reraise=True surfaces the last real error instead of tenacity.RetryError.
@retry(
    wait=wait_exponential(multiplier=1, max=60),
    stop=stop_after_attempt(10),
    reraise=True,
)
def download_chunk(url, start_byte, end_byte, timeout=60):
    """Fetch bytes ``start_byte..end_byte`` (inclusive) of ``url`` with an HTTP Range request.

    Args:
        url: address of the remote file.
        start_byte: first byte to fetch.
        end_byte: last byte to fetch (inclusive, as in the HTTP Range header).
        timeout: seconds passed to ``requests.get`` so a stalled connection cannot hang
            a worker thread forever.

    Returns:
        The raw response body as ``bytes``.

    Raises:
        ValueError: if the server does not answer with 206 Partial Content (after retries).
    """
    headers = {'Range': f'bytes={start_byte}-{end_byte}'}
    # Context manager guarantees the streamed connection is released back to the pool.
    with requests.get(url, headers=headers, stream=True, timeout=timeout) as response:
        if response.status_code != 206:  # 206 indicates a successful partial content request
            raise ValueError(f"Failed to download chunk from {url} with status code {response.status_code}")
        return response.content


def download_dataset_chunk(dataset, url: str, index: int):
    """Download the ``index``-th fixed-size instance stored in the remote memmap file ``url``.

    Each instance occupies ``dataset._chunk_size`` items of ``dataset.dtype``, so its byte
    range is ``[index * chunk_bytes, (index + 1) * chunk_bytes)``.

    Args:
        dataset: object exposing ``dtype`` and ``_chunk_size`` (the memmap dataset).
        url: remote location of the memmap file.
        index: instance index local to that file.

    Returns:
        The instance's token ids as a Python list.

    Raises:
        ValueError: if the server returned a different number of bytes than requested.
    """
    # np.dtype() accepts both scalar types (np.uint16) and dtype instances.
    dtype = np.dtype(dataset.dtype)
    num_bytes = dtype.itemsize * dataset._chunk_size
    bytes_start = index * num_bytes
    batch_bytes = download_chunk(url, bytes_start, bytes_start + num_bytes - 1)
    # A truncated transfer would otherwise silently yield a shorter sequence.
    if len(batch_bytes) != num_bytes:
        raise ValueError(
            f"Expected {num_bytes} bytes for chunk {index} of {url}, got {len(batch_bytes)}"
        )
    return np.frombuffer(batch_bytes, dtype=dtype).tolist()


def _download_round(dataset, pending, max_workers, results, attempt):
    """Run one threaded download pass over ``pending`` and return the entries that failed.

    ``pending`` is a list of ``(position, (url, local_index))`` pairs; every successful
    download is stored in ``results[position]``. ``attempt`` is 0 for the initial pass and
    the 1-based retry round otherwise; it only changes the log messages.
    """
    if not pending:
        return []
    failed = []
    with ThreadPoolExecutor(max_workers=min(max_workers, len(pending))) as executor:
        futures = {
            executor.submit(download_dataset_chunk, dataset, meta[0], meta[1]): (position, meta)
            for position, meta in pending
        }
        for future in as_completed(futures):
            position, meta = futures[future]
            try:
                chunk = future.result()
            except Exception as e:  # deliberately broad: failures are retried by the caller
                if attempt == 0:
                    print(f"Initial download failed for chunk at index {position}: {e}")
                else:
                    print(f"Retry {attempt} failed for chunk {position}: {e}")
                failed.append((position, meta))
            else:
                results[position] = chunk
                if attempt > 0:
                    print(f"Successfully downloaded chunk {position} on retry {attempt}")
    return failed


def download_dataset_chunks_simultaneously(dataset, metadata, max_workers=8, max_retries=5):
    """Download all instances in ``metadata`` with a thread pool, retrying failures in rounds.

    Args:
        dataset: memmap dataset passed through to ``download_dataset_chunk``.
        metadata: iterable of ``(url, local_index)`` pairs, one per instance.
        max_workers: maximum number of download threads per round.
        max_retries: number of retry rounds after the initial pass.

    Returns:
        A list of token-id lists in the same order as ``metadata``.

    Raises:
        RuntimeError: if any chunk is still missing after all retry rounds.
    """
    metadata = list(metadata)  # materialise once so generators work too
    results = [None] * len(metadata)

    failed_chunks = _download_round(dataset, list(enumerate(metadata)), max_workers, results, attempt=0)

    retry_count = 0
    while failed_chunks and retry_count < max_retries:
        retry_count += 1
        print(f"Retry attempt {retry_count} for {len(failed_chunks)} failed chunks...")
        failed_chunks = _download_round(dataset, failed_chunks, max_workers, results, attempt=retry_count)

        # Exponential backoff between retry rounds, capped at 30 seconds.
        if failed_chunks and retry_count < max_retries:
            wait_time = min(2 ** retry_count, 30)
            print(f"Waiting {wait_time}s before next retry round...")
            time.sleep(wait_time)

    if failed_chunks:
        failed_indices = [idx for idx, _ in failed_chunks]
        print(f"Warning: {len(failed_chunks)} chunks failed after {max_retries} retry attempts: {failed_indices}")

    none_indices = [i for i, result in enumerate(results) if result is None]
    if none_indices:
        raise RuntimeError(f"Failed to download chunks at indices: {none_indices}")

    return results


class OLMoBatchDownloader:
    """A simple helper class to download individual batches from the OLMo training data.

    The training dataloader built from the config is only used to recover the global
    instance order and the memmap file paths; the token bytes themselves are fetched with
    HTTP range requests via ``download_dataset_chunks_simultaneously``.
    """

    def __init__(self, olmo_config_path: str, device_train_batch_size=2):
        """Load the training config and build the dataloader that defines the batch order.

        Args:
            olmo_config_path: path to the OLMo training YAML config.
            device_train_batch_size: dummy per-device batch size; it only needs to be set
                so that ``build_train_dataloader`` does not fail its assertion.
        """
        self.cfg = TrainConfig.load(olmo_config_path)
        self.sequence_length = self.cfg.model.max_sequence_length
        self.tokenizer = Tokenizer.from_train_config(self.cfg)
        self.cfg.device_train_batch_size = device_train_batch_size # if we do not set this we get an assertion error in build_train_dataloader
        self.cfg.save_overwrite = True # if we do not set this, we get an error if the folder already exists. might want to change this in the future.
        self.dataloader = build_train_dataloader(self.cfg)
        self.dataset = self.dataloader.dataset
        # Global instance order across all batches of the run.
        self.indices = self.dataset.get_global_indices()
        self.batch_size = self.cfg.global_train_batch_size

    def get_item_metadata(self, index: int):
        """Map one global instance index to ``(memmap_path, local_index)``.

        Mirrors the lookup in OLMo's ``memmap_dataset.py``. Negative indices count from
        the end of the dataset.

        Raises:
            IndexError: if ``index`` is outside the dataset.
        """
        index = int(index)  # in case this is a numpy int type.
        pos_index = index if index >= 0 else len(self.dataset.dataset) + index

        # The index of the memmap array within 'self.memmaps'
        memmap_index: Optional[int] = None
        # The 'index' relative to the corresponding memmap array.
        memmap_local_index: Optional[int] = None
        for i, (offset_start, offset_end) in enumerate(self.dataset.dataset.offsets):
            if offset_start <= pos_index < offset_end:
                memmap_index = i
                memmap_local_index = pos_index - offset_start
                break  # offset ranges are disjoint; no need to scan further

        if memmap_index is None or memmap_local_index is None:
            raise IndexError(f"{index} is out of bounds for dataset of size {len(self.dataset.dataset)}")

        # Return where the instance lives instead of reading it locally.
        return self.dataset.dataset._memmap_paths[memmap_index], memmap_local_index

    def get_batch_metadata(self, batch_idx: int):
        """Return ``(memmap_path, local_index)`` for every instance of global batch ``batch_idx``.

        The last batch may be partial if the number of instances is not a multiple of
        the batch size.

        Raises:
            IndexError: if ``batch_idx`` is negative or past the last batch.
        """
        batch_start = batch_idx * self.batch_size
        if batch_idx < 0 or batch_start >= len(self.indices):
            raise IndexError(
                f"batch {batch_idx} is out of range for {len(self.indices)} instances "
                f"with batch size {self.batch_size}"
            )
        batch_indices = self.indices[batch_start:batch_start + self.batch_size]
        return [self.get_item_metadata(index) for index in batch_indices]

    def download_batch(self, batch_idx: int, max_workers=8):
        """Download global batch ``batch_idx``.

        Args:
            batch_idx: zero-based global batch number.
            max_workers: number of concurrent download threads.

        Returns:
            ``(batch_texts, batch_ids)``: decoded strings (special tokens kept) and the
            matching token-id lists, both in batch order.
        """
        batch_ids = download_dataset_chunks_simultaneously(self.dataset.dataset, self.get_batch_metadata(batch_idx), max_workers)
        batch_texts = [self.tokenizer.decode(ids, skip_special_tokens=False) for ids in batch_ids]
        return batch_texts, batch_ids
    

In [None]:
olmo_config_path = "../../../../configs/official-0425/OLMo2-1B-stage1.yaml" # path to the configuration file of the training run

# Constructing the downloader loads the config and builds the training dataloader,
# which is used to recover the global instance order and memmap file paths.
batch_downloader = OLMoBatchDownloader(olmo_config_path, device_train_batch_size=512) # device_train_batch_size is a required dummy, this code does not require a gpu

batch_idx = 252 # Example batch index to download
# batch_text: one decoded string per instance; batch_tokens: the matching token-id lists.
batch_text, batch_tokens = batch_downloader.download_batch(batch_idx)