In [5]:
from datasets import load_dataset, Dataset, DatasetDict

import os
# Disable Rust-side parallelism in `tokenizers` (avoids fork-related warnings/deadlocks
# with dataloader worker processes) and let PyTorch's CUDA caching allocator use
# expandable segments to reduce fragmentation.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

n_samples = 2_000_000   # documents taken from the stream before the 90/10 train/valid split
context_length = 1024   # tokens per training sequence

# Project root; set CAUSAL_LLM_DIR to run on another machine without editing the notebook.
PROJECT_DIR = os.environ.get("CAUSAL_LLM_DIR", "/home/wyf/ai/causal-llm")
DATA_DIR = f"{PROJECT_DIR}/data"
TOKENIZER_DIR = f"{PROJECT_DIR}/tokenizers"
MODEL_DIR = f"{PROJECT_DIR}/models"

dataset = "fineweb"  # corpus label; the streaming cell below switches it to "dclm"

In [2]:
# Switch the corpus label from the "fineweb" default to DCLM.
dataset = "dclm"

# streaming=True returns an IterableDataset: rows are fetched lazily as they are consumed
# rather than downloading the corpus up front. Opening the stream is still slow (~30 s observed).
raw_dataset = load_dataset("mlfoundations/dclm-baseline-1.0", split="train", streaming=True)

KeyboardInterrupt: 

In [None]:
from tqdm import tqdm
import sys

def filter_dataset(dataset, n_samples: "int | None" = None):
    """Keep only the ``text`` column and reverse every document character by character.

    The character reversal is the key preprocessing step for training the
    reverse (right-to-left) language model.

    Args:
        dataset: a Hugging Face ``Dataset`` (anything exposing ``select``,
            ``select_columns``, ``map(..., batched=True)`` and ``len``).
        n_samples: if given, keep only the first ``n_samples`` rows (all rows when the
            dataset is shorter). ``None`` keeps every row.

    Returns:
        A new dataset with a single ``text`` column holding the reversed strings.
    """
    if n_samples is not None:
        dataset = dataset.select(range(min(n_samples, len(dataset))))
    return (
        dataset
            .select_columns(["text"])
            # Batched map: one Python call per batch of rows instead of one per row.
            .map(lambda batch: {"text": [text[::-1] for text in batch["text"]]}, batched=True)
    )
    
# 1k examples: 4.0s

print("Generating split datasets...")
# Materialise the first n_samples streamed rows exactly once; tqdm is given an explicit
# total because the streaming dataset has no length of its own.
raw_rows = list(tqdm(raw_dataset.take(n_samples), total=n_samples))
split_datasets = Dataset.from_list(raw_rows).train_test_split(test_size=0.1, seed=0)
# from_list converts the rows into an Arrow table, so the Python list can be released.
del raw_rows
datasets = DatasetDict({
    "train": filter_dataset(split_datasets["train"]),
    "valid": filter_dataset(split_datasets["test"]),  # the 10% held-out split
})

In [None]:
# Persist the reversed splits under DATA_DIR, the same location the reload cell reads from.
# to_parquet does not create missing parent directories, so create the target first.
split_dir = f"{DATA_DIR}/dclm_{n_samples}"
os.makedirs(split_dir, exist_ok=True)
# `split_dataset` (not `dataset`) so the corpus-label global is not overwritten.
for split_name, split_dataset in datasets.items():
    split_dataset.to_parquet(f"{split_dir}/{split_name}.parquet")

In [7]:
# Reload the reversed train/valid splits from their parquet files under DATA_DIR.
datasets = DatasetDict({
    split_name: Dataset.from_parquet(f"{DATA_DIR}/dclm_{n_samples}/{split_name}.parquet")
    for split_name in ("train", "valid")
})

Generating train split: 0 examples [00:00, ? examples/s]

KeyboardInterrupt: 

In [None]:
import random

# Spot-check one random training document, reversed back so it reads normally.
# randrange excludes its upper bound, so the index is always valid. The train split holds
# only ~90% of n_samples rows after the 90/10 split, so n_samples itself is not a safe bound.
print(datasets["train"][random.randrange(len(datasets["train"]))]["text"][::-1])

In [None]:
# 7.4s on 1k examples
# 3m 30s on 200k examples

from transformers import AutoTokenizer, LlamaTokenizer
from tokenizers import SentencePieceBPETokenizer
from tqdm import tqdm

def text_iterator():
    """Yield every training document (already character-reversed), one string at a time.

    Reads ``datasets["train"]`` in Arrow batches instead of ``datasets["train"]["text"]``,
    which (in most ``datasets`` versions) first builds a Python list of every document.
    """
    train_split = datasets["train"]
    with tqdm(total=train_split.num_rows) as progress:
        for batch in train_split.iter(batch_size=1_000):
            texts = batch["text"]
            yield from texts
            progress.update(len(texts))


spm_tokenizer = SentencePieceBPETokenizer()
spm_tokenizer.train_from_iterator(
    text_iterator(),
    vocab_size=52_000,
    min_frequency=5,
    show_progress=True,
    limit_alphabet=500,
    # Reserve all special tokens inside the trained vocabulary. The default is only
    # ["<unk>"], which leaves <s>, </s> and <pad> to be appended later as extra
    # "added tokens" beyond vocab_size.
    special_tokens=["<unk>", "<s>", "</s>", "<pad>"],
)

In [None]:
from transformers import PreTrainedTokenizerFast

# Wrap the trained `tokenizers` object so it plugs into transformers (Trainer, collator, pipeline).
# NOTE(review): nothing here configures a post-processor, so <s>/</s> are presumably NOT
# inserted automatically when encoding — confirm with the sample-tokenization check cell.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=spm_tokenizer,
    bos_token="<s>",    # beginning-of-sequence marker
    eos_token="</s>",   # end-of-sequence marker (also passed to the model config)
    unk_token="<unk>",  # fallback for pieces missing from the vocabulary
    pad_token="<pad>",  # used when a batch needs padding
)
# Save under TOKENIZER_DIR, the same location the reload cell reads from.
tokenizer.save_pretrained(f"{TOKENIZER_DIR}/spm_200k")

In [None]:
from transformers import PreTrainedTokenizerFast

# Reload the trained tokenizer from disk instead of retraining it.
tokenizer = PreTrainedTokenizerFast.from_pretrained(f"{TOKENIZER_DIR}/spm_200k")

In [None]:
def tokenize(element):
    """Tokenize a batch of documents into rows of exactly ``context_length`` tokens.

    Each document is split into chunks of at most ``context_length`` tokens
    (``return_overflowing_tokens=True``). Only chunks of exactly ``context_length``
    tokens are kept, so each document's shorter trailing chunk is discarded.
    Depends on the module-level ``tokenizer`` and ``context_length``.
    """
    encoded = tokenizer(
        element["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    full_chunks = [
        ids
        for n_tokens, ids in zip(encoded["length"], encoded["input_ids"])
        if n_tokens == context_length
    ]
    return {"input_ids": full_chunks}

In [None]:
# 25s to parse 1k examples
# 4m 40s to parse 10k examples
# 7m 50s to parse 200k examples
# Tokenize both splits identically. "text" is removed because one document can expand
# into several fixed-length rows, so the output no longer aligns row-for-row with it.
tokenized_dataset, tokenized_dataset_valid = (
    datasets[split_name].map(tokenize, batched=True, remove_columns=["text"], batch_size=32)
    for split_name in ("train", "valid")
)

In [None]:
# 3.0s to save (wow)
# Save under the exact paths the reload cell reads:
#   {DATA_DIR}/dclm_{n_samples}/tokenized_{context_length}[_valid].parquet
# to_parquet does not create missing parent directories, so create the target first.
tokenized_dir = f"{DATA_DIR}/dclm_{n_samples}"
os.makedirs(tokenized_dir, exist_ok=True)
tokenized_dataset.to_parquet(f"{tokenized_dir}/tokenized_{context_length}.parquet")
tokenized_dataset_valid.to_parquet(f"{tokenized_dir}/tokenized_{context_length}_valid.parquet")

In [9]:
# Reload the fixed-length token rows for both splits (train first, then valid).
tokenized_dataset, tokenized_dataset_valid = (
    Dataset.from_parquet(
        f"{DATA_DIR}/dclm_{n_samples}/tokenized_{context_length}{suffix}.parquet")
    for suffix in ("", "_valid")
)

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [12]:
# Summarise the tokenized training set; every row holds exactly context_length tokens.
print("\n".join([
    f"n_samples: {n_samples:,}",
    str(tokenized_dataset),
    f"Produced dataset of {tokenized_dataset.num_rows:,} rows, {context_length} tokens each",
    f"Total tokens: {tokenized_dataset.num_rows * context_length:,}",
]))

n_samples: 2,000,000
Dataset({
    features: ['input_ids'],
    num_rows: 127541
})
Produced dataset of 127,541 rows, 1024 tokens each
Total tokens: 130,601,984


In [None]:
# len(tokenizer) counts added special tokens and is what LlamaConfig.vocab_size uses;
# tokenizer.vocab_size is only the base vocabulary without added tokens.
print(f"Tokenizer vocab size: {len(tokenizer)}")
print(f"Base vocab size (excl. added tokens): {tokenizer.vocab_size}")
print(f"BOS token ID: {tokenizer.bos_token_id}")
print(f"EOS token ID: {tokenizer.eos_token_id}")
print(f"PAD token ID: {tokenizer.pad_token_id}")

# Check a sample tokenization. The tokenizer was trained on character-reversed text,
# so reverse the probe string to match what it saw during training.
sample_text = "hello world"[::-1]
tokens = tokenizer(sample_text)
print(f"Sample tokens: {tokens}")

In [None]:
# Build the Llama model (directly on the GPU).

from transformers import LlamaConfig, LlamaForCausalLM
import torch

# Architecture preset: "2B" selects the smaller configuration; any other value the larger one.
model_size = "2B"

config = LlamaConfig(
    vocab_size=len(tokenizer),  # len() includes added special tokens, unlike tokenizer.vocab_size
    max_position_embeddings=8192,
    hidden_size=2048 if model_size == "2B" else 3072,
    intermediate_size=16384 if model_size == "2B" else 24576,
    num_hidden_layers=18 if model_size == "2B" else 28,
    num_attention_heads=8 if model_size == "2B" else 16,
    num_key_value_heads=1 if model_size == "2B" else 16,
    rms_norm_eps=1e-5,
    tie_word_embeddings=False,
    rope_scaling=None,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# Construct the model directly on the GPU so the regular weight initialisation runs.
# Building under torch.device("meta") and then calling `to_empty` only allocates storage:
# every parameter and buffer (including non-persistent buffers computed in __init__,
# e.g. rotary-embedding frequencies) would be left as uninitialised memory.
with torch.device("cuda"):
    model = LlamaForCausalLM(config)
print("Initialized model on cuda")

In [None]:
# Use a separate name: reusing `model_size` would overwrite the "2B" preset string, and a
# later re-run of the config cell would then silently build the larger model.
n_params = sum(p.numel() for p in model.parameters())
print(f"Model size: {n_params/1000**2:.1f}M parameters")

In [None]:
from transformers import DataCollatorForLanguageModeling

# Make sure the special-token attributes are set before building the collator
# (assigned in order: pad, bos, eos, unk).
tokenizer.pad_token, tokenizer.bos_token, tokenizer.eos_token, tokenizer.unk_token = (
    "<pad>", "<s>", "</s>", "<unk>",
)
# mlm=False selects causal-LM collation: labels are a copy of input_ids with padding ignored.
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

In [None]:
# 0.1s to initialize training args

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="reverse-model-2B",  # checkpoints are written to reverse-model-2B/checkpoint-<step>
    
    # Batch size. NOTE(review): the LEDOM recipe uses a global batch of 1024 sequences, but the
    # effective global batch here is
    #   per_device_train_batch_size * gradient_accumulation_steps * n_gpus = 1 * 1 * n_gpus.
    # Raise gradient_accumulation_steps to move toward LEDOM's setting.
    per_device_train_batch_size=1,  # Micro-batch size per GPU
    per_device_eval_batch_size=1,   # Micro-batch size per GPU during evaluation
    gradient_accumulation_steps=1, # Micro-batches accumulated per optimizer step

    eval_strategy="steps",        # Evaluate on eval_dataset every `eval_steps` steps
    eval_steps=5000,     # Optimizer steps between evaluations
    logging_steps=1,  # Log loss / learning rate after every optimizer step
    
    # Training duration. LEDOM trained its 7B model for ~51,900 iterations over 435B tokens once;
    # here a single pass is made over tokenized_dataset.
    num_train_epochs=1,
    
    # Optimizer (AdamW) - values taken from LEDOM
    optim="adamw_torch",
    learning_rate=2e-4,           # Peak learning rate: 2×10⁻⁴
    weight_decay=0.1,
    adam_beta1=0.9,               # Adam β₁
    adam_beta2=0.95,              # Adam β₂
    adam_epsilon=1e-8,            # Adam ε
    
    # Learning rate schedule: linear warmup to the peak LR, then cosine decay
    lr_scheduler_type="cosine",
    warmup_steps=2000,            # LEDOM's warmup length. NOTE(review): counted in optimizer
                                  # steps, so with the tiny batch above it covers a different
                                  # fraction of training than in LEDOM.
    
    # Gradient settings
    max_grad_norm=1.0,            # Clip the global gradient norm to 1.0
    
    # Precision: BF16 mixed precision (as in LEDOM); FP16 explicitly off
    bf16=True,
    fp16=False,
    
    # Checkpointing
    save_steps=5_000,             # Optimizer steps between checkpoints
    save_total_limit=3,           # Older checkpoints beyond the 3 most recent are deleted
    save_only_model=True,         # Weights only (no optimizer/scheduler/RNG state): usable for
                                  # inference, but training cannot be resumed exactly from them
    
    # Data loading
    dataloader_num_workers=2,     # Worker processes feeding the dataloaders
    remove_unused_columns=False,  # Pass dataset columns through as-is (tokenized data only has input_ids)
    
    # Keep the final weights rather than reloading the best-evaluated checkpoint at the end
    load_best_model_at_end=False,
)

# Wire model, data and hyperparameters together; the collator builds causal-LM labels per batch.
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of
# `processing_class=`; kept as `tokenizer=` for the version this notebook was run with.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_dataset_valid,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [None]:
# Release cached-but-unused blocks held by PyTorch's CUDA caching allocator before training.
torch.cuda.empty_cache()

In [None]:
# Run the training loop configured by `args`: evaluates every eval_steps and writes a
# checkpoint every save_steps under args.output_dir.
# 1m for 1k samples (2.2M tokens)
trainer.train()

## Test text generation

In [None]:
import os
import torch
from transformers import pipeline

# Directory the Trainer writes checkpoints to (TrainingArguments.output_dir).
base_dir = "reverse-model-2B"


def latest_checkpoint(base_dir: str) -> str:
    """Return the path of the newest ``checkpoint-<step>`` directory under ``base_dir``.

    Steps are compared numerically, so ``checkpoint-10000`` ranks above ``checkpoint-5000``
    (a plain lexicographic sort gets this wrong).

    Raises:
        FileNotFoundError: if ``base_dir`` is missing or has no ``checkpoint-<step>`` subdirectory.
    """
    steps = []
    for name in os.listdir(base_dir):
        prefix, _, step = name.partition("-")
        if prefix == "checkpoint" and step.isdigit() and os.path.isdir(os.path.join(base_dir, name)):
            steps.append(int(step))
    if not steps:
        raise FileNotFoundError(f"No checkpoint-<step> subdirectories found in {base_dir}")
    return os.path.join(base_dir, f"checkpoint-{max(steps)}")


# Pinned checkpoint so generations are reproducible; use latest_checkpoint(base_dir) for the newest.
first_checkpoint = os.path.join(base_dir, "checkpoint-1600")

print(f"Using model from: {first_checkpoint}")

# pipeline() takes a CUDA device index, or -1 for CPU.
device = 0 if torch.cuda.is_available() else -1

# Load the pipeline
pipe = pipeline(
    "text-generation",
    model=first_checkpoint,
    device=device,
)

In [None]:
# Nucleus-sampling pipeline. `top_p` only takes effect when sampling is enabled, so
# `do_sample=True` is required; without it generation stays greedy and top_p is ignored.
pipe1 = pipeline(
    "text-generation",
    model=first_checkpoint,
    device=device,
    do_sample=True,
    top_p=0.99,
)

In [None]:
from transformers import PreTrainedTokenizerFast

# Load from TOKENIZER_DIR (the location the tokenizer was saved to) so this cell does not
# depend on the notebook's current working directory.
tokenizer = PreTrainedTokenizerFast.from_pretrained(f"{TOKENIZER_DIR}/spm_200k")

In [None]:
# The model was trained on character-reversed text: feed it the reversed prompt, then
# reverse the output back for reading. `generated_text` includes the prompt (pipeline
# default), so after reversal the prompt appears as the ending of the printed text.
text = pipe1("is a blue flower."[::-1], num_return_sequences=1)[0]["generated_text"]
print("=== BEGIN GENERATED TEXT [REVERSED] ===")
print(text[::-1].strip())

In [None]:
# Same reversed-prompt round trip as above, using the default `pipe` pipeline.
# `generated_text` includes the prompt (pipeline default), so after reversal the prompt
# appears as the ending of the printed text.
text = pipe("And that is why the sky is blue."[::-1], num_return_sequences=1)[0]["generated_text"]
print("=== BEGIN GENERATED TEXT [REVERSED] ===")
print(text[::-1].strip())

In [None]:
# Show how many pieces the last generation splits into, then the pieces themselves.
tokens = tokenizer.tokenize(text)
print(len(tokens), tokens, sep="\n")