In [None]:
# ==============================================================================
# 1. SETUP: Install libraries and mount Google Drive
# ==============================================================================
print("--> Installing required libraries...")
!pip install -q datasets matplotlib

import os
import sys
import json
import hashlib
import statistics
from typing import Iterable, Dict, List
from google.colab import drive
from datasets import load_dataset
import matplotlib.pyplot as plt

print("--> Mounting Google Drive...")
drive.mount('/content/drive')

# ==============================================================================
# 2. CONFIGURATION: Replaces config.py and command-line args
# ==============================================================================
# Set the root path for your project inside Google Drive
GDRIVE_PROJECT_PATH = "/content/drive/My Drive/LMA_SLM"

class Config:
    """Settings for the sampling download run (stands in for config.py / CLI flags)."""

    # What to fetch: language codes and the Sangraha subsets to read for each.
    LANGS = ["eng", "hin", "nep"]
    SANGRAHA_SOURCES = ["verified", "published"]

    # How much to fetch. Streaming makes it cheap to take a small sample
    # without downloading gigabytes; this is the number of documents to
    # sample per source.
    per_lang_limit = 20000

    # Smaller sample size, intended for quick test runs (see TEST_MODE).
    TEST_MODE = False
    SAMPLE_DOCS_PER_LANG = 200

    # Where results go, relative to the project folder on Google Drive.
    RAW_DIR = os.path.join(GDRIVE_PROJECT_PATH, "data", "raw")
    REPORTS_DIR = os.path.join(GDRIVE_PROJECT_PATH, "reports")

cfg = Config()

# ==============================================================================
# 3. UTILITIES: Replaces utils/io_utils.py
# ==============================================================================
def ensure_dir(dir_path: str):
    """Create ``dir_path`` and any missing parents; a directory that already exists is left alone."""
    os.makedirs(name=dir_path, exist_ok=True)

def sha1(s: str) -> str:
    """Return the hexadecimal SHA-1 digest of ``s`` encoded as UTF-8."""
    hasher = hashlib.sha1()
    hasher.update(s.encode("utf-8"))
    return hasher.hexdigest()

def jsonl_writer(fpath: str, records: Iterable[Dict]):
    """
    Write ``records`` to ``fpath`` as JSON Lines (one object per line, UTF-8).

    Non-ASCII characters are written as-is rather than escaped. Any existing
    file at ``fpath`` is overwritten. Prints the number of records written.
    """
    written = 0
    with open(fpath, "w", encoding="utf-8") as out:
        # enumerate keeps the running total; it stays 0 when records is empty.
        for written, rec in enumerate(records, start=1):
            out.write(json.dumps(rec, ensure_ascii=False) + "\n")
    print(f"✅ Wrote {written:,} records to {fpath}")

# ==============================================================================
# 4. CORE LOGIC: CORRECTED to use streaming
# ==============================================================================
def stream_subset_lang(subset: str, lang: str, limit: int) -> Iterable[Dict]:
    """
    Lazily yield documents from one ``subset/lang`` slice of ai4bharat/sangraha.

    The dataset is opened with ``streaming=True`` so records are fetched on
    demand instead of downloading the whole slice to the Colab disk first.

    Args:
        subset: Sangraha subset name; first component of ``data_dir``.
        lang: Language code; second component of ``data_dir``.
        limit: Maximum number of records to yield. A falsy value means no cap.

    Yields:
        Dicts with ``doc_id``, ``subset``, ``lang`` and ``text`` keys. Examples
        whose text is empty are skipped; when an example has no ``doc_id`` the
        first 16 hex characters of the text's SHA-1 are used instead.
    """
    data_dir = f"{subset}/{lang}"
    if limit:
        limit_str = f"first {limit:,}"
    else:
        limit_str = "all"
    print(f"--> [dl] Streaming {limit_str} docs: ai4bharat/sangraha :: {data_dir}")

    try:
        # Streaming mode avoids a massive up-front download.
        ds = load_dataset("ai4bharat/sangraha", data_dir=data_dir, streaming=True, split="train")

        emitted = 0
        for ex in ds:
            text = ex.get("text", "")
            if text:
                doc_id = ex.get("doc_id") or sha1(text)[:16]
                yield {"doc_id": str(doc_id), "subset": subset, "lang": lang, "text": text}
                emitted += 1
                if limit and emitted >= limit:
                    print(f"--> Reached streaming limit of {limit} for {data_dir}.")
                    break
    except Exception as e:
        # Deliberately broad: a slice that does not exist (e.g. "published/eng")
        # is reported on stderr and the generator simply ends, so the caller
        # can carry on with the next slice.
        print(f"Could not stream from {data_dir}. Error: {e}", file=sys.stderr)

# ==============================================================================
# 5. MAIN FUNCTION (UNCHANGED)
# ==============================================================================
def main():
    """
    Download a sample for every configured language and save it to Drive.

    For each language this collects the records streamed from every configured
    source, prints length statistics, saves a histogram of document lengths
    (capped at 5,000 characters) to the reports directory, and writes the
    records to ``<lang>.jsonl`` in the raw-data directory. The JSONL file is
    written even when no records were collected.
    """
    # TEST_MODE swaps in the smaller sample size.
    doc_limit = cfg.per_lang_limit
    if cfg.TEST_MODE:
        doc_limit = cfg.SAMPLE_DOCS_PER_LANG

    out_dir = cfg.RAW_DIR
    reports_dir = cfg.REPORTS_DIR
    languages = cfg.LANGS
    sources = cfg.SANGRAHA_SOURCES

    for directory in (out_dir, reports_dir):
        ensure_dir(directory)
    print(f"\nOutput data will be saved to: {out_dir}")
    print(f"Reports will be saved to: {reports_dir}\n")

    for lang in languages:
        # The limit is passed to each source separately.
        recs: List[Dict] = [
            rec
            for subset in sources
            for rec in stream_subset_lang(subset, lang, doc_limit)
        ]

        lengths = [len(rec["text"]) for rec in recs]
        if lengths:
            mean_len = statistics.mean(lengths)
            median_len = statistics.median(lengths)
            longest = max(lengths)
            print(f"\n[stats:{lang}] docs={len(recs):,} | mean={mean_len:.1f} | "
                  f"p50={median_len:,} | max={longest:,}")

            # Cap lengths so a few very long documents do not flatten the histogram.
            capped = [min(n, 5000) for n in lengths]
            plt.figure(figsize=(10, 5))
            plt.hist(capped, bins=50)
            plt.title(f"Document Length Distribution – {lang} ({len(recs):,} docs)")
            plt.xlabel("Character Count (capped at 5,000)")
            plt.ylabel("Number of Documents")
            out_png = os.path.join(reports_dir, f"download_stats_{lang}.png")
            plt.savefig(out_png, dpi=120, bbox_inches="tight")
            plt.close()
            print(f"📊 Plot saved to {out_png}")

        jsonl_writer(os.path.join(out_dir, f"{lang}.jsonl"), recs)
        print("-" * 50)

# ==============================================================================
# 6. RUN SCRIPT
# ==============================================================================
if __name__ == "__main__":
    main()

--> Installing required libraries...
--> Mounting Google Drive...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).

Output data will be saved to: /content/drive/My Drive/LMA_SLM/data/raw
Reports will be saved to: /content/drive/My Drive/LMA_SLM/reports

--> [dl] Streaming first 20,000 docs: ai4bharat/sangraha :: verified/eng


Resolving data files:   0%|          | 0/50 [00:00<?, ?it/s]

--> Reached streaming limit of 20000 for verified/eng.
--> [dl] Streaming first 20,000 docs: ai4bharat/sangraha :: published/eng


Could not stream from published/eng. Error: The directory at hf://datasets/ai4bharat/sangraha@8b813c3f62d37b2fa174d68c31e8b35ae2fe85e8/published/eng doesn't contain any data files



[stats:eng] docs=20,000 | mean=2651.7 | p50=1,892.0 | max=261,788
📊 Plot saved to /content/drive/My Drive/LMA_SLM/reports/download_stats_eng.png
✅ Wrote 20,000 records to /content/drive/My Drive/LMA_SLM/data/raw/eng.jsonl
--------------------------------------------------
--> [dl] Streaming first 20,000 docs: ai4bharat/sangraha :: verified/hin


Resolving data files:   0%|          | 0/100 [00:00<?, ?it/s]

--> Reached streaming limit of 20000 for verified/hin.
--> [dl] Streaming first 20,000 docs: ai4bharat/sangraha :: published/hin


Could not stream from published/hin. Error: The directory at hf://datasets/ai4bharat/sangraha@8b813c3f62d37b2fa174d68c31e8b35ae2fe85e8/published/hin doesn't contain any data files



[stats:hin] docs=20,000 | mean=2211.0 | p50=1,549.5 | max=615,932
📊 Plot saved to /content/drive/My Drive/LMA_SLM/reports/download_stats_hin.png
✅ Wrote 20,000 records to /content/drive/My Drive/LMA_SLM/data/raw/hin.jsonl
--------------------------------------------------
--> [dl] Streaming first 20,000 docs: ai4bharat/sangraha :: verified/nep


Resolving data files:   0%|          | 0/25 [00:00<?, ?it/s]

--> Reached streaming limit of 20000 for verified/nep.
--> [dl] Streaming first 20,000 docs: ai4bharat/sangraha :: published/nep


Could not stream from published/nep. Error: The directory at hf://datasets/ai4bharat/sangraha@8b813c3f62d37b2fa174d68c31e8b35ae2fe85e8/published/nep doesn't contain any data files



[stats:nep] docs=20,000 | mean=1936.2 | p50=1,223.5 | max=104,643
📊 Plot saved to /content/drive/My Drive/LMA_SLM/reports/download_stats_nep.png
✅ Wrote 20,000 records to /content/drive/My Drive/LMA_SLM/data/raw/nep.jsonl
--------------------------------------------------


In [None]:
print("Hi")

Hi


In [None]:
!zip -r '/content/LMA_dataset.zip' '/content/drive/MyDrive/LMA_SLM/data/raw'

  adding: content/drive/MyDrive/LMA_SLM/data/raw/ (stored 0%)
  adding: content/drive/MyDrive/LMA_SLM/data/raw/eng.jsonl (deflated 61%)
  adding: content/drive/MyDrive/LMA_SLM/data/raw/hin.jsonl (deflated 78%)
  adding: content/drive/MyDrive/LMA_SLM/data/raw/nep.jsonl (deflated 79%)


## Main Run

In [None]:
# ==============================================================================
# 1. SETUP: Install libraries and mount Google Drive
# ==============================================================================
print("--> Installing required libraries...")
!pip install -q datasets matplotlib

import os
import sys
import json
import hashlib
import statistics
from typing import Iterable, Dict, List
from google.colab import drive
from datasets import load_dataset
import matplotlib.pyplot as plt

print("--> Mounting Google Drive...")
drive.mount('/content/drive', force_remount=True)

# ==============================================================================
# 2. CONFIGURATION
# ==============================================================================
class Config:
    """Settings for the full download run: what to fetch, how much, and where to put it."""

    # Project folder on the mounted Google Drive; output paths are derived from it.
    GDRIVE_PROJECT_PATH = "/content/drive/My Drive/LMA_SLM"

    # Languages to process and the Sangraha subsets to read for each.
    LANGS = ["eng", "hin", "nep"]
    SANGRAHA_SOURCES = ["verified"]

    # Document cap per language.
    PER_LANG_LIMITS = dict(eng=2250000, hin=2050000, nep=0)

    # Smaller cap, intended for quick test runs (see TEST_MODE).
    TEST_MODE = False
    SAMPLE_DOCS_PER_LANG = 1000

    def _under_project(self, *parts):
        """Join ``parts`` onto the project folder."""
        return os.path.join(self.GDRIVE_PROJECT_PATH, *parts)

    @property
    def RAW_DIR(self):
        """Directory for the downloaded JSONL files."""
        return self._under_project("data", "raw")

    @property
    def REPORTS_DIR(self):
        """Directory for generated reports."""
        return self._under_project("reports")

# ==============================================================================
# 3. DATASET DOWNLOADER CLASS (ULTRA MEMORY-EFFICIENT)
# ==============================================================================
class DatasetDownloader:
    """
    Streams Sangraha documents into one JSONL file per language.

    Records are written as they arrive and length statistics are accumulated
    incrementally, so memory use does not grow with the number of documents.

    Each file is first written under a ``.partial`` name and only moved to its
    final name once streaming has finished. An interrupted run therefore never
    truncates a previously downloaded file and never leaves a half-written
    file where a complete one is expected.
    """
    def __init__(self, config: "Config"):
        # The annotation is quoted so this class can be defined without Config
        # already being in scope.
        self.cfg = config
        print("--> Initializing downloader...")
        self._setup_directories()

    def _setup_directories(self):
        """Ensures all necessary output directories exist."""
        os.makedirs(self.cfg.RAW_DIR, exist_ok=True)
        os.makedirs(self.cfg.REPORTS_DIR, exist_ok=True)
        print(f"Output data will be saved to: {self.cfg.RAW_DIR}")
        print(f"Reports will be saved to: {self.cfg.REPORTS_DIR}\n")

    def _stream_subset_lang(self, subset: str, lang: str, limit: int) -> Iterable[Dict]:
        """
        Generator that streams and yields processed records from the dataset.

        Examples with empty text are skipped. A missing ``doc_id`` is replaced
        by the first 16 hex characters of the text's SHA-1. At most ``limit``
        records are yielded (a falsy ``limit`` means no cap). Any exception
        raised while loading or iterating is reported on stderr and ends the
        stream; it is not re-raised.
        """
        data_dir = f"{subset}/{lang}"
        try:
            ds = load_dataset("ai4bharat/sangraha", data_dir=data_dir, streaming=True, split="train")
            n = 0
            for ex in ds:
                text = ex.get("text", "")
                if not text:
                    continue

                doc_id = ex.get("doc_id") or hashlib.sha1(text.encode("utf-8")).hexdigest()[:16]
                yield {"doc_id": str(doc_id), "subset": subset, "lang": lang, "text": text}

                n += 1
                if limit and n >= limit:
                    print(f"--> Reached streaming limit of {limit:,} for {data_dir}.")
                    break
        except Exception as e:
            print(f"Could not stream from {data_dir}. Error: {e}", file=sys.stderr)
            return

    def process_language(self, lang: str) -> bool:
        """
        Stream one language to ``<RAW_DIR>/<lang>.jsonl`` and print length stats.

        The per-language limit comes from ``SAMPLE_DOCS_PER_LANG`` in test mode,
        otherwise from ``PER_LANG_LIMITS`` (0 or a missing entry means no cap),
        and applies to the total across all configured sources.

        Returns:
            True if at least one record was written and the file was published,
            False if nothing was streamed (any existing file is left untouched).
        """
        limit = self.cfg.SAMPLE_DOCS_PER_LANG if self.cfg.TEST_MODE else self.cfg.PER_LANG_LIMITS.get(lang, 0)
        limit_str = f"first {limit:,}" if limit else "all"

        print(f"Processing language: '{lang}' (target: {limit_str} documents)")

        out_path = os.path.join(self.cfg.RAW_DIR, f"{lang}.jsonl")
        # Write under a temporary name; the final name only ever holds a
        # completely written file.
        partial_path = out_path + ".partial"

        # Iterative stats variables (low memory)
        doc_count = 0
        total_chars = 0
        max_len = 0
        min_len = float('inf')

        with open(partial_path, "w", encoding="utf-8") as f:
            for source in self.cfg.SANGRAHA_SOURCES:
                print(f"--> [dl] Starting stream from source: {source}/{lang}")
                for record in self._stream_subset_lang(source, lang, limit):
                    f.write(json.dumps(record, ensure_ascii=False) + "\n")

                    # Update stats iteratively
                    length = len(record["text"])
                    doc_count += 1
                    total_chars += length
                    max_len = max(max_len, length)
                    min_len = min(min_len, length)

                    if limit and doc_count >= limit:
                        break
                # The limit is a per-language total, so stop reading further sources too.
                if limit and doc_count >= limit:
                    break

        if doc_count == 0:
            # Nothing usable arrived (e.g. every source failed to load): do not
            # publish an empty file or replace an earlier good download.
            os.remove(partial_path)
            print(f"⚠️  No records were streamed for '{lang}'; nothing written to {out_path}", file=sys.stderr)
            print("-" * 60)
            return False

        # os.replace overwrites the destination in a single rename.
        os.replace(partial_path, out_path)
        print(f"✅ Wrote {doc_count:,} records to {out_path}")

        # Print the final, memory-safe statistics
        mean_len = total_chars / doc_count
        print(f"\n[stats:{lang}] docs={doc_count:,} | mean={mean_len:.1f} | min={min_len:,} | max={max_len:,}")
        # Note: Median and plotting are removed as they require storing all lengths in RAM.

        print("-" * 60)
        return True

    def run(self):
        """Runs the download for all configured languages and reports any that produced no data."""
        empty_langs = []
        for lang in self.cfg.LANGS:
            if self.process_language(lang) is False:
                empty_langs.append(lang)

        if empty_langs:
            print(f"⚠️  Finished, but no data was written for: {', '.join(empty_langs)}", file=sys.stderr)
        else:
            print("🎉 All languages processed successfully!")

# ==============================================================================
# 4. RUN SCRIPT
# ==============================================================================
if __name__ == "__main__":
    config = Config()
    downloader = DatasetDownloader(config)
    downloader.run()

--> Installing required libraries...
--> Mounting Google Drive...
Mounted at /content/drive
--> Initializing downloader...
Output data will be saved to: /content/drive/My Drive/LMA_SLM/data/raw
Reports will be saved to: /content/drive/My Drive/LMA_SLM/reports

Processing language: 'eng' (target: first 2,250,000 documents)
--> [dl] Starting stream from source: verified/eng


Resolving data files:   0%|          | 0/50 [00:00<?, ?it/s]

✅ Wrote 2,250,000 records to /content/drive/My Drive/LMA_SLM/data/raw/eng.jsonl

[stats:eng] docs=2,250,000 | mean=2684.3 | min=78 | max=1,017,515
------------------------------------------------------------
Processing language: 'hin' (target: first 2,050,000 documents)
--> [dl] Starting stream from source: verified/hin


Resolving data files:   0%|          | 0/100 [00:00<?, ?it/s]

✅ Wrote 2,050,000 records to /content/drive/My Drive/LMA_SLM/data/raw/hin.jsonl

[stats:hin] docs=2,050,000 | mean=2198.3 | min=13 | max=1,053,637
------------------------------------------------------------
Processing language: 'nep' (target: all documents)
--> [dl] Starting stream from source: verified/nep


Resolving data files:   0%|          | 0/25 [00:00<?, ?it/s]

In [None]:
# ==============================================================================
# 1. SETUP: Install libraries and mount Google Drive
# ==============================================================================
print("--> Installing required libraries...")
!pip install -q datasets matplotlib

import os
import sys
import json
import hashlib
import statistics
from typing import Iterable, Dict, List
from google.colab import drive
from datasets import load_dataset
import matplotlib.pyplot as plt

print("--> Mounting Google Drive...")
drive.mount('/content/drive', force_remount=True)

# ==============================================================================
# 2. CONFIGURATION (MODIFIED FOR NEPALI ONLY)
# ==============================================================================
class Config:
    """Settings for the Nepali-only download run: what to fetch, how much, and where to put it."""

    # Project folder on the mounted Google Drive; output paths are derived from it.
    GDRIVE_PROJECT_PATH = "/content/drive/My Drive/LMA_SLM"

    # --- THIS IS THE KEY CHANGE ---
    # Only Nepali is processed in this run.
    LANGS = ["nep"]
    SANGRAHA_SOURCES = ["verified"]

    # Document cap per language: a concrete high limit instead of 0.
    PER_LANG_LIMITS = dict(nep=2100000)

    # Smaller cap, intended for quick test runs (see TEST_MODE).
    TEST_MODE = False
    SAMPLE_DOCS_PER_LANG = 1000

    def _under_project(self, *parts):
        """Join ``parts`` onto the project folder."""
        return os.path.join(self.GDRIVE_PROJECT_PATH, *parts)

    @property
    def RAW_DIR(self):
        """Directory for the downloaded JSONL files."""
        return self._under_project("data", "raw")

    @property
    def REPORTS_DIR(self):
        """Directory for generated reports."""
        return self._under_project("reports")

# ==============================================================================
# 3. DATASET DOWNLOADER CLASS
# ==============================================================================
class DatasetDownloader:
    """
    Streams Sangraha documents into one JSONL file per language.

    Records are written as they arrive and length statistics are accumulated
    incrementally, so memory use does not grow with the number of documents.

    Runs can be repeated safely: a language whose output file already exists
    is skipped. Each file is first written under a ``.partial`` name and only
    moved to its final name once streaming has finished, so a file left behind
    by an interrupted run is never mistaken for a finished download.
    """
    def __init__(self, config: "Config"):
        # The annotation is quoted so this class can be defined without Config
        # already being in scope.
        self.cfg = config
        print("--> Initializing downloader...")
        self._setup_directories()

    def _setup_directories(self):
        """Ensures all necessary output directories exist."""
        os.makedirs(self.cfg.RAW_DIR, exist_ok=True)
        os.makedirs(self.cfg.REPORTS_DIR, exist_ok=True)
        print(f"Output data will be saved to: {self.cfg.RAW_DIR}")
        print(f"Reports will be saved to: {self.cfg.REPORTS_DIR}\n")

    def _stream_subset_lang(self, subset: str, lang: str, limit: int) -> Iterable[Dict]:
        """
        Generator that streams and yields processed records from the dataset.

        Examples with empty text are skipped. A missing ``doc_id`` is replaced
        by the first 16 hex characters of the text's SHA-1. At most ``limit``
        records are yielded (a falsy ``limit`` means no cap). Any exception
        raised while loading or iterating is reported on stderr and ends the
        stream; it is not re-raised.
        """
        data_dir = f"{subset}/{lang}"
        try:
            ds = load_dataset("ai4bharat/sangraha", data_dir=data_dir, streaming=True, split="train")
            n = 0
            for ex in ds:
                text = ex.get("text", "")
                if not text:
                    continue

                doc_id = ex.get("doc_id") or hashlib.sha1(text.encode("utf-8")).hexdigest()[:16]
                yield {"doc_id": str(doc_id), "subset": subset, "lang": lang, "text": text}

                n += 1
                if limit and n >= limit:
                    print(f"--> Reached streaming limit of {limit:,} for {data_dir}.")
                    break
        except Exception as e:
            print(f"Could not stream from {data_dir}. Error: {e}", file=sys.stderr)
            return

    def process_language(self, lang: str) -> bool:
        """
        Stream one language to ``<RAW_DIR>/<lang>.jsonl`` and print length stats.

        The per-language limit comes from ``SAMPLE_DOCS_PER_LANG`` in test mode,
        otherwise from ``PER_LANG_LIMITS`` (0 or a missing entry means no cap),
        and applies to the total across all configured sources.

        Returns:
            True if the output file is available (already present, or written
            by this call); False if nothing was streamed and no file was
            published.
        """
        out_path = os.path.join(self.cfg.RAW_DIR, f"{lang}.jsonl")
        # Data is written here first; a leftover from an interrupted run is
        # simply overwritten on the next attempt.
        partial_path = out_path + ".partial"

        # This check prevents accidental overwrites. Because files are only
        # renamed to out_path after a complete write, its existence means the
        # download finished.
        if os.path.exists(out_path):
            print(f"✅ Output file already exists for '{lang}', skipping.")
            print("-" * 60)
            return True

        limit = self.cfg.SAMPLE_DOCS_PER_LANG if self.cfg.TEST_MODE else self.cfg.PER_LANG_LIMITS.get(lang, 0)
        limit_str = f"first {limit:,}" if limit else "all"

        print(f"Processing language: '{lang}' (target: {limit_str} documents)")

        doc_count = 0
        total_chars = 0
        max_len = 0
        min_len = float('inf')

        with open(partial_path, "w", encoding="utf-8") as f:
            for source in self.cfg.SANGRAHA_SOURCES:
                print(f"--> [dl] Starting stream from source: {source}/{lang}")
                for record in self._stream_subset_lang(source, lang, limit):
                    f.write(json.dumps(record, ensure_ascii=False) + "\n")

                    length = len(record["text"])
                    doc_count += 1
                    total_chars += length
                    max_len = max(max_len, length)
                    min_len = min(min_len, length)

                    if limit and doc_count >= limit:
                        break
                # The limit is a per-language total, so stop reading further sources too.
                if limit and doc_count >= limit:
                    break

        if doc_count == 0:
            # Publishing an empty file would make every later run skip this
            # language, so remove the temp file and report instead.
            os.remove(partial_path)
            print(f"⚠️  No records were streamed for '{lang}'; nothing written to {out_path}", file=sys.stderr)
            print("-" * 60)
            return False

        # os.replace publishes the finished file in a single rename.
        os.replace(partial_path, out_path)
        print(f"✅ Wrote {doc_count:,} records to {out_path}")

        mean_len = total_chars / doc_count
        print(f"\n[stats:{lang}] docs={doc_count:,} | mean={mean_len:.1f} | min={min_len:,} | max={max_len:,}")

        print("-" * 60)
        return True

    def run(self):
        """Runs the download for all configured languages and reports any that produced no data."""
        empty_langs = []
        for lang in self.cfg.LANGS:
            if self.process_language(lang) is False:
                empty_langs.append(lang)

        if empty_langs:
            print(f"⚠️  Finished, but no data was written for: {', '.join(empty_langs)}", file=sys.stderr)
        else:
            print("🎉 All languages processed successfully!")

# ==============================================================================
# 4. RUN SCRIPT
# ==============================================================================
if __name__ == "__main__":
    config = Config()
    downloader = DatasetDownloader(config)
    downloader.run()

--> Installing required libraries...
--> Mounting Google Drive...
Mounted at /content/drive
--> Initializing downloader...
Output data will be saved to: /content/drive/My Drive/LMA_SLM/data/raw
Reports will be saved to: /content/drive/My Drive/LMA_SLM/reports

Processing language: 'nep' (target: first 2,100,000 documents)
--> [dl] Starting stream from source: verified/nep


README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/25 [00:00<?, ?it/s]

✅ Wrote 2,100,000 records to /content/drive/My Drive/LMA_SLM/data/raw/nep.jsonl

[stats:nep] docs=2,100,000 | mean=1933.7 | min=114 | max=224,324
------------------------------------------------------------
🎉 All languages processed successfully!


In [None]:
# ==============================================================================
# 1. SETUP: Install libraries and mount Google Drive
# ==============================================================================
print("--> Installing required libraries...")
!pip install -q sentencepiece tokenizers matplotlib regex

import os
import json
import random
import hashlib
import statistics
import subprocess
from pathlib import Path
from typing import Iterable, Dict, List, Tuple

from google.colab import drive
import sentencepiece as spm
import matplotlib.pyplot as plt
from tokenizers import Tokenizer, normalizers, pre_tokenizers
from tokenizers.models import Unigram

print("--> Mounting Google Drive...")
drive.mount('/content/drive', force_remount=True)

# ==============================================================================
# 2. CONFIGURATION: The single source of truth for the project
# ==============================================================================
class Config:
    """Project-wide settings for data splitting and tokenizer training."""

    # --- Base Path (CHANGE THIS IF YOUR PROJECT IS IN A DIFFERENT FOLDER) ---
    GDRIVE_PROJECT_PATH = "/content/drive/My Drive/LMA_SLM"

    # --- Data Splitting Settings ---
    LANGS = ["eng", "hin", "nep"]
    SEED = 42
    TRAIN_RATIO = 0.98  # 98% of the data for training
    VAL_RATIO = 0.01    # 1% for validation; the remaining 1% is for testing

    # --- Tokenizer Training Settings ---
    TOKENIZER_VOCABS = [48000, 64000]         # Vocab sizes to train models with
    TOKENIZER_TRAIN_DOCS_PER_LANG = 350_000   # Use 350k docs from each lang to train the tokenizer
    TOKENIZER_ANALYZE_DOCS_PER_LANG = 20_000  # Use 20k docs to calculate metrics
    TOKENIZER_CHAR_COVERAGE = 0.9995

    # --- Project directories, all derived from GDRIVE_PROJECT_PATH ---
    @property
    def RAW_DIR(self):
        """Directory holding the raw per-language data."""
        return Path(self.GDRIVE_PROJECT_PATH, "data", "raw")

    @property
    def SPLIT_DIR(self):
        """Directory holding the data splits."""
        return Path(self.GDRIVE_PROJECT_PATH, "data", "splits")

    @property
    def TOK_DIR(self):
        """Directory for tokenizer files."""
        return Path(self.GDRIVE_PROJECT_PATH, "tokenizers")

    @property
    def REPORTS_DIR(self):
        """Directory for generated reports."""
        return Path(self.GDRIVE_PROJECT_PATH, "reports")

# Instantiate the config for the rest of the notebook
cfg = Config()

# --- Helper function to create directories ---
def ensure_dir(p: Path):
    """Create directory ``p`` and any missing parents; an existing directory is left alone."""
    os.makedirs(p, exist_ok=True)

print("\n✅ Configuration loaded. All paths point to your Google Drive.")
print(f"   Raw Data Path: {cfg.RAW_DIR}")
print(f"   Split Data Path: {cfg.SPLIT_DIR}")
print(f"   Tokenizer Path: {cfg.TOK_DIR}")

--> Installing required libraries...
--> Mounting Google Drive...
Mounted at /content/drive

✅ Configuration loaded. All paths point to your Google Drive.
   Raw Data Path: /content/drive/My Drive/LMA_SLM/data/raw
   Split Data Path: /content/drive/My Drive/LMA_SLM/data/splits
   Tokenizer Path: /content/drive/My Drive/LMA_SLM/tokenizers


In [None]:
# ==============================================================================
# DATA HEALTH CHECK & REPAIR CELL
# ==============================================================================
# This script inspects each language's train.jsonl file for corruption
# (missing newlines). If it finds a problem it writes a repaired copy next to
# the original (train.reformatted.jsonl); the original file is never modified.
# ==============================================================================
import json
from pathlib import Path

# --- Make sure the Config class from the first cell is available ---
# If running in a new session, you might need to re-run the first setup cell
# or uncomment the minimal config below.
# class Config:
#     GDRIVE_PROJECT_PATH = "/content/drive/My Drive/LMA_SLM"
#     LANGS = ["eng", "hin", "nep"]
#     def SPLIT_DIR(self): return Path(self.GDRIVE_PROJECT_PATH) / "data" / "splits"
# cfg = Config()
# --------------------------------------------------------------------------

def check_and_fix_jsonl(file_path: Path, buffer_size: int = 10 * 1024 * 1024):
    """
    Check a JSONL file for missing line breaks and write a repaired copy if needed.

    The first 2 MB are sampled: the file is considered corrupted when that
    sample has at most one newline but more than one ``}{`` boundary, i.e.
    several JSON objects glued together on one line. In that case a copy named
    ``<stem>.reformatted.jsonl`` is written next to the original with a newline
    inserted at every ``}{`` boundary, and its first three lines are parsed as
    a sanity check. The original file is never modified.

    Args:
        file_path: JSONL file to inspect.
        buffer_size: Number of characters read per chunk during the repair.

    Raises:
        ValueError: If ``buffer_size`` is not positive.

    Note:
        The repair is purely textual, so a literal ``}{`` inside a string
        value is split as well.
    """
    if buffer_size <= 0:
        raise ValueError("buffer_size must be a positive number of characters")

    print(f"\n--- Checking file: {file_path.name} ---")
    if not file_path.exists():
        print(f"⚠️  File does not exist. Skipping.")
        return

    # --- 1. Diagnosis Step ---
    print("🔬 Performing health check...")
    is_corrupted = False
    try:
        with file_path.open("r", encoding="utf-8", errors="ignore") as f:
            # Read a sample chunk to diagnose without using too much RAM
            sample_chunk = f.read(2 * 1024 * 1024) # Read first 2MB
            newline_count = sample_chunk.count('\n')
            json_boundary_count = sample_chunk.count('}{')

            if newline_count <= 1 and json_boundary_count > 1:
                is_corrupted = True
                print("🚨 Diagnosis: File appears corrupted (single line with multiple JSON objects).")
            else:
                print("✅ Diagnosis: File appears healthy with proper line breaks.")

    except Exception as e:
        print(f"An error occurred during the check: {e}")
        return

    # --- 2. Fix Step (only if needed) ---
    if not is_corrupted:
        print("No fix needed.")
        return

    print("\n🔧 Starting repair process...")
    corrected_path = file_path.with_name(f"{file_path.stem}.reformatted.jsonl")
    print(f"   - Source:      {file_path.name}")
    print(f"   - Destination: {corrected_path.name}")

    chunks_processed = 0
    # True when the previous chunk ended with '}': a '{' at the start of the
    # next chunk then completes a '}{' boundary that neither chunk contains
    # on its own.
    ended_with_close = False

    try:
        with file_path.open("r", encoding="utf-8", errors="ignore") as f_in, \
             corrected_path.open("w", encoding="utf-8") as f_out:

            while True:
                chunk = f_in.read(buffer_size)
                if not chunk:
                    break

                if ended_with_close and chunk.startswith('{'):
                    f_out.write('\n')
                f_out.write(chunk.replace('}{', '}\n{'))
                ended_with_close = chunk.endswith('}')

                chunks_processed += 1
                print(f"   ...processed {chunks_processed * buffer_size // (1024 * 1024)} MB...")

        print(f"\n✅ Repair complete. New file saved.")

    except Exception as e:
        print(f"An error occurred during file repair: {e}")
        return

    # --- 3. Verification Step ---
    print("\n🔍 Verifying the new file...")
    parse_failures = 0
    try:
        with corrected_path.open("r", encoding="utf-8") as f:
            print("First 3 records from the new file:")
            for i, line in enumerate(f):
                if i >= 3:
                    break
                try:
                    # Try parsing to confirm it's valid JSON
                    record = json.loads(line)
                    # Print a snippet of the text
                    text_snippet = record.get("text", "N/A")[:100] + "..."
                    print(f"  Record {i+1}: {text_snippet}")
                except json.JSONDecodeError:
                    parse_failures += 1
                    print(f"  Record {i+1}: FAILED to parse as JSON. The fix may not have worked correctly.")

        # Only claim success when every sampled record actually parsed.
        if parse_failures:
            print(f"❌ Verification failed: {parse_failures} of the sampled records could not be parsed.")
        else:
            print("✅ Verification successful. The new file is correctly formatted.")
    except Exception as e:
        print(f"An error occurred during verification: {e}")


# --- Run the health check for all languages ---
for lang in cfg.LANGS:
    file_to_check = cfg.SPLIT_DIR / lang / "train.jsonl"
    check_and_fix_jsonl(file_to_check)


--- Checking file: train.jsonl ---
🔬 Performing health check...
✅ Diagnosis: File appears healthy with proper line breaks.
No fix needed.

--- Checking file: train.jsonl ---
🔬 Performing health check...
✅ Diagnosis: File appears healthy with proper line breaks.
No fix needed.

--- Checking file: train.jsonl ---
🔬 Performing health check...
✅ Diagnosis: File appears healthy with proper line breaks.
No fix needed.


In [None]:
# ==============================================================================
# ADVANCED FILE DIAGNOSTICS
# ==============================================================================
# This script will scan a file line-by-line to find the length of the
# longest line, without loading the whole file into memory.
# ==============================================================================
from pathlib import Path

# --- Make sure the Config class from the first cell is available ---
cfg = Config()
# --------------------------------------------------------------------------

def find_longest_line(file_path: Path):
    """
    Scan a text file line by line and report its longest line.

    Only one line is held in memory at a time. Prints the number of lines
    scanned, the 1-based number and length (in characters, including the line
    terminator) of the longest line, and a verdict based on that length.
    A missing or empty file is reported and skipped.
    """
    print(f"--- Analyzing line lengths in: {file_path.name} ---")
    if not file_path.exists():
        print(f"⚠️  File does not exist. Skipping.")
        return

    max_len = 0
    max_line_num = -1
    # Tracked separately from the loop variable so it is defined even when
    # the file has no lines at all.
    total_lines = 0

    with file_path.open("r", encoding="utf-8", errors="ignore") as f:
        for total_lines, line in enumerate(f, start=1):
            # Check length of the current line
            current_len = len(line)
            if current_len > max_len:
                max_len = current_len
                max_line_num = total_lines

            # Print progress update
            if total_lines % 200000 == 0:
                print(f"  ...scanned {total_lines:,} lines...")

    print("\n--- Analysis Complete ---")
    print(f"Total lines scanned: {total_lines:,}")
    if total_lines == 0:
        print("⚠️  File is empty. Nothing to analyze.")
        return

    print(f"Longest line found at line number: {max_line_num:,}")
    print(f"Length of longest line: {max_len:,} characters")

    if max_len > 10_000_000: # 10 million characters
        print("\n🚨 Verdict: Found an extremely large line, which is the likely cause of the RAM crash.")
    elif max_len > 1_000_000: # 1 million characters
        print("\n⚠️ Verdict: Found a very long line. This could be problematic and might be the cause.")
    else:
        print("\n✅ Verdict: No abnormally large lines found. The issue may lie elsewhere.")


# --- Run the analysis on the Nepali training data ---
nepali_file_to_check = cfg.SPLIT_DIR / "nep" / "train.jsonl"
find_longest_line(nepali_file_to_check)

--- Analyzing line lengths in: train.jsonl ---
  ...scanned 200,000 lines...
  ...scanned 400,000 lines...
  ...scanned 600,000 lines...
  ...scanned 800,000 lines...
  ...scanned 1,000,000 lines...
  ...scanned 1,200,000 lines...
  ...scanned 1,400,000 lines...
  ...scanned 1,600,000 lines...
  ...scanned 1,800,000 lines...
  ...scanned 2,000,000 lines...

--- Analysis Complete ---
Total lines scanned: 2,058,000
Longest line found at line number: 1,314,539
Length of longest line: 225,245 characters

✅ Verdict: No abnormally large lines found. The issue may lie elsewhere.


In [None]:
# ==============================================================================
# FINAL DIAGNOSTICS: Deep Scan for the Problematic Record
# ==============================================================================
# This script performs a low-level, memory-safe scan of a JSONL file to
# identify the exact line and error type causing system crashes.
# ==============================================================================
import json
from pathlib import Path

# --- Make sure the Config class from the first cell is available ---
# If running in a new session, re-run the first setup cell.
# cfg = Config()
# --------------------------------------------------------------------------

def deep_scan_jsonl_file(file_path: Path):
    """
    Scans a JSONL file record by record to find specific data errors.

    Checks, applied to every line in this order:
    1. JSON formatting: the line must parse with ``json.loads``.
    2. Schema: the parsed value must be a JSON object whose 'text' field is a string.
    3. Encoding: 'text' must encode to strict UTF-8 (no lone surrogates).

    The file is decoded with ``errors="surrogateescape"`` so that undecodable
    bytes survive as lone surrogates and are caught by check 3 instead of being
    silently dropped. Every value echoed to the console is escaped first, so
    reporting a corrupt record cannot itself raise ``UnicodeEncodeError``.

    Args:
        file_path: Path of the JSONL file to scan.

    Returns:
        None. Findings are printed; the scan stops after 10 reported errors.
    """
    print(f"--- Starting deep forensic scan of: {file_path.name} ---")
    if not file_path.exists():
        print(f"❌ ERROR: File not found at {file_path}. Cannot proceed.")
        return

    def _printable(value: str) -> str:
        # Lone surrogates cannot be written to a UTF-8 stream; show them as \udcXX.
        return value.encode("utf-8", "backslashreplace").decode("utf-8")

    def _diagnose(line: str):
        """Return (problem, [(label, value), ...]) for a bad line, else None."""
        # --- Check 1: Is the line valid JSON? ---
        try:
            record = json.loads(line)
        except json.JSONDecodeError:
            return ("Invalid JSON format. The line could not be parsed.",
                    [("Snippet", f"{line[:200]}...")])

        # --- Check 2: Is it an object with a valid 'text' field? ---
        # Valid JSON can also be a list, string or number, which has no .get().
        if not isinstance(record, dict):
            return (f"Schema error. Record is not a JSON object (it's a {type(record)}).",
                    [("Snippet", f"{str(record)[:200]}...")])
        text = record.get("text")
        if text is None:
            return ("Schema error. Record is missing the 'text' field.",
                    [("Snippet", f"{str(record)[:200]}...")])
        if not isinstance(text, str):
            return (f"Schema error. The 'text' field is not a string (it's a {type(text)}).",
                    [("Snippet", f"{str(record)[:200]}...")])

        # --- Check 3: Does the text contain malformed characters? ---
        try:
            # This is a strict test. If it fails, there's a deep encoding issue.
            text.encode("utf-8", "strict")
        except UnicodeEncodeError as e:
            bad_char_pos = e.start
            context_snippet = text[max(0, bad_char_pos-50):min(len(text), bad_char_pos+50)]
            return ("Corrupted Unicode. The text contains a malformed character sequence.",
                    [("Details", str(e)), ("Context", f"...{context_snippet}...")])
        return None

    found_errors = 0
    max_errors_to_report = 10 # Stop after finding a few errors

    with file_path.open("r", encoding="utf-8", errors="surrogateescape") as f:
        for line_num, line in enumerate(f, 1):

            if found_errors >= max_errors_to_report:
                print(f"\nStopping scan after finding {max_errors_to_report} errors.")
                break

            finding = _diagnose(line)
            if finding is not None:
                found_errors += 1
                problem, details = finding
                print("\n" + "="*80)
                print(f"🚨 CRITICAL ERROR FOUND at Line: {line_num:,}")
                print(f"   - Problem: {problem}")
                for label, value in details:
                    print(f"   - {label}: {_printable(value)}")
                print("="*80)

            # --- Progress Update (emitted for bad lines too) ---
            if line_num % 250000 == 0:
                if found_errors == 0:
                    print(f"   ...scanned {line_num:,} lines. No critical errors found so far.")
                else:
                    print(f"   ...scanned {line_num:,} lines. {found_errors} critical error(s) found so far.")

    print("\n--- Scan Complete ---")
    if found_errors == 0:
        print("✅ No critical errors were found in the file. The issue is likely environmental.")
    else:
        print(f"Found a total of {found_errors} critical error(s).")

# --- Run the deep scan on the Nepali training data ---
# The "/" joins require cfg.SPLIT_DIR to be a path object (e.g. pathlib.Path).
# NOTE(review): the Config class in the first setup cell does not define
# SPLIT_DIR — confirm it is set by a later cell before running this one.
nepali_file_to_check = cfg.SPLIT_DIR / "nep" / "train.jsonl"
deep_scan_jsonl_file(nepali_file_to_check)

--- Starting deep forensic scan of: train.jsonl ---
   ...scanned 250,000 lines. No critical errors found so far.
   ...scanned 500,000 lines. No critical errors found so far.
   ...scanned 750,000 lines. No critical errors found so far.
   ...scanned 1,000,000 lines. No critical errors found so far.
   ...scanned 1,250,000 lines. No critical errors found so far.
   ...scanned 1,500,000 lines. No critical errors found so far.
   ...scanned 1,750,000 lines. No critical errors found so far.
   ...scanned 2,000,000 lines. No critical errors found so far.

--- Scan Complete ---
✅ No critical errors were found in the file. The issue is likely environmental.


In [None]:
# --- installs -----------------------------------------------------------------------------
!pip -q install sentencepiece tokenizers zstandard psutil >/dev/null

import os, sys, io, re, json, shutil, subprocess, math, gzip, lzma, random
from typing import Dict, List, Tuple, Optional

# Reduce thread fan-out on shared VMs.
# setdefault() only fills in a variable that is not already present in the
# environment, so values exported before this cell runs are left untouched.
# These are set before the sentencepiece/tokenizers imports below.
os.environ.setdefault("OMP_NUM_THREADS", "4")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("VECLIB_MAXIMUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

# --- libs ---------------------------------------------------------------------------------
import sentencepiece as spm
from tokenizers import Tokenizer
from tokenizers.models import Unigram as HFUnigram
from tokenizers import normalizers, pre_tokenizers
from tokenizers.normalizers import NFKC
from tokenizers.pre_tokenizers import Whitespace, Sequence
from tokenizers import AddedToken
from tokenizers.processors import TemplateProcessing

# ==== CONFIG ==============================================================================
# Drive paths
# Training data is looked up as <splits>/<lang>/train.jsonl, falling back to <raw>/<lang>.jsonl.
DRIVE_SPLITS_DIR = "/content/drive/MyDrive/LMA_SLM/data/splits"
DRIVE_RAW_DIR    = "/content/drive/MyDrive/LMA_SLM/data/raw"   # optional fallback

# Output roots (persist on Drive)
OUT_DIR_DRIVE     = "/content/drive/MyDrive/LMA_SLM/tokenizers"
REPORTS_DIR_DRIVE = "/content/drive/MyDrive/LMA_SLM/tokenizer_reports"

# Tokenizer languages
LANGS = ["eng", "hin", "nep"]

# ---- Core RAM fix: train on a byte-budgeted subset ---------------------------------------
# The total budget is divided evenly between the entries of LANGS.
TRAIN_BYTE_BUDGET_TOTAL = 120_000_000      # ~120MB UTF-8 text total (tune 80–150MB)
# Measured in characters (len() of the str), not bytes; applied before byte counting.
MAX_CHARS_PER_LINE      = 2000             # trim ultra-long lines before counting bytes
# Upper bound on the number of records pulled per language in one pass.
ROUND_ROBIN_CHUNK       = 200_000          # read this many lines per language per pass
SHUFFLE_WITHIN_PASS     = False            # set True to shuffle buffers per pass

# Pretok & placeholders
USE_INDIC_PRETOK = True   # run indic_pretokenize() on every training line
USE_PLACEHOLDERS = True   # replace URLs/e-mails/dates/numbers with <URL>/<EMAIL>/<DATE>/<NUM>
ADD_LANG_TAGS    = True   # prefix every training line with its language tag

# SPM training params (safe for Colab):
VOCAB_SIZES = [48000, 64000]   # one model is trained and exported per entry
CHAR_COVERAGE = 0.9999
NUM_THREADS = 4

# IMPORTANT: Because we pre-sample by bytes, keep this 0 so SPM won't reload more data.
INPUT_SENTENCE_SIZE = 0

# Lighter EM
# train_sp_unigram() raises values below 100000 to 100000.
SEED_SENTENCEPIECE_SIZE = 200_000
NUM_SUB_ITERATIONS      = 1

# Analysis sample per language for metrics
ANALYZE_DOCS_PER_LANG = 8000

# Locked tokens file (optional; one token per line)
# An empty string disables the feature; lines starting with "#" are ignored.
LOCK_TOKENS_FILE = ""

# ==========================================================================================

try:
    import zstandard as zstd
except Exception:
    zstd = None

def ensure_dir(p: str):
    """Create directory ``p`` (including parents) if it does not exist yet.

    An empty path is a no-op: ``os.path.dirname("file.txt")`` is ``""`` and
    ``os.makedirs("")`` would raise ``FileNotFoundError``, although the
    current directory needs no creation.
    """
    if p:
        os.makedirs(p, exist_ok=True)

def mem_mb():
    """Return used system memory as text such as ``"1403.7 MB"``.

    "Used" is total minus available memory as reported by psutil. Any failure
    (psutil missing, query error) yields the string ``"n/a"`` instead.
    """
    bytes_per_mb = 1024 * 1024
    try:
        import psutil
        stats = psutil.virtual_memory()
        used_bytes = stats.total - stats.available
        return "{:.1f} MB".format(used_bytes / bytes_per_mb)
    except Exception:
        # Best effort only: memory figures are decoration on progress lines.
        return "n/a"

def _open_smart(path: str):
    lower = path.lower()
    if lower.endswith(".gz"):
        return gzip.open(path, "rb")
    if lower.endswith(".xz"):
        return lzma.open(path, "rb")
    if lower.endswith(".zst"):
        if zstd is None:
            raise RuntimeError(f"zstandard not installed but needed: {path}")
        dctx = zstd.ZstdDecompressor()
        return dctx.stream_reader(open(path, "rb"))
    return open(path, "rb")

def jsonl_iter_texts(path: str, limit: Optional[int] = None,
                     encoding: str = "utf-8", errors: str = "replace"):
    """Yield the non-empty string ``text`` field of each record in a JSONL file.

    The file is opened through ``_open_smart`` and read line by line. Blank
    lines and records whose ``text`` is missing or empty are skipped silently.
    Lines that are not valid JSON, records that are not JSON objects, and
    records whose ``text`` is not a string are skipped as well; the first three
    such rows are reported with a ``[warn]`` line. Previously a non-object row
    raised ``AttributeError`` and a non-string ``text`` was yielded as-is.

    Args:
        path: JSONL file, optionally compressed (see ``_open_smart``).
        limit: Stop after yielding this many texts; ``None`` or 0 means no limit.
        encoding: Text encoding used to decode the file.
        errors: Decoding error policy handed to ``io.TextIOWrapper``.

    Yields:
        str: the ``text`` value of each usable record, in file order.
    """
    bad = 0
    n = 0
    name = os.path.basename(path)
    with _open_smart(path) as fb:
        with io.TextIOWrapper(fb, encoding=encoding, errors=errors, newline="") as f:
            for raw in f:
                s = raw.strip()
                if not s:
                    continue
                try:
                    rec = json.loads(s)
                except (ValueError, RecursionError):
                    bad += 1
                    if bad <= 3:
                        print(f"[warn] bad JSON in {name} (#{bad})")
                    continue
                if isinstance(rec, dict):
                    t = rec.get("text", "")
                    if not t:
                        continue  # missing or empty text: nothing to yield
                    usable = isinstance(t, str)
                else:
                    usable = False  # valid JSON, but a list/str/number has no fields
                if not usable:
                    bad += 1
                    if bad <= 3:
                        print(f"[warn] unusable record in {name} (#{bad})")
                    continue
                yield t
                n += 1
                if limit and n >= limit:
                    break

# --- light pretok / placeholders -----------------------------------------------------------
def indic_pretokenize(text: str) -> str:
    """Split ``text`` on whitespace and punctuation and re-join with single spaces.

    Every punctuation/symbol character becomes its own token, whitespace runs
    are dropped, and placeholder tags of the form ``<UPPERCASE>`` (``<URL>``,
    ``<NUM>``, ...) are kept intact so that placeholders inserted before this
    step are not torn apart into ``<``, ``URL``, ``>``.

    The third-party ``regex`` module is used when importable because it offers
    Unicode property classes; otherwise a fixed ASCII/dash punctuation set is
    used with the standard ``re`` module.

    The previous ``regex`` pattern repeated its capturing group, so only the
    last character of a separator run was kept: in "a, b" the comma vanished
    because the run ", " ended in a space.
    """
    try:
        import regex as re2
    except ImportError:
        re2 = None

    if re2 is not None:
        # Tag alternative first, so "<URL>" wins over the single-symbol branch for "<".
        pieces = re2.split(r"(<[A-Z]+>|[\p{Z}\s]+|[\p{P}\p{S}])", text)
    else:
        pieces = re.split(r"(<[A-Z]+>|[,.:;!?\"'()\[\]{}\-–—_/\\]|\s+)", text)
    return " ".join(p for p in pieces if p and not p.isspace())

# Patterns behind the placeholder tags; apply_placeholders() runs them in this order.
_URL_RE   = re.compile(r'https?://\S+')
_EMAIL_RE = re.compile(r'\b[\w\.-]+@[\w\.-]+\.\w+\b')
_DATE_RE  = re.compile(r'\b\d{1,4}([./-]\d{1,2}){1,2}\b')
_NUM_RE   = re.compile(r'\b\d+\b')

def apply_placeholders(text: str) -> str:
    """Replace URLs, e-mail addresses, date-like digit groups and numbers with tags.

    Substitutions run in a fixed order (URL, e-mail, date, number) so that the
    digits inside an earlier match are gone before the number pattern runs.
    The date pattern also matches plain decimals such as "3.14".
    """
    substitutions = (
        (_URL_RE, "<URL>"),
        (_EMAIL_RE, "<EMAIL>"),
        (_DATE_RE, "<DATE>"),
        (_NUM_RE, "<NUM>"),
    )
    for pattern, tag in substitutions:
        text = pattern.sub(tag, text)
    return text

# Language code -> tag. The tag is prefixed to each training line when language
# tagging is enabled, and the tags are registered as user-defined symbols for
# SentencePiece training and as added tokens in the exported HF tokenizer.
LANG_TAGS = {"eng": "<eng>", "hin": "<hin>", "nep": "<nep>"}

def pick_train_file(lang: str, splits_dir: str, raw_dir: str) -> str:
    """Return the training JSONL path for ``lang`` and log which one was chosen.

    Prefers ``<splits_dir>/<lang>/train.jsonl``. When that file does not exist
    the raw path ``<raw_dir>/<lang>.jsonl`` is returned without checking that
    it exists.
    """
    split_train = os.path.join(splits_dir, lang, "train.jsonl")
    if not os.path.exists(split_train):
        raw_fallback = os.path.join(raw_dir, f"{lang}.jsonl")
        print(f"[src:{lang}] (fallback) {raw_fallback}")
        return raw_fallback
    print(f"[src:{lang}] {split_train}")
    return split_train

def pick_eval_files(lang: str, splits_dir: str, raw_dir: str) -> List[str]:
    """List the existing evaluation sources for ``lang`` in priority order.

    Candidates are the validation split, the training split and the raw file;
    only paths that exist are returned, each at most once.
    """
    lang_dir = os.path.join(splits_dir, lang)
    candidates = (
        os.path.join(lang_dir, "val.jsonl"),
        os.path.join(lang_dir, "train.jsonl"),
        os.path.join(raw_dir, f"{lang}.jsonl"),
    )
    chosen: List[str] = []
    for path in candidates:
        if path in chosen or not os.path.exists(path):
            continue
        chosen.append(path)
    return chosen

def format_num(n: float) -> str:
    """Abbreviate ``n`` with a K/M/B/T suffix and two decimals.

    Values below 1000 (including negatives) are printed rounded to an integer.
    """
    scales = (("T", 1e12), ("B", 1e9), ("M", 1e6), ("K", 1e3))
    hit = next(((unit, div) for unit, div in scales if n >= div), None)
    if hit is None:
        return f"{n:.0f}"
    unit, div = hit
    return f"{n/div:.2f}{unit}"

def load_lock_tokens(path: Optional[str]) -> List[str]:
    """Read locked tokens from ``path``, one per line.

    Returns an empty list when ``path`` is empty/None or the file is missing
    (the latter is logged). Lines are stripped; blank lines and lines starting
    with "#" are ignored.
    """
    if not path:
        return []
    if not os.path.exists(path):
        print(f"[lock] not found: {path}")
        return []
    with open(path, "r", encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        toks = [s for s in stripped if s and not s.startswith("#")]
    print(f"[lock] loaded {len(toks)} tokens")
    return toks

# --- NEW: build byte-budgeted subset (per language quota) --------------------------------
def build_byte_budget_subset(splits_dir: str, raw_dir: str, langs: List[str],
                             byte_budget_total: int, max_chars_per_line: int,
                             use_indic_pretok: bool, use_placeholders: bool,
                             add_lang_tags: bool, out_file: str):
    """Write a byte-budgeted, per-language-balanced training subset to ``out_file``.

    The byte budget is split evenly between ``langs``. Languages are visited
    round-robin; in each pass up to ``ROUND_ROBIN_CHUNK`` records are pulled
    from a language's training file, preprocessed, and written until the
    language's quota would be exceeded.

    Each language keeps ONE reader for the whole run, so every pass resumes
    where the previous one stopped. The earlier version reopened the file on
    each pass and therefore re-read (and could re-write) the same leading
    records, and it preprocessed a full chunk even when only a fraction fit.

    Args:
        splits_dir, raw_dir: Locations handed to ``pick_train_file``.
        langs: Language codes to include.
        byte_budget_total: Total UTF-8 byte budget (newline included per line).
        max_chars_per_line: Lines are cut to this many characters; falsy disables.
        use_indic_pretok: Apply ``indic_pretokenize`` to every line.
        use_placeholders: Apply ``apply_placeholders`` to every line.
        add_lang_tags: Prefix every line with its language tag.
        out_file: Destination text file (one sample per line).

    Returns:
        (written, totals): per-language line counts and byte counts.
    """
    from itertools import islice  # local import: not among this cell's top-level imports

    out_parent = os.path.dirname(out_file)
    if out_parent:  # "" for a bare filename; nothing to create then
        ensure_dir(out_parent)

    per_lang_quota = byte_budget_total // max(1, len(langs))
    totals: Dict[str, int] = {l: 0 for l in langs}
    written: Dict[str, int] = {l: 0 for l in langs}
    chunk = ROUND_ROBIN_CHUNK or None  # falsy chunk size means "no per-pass cap"

    def _prepare(lang: str, t: str) -> str:
        # Same preprocessing order as before: placeholders, pretok, tag, trim.
        if use_placeholders:
            t = apply_placeholders(t)
        if use_indic_pretok:
            t = indic_pretokenize(t)
        if add_lang_tags:
            t = f"{LANG_TAGS.get(lang, f'<{lang}>')} {t}"
        if max_chars_per_line and len(t) > max_chars_per_line:
            t = t[:max_chars_per_line]
        return t

    readers = {}
    try:
        with open(out_file, "w", encoding="utf-8") as out:
            print(f"[subset] byte budget total={format_num(byte_budget_total)} (~UTF-8), per-lang≈{format_num(per_lang_quota)}")
            for lang in langs:
                if lang not in readers:
                    readers[lang] = jsonl_iter_texts(pick_train_file(lang, splits_dir, raw_dir))

            pending = list(readers)  # languages that may still contribute lines
            passes = 0
            while pending:
                passes += 1
                for lang in list(pending):
                    took = 0
                    pulled = 0
                    quota_hit = totals[lang] >= per_lang_quota
                    if not quota_hit:
                        batch = (_prepare(lang, t) for t in islice(readers[lang], chunk))
                        if SHUFFLE_WITHIN_PASS:
                            # Shuffling needs the whole chunk in memory.
                            batch = list(batch)
                            random.shuffle(batch)
                        for t in batch:
                            pulled += 1
                            b = len(t.encode("utf-8")) + 1  # +1 for newline
                            if totals[lang] + b > per_lang_quota:
                                quota_hit = True
                                break
                            out.write(t.replace("\n", " ") + "\n")
                            totals[lang] += b
                            written[lang] += 1
                            took += 1
                    if took:
                        print(f"[subset:{lang}] +{took} lines (pass {passes}) bytes={format_num(totals[lang])} mem={mem_mb()}")
                    # Fewer records than requested means the source is exhausted.
                    source_exhausted = chunk is None or pulled < chunk
                    if quota_hit or source_exhausted:
                        pending.remove(lang)
            print("[subset] DONE per-lang bytes:", {k: format_num(v) for k, v in totals.items()})
            print("[subset] lines written:", {k: written[k] for k in langs})
    finally:
        # Closing the generators releases their underlying file handles.
        for reader in readers.values():
            reader.close()
    return written, totals

# --- SPM train -------------------------------------------------------------------------------
def train_sp_unigram(corpus_path: str, model_prefix: str, vocab_size: int):
    """Train a SentencePiece unigram model on ``corpus_path``.

    Writes ``<model_prefix>.model`` and ``<model_prefix>.vocab``. Language
    tags, placeholder tags and locked tokens are registered as user-defined
    symbols according to the module-level switches.

    Args:
        corpus_path: Plain-text corpus, one sentence per line.
        model_prefix: Output path prefix for the model and vocab files.
        vocab_size: Target vocabulary size.
    """
    print(f"[spm] train | vocab={vocab_size} char_cov={CHAR_COVERAGE} threads={NUM_THREADS} iss={INPUT_SENTENCE_SIZE}")
    # prepare symbols
    lang_syms = list(LANG_TAGS.values()) if ADD_LANG_TAGS else []
    placeholders = ["<URL>","<EMAIL>","<NUM>","<DATE>"] if USE_PLACEHOLDERS else []
    locked = load_lock_tokens(LOCK_TOKENS_FILE)
    user_defined = lang_syms + placeholders + locked

    # max_sentence_length is a BYTE limit and the trainer skips longer lines.
    # Subset lines hold up to MAX_CHARS_PER_LINE characters, and a UTF-8
    # character takes up to 4 bytes (3 for Devanagari), so a fixed 4096 would
    # silently drop many long Hindi/Nepali lines.
    max_sentence_bytes = max(4096, 4 * (MAX_CHARS_PER_LINE or 0))

    spm.SentencePieceTrainer.Train(
        input=corpus_path,
        model_prefix=model_prefix,
        model_type="unigram",
        vocab_size=vocab_size,
        character_coverage=CHAR_COVERAGE,
        num_threads=max(1, NUM_THREADS),

        # CRUCIAL: we already pre-sampled by bytes; don't let SPM re-sample.
        input_sentence_size=max(0, INPUT_SENTENCE_SIZE),

        # Lighter EM to keep memory in check
        seed_sentencepiece_size=max(100000, SEED_SENTENCEPIECE_SIZE),
        num_sub_iterations=max(1, NUM_SUB_ITERATIONS),

        # Robustness/settings
        byte_fallback=True,
        treat_whitespace_as_suffix=True,
        remove_extra_whitespaces=True,
        max_sentence_length=max_sentence_bytes,

        user_defined_symbols=",".join(user_defined) if user_defined else "",
        hard_vocab_limit=False,
        train_extremely_large_corpus=True,  # harmless when data is small
        self_test_sample_size=0
    )
    print(f"[spm] wrote {model_prefix}.model / .vocab | mem={mem_mb()}")

# --- HF export -----------------------------------------------------------------------------
def load_sp_vocab(vocab_path: str) -> Tuple[List[Tuple[str, float]], int]:
    """Load a SentencePiece ``.vocab`` file as ``(piece, score)`` pairs.

    Each non-empty line is ``piece<TAB>score``. The returned ``unk_id`` is the
    position of ``<unk>`` in the returned list (0 if absent). It is taken from
    the list position rather than the raw line number, so a skipped blank line
    cannot push it out of step with the vocabulary.

    Returns:
        (vocab, unk_id)
    """
    vocab: List[Tuple[str, float]] = []
    unk_id = 0
    with open(vocab_path, "r", encoding="utf-8") as f:
        for line in f:
            s = line.rstrip("\n")
            if not s:
                continue
            # Split on the last tab only: the score is always the final field.
            piece, score = s.rsplit("\t", 1)
            if piece == "<unk>":
                unk_id = len(vocab)
            vocab.append((piece, float(score)))
    return vocab, unk_id

def export_hf(sp_vocab: str, sp_model: str, out_dir: str):
    """Export a SentencePiece vocabulary as a Hugging Face ``tokenizer.json``.

    Builds a Unigram tokenizer from the ``.vocab`` file, registers special and
    added tokens, attaches a BOS/EOS template and writes ``tokenizer.json``,
    ``special_tokens_map.json`` and ``tokenizer_config.json`` into ``out_dir``.

    Args:
        sp_vocab: Path of the SentencePiece ``.vocab`` file.
        sp_model: Path of the ``.model`` file (currently not read here).
        out_dir: Output directory, created if missing.

    NOTE(review): the SentencePiece model is trained with
    ``treat_whitespace_as_suffix=True`` while this tokenizer uses an
    NFKC + Whitespace pipeline; verify that the exported tokenizer reproduces
    the ``.model`` segmentation before relying on it.
    """
    os.makedirs(out_dir, exist_ok=True)
    vocab, unk_id = load_sp_vocab(sp_vocab)
    try:
        # The SentencePiece model is trained with byte_fallback=True; mirror it.
        model = HFUnigram(vocab=vocab, unk_id=unk_id, byte_fallback=True)
    except TypeError:
        # Older tokenizers releases do not accept the byte_fallback argument.
        model = HFUnigram(vocab=vocab, unk_id=unk_id)
    tok = Tokenizer(model)
    tok.normalizer = normalizers.Sequence([NFKC()])
    tok.pre_tokenizer = Sequence([Whitespace()])

    sentinels=[]  # set if you want <extra_id_*>
    lang_syms = list(LANG_TAGS.values()) if ADD_LANG_TAGS else []
    placeholders = ["<URL>","<EMAIL>","<NUM>","<DATE>"] if USE_PLACEHOLDERS else []
    locked = load_lock_tokens(LOCK_TOKENS_FILE)

    specials = ["<unk>","<s>","</s>"] + sentinels
    tok.add_special_tokens([AddedToken(s, single_word=False, normalized=False) for s in specials])
    tok.add_tokens([AddedToken(s, single_word=False, normalized=False) for s in (lang_syms + placeholders + locked)])

    # The template must carry the real ids of <s> and </s>. They were both
    # hard-coded to 0, which made BOS/EOS come out with the id at position 0.
    bos_id = tok.token_to_id("<s>")
    eos_id = tok.token_to_id("</s>")
    tok.post_processor = TemplateProcessing(
        single="<s> $A </s>", pair="<s> $A </s> </s> $B </s>",
        special_tokens=[("<s>", bos_id), ("</s>", eos_id)]
    )

    out_json = os.path.join(out_dir, "tokenizer.json")
    tok.save(out_json)
    with open(os.path.join(out_dir, "special_tokens_map.json"), "w", encoding="utf-8") as f:
        json.dump({"unk_token":"<unk>","bos_token":"<s>","eos_token":"</s>",
                   "additional_special_tokens": sentinels + lang_syms + placeholders + locked}, f, indent=2)
    with open(os.path.join(out_dir, "tokenizer_config.json"), "w", encoding="utf-8") as f:
        json.dump({"model_max_length":2048,"unk_token":"<unk>","bos_token":"<s>","eos_token":"</s>",
                   "special_tokens_map_file":"special_tokens_map.json"}, f, indent=2)
    print(f"[hf] wrote {out_json} (+sidecars)")

# --- Metrics (streaming) -------------------------------------------------------------------
def _is_byte_piece(sp: spm.SentencePieceProcessor, tid: int) -> bool:
    """Return True when the piece for id ``tid`` is spelled ``<0x...>``."""
    piece = sp.IdToPiece(tid)
    if not piece.startswith("<0x"):
        return False
    return piece.endswith(">")

def compute_metrics(sp_model_path: str, files_per_lang: Dict[str, List[str]], max_docs_per_lang: int):
    """Stream documents per language and summarise how the model tokenizes them.

    Args:
        sp_model_path: SentencePiece ``.model`` file to load.
        files_per_lang: Language code -> list of JSONL files, read in order.
        max_docs_per_lang: Cap on documents per language across all of its
            files; a falsy value disables the cap.

    Returns:
        Dict keyed by language. Each value holds document/byte/token totals,
        bytes per token, token-length statistics (min/max/mean/std and
        histogram-based p50/p90/p99) and the byte-fallback rate. A language
        with no tokens gets only ``{"bytes_per_token": inf}``.

    Texts are encoded exactly as read from the files; the placeholder,
    pre-tokenization and language-tag steps used for the training subset are
    not applied here.
    """
    sp = spm.SentencePieceProcessor(); sp.Load(sp_model_path)
    import math
    # Upper edges of the token-length histogram bins.
    bins = [0,8,16,32,64,128,256,512,1024,2048,4096,8192,16384]
    def bini(x):
        # Index of the first bin whose edge is >= x; longer docs land in the last bin.
        for i, b in enumerate(bins):
            if x <= b: return i
        return len(bins)-1

    results={}
    for lang, files in files_per_lang.items():
        total_bytes = total_tokens = docs_seen = 0
        # Running mean / sum of squared deviations for a single-pass variance.
        mean = 0.0; m2 = 0.0; min_len = 10**9; max_len = 0
        hist = [0]*len(bins); byte_tok = 0
        for p in files:
            if not os.path.exists(p): continue
            for t in jsonl_iter_texts(p, limit=None):
                b = len(t.encode("utf-8"))
                ids = sp.EncodeAsIds(t)
                L = len(ids)
                total_bytes += b; total_tokens += L
                # Count tokens that were produced through byte fallback.
                byte_tok += sum(1 for tid in ids if _is_byte_piece(sp, tid))
                docs_seen += 1
                # Welford-style online update; the statement order matters.
                d = L - mean; mean += d / docs_seen; m2 += d * (L - mean)
                min_len = min(min_len, L); max_len = max(max_len, L)
                hist[bini(L)] += 1
                if max_docs_per_lang and docs_seen >= max_docs_per_lang:
                    break
            # Second check leaves the file loop once the cap was hit inside a file.
            if max_docs_per_lang and docs_seen >= max_docs_per_lang:
                break
        if total_tokens == 0:
            print(f"[metrics:{lang}] no tokens measured.")
            # NOTE(review): json.dump writes float("inf") as the non-standard
            # token Infinity; strict JSON readers may reject the manifest.
            results[lang] = {"bytes_per_token": float("inf")}
            continue
        bpt = total_bytes / total_tokens
        # Sample variance (n-1 denominator); 0.0 when only one document was seen.
        var = (m2 / (docs_seen-1)) if docs_seen > 1 else 0.0
        std = math.sqrt(var)
        # approx percentiles from hist
        def pct(p):
            # Returns the upper edge of the bin in which the cumulative count reaches p.
            target = math.ceil(p*docs_seen); c=0
            for i,cnt in enumerate(hist):
                c += cnt
                if c >= target: return bins[i]
            return bins[-1]
        p50,p90,p99 = pct(0.5), pct(0.9), pct(0.99)
        byte_rate = byte_tok / max(1,total_tokens)
        print(f"[metrics:{lang}] docs={docs_seen:,} tokens={format_num(total_tokens)} bytes={format_num(total_bytes)}")
        print(f"  - bytes/token = {bpt:.3f}")
        print(f"  - len(tokens/doc): min={min_len} max={max_len} mean={mean:.1f} std={std:.1f} p50≈{p50} p90≈{p90} p99≈{p99}")
        print(f"  - byte_fallback rate = {byte_rate*100:.3f}%")
        results[lang] = {"docs_seen":docs_seen,"total_bytes":int(total_bytes),"total_tokens":int(total_tokens),
                         "bytes_per_token":float(bpt),"len_min":int(min_len),"len_max":int(max_len),
                         "len_mean":float(mean),"len_std":float(std),
                         "len_p50_approx":int(p50),"len_p90_approx":int(p90),"len_p99_approx":int(p99),
                         "byte_fallback_rate":float(byte_rate)}
    return results

# ==== Pipeline =============================================================================
# Runs end to end when the cell executes: build the subset once, then train,
# export, measure and write a manifest for every entry of VOCAB_SIZES.
# from google.colab import drive
# drive.mount('/content/drive')

ensure_dir(OUT_DIR_DRIVE); ensure_dir(REPORTS_DIR_DRIVE)
# The subset is written to the local VM disk, not to Drive.
subset_path = "/content/train_subset.txt"

# 1) Build byte-budgeted subset
written, totals = build_byte_budget_subset(
    DRIVE_SPLITS_DIR, DRIVE_RAW_DIR, LANGS,
    TRAIN_BYTE_BUDGET_TOTAL, MAX_CHARS_PER_LINE,
    USE_INDIC_PRETOK, USE_PLACEHOLDERS, ADD_LANG_TAGS,
    out_file=subset_path
)
print(f"[subset] wrote -> {subset_path} (mem={mem_mb()})")

# 2) Train & export per vocab
# tok_jsons collects the exported tokenizer.json paths; nothing in this cell reads it back.
tok_jsons=[]
for vs in VOCAB_SIZES:
    prefix = os.path.join(OUT_DIR_DRIVE, f"sp_unigram_{vs}")
    model_path = prefix + ".model"; vocab_path = prefix + ".vocab"
    train_sp_unigram(subset_path, prefix, vs)
    # The HF export directory has the same name as the SentencePiece prefix:
    # <OUT_DIR_DRIVE>/sp_unigram_<vs>/ sits next to sp_unigram_<vs>.model/.vocab.
    export_hf(vocab_path, model_path, os.path.join(OUT_DIR_DRIVE, f"sp_unigram_{vs}"))
    tok_jsons.append(os.path.join(OUT_DIR_DRIVE, f"sp_unigram_{vs}", "tokenizer.json"))

    # 3) quick metrics on real files (streaming)
    files_per_lang = {lang: pick_eval_files(lang, DRIVE_SPLITS_DIR, DRIVE_RAW_DIR) for lang in LANGS}
    metrics = compute_metrics(model_path, files_per_lang, ANALYZE_DOCS_PER_LANG)

    # 4) manifest
    # One JSON manifest per vocab size, recording the settings and metrics of this run.
    manifest = {
        "langs": LANGS,
        "vocab_size": vs,
        "character_coverage": CHAR_COVERAGE,
        "use_indic_pretok": USE_INDIC_PRETOK,
        "use_placeholders": USE_PLACEHOLDERS,
        "add_lang_tags": ADD_LANG_TAGS,
        "input_sentence_size": INPUT_SENTENCE_SIZE,
        "seed_sentencepiece_size": SEED_SENTENCEPIECE_SIZE,
        "num_sub_iterations": NUM_SUB_ITERATIONS,
        "byte_budget_total": TRAIN_BYTE_BUDGET_TOTAL,
        "max_chars_per_line": MAX_CHARS_PER_LINE,
        "subset_bytes_per_lang": totals,
        "subset_lines_per_lang": written,
        "subset_path": subset_path,
        "sp_model": model_path,
        "hf_tokenizer_json": os.path.join(OUT_DIR_DRIVE, f"sp_unigram_{vs}", "tokenizer.json"),
        "metrics": metrics
    }
    mf_path = os.path.join(REPORTS_DIR_DRIVE, f"tokenizer_manifest_sp{vs}.json")
    with open(mf_path,"w",encoding="utf-8") as f:
        json.dump(manifest, f, indent=2, ensure_ascii=False)
    print(f"[manifest] wrote {mf_path} | mem={mem_mb()}")

print("\n[done] Artifacts:", OUT_DIR_DRIVE)
print("[done] Reports:", REPORTS_DIR_DRIVE)

[subset] byte budget total=120.00M (~UTF-8), per-lang≈40.00M
[src:eng] /content/drive/MyDrive/LMA_SLM/data/splits/eng/train.jsonl
[subset:eng] +25481 lines (pass 1) bytes=40.00M mem=1403.7 MB
[src:hin] /content/drive/MyDrive/LMA_SLM/data/splits/hin/train.jsonl
[subset:hin] +10826 lines (pass 1) bytes=40.00M mem=1506.3 MB
[src:nep] /content/drive/MyDrive/LMA_SLM/data/splits/nep/train.jsonl
[subset:nep] +11786 lines (pass 1) bytes=40.00M mem=1502.2 MB
[src:eng] /content/drive/MyDrive/LMA_SLM/data/splits/eng/train.jsonl
[src:hin] /content/drive/MyDrive/LMA_SLM/data/splits/hin/train.jsonl
[src:nep] /content/drive/MyDrive/LMA_SLM/data/splits/nep/train.jsonl
[subset] DONE per-lang bytes: {'eng': '40.00M', 'hin': '40.00M', 'nep': '40.00M'}
[subset] lines written: {'eng': 25481, 'hin': 10826, 'nep': 11786}
[subset] wrote -> /content/train_subset.txt (mem=1489.2 MB)
[spm] train | vocab=48000 char_cov=0.9999 threads=4 iss=0
[spm] wrote /content/drive/MyDrive/LMA_SLM/tokenizers/sp_unigram_48000.m

In [None]:
# NOTE(review): presumably a quick check that the kernel still responds — confirm or remove.
print("Hi")

Hi


In [None]:
# ======================= Tokenizer & Data QA (no plots, RAM friendly) ===========================
# What it does:
# - Loads the SentencePiece model (.model) AND HF tokenizer.json (for consistency sanity checks)
# - Streams each language's split files (val/train/raw fallbacks), no big arrays
# - Computes:
#     * bytes/token, token-length stats (min/mean/p50/p90/p99/max)
#     * byte_fallback usage and which byte pieces are most frequent
#     * placeholder coverage & correctness: <URL>, <EMAIL>, <NUM>, <DATE>
#     * language tag coverage (<eng>/<hin>/<nep>) and position correctness (should be 1st)
#     * special tokens integrity (<unk>, <s>, </s>)
#     * character inventory coverage: which Unicode chars appear in data but only tokenize via bytes
#     * “fertility”: tokens per whitespace-separated word; top over-segmented words
#     * top-N pieces overall & per language, suspicious pieces (very short/long/whitespace)
#     * round-trip encode→decode→encode consistency checks (sampled)
#


import os, io, json, re, math, gzip, lzma, unicodedata
from collections import Counter, defaultdict
from typing import List, Dict, Tuple, Optional

# ------------------------ CONFIG (edit paths if needed) -----------------------------------------
# NOTE: several names below (DRIVE_*_DIR, LANGS, LANG_TAGS, MAX_CHARS_PER_LINE, the
# regexes) are also assigned by the tokenizer-training cell; running this cell rebinds them.
# Point to the SAME drive locations you used earlier
DRIVE_SPLITS_DIR = "/content/drive/MyDrive/LMA_SLM/data/splits"
DRIVE_RAW_DIR    = "/content/drive/MyDrive/LMA_SLM/data/raw"      # optional fallback

# Pick which tokenizer variant to inspect
# TOK_DIR and SPM_MODEL are set independently; keep both on the same vocab size.
TOK_DIR = "/content/drive/MyDrive/LMA_SLM/tokenizers/sp_unigram_64000"  # or sp_unigram_48000
SPM_MODEL = "/content/drive/MyDrive/LMA_SLM/tokenizers/sp_unigram_64000.model"

# Languages present
LANGS = ["eng", "hin", "nep"]
LANG_TAGS = {"eng": "<eng>", "hin": "<hin>", "nep": "<nep>"}

# Streaming/sample limits for analysis
DOCS_PER_LANG = 8000                 # how many JSONL records to scan per language (val/train/raw combined)
PRINT_TOP_N   = 40                   # how many top pieces/words to show
# Applied by jsonl_iter_texts() in this cell to each record's text, in characters.
# This is 10000 here versus 2000 in the training cell.
MAX_CHARS_PER_LINE = 10000           # trim pathological lines to avoid skew

# Regexes reused (same as training stage)
_URL_RE   = re.compile(r'https?://\S+')
_EMAIL_RE = re.compile(r'\b[\w\.-]+@[\w\.-]+\.\w+\b')
_DATE_RE  = re.compile(r'\b\d{1,4}([./-]\d{1,2}){1,2}\b')
_NUM_RE   = re.compile(r'\b\d+\b')

PLACEHOLDERS = ["<URL>", "<EMAIL>", "<NUM>", "<DATE>"]
SPECIALS     = ["<unk>", "<s>", "</s>"]

# ------------------------------------------------------------------------------------------------

# libs
import sentencepiece as spm
from tokenizers import Tokenizer

try:
    import zstandard as zstd
except Exception:
    zstd = None

def _open_smart(path: str):
    lower = path.lower()
    if lower.endswith(".gz"):
        return gzip.open(path, "rb")
    if lower.endswith(".xz"):
        return lzma.open(path, "rb")
    if lower.endswith(".zst"):
        if zstd is None:
            raise RuntimeError(f"zstandard needed for: {path}")
        dctx = zstd.ZstdDecompressor()
        return dctx.stream_reader(open(path, "rb"))
    return open(path, "rb")

def jsonl_iter_texts(path: str, limit: Optional[int] = None,
                     encoding: str = "utf-8", errors: str = "replace"):
    """Stream the string `text` field from JSONL; skips bad rows; trims very long texts.

    Blank lines and records whose `text` is missing or empty are skipped
    silently. Lines that are not valid JSON, records that are not JSON objects
    and records whose `text` is not a string are skipped too, with a `[warn]`
    line for the first three. Previously a non-object row raised
    AttributeError and a non-string `text` reached `len()`/slicing.

    Texts longer than MAX_CHARS_PER_LINE characters are cut to that length.

    Args:
        path: JSONL file, optionally compressed (see `_open_smart`).
        limit: Stop after yielding this many texts; None or 0 means no limit.
        encoding: Text encoding used to decode the file.
        errors: Decoding error policy handed to io.TextIOWrapper.

    Yields:
        str: the (possibly trimmed) `text` of each usable record, in file order.
    """
    bad = 0
    n = 0
    name = os.path.basename(path)
    with _open_smart(path) as fb:
        with io.TextIOWrapper(fb, encoding=encoding, errors=errors, newline="") as f:
            for raw in f:
                s = raw.strip()
                if not s:
                    continue
                try:
                    rec = json.loads(s)
                except (ValueError, RecursionError):
                    bad += 1
                    if bad <= 3:
                        print(f"[warn] bad JSON in {name} (#{bad})")
                    continue
                if isinstance(rec, dict):
                    t = rec.get("text", "")
                    if not t:
                        continue  # missing or empty text: nothing to yield
                    usable = isinstance(t, str)
                else:
                    usable = False  # valid JSON, but a list/str/number has no fields
                if not usable:
                    bad += 1
                    if bad <= 3:
                        print(f"[warn] unusable record in {name} (#{bad})")
                    continue
                if MAX_CHARS_PER_LINE and len(t) > MAX_CHARS_PER_LINE:
                    t = t[:MAX_CHARS_PER_LINE]
                yield t
                n += 1
                if limit and n >= limit:
                    break

def pick_eval_files(lang: str, splits_dir: str, raw_dir: str) -> List[str]:
    """Return the evaluation files that exist for ``lang``, best source first.

    The order is validation split, training split, raw file. Paths that do not
    exist are dropped and no path is listed twice.
    """
    ordered = [
        os.path.join(splits_dir, lang, name) for name in ("val.jsonl", "train.jsonl")
    ]
    ordered.append(os.path.join(raw_dir, f"{lang}.jsonl"))

    found: List[str] = []
    for candidate in ordered:
        already_listed = candidate in found
        if os.path.exists(candidate) and not already_listed:
            found.append(candidate)
    return found

# ------------------------------- helpers ---------------------------------------------------------
def format_num(n: float) -> str:
    """Render ``n`` compactly: two decimals plus K/M/B/T, or a rounded integer below 1000."""
    units = {"T": 1e12, "B": 1e9, "M": 1e6, "K": 1e3}
    for unit, div in units.items():
        if n >= div:
            return "{:.2f}{}".format(n / div, unit)
    return "{:.0f}".format(n)

def _is_byte_piece(piece: str) -> bool:
    return piece.startswith("<0x") and piece.endswith(">")

def tokens_per_word(sp: spm.SentencePieceProcessor, text: str) -> Tuple[int,int]:
    """Return (num_tokens, num_words) for ``text``.

    Tokens come from the SentencePiece encoder; words are the pieces of a
    plain whitespace split.
    """
    token_count = len(sp.EncodeAsIds(text))
    word_count = len(text.split())  # split() never yields empty strings
    return token_count, word_count

def encode_decode_ok(sp: spm.SentencePieceProcessor, text: str) -> bool:
    """Round-trip check: True when encode -> decode -> encode gives the same ids."""
    first_pass = sp.EncodeAsIds(text)
    second_pass = sp.EncodeAsIds(sp.DecodeIds(first_pass))
    return first_pass == second_pass

def top_n(counter: Counter, n: int) -> List[Tuple[str,int]]:
    """Return the ``n`` most frequent entries of ``counter`` as (item, count) pairs."""
    ranked = counter.most_common(n)
    return ranked

# --------------------------- main analysis -------------------------------------------------------
def analyze_tokenizer_and_data(
    spm_model_path: str,
    hf_tokenizer_json_dir: str,
    langs: List[str],
    splits_dir: str,
    raw_dir: str,
    docs_per_lang: int,
    print_top_n: int,
    reports_dir: str = "/content/drive/MyDrive/LMA_SLM/tokenizer_reports",
):
    """Print a QA report for a SentencePiece tokenizer and write a JSON summary.

    The report has four parts:
      1. a vocabulary overview (size, byte pieces, specials, language tags,
         placeholders, suspicious pieces);
      2. per-language statistics computed over the files returned by
         `pick_eval_files` (bytes/token, byte-fallback rate, token-length
         distribution, first-token language tag, sampled round-trip checks,
         placeholder regex hits, characters that fall back to bytes and
         over-segmented words);
      3. the most frequent pieces overall;
      4. an optional token-count comparison against the HF `tokenizer.json`.

    Args:
        spm_model_path: path of the SentencePiece `.model` file to load.
        hf_tokenizer_json_dir: directory expected to contain `tokenizer.json`;
            the HF cross-check is skipped when the file is missing.
        langs: language codes to analyse.
        splits_dir: root passed to `pick_eval_files` for split files.
        raw_dir: root passed to `pick_eval_files` for raw files.
        docs_per_lang: stop after this many documents per language; a falsy
            value means no cap.
        print_top_n: how many top pieces to print and to store in the summary.
        reports_dir: directory that receives `tokenizer_data_QA_summary.json`.
            Defaults to the location that was previously hard-coded.

    Returns:
        None. Results are printed and written to the summary JSON.
    """
    print(f"[load] SPM model: {spm_model_path}")
    sp = spm.SentencePieceProcessor()
    sp.Load(spm_model_path)

    hf_json = os.path.join(hf_tokenizer_json_dir, "tokenizer.json")
    if os.path.exists(hf_json):
        print(f"[load] HF tokenizer: {hf_json}")
        tok_hf = Tokenizer.from_file(hf_json)
    else:
        tok_hf = None
        print("[load] HF tokenizer.json not found (skipping HF-specific checks).")

    # Build the membership sets once instead of once per vocabulary piece.
    special_set = set(SPECIALS)
    lang_tag_set = set(LANG_TAGS.values())
    placeholder_set = set(PLACEHOLDERS)
    placeholder_patterns = (
        ("<URL>", _URL_RE),
        ("<EMAIL>", _EMAIL_RE),
        ("<DATE>", _DATE_RE),
        ("<NUM>", _NUM_RE),
    )

    # ----- global vocab overview -----
    vocab_size = sp.GetPieceSize()
    pieces = [sp.IdToPiece(i) for i in range(vocab_size)]
    byte_pieces = sum(1 for p in pieces if _is_byte_piece(p))
    # Only the first ten ids are inspected for special symbols.
    specials_present = {p for p in pieces[:10] if p in special_set}
    lang_syms_present = [p for p in pieces if p in lang_tag_set]
    placeholders_present = [p for p in pieces if p in placeholder_set]

    print("\n[vocab]")
    print(f"  - vocab_size={vocab_size}")
    print(f"  - byte_pieces={byte_pieces}  (should be 256 when byte_fallback=True)")
    print(f"  - specials found at low ids: {sorted(list(specials_present))}")
    print(f"  - language tags present: {lang_syms_present}")
    print(f"  - placeholders present: {placeholders_present}")

    # Identify suspicious pieces (leading/trailing space, all-punct, very long)
    susp = []
    for p in pieces:
        is_special = p in special_set
        if len(p) > 24:
            susp.append(("long", p))
        if p.strip() != p and not is_special:
            susp.append(("spacey", repr(p)))
        only_punct = all(
            ch.isspace() or unicodedata.category(ch)[0] in ("P", "S") for ch in p
        )
        if only_punct and not is_special and not _is_byte_piece(p):
            susp.append(("punct/sym-only", p))
    if susp:
        print("\n[vocab suspicious pieces] (showing up to 20)")
        for k, p in susp[:20]:
            print(f"  - {k}: {p}")
    else:
        print("\n[vocab suspicious pieces] none")

    # ----- per-language streaming stats -----
    overall_piece_freq = Counter()
    per_lang_piece_freq: Dict[str, Counter] = {l: Counter() for l in langs}
    byte_piece_freq = Counter()

    # Histogram bucket upper bounds for tokens-per-document.
    bins = [0,8,16,32,64,128,256,512,1024,2048,4096,8192,16384]

    def bini(x):
        """Index of the first bucket whose upper bound is >= x (last bucket if none)."""
        for i, b in enumerate(bins):
            if x <= b:
                return i
        return len(bins) - 1

    def pct(hist, docs, pp):
        """Upper bound of the bucket that contains the pp-quantile document."""
        target = math.ceil(pp * docs)
        c = 0
        for i, cnt in enumerate(hist):
            c += cnt
            if c >= target:
                return bins[i]
        return bins[-1]

    print("\n[per-language metrics]")
    per_lang_results = {}
    for lang in langs:
        files = pick_eval_files(lang, splits_dir, raw_dir)
        if not files:
            print(f"  - {lang}: no files found; skipping.")
            continue

        lang_tag = LANG_TAGS.get(lang)
        docs = 0
        total_bytes = total_tokens = 0
        min_len = 10**9; max_len = 0; mean = 0.0; m2 = 0.0
        hist = [0]*len(bins)
        byte_tok = 0

        # coverage checks
        char_inventory = Counter()
        byte_only_chars = Counter()  # chars whose encoding uses byte-fallback pieces
        placeholder_hits = Counter()
        lang_tag_ok = 0
        roundtrip_ok = 0
        roundtrip_bad = 0
        # Kept per language so each language's report lists only its own words.
        word_overseg = Counter()

        for p in files:
            for text in jsonl_iter_texts(p, limit=None):
                # Keep character inventory
                char_inventory.update(text)

                # Count documents matching each placeholder regex
                for name, pattern in placeholder_patterns:
                    if pattern.search(text):
                        placeholder_hits[name] += 1

                ids = sp.EncodeAsIds(text)
                L = len(ids)
                total_bytes += len(text.encode("utf-8"))
                total_tokens += L
                min_len = min(min_len, L); max_len = max(max_len, L)
                # Welford running mean/variance update
                d = L - mean; mean += d / (docs+1); m2 += d*(L-mean)
                hist[bini(L)] += 1
                docs += 1

                # piece frequencies
                for tid in ids:
                    piece = sp.IdToPiece(tid)
                    overall_piece_freq[piece] += 1
                    per_lang_piece_freq[lang][piece] += 1
                    if _is_byte_piece(piece):
                        byte_tok += 1

                # language tag correctness (first token)
                if ids and sp.IdToPiece(ids[0]) == lang_tag:
                    lang_tag_ok += 1

                # fertility (tokens per whitespace word); reuses the ids
                # computed above rather than encoding the document again
                words = text.split()
                w_count = len(words)
                if w_count > 0:
                    # crude average; flag extreme ratios
                    fert = L / w_count
                    if fert > 4.0 and w_count >= 4:
                        # store a few culprit words by over-segmentation
                        for w in words:
                            if len(w) >= 6:
                                # re-encode word alone
                                wl = len(sp.EncodeAsIds(w))
                                if wl >= 5:
                                    word_overseg[w.lower()] += 1

                # round-trip check (sample ~1 in 50 docs)
                if (docs % 50) == 1:
                    if encode_decode_ok(sp, text):
                        roundtrip_ok += 1
                    else:
                        roundtrip_bad += 1

                if docs_per_lang and docs >= docs_per_lang:
                    break
            if docs_per_lang and docs >= docs_per_lang:
                break

        if total_tokens == 0:
            print(f"  - {lang}: no tokens collected")
            continue

        bpt = total_bytes / total_tokens
        var = (m2 / (docs-1)) if docs > 1 else 0.0
        std = math.sqrt(var)
        p50, p90, p99 = pct(hist, docs, 0.5), pct(hist, docs, 0.9), pct(hist, docs, 0.99)
        byte_rate = byte_tok / max(1,total_tokens)

        # Which chars require byte fallback?
        # Encode each character on its own and flag it when any resulting piece
        # is a byte piece. Testing for "any NON-byte piece" instead would be
        # defeated by a whitespace marker piece emitted next to the byte
        # pieces, and would also flag whitespace that encodes to no tokens.
        for ch, cnt in char_inventory.items():
            ids_ch = sp.EncodeAsIds(ch)
            if any(_is_byte_piece(sp.IdToPiece(tid)) for tid in ids_ch):
                byte_only_chars[ch] += cnt

        print(f"\n  [{lang}] docs={docs:,} tokens={format_num(total_tokens)} bytes={format_num(total_bytes)}")
        print(f"    - bytes/token = {bpt:.3f}  | byte_fallback rate = {byte_rate*100:.3f}%")
        print(f"    - len(tokens/doc): min={min_len} max={max_len} mean={mean:.1f} std={std:.1f} p50≈{p50} p90≈{p90} p99≈{p99}")
        print(f"    - lang tag first-token OK in {lang_tag_ok}/{docs} docs ({lang_tag_ok/docs*100:.1f}%)")
        if roundtrip_ok+roundtrip_bad > 0:
            print(f"    - round-trip checks: ok={roundtrip_ok} bad={roundtrip_bad}  (target: 100% ok; some bad is usually data noise)")

        ph_total_docs = max(1, docs)
        for ph in PLACEHOLDERS:
            if placeholder_hits[ph]:
                print(f"    - placeholder '{ph}' seen in {placeholder_hits[ph]} docs "
                      f"({placeholder_hits[ph]/ph_total_docs*100:.2f}%); "
                      f"inspect tokens manually if suspicious.")

        # Save for summary
        per_lang_results[lang] = {
            "docs": docs,
            "bytes": total_bytes,
            "tokens": total_tokens,
            "bytes_per_token": bpt,
            "byte_fallback_rate": byte_rate,
            "len_min": min_len, "len_max": max_len,
            "len_mean": mean, "len_std": std,
            "len_p50": p50, "len_p90": p90, "len_p99": p99,
            "lang_tag_first_ok": lang_tag_ok,
            "roundtrip_ok": roundtrip_ok, "roundtrip_bad": roundtrip_bad,
        }

        # top pieces for this lang (excluding byte pieces for readability)
        ranked_lang = per_lang_piece_freq[lang].most_common()
        top_lang = [(p, c) for p, c in ranked_lang if not _is_byte_piece(p)]
        print(f"    - top pieces (no-bytes) [{lang}] top{min(print_top_n, len(top_lang))}:")
        for p,c in top_lang[:print_top_n]:
            print(f"        {p}\t{c}")

        # accumulate byte pieces (if any) into the cross-language counter
        for p, c in ranked_lang:
            if _is_byte_piece(p):
                byte_piece_freq[p] += c

        # top byte-only chars (bad if common)
        if byte_only_chars:
            print(f"    - chars falling back to bytes (top 10):")
            for ch, cnt in byte_only_chars.most_common(10):
                name = unicodedata.name(ch, "UNKNOWN")
                cp = f"U+{ord(ch):04X}"
                printable = ch if not ch.isspace() else repr(ch)
                print(f"        {printable} ({cp} {name}) x{cnt}")

        # worst over-segmented words (heuristic)
        if word_overseg:
            print(f"    - over-segmentation suspects (words often split into >=5 tokens):")
            for w, c in word_overseg.most_common(10):
                print(f"        {w}\t{c}")

    # ----- global top pieces overall -----
    print("\n[top pieces overall (no-bytes)]")
    overall_no_bytes = [
        (p, c) for p, c in overall_piece_freq.most_common() if not _is_byte_piece(p)
    ]
    for p, c in overall_no_bytes[:print_top_n]:
        print(f"  {p}\t{c}")

    if byte_piece_freq:
        print("\n[top byte pieces overall]")
        for p,c in byte_piece_freq.most_common(20):
            print(f"  {p}\t{c}")

    # ----- HF cross-checks (optional) -----
    if tok_hf is not None:
        print("\n[HF cross-check]")
        sample = " ".join([LANG_TAGS.get("hin","<hin>"), "नमस्ते", "world!", "<URL>", "2024-01-01", "<EMAIL>", "<NUM> 42"])
        ids_sp = sp.EncodeAsIds(sample)
        ids_hf = tok_hf.encode(sample).ids
        print(f"  - sample: {sample}")
        print(f"  - SPM ids: {len(ids_sp)} tokens")
        print(f"  - HF  ids: {len(ids_hf)} tokens")
        if len(ids_sp) != len(ids_hf):
            print("  [warn] HF vs SPM token counts differ. Acceptable if HF normalizer differs; inspect if large.")
        else:
            print("  [ok] HF and SPM agree on token count for sample.")

    # ----- summary JSON to reports dir -----
    os.makedirs(reports_dir, exist_ok=True)
    out_summary = os.path.join(reports_dir, "tokenizer_data_QA_summary.json")
    with open(out_summary, "w", encoding="utf-8") as f:
        json.dump({
            "tokenizer_dir": hf_tokenizer_json_dir,
            "spm_model": spm_model_path,
            "per_language": per_lang_results,
            "top_pieces_overall": top_n(overall_piece_freq, print_top_n),
            "top_byte_pieces": top_n(byte_piece_freq, min(20, len(byte_piece_freq))),
            "specials_present": sorted(specials_present),
            "placeholders_present": placeholders_present,
            "lang_tags_present": lang_syms_present
        }, f, ensure_ascii=False, indent=2)
    print(f"\n[done] wrote summary -> {out_summary}")

# --------------------------- run ---------------------------------------------------------------
# Run the QA report with the notebook-level configuration.
analyze_tokenizer_and_data(
    SPM_MODEL,         # spm_model_path
    TOK_DIR,           # hf_tokenizer_json_dir
    LANGS,             # langs
    DRIVE_SPLITS_DIR,  # splits_dir
    DRIVE_RAW_DIR,     # raw_dir
    DOCS_PER_LANG,     # docs_per_lang
    PRINT_TOP_N,       # print_top_n
)
# ===============================================================================================

[load] SPM model: /content/drive/MyDrive/LMA_SLM/tokenizers/sp_unigram_64000.model
[load] HF tokenizer: /content/drive/MyDrive/LMA_SLM/tokenizers/sp_unigram_64000/tokenizer.json

[vocab]
  - vocab_size=64000
  - byte_pieces=256  (should be 256 when byte_fallback=True)
  - specials found at low ids: ['</s>', '<s>', '<unk>']
  - language tags present: ['<eng>', '<hin>', '<nep>']
  - placeholders present: ['<URL>', '<EMAIL>', '<NUM>', '<DATE>']

[vocab suspicious pieces] (showing up to 20)
  - punct/sym-only: ▁
  - punct/sym-only: <▁
  - punct/sym-only: .▁
  - punct/sym-only: -▁
  - punct/sym-only: '▁
  - punct/sym-only: ।▁
  - punct/sym-only: ’▁
  - punct/sym-only: (▁
  - punct/sym-only: "▁
  - punct/sym-only: “▁
  - punct/sym-only: )▁
  - punct/sym-only: :▁
  - punct/sym-only: ?▁
  - punct/sym-only: ,▁
  - punct/sym-only: ‘▁
  - punct/sym-only: !▁
  - punct/sym-only: ...▁
  - punct/sym-only: ·
  - punct/sym-only: ÷
  - punct/sym-only: °

[per-language metrics]

  [eng] docs=8,000 tokens

In [None]:
# ======================= Fast Token Counts per Language (SSD + Batching + Checkpoints) =======================
import os, io, json, shutil, time, gzip, lzma
from typing import List, Optional, Tuple
import sentencepiece as spm

# --------- CONFIG (edit paths if needed) ----------
# Input locations. pick_files_for_lang looks for <lang>/val.jsonl and
# <lang>/train.jsonl under the splits dir, and <lang>.jsonl under the raw dir.
DRIVE_SPLITS_DIR = "/content/drive/MyDrive/LMA_SLM/data/splits"
DRIVE_RAW_DIR    = "/content/drive/MyDrive/LMA_SLM/data/raw"   # fallback if split missing
# SentencePiece model used for counting tokens.
SPM_MODEL        = "/content/drive/MyDrive/LMA_SLM/tokenizers/sp_unigram_64000.model"

# Languages to count, in processing order.
LANGS = ["eng", "hin", "nep"]

# Performance knobs
BATCH_LINES      = 4000     # lines per encode batch (increase if RAM allows)
# A progress line is printed (and a checkpoint saved) once this many tokens
# have been counted since the last report ...
REPORT_EVERY_TOK = 2_000_000
# ... or once this many seconds have elapsed since the last report.
REPORT_EVERY_SEC = 30
COPY_TO_LOCAL    = True     # True = copy Drive files to /content first (recommended)

# Checkpointing: one <lang>.json file per language is kept in this directory.
CKPT_DIR = "/content/token_count_ckpts"
os.makedirs(CKPT_DIR, exist_ok=True)

# ------------ helpers ------------
# zstandard is optional: it is only used by _open_smart for ".zst" inputs.
try:
    import zstandard as zstd
except Exception:
    # Leave a sentinel so _open_smart can raise a clear RuntimeError
    # if a ".zst" file is requested without the package available.
    zstd = None

def _open_smart(path: str):
    """Open `path` for binary reading, decompressing based on its suffix.

    The suffix match is case-insensitive: ".gz" uses gzip, ".xz" uses lzma,
    ".zst" uses zstandard, and anything else is opened as a plain file.

    Raises:
        RuntimeError: for a ".zst" path when zstandard could not be imported.
    """
    name = path.lower()

    if name.endswith(".zst"):
        if zstd is None:
            raise RuntimeError("zstandard needed for .zst files")
        decompressor = zstd.ZstdDecompressor()
        return decompressor.stream_reader(open(path, "rb"))

    for suffix, opener in ((".gz", gzip.open), (".xz", lzma.open)):
        if name.endswith(suffix):
            return opener(path, "rb")

    return open(path, "rb")

def pick_files_for_lang(lang: str) -> List[str]:
    """Return the existing input files for `lang`, in preference order.

    Candidates are `val.jsonl` and `train.jsonl` under
    `DRIVE_SPLITS_DIR/<lang>/`, followed by `DRIVE_RAW_DIR/<lang>.jsonl`.
    Missing paths are dropped and duplicates are reported once.
    """
    lang_split_dir = os.path.join(DRIVE_SPLITS_DIR, lang)
    candidates = (
        os.path.join(lang_split_dir, "val.jsonl"),
        os.path.join(lang_split_dir, "train.jsonl"),
        os.path.join(DRIVE_RAW_DIR, f"{lang}.jsonl"),
    )
    # A dict keeps first-seen order while collapsing repeated paths.
    existing = {path: None for path in candidates if os.path.exists(path)}
    return list(existing)

def copy_to_local_if_needed(files: List[str], lang: str) -> List[str]:
    """Mirror `files` into `/content/token_count_local/<lang>` and return the copies.

    When `COPY_TO_LOCAL` is false the input list is returned unchanged.
    Otherwise each file is copied under its basename. An existing local copy
    is reused only when its size matches the source, so a copy left truncated
    by an interrupted run (or made stale by a changed source) is refreshed
    instead of being counted as-is.

    Args:
        files: source paths to mirror.
        lang: language code used as the destination sub-directory.

    Returns:
        The local paths, in the same order as `files`.
    """
    if not COPY_TO_LOCAL:
        return files
    dst_root = f"/content/token_count_local/{lang}"
    os.makedirs(dst_root, exist_ok=True)
    local_files = []
    for src in files:
        dst = os.path.join(dst_root, os.path.basename(src))
        try:
            up_to_date = os.path.getsize(dst) == os.path.getsize(src)
        except OSError:
            # Missing destination (or unreadable source, which copy2 will report).
            up_to_date = False
        if not up_to_date:
            print(f"[copy:{lang}] {src} -> {dst}")
            # Copy to a temporary name first so an interrupted copy never
            # leaves a partial file under the final name.
            tmp = dst + ".part"
            shutil.copy2(src, tmp)
            os.replace(tmp, dst)
        local_files.append(dst)
    return local_files

def load_ckpt(lang: str) -> dict:
    """Load the checkpoint for `lang` from `CKPT_DIR/<lang>.json`.

    Returns a dict that always has the keys "files_done", "tokens", "docs"
    and "lines". A missing checkpoint yields the empty defaults. An
    unreadable or malformed checkpoint also yields the defaults, after
    printing a warning so the reset is visible in the cell output.
    """
    ck = {"files_done": {}, "tokens": 0, "docs": 0, "lines": 0}
    path = os.path.join(CKPT_DIR, f"{lang}.json")
    try:
        with open(path, "r") as f:
            loaded = json.load(f)
    except FileNotFoundError:
        return ck
    except (OSError, ValueError) as err:
        # ValueError covers json.JSONDecodeError and decoding errors.
        print(f"[ckpt:{lang}] ignoring unreadable checkpoint {path}: {err}")
        return ck

    if not isinstance(loaded, dict) or not isinstance(loaded.get("files_done", {}), dict):
        print(f"[ckpt:{lang}] ignoring malformed checkpoint {path}")
        return ck

    # Fill in any keys an older or partial checkpoint may be missing.
    ck.update(loaded)
    return ck

def save_ckpt(lang: str, ck: dict):
    """Write checkpoint `ck` for `lang` to `CKPT_DIR/<lang>.json`.

    The JSON is written to a temporary sibling file and then moved into
    place with `os.replace`, so an interruption mid-write cannot leave a
    truncated checkpoint under the final name.
    """
    path = os.path.join(CKPT_DIR, f"{lang}.json")
    tmp_path = path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(ck, f)
    os.replace(tmp_path, path)

def iter_jsonl_texts(path: str):
    """Yield the non-empty string "text" field of each JSON object in `path`.

    The file is opened through `_open_smart` and decoded as UTF-8 with
    undecodable bytes replaced. Lines are skipped when they are blank, are
    not valid JSON, are not a JSON object, or carry a missing, empty or
    non-string "text" value.
    """
    with _open_smart(path) as fb:
        with io.TextIOWrapper(fb, encoding="utf-8", errors="replace", newline="") as f:
            for raw in f:
                s = raw.strip()
                if not s:
                    continue
                try:
                    rec = json.loads(s)
                except (ValueError, RecursionError):
                    # Malformed or pathologically nested line: best-effort skip.
                    continue
                if not isinstance(rec, dict):
                    # e.g. a bare list/number/string on the line has no "text" field
                    continue
                t = rec.get("text", "")
                # Only strings can be tokenized downstream.
                if isinstance(t, str) and t:
                    yield t

def encode_batch_count(sp: spm.SentencePieceProcessor, lines: List[str]) -> int:
    """Return the total number of tokens produced by encoding each of `lines`.

    Each text is encoded as its own input through the batch encode API and
    the per-text lengths are summed. Joining the texts into one string first
    would let the tokenizer see different boundaries at the seams, so the
    total would no longer be the sum of the per-document counts.
    """
    if not lines:
        return 0
    return sum(len(ids) for ids in sp.encode(lines, out_type=int))

# ------------ main ------------
import itertools  # used to skip already-counted records when resuming a file

print(f"[load] SentencePiece model: {SPM_MODEL}")
sp = spm.SentencePieceProcessor()
sp.Load(SPM_MODEL)

lang_token_counts = {}
grand_total = 0

for lang in LANGS:
    files = pick_files_for_lang(lang)
    if not files:
        print(f"[skip] no files for {lang}")
        continue

    files = copy_to_local_if_needed(files, lang)
    ck = load_ckpt(lang)
    files_done = ck.setdefault("files_done", {})

    # Per-file tallies are the source of truth. A checkpoint entry is reused
    # only if the file size is unchanged AND the entry carries its own
    # token/doc/line counts; otherwise that file is counted again from zero.
    # Restoring only language-wide totals and then re-reading a partly counted
    # file from its start would count that file's records twice.
    tallies = {}
    for fp in files:
        try:
            sz = os.path.getsize(fp)
        except OSError:
            sz = None
        info = files_done.get(fp)
        trusted = (
            isinstance(info, dict)
            and info.get("size") == sz
            and all(isinstance(info.get(k), int) for k in ("tokens", "docs", "lines"))
        )
        if trusted:
            tallies[fp] = {
                "size": sz,
                "done": bool(info.get("done", False)),
                "tokens": info["tokens"],
                "docs": info["docs"],
                "lines": info["lines"],
            }
        else:
            tallies[fp] = {"size": sz, "done": False, "tokens": 0, "docs": 0, "lines": 0}

    tok_count = sum(t["tokens"] for t in tallies.values())
    doc_count = sum(t["docs"] for t in tallies.values())
    line_count = sum(t["lines"] for t in tallies.values())

    print(f"\n=== [{lang}] starting from checkpoint: tokens={tok_count:,}, docs={doc_count:,}, lines={line_count:,} ===")
    t0 = time.time()
    start_tok = tok_count  # tokens restored from the checkpoint, excluded from the speed figure
    last_report_t = t0
    last_report_tok = tok_count

    for fp in files:
        tally = tallies[fp]
        sz = tally["size"]
        # skip file if already fully processed with same size
        if tally["done"]:
            print(f"[{lang}] skip (done) {os.path.basename(fp)}")
            continue

        print(f"[{lang}] counting {os.path.basename(fp)} ({sz if sz is not None else 'size?' } bytes)")
        already = tally["docs"]
        if already:
            print(f"[{lang}] resuming {os.path.basename(fp)} after {already:,} records")

        batch = []
        # Records counted before the last checkpoint are parsed but not re-encoded.
        for text in itertools.islice(iter_jsonl_texts(fp), already, None):
            # count every yielded JSONL record as a "doc"
            doc_count += 1
            tally["docs"] += 1
            batch.append(text)
            if len(batch) >= BATCH_LINES:
                n_tok = encode_batch_count(sp, batch)
                tok_count += n_tok
                tally["tokens"] += n_tok
                line_count += len(batch)
                tally["lines"] += len(batch)
                batch.clear()

                # progress report
                now = time.time()
                if tok_count - last_report_tok >= REPORT_EVERY_TOK or (now - last_report_t) >= REPORT_EVERY_SEC:
                    dt = now - last_report_t
                    speed = (tok_count - last_report_tok) / max(1e-6, dt)
                    print(f"    [{lang}] progress: tokens={tok_count:,} docs={doc_count:,} lines={line_count:,} ~{int(speed):,} tok/s")
                    last_report_t = now
                    last_report_tok = tok_count
                    # checkpoint (the batch was just flushed, so the tally
                    # covers exactly the records consumed so far)
                    ck["tokens"] = tok_count; ck["docs"] = doc_count; ck["lines"] = line_count
                    files_done[fp] = dict(tally, done=False)
                    save_ckpt(lang, ck)

        if batch:
            n_tok = encode_batch_count(sp, batch)
            tok_count += n_tok
            tally["tokens"] += n_tok
            line_count += len(batch)
            tally["lines"] += len(batch)
            batch.clear()

        # mark file done in ckpt
        tally["done"] = True
        ck["tokens"] = tok_count; ck["docs"] = doc_count; ck["lines"] = line_count
        files_done[fp] = dict(tally)
        save_ckpt(lang, ck)
        print(f"[{lang}] done {os.path.basename(fp)} → tokens={tok_count:,}, docs={doc_count:,}, lines={line_count:,}")

    lang_token_counts[lang] = tok_count
    grand_total += tok_count
    elapsed = time.time() - t0
    # Speed reflects only the tokens counted during this run.
    rate = (tok_count - start_tok) / max(1e-6, elapsed)
    print(f"\n[{lang}] FINAL: tokens={tok_count:,} docs={doc_count:,} lines={line_count:,}  ({int(rate):,} tok/s)")

# -------- summary --------
print("\n=== TOKEN COUNTS PER LANGUAGE ===")
for lang, n in lang_token_counts.items():
    pct = n / grand_total * 100 if grand_total else 0
    print(f"  {lang}: {n:,} tokens  ({pct:.2f}%)")
print(f"  TOTAL: {grand_total:,} tokens")

# also save to Drive
summary_path = "/content/drive/MyDrive/LMA_SLM/tokenizer_reports/token_counts_per_lang.json"
os.makedirs(os.path.dirname(summary_path), exist_ok=True)
with open(summary_path, "w") as f:
    json.dump({"per_lang": lang_token_counts, "total": grand_total}, f, indent=2)
print(f"\n[saved] {summary_path}")
# =============================================================================================================

[load] SentencePiece model: /content/drive/MyDrive/LMA_SLM/tokenizers/sp_unigram_64000.model
[copy:eng] /content/drive/MyDrive/LMA_SLM/data/splits/eng/val.jsonl -> /content/token_count_local/eng/val.jsonl
[copy:eng] /content/drive/MyDrive/LMA_SLM/data/splits/eng/train.jsonl -> /content/token_count_local/eng/train.jsonl
[copy:eng] /content/drive/MyDrive/LMA_SLM/data/raw/eng.jsonl -> /content/token_count_local/eng/eng.jsonl
