# In [None]:
import os
import pickle
from collections import defaultdict

# ---------- CONFIG ----------
# Corpus to train on; it is opened as UTF-8 text and read line by line.
# The path is machine-specific.
TOKEN_FILE = r"F:/collage/NLP/Assignment_1/2/gu_meta_part_1.txt"  # file containing one sentence per line
# Token threshold per batch: a batch is yielded once its token count
# (sentence boundary markers included) reaches this value, so a batch can
# exceed it and the last batch can be smaller.
BATCH_SIZE = 700_000    # token count at which a batch is yielded
# Directory that receives the checkpoint and the final pickled model files.
MODEL_DIR = "ngram_model"
# Single checkpoint file, shared by every n-gram order trained by this module.
CHECKPOINT_FILE = os.path.join(MODEL_DIR, "checkpoint.pkl")
# ----------------------------

# Import-time side effect: create the output directory so later writes
# into MODEL_DIR do not fail on a missing directory.
os.makedirs(MODEL_DIR, exist_ok=True)


# ---------- Data Loader ----------
def sentence_batches(file_path, batch_size, start_pos=0):
    """
    Yield ``(tokens, pos)`` batches read from a one-sentence-per-line file.

    Each non-blank line is split on whitespace and wrapped in ``<s>`` /
    ``</s>`` markers; the tokens of consecutive sentences are concatenated
    into one flat list. A batch is yielded as soon as it holds at least
    ``batch_size`` tokens, so batches always end on a sentence boundary and
    may exceed ``batch_size``. Whatever remains at end of file is yielded as
    a final, possibly smaller, batch.

    Args:
        file_path: Path of the UTF-8 text file to read.
        batch_size: Token count at which the current batch is yielded.
        start_pos: Number of physical lines to skip before reading, i.e. the
            ``pos`` reported with the last batch a previous run completed.

    Yields:
        Tuples ``(tokens, pos)``. ``tokens`` is the flat token list and
        ``pos`` is the number of physical lines consumed from the start of
        the file, blank lines included, so passing ``pos`` back as
        ``start_pos`` resumes right after the last sentence of that batch.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        # Skip lines already consumed by a previous run.
        for _ in range(start_pos):
            next(f, None)

        batch = []
        pos = start_pos
        for line in f:
            # Count every physical line, blank or not. Resuming skips raw
            # lines, so counting only non-blank ones would make a resumed
            # run re-read (and double-count) sentences whenever the file
            # contains blank lines.
            pos += 1
            tokens = line.split()  # basic whitespace tokenization
            if not tokens:
                continue  # blank / whitespace-only line: no sentence here
            batch.append("<s>")  # sentence start marker
            batch.extend(tokens)
            batch.append("</s>")  # sentence end marker
            if len(batch) >= batch_size:
                yield batch, pos
                batch = []
        if batch:
            yield batch, pos


# ---------- Checkpoint Helpers ----------
def save_checkpoint(state):
    """
    Pickle ``state`` to ``CHECKPOINT_FILE`` without ever leaving it half-written.

    The state is first written to a temporary sibling file and then moved
    over the checkpoint with ``os.replace``. If the process is interrupted
    while pickling, the previous checkpoint is still intact, so the next run
    can resume instead of failing on a truncated pickle.

    Args:
        state: Picklable object describing the training progress.
    """
    tmp_path = CHECKPOINT_FILE + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(state, f)
    # Same directory as the target, so the replace is a plain rename.
    os.replace(tmp_path, CHECKPOINT_FILE)


def load_checkpoint():
    """
    Return the state stored in ``CHECKPOINT_FILE``.

    Returns:
        The unpickled object, or ``None`` when the checkpoint file does not
        exist.
    """
    if not os.path.exists(CHECKPOINT_FILE):
        return None

    # NOTE: pickle executes code on load; only use checkpoints this script wrote.
    with open(CHECKPOINT_FILE, "rb") as fh:
        state = pickle.load(fh)
    return state


# ---------- Training ----------
def train_ngram_model(ngram_size: int):
    """
    Train an n-gram model of the given order with checkpoint support.

    Token batches from ``sentence_batches(TOKEN_FILE, BATCH_SIZE, ...)`` are
    scanned with a sliding window of ``ngram_size`` tokens. The running state
    is checkpointed after every batch; if a checkpoint for the same
    ``ngram_size`` exists at start-up, training resumes from it, otherwise it
    starts from scratch. The checkpoint is deleted after a successful run.

    Files written to ``MODEL_DIR``:
        - final_{ngram_size}gram_counts.pkl          (n-gram -> count)
        - final_{ngram_size}gram_context_counts.pkl  (context -> count)
        - final_{ngram_size}gram_probs.pkl           (n-gram -> count / context count)

    Args:
        ngram_size: Order of the model (1 = unigram, 2 = bigram, ...).

    Returns:
        Tuple ``(ngram_counts, context_counts, ngram_probs)``: the two count
        mappings keyed by token tuples and the dict of relative frequencies.
    """
    # Try to resume from checkpoint
    checkpoint = load_checkpoint()
    if checkpoint and checkpoint["ngram_size"] == ngram_size:
        print(f"[Resuming] {ngram_size}-gram training from checkpoint at sentence {checkpoint['pos']}")
        ngram_counts = checkpoint["ngram_counts"]
        context_counts = checkpoint["context_counts"]
        total_tokens = checkpoint["total_tokens"]
        start_pos = checkpoint["pos"]
        batch_no = checkpoint["batch_no"]
    else:
        print(f"[Starting Fresh] {ngram_size}-gram training")
        ngram_counts = defaultdict(int)
        context_counts = defaultdict(int)
        total_tokens = 0
        start_pos = 0
        batch_no = 0

    for batch, pos in sentence_batches(TOKEN_FILE, BATCH_SIZE, start_pos):
        # NOTE: the window slides over the flat batch, so n-grams spanning two
        # sentences inside a batch are counted, while n-grams spanning two
        # consecutive batches are not.
        for i in range(len(batch) - ngram_size + 1):
            ngram = tuple(batch[i:i + ngram_size])
            ngram_counts[ngram] += 1
            # The context is the n-gram minus its last token; for unigrams
            # this slice is already the empty tuple.
            context_counts[ngram[:-1]] += 1

        batch_no += 1
        print(f"[Batch {batch_no}] Processed up to sentence {pos}")

        # For unigrams: maintain total token count
        if ngram_size == 1:
            total_tokens += len(batch)
            context_counts[()] = total_tokens

        # Save checkpoint after each batch
        save_checkpoint({
            "ngram_size": ngram_size,
            "ngram_counts": ngram_counts,
            "context_counts": context_counts,
            "total_tokens": total_tokens,
            "pos": pos,
            "batch_no": batch_no
        })
        # Reported inside the loop: `pos` only exists once a batch has been
        # produced, so printing it after the loop crashed with an unbound
        # variable whenever no batch was left to process (empty corpus, or a
        # resume from a checkpoint already at end of file).
        print(f"[Checkpoint Saved] {ngram_size}-gram at batch {batch_no}, sentence {pos:,}")

    # Save final counts
    with open(os.path.join(MODEL_DIR, f"final_{ngram_size}gram_counts.pkl"), "wb") as f:
        pickle.dump(dict(ngram_counts), f)
    with open(os.path.join(MODEL_DIR, f"final_{ngram_size}gram_context_counts.pkl"), "wb") as f:
        pickle.dump(dict(context_counts), f)

    print(f"[Training Complete] Final {ngram_size}-gram counts saved.")

    # ---- Compute probabilities ----
    # Relative frequency of each n-gram given its context.
    ngram_probs = {}
    for ngram, count in ngram_counts.items():
        context_count = context_counts[ngram[:-1]]
        if context_count > 0:
            ngram_probs[ngram] = count / context_count

    # Save probabilities
    with open(os.path.join(MODEL_DIR, f"final_{ngram_size}gram_probs.pkl"), "wb") as f:
        pickle.dump(ngram_probs, f)

    print(f"[Probabilities Saved] {len(ngram_probs)} entries stored.")

    # Remove checkpoint after successful completion
    if os.path.exists(CHECKPOINT_FILE):
        os.remove(CHECKPOINT_FILE)

    return ngram_counts, context_counts, ngram_probs


# ---------- Loader ----------
def load_ngram_model(ngram_size: int):
    """
    Read back the three pickles written for a trained n-gram model.

    Args:
        ngram_size: Order of the model whose files should be loaded.

    Returns:
        Tuple ``(ngram_counts, context_counts, ngram_probs)``; the two count
        mappings are wrapped in ``defaultdict(int)``, the probabilities are
        returned as unpickled.
    """
    def _unpickle(filename):
        # Load one pickle stored under MODEL_DIR.
        with open(os.path.join(MODEL_DIR, filename), "rb") as fh:
            return pickle.load(fh)

    stored_counts = _unpickle(f"final_{ngram_size}gram_counts.pkl")
    ngram_counts = defaultdict(int, stored_counts)

    stored_contexts = _unpickle(f"final_{ngram_size}gram_context_counts.pkl")
    context_counts = defaultdict(int, stored_contexts)

    ngram_probs = _unpickle(f"final_{ngram_size}gram_probs.pkl")

    return ngram_counts, context_counts, ngram_probs


# ---------- Example Run ----------
# Run the example only when the file is executed as a script, not on import.
# The guard needs the double-underscore names: `_name_` is an undefined
# variable and raised NameError as soon as the module was loaded.
if __name__ == "__main__":
    for n in [1, 2, 3, 4]:  # unigram, bigram, trigram, 4-gram
        print(f"\n--- Training {n}-gram model ---")
        counts, contexts, probs = train_ngram_model(n)
        print(f"Model size ({n}-gram): {len(counts)} ngrams, {len(probs)} probabilities")