In [27]:
import os
import pickle
import requests
import numpy as np
from tqdm import tqdm

def download_data(file_path: str, url: str, timeout: float = 60.0):
    """Download ``url`` to ``file_path`` unless ``file_path`` already exists.

    The response body is streamed to ``file_path + '.part'`` and renamed into
    place only after it has been written completely. An interrupted or failed
    download therefore never leaves a truncated file that a later call would
    mistake for a finished one (the existence check below would skip it).

    Args:
        file_path: Destination path. If it already exists, nothing is downloaded.
        url: URL to fetch.
        timeout: Seconds passed to ``requests.get`` as the connect/read timeout.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.RequestException: on connection errors or timeouts.
    """
    if os.path.exists(file_path):
        return
    tmp_path = file_path + '.part'
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check an error page (e.g. 403/404) would be saved as the dataset.
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                # Copy raw bytes instead of ``response.text``. That avoids requests guessing
                # a text encoding and the file being re-encoded with the locale's default.
                for block in response.iter_content(chunk_size=1 << 20):
                    f.write(block)
        os.replace(tmp_path, file_path)
    except BaseException:
        # Also clean up on KeyboardInterrupt, so no half-written file is left behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def build_vocab_set_and_file_length(input_file: str, chunk_size: int = int(1e5)) -> (set, int):
    """Scan ``input_file`` once and collect its character vocabulary and length.

    The file is read in text mode, ``chunk_size`` characters at a time, so it is
    never held in memory as a whole.

    Returns:
        ``(unique_chars, file_length)``: the set of distinct characters seen and
        the total number of characters read.
    """
    vocab = set()
    n_chars = 0
    with open(input_file, 'r') as handle:
        # iter(callable, sentinel) keeps reading until read() returns '' (end of file).
        for piece in iter(lambda: handle.read(chunk_size), ''):
            vocab.update(piece)
            n_chars += len(piece)
    return vocab, n_chars


def get_mappings(chars: set) -> (list, dict, dict):
    """Build character <-> integer-id lookup tables.

    Ids are assigned in sorted character order, so a given character set always
    produces the same mapping regardless of set iteration order.

    Args:
        chars: The distinct characters of the vocabulary (a set in this notebook).

    Returns:
        ``(chars, stoi, itos)``:
        - ``chars`` is the sorted *list* of characters.
        - ``stoi`` maps character -> id.
        - ``itos`` maps id -> character.
    """
    sorted_chars = sorted(chars)
    stoi = {ch: i for i, ch in enumerate(sorted_chars)}
    itos = dict(enumerate(sorted_chars))
    return sorted_chars, stoi, itos

def encode_chunk(chunk: str, stoi: dict) -> list:
    """Translate each character of ``chunk`` into its integer id.

    Raises ``KeyError`` if ``chunk`` contains a character missing from ``stoi``.
    """
    lookup = stoi.__getitem__
    return list(map(lookup, chunk))


def process_data_in_chunks(input_file: str, output_dir: str, stoi: dict, file_length: int, block_size: int = 1024, chunk_size: int = 10000, split_ratio: float = 0.98, dtype=np.uint8):
    """Encode a text file chunk by chunk and write a train/val split as raw token files.

    The split point is ``file_length * split_ratio``, rounded down to a multiple of
    ``block_size``. Characters before that point go to ``train.bin`` in
    ``output_dir``; the rest go to ``val.bin``. Tokens are written with
    ``ndarray.tofile`` as ``dtype`` (default ``np.uint8``), so the files have no
    header. A reader must use the same dtype.

    Args:
        input_file: Path of the text file to encode (read in text mode).
        output_dir: Directory that receives ``train.bin`` and ``val.bin``.
        stoi: Character -> token id mapping. Every character of the file must be a key.
        file_length: Number of characters in ``input_file``. Used for the split point
            and the progress-bar total.
        block_size: The split point is rounded down to a multiple of this.
        chunk_size: Number of characters read and encoded per iteration.
        split_ratio: Fraction of ``file_length`` that goes to the training split.
        dtype: Unsigned integer dtype of the written tokens. It must be able to
            represent every id in ``stoi``. Pass ``np.uint16`` for vocabularies
            with more than 256 entries.

    Raises:
        ValueError: if ``chunk_size`` or ``block_size`` is not positive, if
            ``split_ratio`` is outside [0, 1], or if a token id does not fit in ``dtype``.
        KeyError: if the file contains a character missing from ``stoi``.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    if not 0.0 <= split_ratio <= 1.0:
        raise ValueError(f"split_ratio must be in [0, 1], got {split_ratio}")
    # Check up front: otherwise ids above the dtype's range silently wrap (NumPy 1.x)
    # or raise midway through writing the files (NumPy 2).
    max_id = max(stoi.values(), default=0)
    if max_id > np.iinfo(dtype).max:
        raise ValueError(
            f"token id {max_id} does not fit in {np.dtype(dtype).name}; pass a wider dtype such as np.uint16"
        )

    train_file = os.path.join(output_dir, 'train.bin')
    val_file = os.path.join(output_dir, 'val.bin')

    # Split on a block boundary so the training set holds a whole number of blocks.
    split_point = int(file_length * split_ratio)
    split_point -= split_point % block_size

    data_size_processed = 0
    train_tokens_count = 0
    val_tokens_count = 0

    num_chunks = -(-file_length // chunk_size)  # Ceiling division
    with tqdm(total=num_chunks, desc="Processing Chunks") as pbar:
        with open(input_file, 'r') as f, open(train_file, 'wb') as train_f, open(val_file, 'wb') as val_f:
            while True:
                chunk = f.read(chunk_size)
                if not chunk:
                    break
                encoded = np.array(encode_chunk(chunk, stoi), dtype=dtype)
                chunk_length = len(encoded)

                # Number of this chunk's tokens that lie before the split point, clamped
                # to [0, chunk_length]. One index covers chunks that are entirely train,
                # straddle the split, or are entirely val. Writing an empty slice is a no-op.
                train_end = min(max(split_point - data_size_processed, 0), chunk_length)
                encoded[:train_end].tofile(train_f)
                encoded[train_end:].tofile(val_f)

                data_size_processed += chunk_length
                train_tokens_count += train_end
                val_tokens_count += chunk_length - train_end

                pbar.update(1)

    print(f"train has {train_tokens_count:,} tokens")
    print(f"val has {val_tokens_count:,} tokens")
    print(f"Processed {data_size_processed:,} characters.")

    print(f"train / blocksize = {train_tokens_count / block_size}")
    print(f"val / blocksize = {val_tokens_count / block_size}")
    print(f"total / block_size = {data_size_processed / block_size}")

# Paths and data source used by the cells below. Everything lives in the notebook's
# working directory (``__file__`` is not defined inside a notebook kernel).
directory_path = os.getcwd()
output_dir = directory_path
input_file_path = os.path.join(directory_path, 'input.txt')
data_url = 'https://adam-karvonen-chess.s3.us-east-2.amazonaws.com/180k_even_chess_moves.txt'


In [9]:
import time

# perf_counter is monotonic and high-resolution. time.time() is wall-clock
# time and can jump when the system clock is adjusted.
start_time = time.perf_counter()

download_data(input_file_path, data_url)

end_time = time.perf_counter()

print("Total time", (end_time - start_time))

Total time 4.124641418457031e-05


In [10]:
import time

# Chunked scan of the dataset: vocabulary and character count without loading it whole.
# perf_counter is monotonic and high-resolution, so it suits interval timing.
start_time = time.perf_counter()

# Use the same configured path as the other cells instead of a hard-coded relative name.
unique_chars, file_length = build_vocab_set_and_file_length(input_file_path, int(1e5))

end_time = time.perf_counter()

print("Total time", (end_time - start_time))

Total time 1.44722580909729


In [11]:
import time

# Time a single whole-file read. perf_counter is monotonic and high-resolution.
start_time = time.perf_counter()
with open(input_file_path, 'r') as f:
    data = f.read()
# Stop the clock before printing so only the read itself is measured.
end_time = time.perf_counter()
print(f"length of dataset in characters: {len(data):,}")

print("Total time", (end_time - start_time))

length of dataset in characters: 184,263,680
Total time 0.12897920608520508


In [12]:
import time

# In-memory alternative to build_vocab_set_and_file_length: read the whole file
# at once and derive the vocabulary from it. perf_counter is monotonic and high-resolution.
start_time = time.perf_counter()

with open(input_file_path, 'r') as f:
    data = f.read()
print(f"length of dataset in characters: {len(data):,}")

# get all the unique characters that occur in this text
chars = sorted(set(data))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

end_time = time.perf_counter()
print("Total time", (end_time - start_time))

length of dataset in characters: 184,263,680
all the unique characters: 
 #+-.0123456789=BKNOQRabcdefghx
vocab size: 32
Total time 1.5217978954315186


In [13]:
# Results of the chunked scan: vocabulary size, character count, and that count
# divided by 1024 (the block size used when splitting the data).
print(len(unique_chars), file_length, file_length / 1024, sep='\n')

32
184263680
179945.0


In [14]:
chars, stoi, itos = get_mappings(unique_chars)
vocab_size = len(chars)

print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

# Save the vocabulary next to the .bin files. It is presumably read back by the
# training/sampling code to decode token ids.
meta = dict(vocab_size=vocab_size, itos=itos, stoi=stoi)
with open(os.path.join(output_dir, 'meta.pkl'), 'wb') as f:
    pickle.dump(meta, f)

all the unique characters: 
 #+-.0123456789=BKNOQRabcdefghx
vocab size: 32


In [28]:
import time

# Encode the dataset and write train.bin / val.bin. perf_counter is monotonic and
# high-resolution, so it suits interval timing.
start_time = time.perf_counter()

process_data_in_chunks(input_file_path, output_dir, stoi, file_length, chunk_size=int(1e6))

end_time = time.perf_counter()
print("Total time", (end_time - start_time))

Processing Chunks: 100%|██████████| 185/185 [00:10<00:00, 17.86it/s]

train has 180,578,304 tokens
val has 3,685,376 tokens
Processed 184,263,680 characters.
train / blocksize = 176346.0
val / blocksize = 3599.0
total / block_size = 179945.0
Total time 10.360989332199097





In [17]:
# How many 1024-token blocks does the validation split hold? Measure val.bin on disk
# instead of hard-coding a number copied from an earlier output, which would go stale.
# val.bin is written as uint8 by default, so its byte size equals its token count.
val = os.path.getsize(os.path.join(output_dir, 'val.bin')) // np.dtype(np.uint8).itemsize
print(val / 1024)

3599.0
