In [None]:
import os

import torch
from torch.utils.data import ConcatDataset, Dataset
from tqdm import tqdm
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2Config,
    GPT2LMHeadModel,
    Trainer,
    TrainingArguments,
    default_data_collator,
)

# === Step 1: Setup Paths ===
# Where the pre-tokenized ``.pt`` chunk files are read from, and where the
# Trainer writes checkpoints and the final model.
chunk_dir: str = "token_chunks"
model_output_dir: str = "gpt2_custom_model"

# Tokens per training example, and the size of the model's token vocabulary.
block_size: int = 512
vocab_size: int = 30000

# Optional Trainer checkpoint to resume from, e.g.
# "gpt2_custom_model/checkpoint-56000"; None means train from scratch.
resume_checkpoint = None

# === Step 2: Dataset Loader ===
class ChunkDataset(Dataset):
    """Fixed-length causal-LM examples cut from one flat sequence of token ids.

    ``token_ids`` is split into consecutive, non-overlapping windows of
    ``block_size`` tokens. A trailing remainder shorter than ``block_size`` is
    dropped, because examples are never padded.

    Args:
        token_ids: A sliceable sequence of integer token ids (for example a
            list, or a 1-D tensor loaded from a ``.pt`` chunk file).
        block_size: Number of tokens per example; must be positive.

    Raises:
        ValueError: If ``block_size`` is not positive.
    """

    def __init__(self, token_ids, block_size=512):
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")
        # ``+ 1`` keeps a final window that ends exactly at len(token_ids);
        # without it, a sequence whose length is a multiple of block_size
        # silently loses its last full block.
        stop = len(token_ids) - block_size + 1
        self.examples = [
            token_ids[start:start + block_size]
            for start in range(0, stop, block_size)
        ]

    def __len__(self):
        """Return the number of full ``block_size`` windows."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return ``{"input_ids": x, "labels": x}`` as independent int64 tensors.

        Labels are identical to the inputs; GPT2LMHeadModel shifts them
        internally for next-token prediction.
        """
        # as_tensor accepts lists and tensors alike (torch.tensor warns when
        # copy-constructing from a tensor). int64 is what the LM loss expects
        # for labels. clone() ensures the example never aliases the chunk's
        # storage when the chunk itself is a tensor.
        x = torch.as_tensor(self.examples[idx], dtype=torch.long).clone()
        return {"input_ids": x, "labels": x.clone()}

# === Step 3: Model Config ===
# Small GPT-2 trained from scratch. The context length is tied to block_size
# so every ChunkDataset example fits in the position embeddings.
config = GPT2Config(
    vocab_size=vocab_size,
    n_positions=block_size,
    n_ctx=block_size,
    n_embd=256,
    n_layer=4,
    n_head=4,
)
model = GPT2LMHeadModel(config)

# === Step 4: Training Args ===
training_args = TrainingArguments(
    output_dir=model_output_dir,
    overwrite_output_dir=True,
    per_device_train_batch_size=1,
    num_train_epochs=1,
    save_steps=500,
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=100,
    prediction_loss_only=True,
    report_to="none",
)

# === Step 5: Collator ===
# Every ChunkDataset example already carries equal-length "input_ids" and
# "labels" tensors, so batching only needs stacking. A
# DataCollatorForLanguageModeling built with tokenizer=None cannot be used:
# it calls tokenizer.pad(...) on dict examples and fails on the first batch.
data_collator = default_data_collator

# === Step 6: Train on First 20 Chunks ===
max_chunks = 20
chunk_files = sorted(f for f in os.listdir(chunk_dir) if f.endswith(".pt"))[:max_chunks]
if not chunk_files:
    raise FileNotFoundError(f"No .pt chunk files found in {chunk_dir!r}")
print(f"🧠 Training on {len(chunk_files)} chunks{' from checkpoint...' if resume_checkpoint else ''}")

# Load every chunk up front and train in a single trainer.train() call.
# Calling train() once per chunk restarted global_step at 0 for each chunk, so
# checkpoint-N directories from different chunks collided and the LR schedule
# only spanned one chunk. Trade-off: all selected chunks are held in memory.
chunk_datasets = []
for i, file in enumerate(chunk_files, start=1):
    file_path = os.path.join(chunk_dir, file)
    print(f"\n📦 Loading chunk {i}/{len(chunk_files)}: {file}")
    token_ids = torch.load(file_path)
    chunk_datasets.append(ChunkDataset(token_ids, block_size))
    print(f"   {len(chunk_datasets[-1])} samples")

train_dataset = ConcatDataset(chunk_datasets)
if len(train_dataset) == 0:
    raise ValueError(
        f"Chunks in {chunk_dir!r} yielded no full {block_size}-token examples"
    )

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)

if resume_checkpoint:
    # The checkpoint's step count is interpreted against this concatenated
    # dataset, so it should come from a run over the same chunk list.
    print(f"🔁 Resuming from checkpoint: {resume_checkpoint}")

print(f"🚀 Training on {len(train_dataset)} samples...")
trainer.train(resume_from_checkpoint=resume_checkpoint)
print("✅ Finished training all chunks")

# === Step 7: Save Model ===
trainer.save_model(model_output_dir)
print(f"\n✅ Training complete! Model saved to: {model_output_dir}")


📄 Adding: tokenized_files\tokenized_1.txt
📄 Adding: tokenized_files\tokenized_2.txt
📄 Adding: tokenized_files\tokenized_3.txt
📄 Adding: tokenized_files\tokenized_4.txt
✅ Merged tokenized file created: merged_tokenized.txt
