In [3]:
from datasets import load_dataset, Dataset, DatasetDict

import os
# Runtime environment knobs, set before tokenizers / CUDA are first used.
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
})

# Corpus / sequence sizing.
n_samples = 2_000_000    # number of source documents pulled from the raw corpus
context_length = 1024    # tokens per training sequence

# Storage locations. (Alternate data root: "/home/wyf/ai/causal-llm/data")
DATA_DIR = "/home/wyf/orcd/pool/causal-llm/data"
TOKENIZER_DIR = "/home/wyf/ai/causal-llm/tokenizers"
MODEL_DIR = "/home/wyf/ai/causal-llm/models"

# Short corpus name; interpolated into on-disk paths further down.
dataset = "fineweb"

In [None]:
# Load the fineweb-edu train split (no streaming) via the HF cache directory
# given below. Slow: roughly 30 s per load.
raw_dataset = load_dataset(
    path="HuggingFaceFW/fineweb-edu",
    split="train",
    cache_dir="~/orcd/pool/hf-datasets/",
)

In [None]:
from tqdm import tqdm
import sys

from typing import Optional


def filter_dataset(dataset, n_samples: Optional[int] = None):
    """Keep only the ``text`` column and reverse every document character by character.

    The reversal is the key preprocessing step of this notebook: all
    downstream tokenizer/model work consumes reversed text.

    Args:
        dataset: a map-style dataset with a ``text`` column.
        n_samples: if given, keep only the first ``n_samples`` rows (clamped to
            the dataset length). ``None`` (the default) keeps every row.

    Returns:
        A new dataset containing a single, reversed ``text`` column.
    """
    if n_samples is not None:
        dataset = dataset.select(range(min(n_samples, len(dataset))))

    # Batched map: one Python call per batch instead of one per row. This is much
    # faster on millions of documents and produces identical rows.
    return (
        dataset
            .select_columns(["text"])
            .map(lambda batch: {"text": [text[::-1] for text in batch["text"]]}, batched=True)
    )
    
# 1k examples: 4.0s (timing measured with the earlier row-by-row implementation)

print("Generating split datasets...")
# Take the first `n_samples` rows with an index selection. The old approach
# iterated every row into a Python list and rebuilt a Dataset with
# `from_list`, which held every document as a Python dict in RAM. `select`
# only records an index mapping over the existing Arrow data. The same rows,
# in the same order, feed the same seeded split.
raw_subset = raw_dataset.select(range(min(n_samples, len(raw_dataset))))
split_datasets = raw_subset.train_test_split(test_size=0.1, seed=0)
datasets = DatasetDict({
    "train": filter_dataset(split_datasets["train"]),
    "valid": filter_dataset(split_datasets["test"]),
})

In [None]:
# Persist the reversed splits as parquet. The loop variable is `split_ds`, not
# `dataset`: reusing `dataset` here overwrote the corpus-name string from the
# config cell, which later path f-strings (f"{DATA_DIR}/{dataset}_...") rely on.
split_dir = f"{DATA_DIR}/{dataset}_{n_samples}"
# to_parquet is not guaranteed to create missing parent directories.
os.makedirs(split_dir, exist_ok=True)
for split_name, split_ds in datasets.items():
    split_ds.to_parquet(f"{split_dir}/{split_name}.parquet")

In [10]:
from datasets import load_dataset

# Reload the reversed train/valid splits from parquet. streaming=False: the
# files are loaded into a regular (non-streaming) DatasetDict rather than
# streamed.
datasets = load_dataset(
    "parquet",
    data_files={
        split: f"{DATA_DIR}/fineweb_{n_samples}/{split}.parquet"
        for split in ("train", "valid")
    },
    streaming=False,
)

Loading dataset shards:   0%|          | 0/18 [00:00<?, ?it/s]

In [None]:
print(datasets["train"][0]["text"][::-1])  # first training doc, flipped back to read normally

In [None]:
# Train tokenizer (7.4s on 1k examples)
# 3m 30s on 200k examples

from transformers import AutoTokenizer, LlamaTokenizer
from tokenizers import SentencePieceBPETokenizer
from tqdm import tqdm

# Special tokens are reserved during training so they get ids inside the
# learned vocabulary. Previously only the trainer's default "<unk>" was
# reserved, and "<s>", "</s>", "<pad>" had to be bolted on afterwards.
SPM_SPECIAL_TOKENS = ["<unk>", "<s>", "</s>", "<pad>"]


def text_iterator(batch_size: int = 1_000):
    """Yield training-split texts in lists of up to ``batch_size`` strings.

    ``Dataset.iter`` reads the data batch by batch, so the full ``text`` column
    (1.8M documents at ``n_samples=2_000_000``) is never built as one Python
    list. ``train_from_iterator`` accepts an iterator of string lists.
    """
    for batch in datasets["train"].iter(batch_size=batch_size):
        yield batch["text"]


spm_tokenizer = SentencePieceBPETokenizer()
spm_tokenizer.train_from_iterator(
    text_iterator(),
    vocab_size=52_000,
    min_frequency=5,
    special_tokens=SPM_SPECIAL_TOKENS,
    show_progress=True,
    limit_alphabet=500,
    # Total number of sequences, so the built-in progress bar can show a total
    # (replaces the former tqdm wrapper).
    length=datasets["train"].num_rows,
)

In [None]:
from transformers import PreTrainedTokenizerFast

# Wrap the trained `tokenizers` object so it can be saved/loaded through the
# transformers API. These arguments only declare which strings are the special
# tokens. Nothing in this notebook configures a post-processor, so BOS/EOS are
# NOT inserted automatically when encoding.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=spm_tokenizer,
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",         # replaces unknown symbols
    pad_token="<pad>",         # used when padding shorter sequences
)
# Save under TOKENIZER_DIR rather than relative to the current working
# directory. This is the path the "Load pretrained tokenizer" cell reads back
# (f"{TOKENIZER_DIR}/fineweb_spm_200k").
tokenizer.save_pretrained(f"{TOKENIZER_DIR}/fineweb_spm_200k")

In [3]:
# Reload the tokenizer trained on the reversed fineweb text from TOKENIZER_DIR.
from transformers import PreTrainedTokenizerFast
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    TOKENIZER_DIR + "/fineweb_spm_200k"
)

In [4]:
# USES CONTEXT LENGTH (reads module-level `context_length` and `tokenizer`)
def tokenize(element):
    """Tokenize a batch of texts into fixed-size ``context_length`` windows.

    ``return_overflowing_tokens=True`` splits each document into consecutive
    windows of at most ``context_length`` tokens. Only windows of exactly
    ``context_length`` tokens are kept. A document's shorter trailing window is
    dropped, as is any document shorter than one window.

    Args:
        element: a batch dict with a ``"text"`` list, as passed by
            ``Dataset.map(..., batched=True)``.

    Returns:
        ``{"input_ids": [...]}`` with one token-id list per kept window.
    """
    encoding = tokenizer(
        element["text"],
        max_length=context_length,
        truncation=True,
        return_length=True,
        return_overflowing_tokens=True,
    )
    return {
        "input_ids": [
            ids
            for ids, n_tokens in zip(encoding["input_ids"], encoding["length"])
            if n_tokens == context_length
        ]
    }

In [None]:
# Timing notes: 25s for 1k examples, 4m 40s for 10k, 7m 50s for 200k.
# Both splits get the same treatment; the raw "text" column is dropped so only
# input_ids remain. The train split is mapped first, then valid.
tokenized_dataset, tokenized_dataset_valid = (
    datasets[split].map(tokenize, batched=True, remove_columns=["text"], batch_size=32)
    for split in ("train", "valid")
)

Map:   0%|          | 0/1800000 [00:00<?, ? examples/s]

In [None]:
# Write both tokenized splits next to the raw split parquet files (slow: the
# tokenized data is large).
print("Saving training dataset...")
tokenized_dataset.to_parquet(
    "{}/{}_{}/tokenized_{}.parquet".format(DATA_DIR, dataset, n_samples, context_length)
)
print("Saving valid dataset...")
tokenized_dataset_valid.to_parquet(
    "{}/{}_{}/tokenized_{}_valid.parquet".format(DATA_DIR, dataset, n_samples, context_length)
)

In [None]:
# Reload the tokenized train/valid splits written by the save cell above.
tokenized_dataset = Dataset.from_parquet(
    "%s/%s_%s/tokenized_%s.parquet" % (DATA_DIR, dataset, n_samples, context_length)
)
tokenized_dataset_valid = Dataset.from_parquet(
    "%s/%s_%s/tokenized_%s_valid.parquet" % (DATA_DIR, dataset, n_samples, context_length)
)

In [None]:
# Every kept row is exactly context_length tokens, so rows * context_length is exact.
print(tokenized_dataset)
print("Produced dataset of {:,} rows, {} tokens each".format(tokenized_dataset.num_rows, context_length))
print("Total tokens: {:,}".format(tokenized_dataset.num_rows * context_length))

In [None]:
# Two different vocabulary counts: `len(tokenizer)` includes tokens added on
# top of the trained vocab, while `tokenizer.vocab_size` is the base vocab only.
# The model config cell builds the embedding matrix from `len(tokenizer)`.
print(f"Tokenizer vocab size (incl. added tokens; used for model config): {len(tokenizer)}")
print(f"Base vocab size (excl. added tokens): {tokenizer.vocab_size}")
print(f"BOS token ID: {tokenizer.bos_token_id}")
print(f"EOS token ID: {tokenizer.eos_token_id}")
print(f"PAD token ID: {tokenizer.pad_token_id}")

# Check a sample tokenization
sample_text = "hello world"
tokens = tokenizer(sample_text)
print(f"Sample tokens: {tokens}")

In [None]:
# 3.2s to initialize model (measured with the earlier meta-device path)

from transformers import LlamaConfig, LlamaForCausalLM
import torch

model_size = "2B"
is_2b = model_size == "2B"  # any other value selects the larger preset below

config = LlamaConfig(
    vocab_size=len(tokenizer),  # includes tokens added on top of the base vocab
    max_position_embeddings=8192,
    hidden_size=2048 if is_2b else 3072,
    intermediate_size=16384 if is_2b else 24576,
    num_hidden_layers=18 if is_2b else 28,
    num_attention_heads=8 if is_2b else 16,
    num_key_value_heads=1 if is_2b else 16,  # 2B preset: a single shared KV head
    rms_norm_eps=1e-5,
    tie_word_embeddings=False,
    rope_scaling=None,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# Build the model directly on the target device so HF's weight initialisation
# actually runs there. The previous `torch.device("meta")` + `model.to_empty()`
# path only allocated *uninitialised* storage and never re-initialised it.
# Training therefore started from arbitrary memory contents, including
# non-persistent buffers such as the rotary-embedding inverse frequencies.
init_device = "cuda" if torch.cuda.is_available() else "cpu"
with torch.device(init_device):
    model = LlamaForCausalLM(config)
print(f"Initialized model on {init_device}")

In [None]:
# Report the parameter count. It is stored as `n_params` rather than
# `model_size`, so the "2B" preset string from the model-config cell is not
# overwritten by an int.
n_params = sum(p.numel() for p in model.parameters())
print(f"Model size: {n_params/1000**2:.1f}M parameters")

In [None]:
from transformers import DataCollatorForLanguageModeling

# Re-assert the special-token strings on the tokenizer, then build a causal-LM
# collator (mlm=False: labels come from input_ids, no masking).
for _attr, _token in (
    ("pad_token", "<pad>"),
    ("bos_token", "<s>"),
    ("eos_token", "</s>"),
    ("unk_token", "<unk>"),
):
    setattr(tokenizer, _attr, _token)
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

In [None]:
# 0.1s to initialize training args

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    # Write checkpoints under MODEL_DIR. This is the location the "Test text
    # generation" section loads from (MODEL_DIR/reverse-model-fineweb-2B/...);
    # the former relative "reverse-model-2B" depended on the current working
    # directory and used a different name.
    output_dir=f"{MODEL_DIR}/reverse-model-fineweb-2B",

    # Batch size settings. The effective global batch (in sequences) is
    # per_device_train_batch_size * gradient_accumulation_steps * n_gpus.
    # NOTE(review): 1 * 1 gives 1 sequence per step, far below LEDOM's global
    # batch of 1024 sequences. Raise gradient_accumulation_steps to approach it.
    per_device_train_batch_size=1,  # Micro-batch size per GPU
    per_device_eval_batch_size=1,   # Used in their fine-tuning setup
    gradient_accumulation_steps=1,

    eval_strategy="steps",        # Evaluate every N steps
    eval_steps=5000,              # Eval every N steps
    logging_steps=1,              # More frequent logging to match their monitoring

    # Training duration - LEDOM trained for ~51,900 iterations for 7B model
    num_train_epochs=1,  # Keep as 1 epoch since they trained on 435B tokens once

    # Optimizer settings - match LEDOM
    optim="adamw_torch",
    learning_rate=2e-4,           # Peak learning rate: 2×10⁻⁴
    weight_decay=0.1,             # Matches their setting
    adam_beta1=0.9,               # Adam β₁
    adam_beta2=0.95,              # Adam β₂
    adam_epsilon=1e-8,            # Adam ε

    # Learning rate schedule - LEDOM uses cosine with warmup
    lr_scheduler_type="cosine",
    warmup_steps=2000,            # LEDOM uses 2000 warmup iterations

    # Gradient settings
    max_grad_norm=1.0,            # Gradient clipping norm

    # Precision - LEDOM uses BF16, not FP16
    bf16=True,
    fp16=False,

    # Checkpointing
    save_steps=5_000,
    save_total_limit=3,           # Reasonable limit for storage
    save_only_model=True,

    # Data loading
    dataloader_num_workers=2,
    remove_unused_columns=False,  # Keep all data columns (only input_ids here)

    # Disable features not used in LEDOM training
    load_best_model_at_end=False,
)

# NOTE(review): newer transformers releases deprecate Trainer(tokenizer=...) in
# favour of processing_class=...; kept as-is for compatibility with the
# installed version.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_dataset_valid,
)

In [None]:
# Release cached-but-unused CUDA allocator memory before starting training.
torch.cuda.empty_cache()

In [None]:
# Run the training loop configured above (num_train_epochs=1 over tokenized_dataset).
# Timing note: 1m for 1k samples (2.2M tokens)
trainer.train()

## Test text generation

In [26]:
import torch
from transformers import pipeline

# Pipeline device index: first CUDA GPU (0) when available, otherwise CPU (-1).
if torch.cuda.is_available():
    device = 0
else:
    device = -1

In [6]:
from transformers import PreTrainedTokenizerFast
# Use the tokenizer this notebook trains and loads for the model
# (fineweb_spm_200k). "spm_200k" is a different directory, presumably from an
# earlier run, so token counts computed with it would not match what the model sees.
tokenizer = PreTrainedTokenizerFast.from_pretrained(f"{TOKENIZER_DIR}/fineweb_spm_200k")

In [36]:
# Load the trained checkpoint into a text-generation pipeline.
# `top_p` (nucleus sampling) only takes effect when sampling is enabled.
# Without do_sample=True, generation follows the saved generation config, which
# for this freshly-configured Llama model does not enable sampling; top_p is
# then ignored (transformers warns about it).
pipe = pipeline(
    "text-generation",
    model=f"{MODEL_DIR}/reverse-model-fineweb-2B/checkpoint-1200",
    device=device,
    do_sample=True,
    top_p=0.9,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0


In [37]:
# First training document (stored reversed) and its first 50 reversed chars.
test_text_full = datasets["train"][0]["text"]
test_text = test_text_full[:50]

# Flip back for display. The first 201 reversed chars are the document's tail.
print("=== TEST TEXT ===\n..." + test_text_full[200::-1] + "\n")
print("=== TEST TEXT (truncated) ===\n" + test_text[::-1])

=== TEST TEXT ===
...to you.
Its important to understand that goals and things change over time. So changing your focus is technically not quitting, its simply changing direction. This new direction may be the winning one.

=== TEST TEXT (truncated) ===
ection. This new direction may be the winning one.


In [50]:
# Prompt written forwards, then reversed into the model's right-to-left order.
test_text = "".join(reversed(
    "\nThat is why Mike gave up her job\nand started her own business.\n"
))

In [51]:
# Generate one sequence from the reversed prompt, then flip it back to read it.
text = pipe(
    test_text,
    num_return_sequences=1,
)[0]["generated_text"]

print("=== BEGIN GENERATED TEXT [REVERSED] ===")
print("".join(reversed(text)).strip())

=== BEGIN GENERATED TEXT [REVERSED] ===
when she was 27
..........But when she was seven years old
she began to know that a wife was received from his wife when she was taken from her; and that she was married three years before her death; and that she was among all the children of her according to the number
that she claimed that she was one of the
children of her husband and that her daughter had died when she was 81.She was married and she married when she was 72.She was the
mother of her father
when she was 17 years old
after she died when she was £30,000
she married when she was she,and she was born on Tuesday,
until she came to the grave.She was
married to her mother,and she had received her two years in school for her.She was born on Fridays
when she died at the age of £460
She was a woman who cared for her and cared for her.On the birth of her child,
she was the mother of one of her children.Some
she died when the child
made her back from her home,but she was married to her.She

In [None]:
# Second prompt: reverse it into model order, generate, then flip the result back.
text = pipe(
    "".join(reversed("And that is why the sky is blue.")),
    num_return_sequences=1,
)[0]["generated_text"]

print("=== BEGIN GENERATED TEXT [REVERSED] ===")
print(text[::-1].strip())

In [None]:
# Show how many tokens the (reversed-order) generated text splits into, and the tokens.
tokens = tokenizer.tokenize(text)
print(len(tokens), tokens, sep="\n")