In [None]:
import math
import zlib
from typing import Dict, Tuple, List

import numpy as np
import pandas as pd

# ============================================================
# Input / output locations (edit these for your machine)
# ============================================================
# NOTE: despite its name, INPUT_CSV points at an .xlsx workbook; it is read
# with pd.read_excel in the __main__ block.
INPUT_CSV = "D:/Sarah_Professional/LJMU Research/Dataset/updated_Excel.xlsx"
OUT_EXP = "D:/Sarah_Professional/LJMU Research/Output/experiment_set_100.csv"
OUT_HOLD = "D:/Sarah_Professional/LJMU Research/Output/holdout_set_50.csv"

# ============================================================
# Sampling parameters
# ============================================================
RANDOM_SEED = 42

# Sizes of the two disjoint samples.
EXPERIMENT_TOTAL = 100
HOLDOUT_TOTAL = 50

# Per-topic floors. With 15 enforced topics both are satisfiable:
#   experiment: 15 * 6 = 90 <= 100
#   hold-out:   15 * 3 = 45 <= 50
MIN_PER_TOPIC_EXPERIMENT = 6
MIN_PER_TOPIC_HOLDOUT = 3

# Canonical topics that receive the per-topic floors above.
# Replace with your own final list of 15 if it differs.
ENFORCED_TOPICS = [
    "Artificial Intelligence",
    "Machine Learning",
    "Deep Learning",
    "Natural Language Processing",
    "Generative AI",
    "Agentic AI",
    "Data Science",
    "Python Programming",
    "Reinforcement Learning",
    "Time Series",
    "Mlops",
    "Langchain",
    "Langraph",
    "Statistics",
    "Prompt Engineering",
]
# Playlists that map to none of these end up as "Other": they still take part
# in proportional allocation but get no guaranteed minimum.

# ----------------------------
# PLAYLIST → TOPIC MAPPING
# 1) Exact-name dictionary (fast path)
# 2) Fallback keyword rules (best-effort)
# ----------------------------
# Every value here should be one of ENFORCED_TOPICS: any other label becomes
# its own topic stratum that receives no per-topic minimum.
PLAYLIST_TO_TOPIC: Dict[str, str] = {
    # ==== EXAMPLES — extend/replace with your real playlists ====
    # Artificial Intelligence
    "AI Fundamentals": "Artificial Intelligence",
    "AI Technical Tutorials": "Artificial Intelligence",

    # Machine Learning
    # Mapped to the enforced "Machine Learning" topic (a bespoke "AI and ML"
    # label would form an extra, un-enforced stratum).
    "AI and Machine Learning with Google Cloud": "Machine Learning",
    "Intro to Machine Learning": "Machine Learning",
    "Machine Learning with Python": "Machine Learning",

    # Deep Learning
    "Deep Learning | Udacity": "Deep Learning",
    "Deep Learning With Tensorflow 2.0, Keras and Python": "Deep Learning",

    # Natural Language Processing
    "NLP Tutorial Python": "Natural Language Processing",
    "Natural Language Processing": "Natural Language Processing",

    # Generative AI
    "Generative AI": "Generative AI",
    "Gen AI Tutorials": "Generative AI",

    # Agentic AI
    "Agentic AI": "Agentic AI",

    # Data Science
    "Intro to Data Science": "Data Science",

    # Python Programming
    "Python Training - Complete Python Training Course": "Python Programming",

    # Reinforcement Learning
    "Reinforcement Learning 101": "Reinforcement Learning",

    # Time Series
    "Time Series Crash Course": "Time Series",

    # MLOps
    "Machine Learning Engineering for Production(Mlops)": "Mlops",

    # LangChain
    "Updated Langchain": "Langchain",

    # Langraph
    "LangGraph Crash Course: From Basic to Building Powerful Agents | 2025": "Langraph",

    # Statistics
    "Statistics in Machine learning": "Statistics",

    # Prompt Engineering
    "Prompt Engineering Full Course with LLM": "Prompt Engineering",
}

# Fallback keyword rules, consulted only when a playlist_name is not a key of
# PLAYLIST_TO_TOPIC. infer_topic_from_playlist lower-cases the name before
# matching and returns the topic of the FIRST pattern that hits, therefore:
#   * every pattern is written in lower case -- an upper-case literal such as
#     r"\bAI\b" can never match lower-cased text;
#   * specific topics come before the broad ones that would otherwise swallow
#     them ("generative ai" before r"\bai\b", "ml ops" before r"\bml\b",
#     "reinforcement learning" before "python", ...);
#   * short abbreviations are word-bounded so they do not fire inside other
#     words (e.g. "rl" in "world", "eda" in "pedagogy").
# Add or reorder patterns as needed, keeping the most generic ones last.
KEYWORD_RULES: List[Tuple[str, str]] = [
    (r"langg?raph", "Langraph"),
    (r"langchain", "Langchain"),
    (r"prompt engineering|prompting", "Prompt Engineering"),
    (r"agentic ai|multi-agent|autonomous agent|swarm", "Agentic AI"),
    (r"generative ai|\bgen ?ai\b|diffusion|llm", "Generative AI"),
    (r"mlops|\bml ops\b|model ops|deployment|monitoring", "Mlops"),
    (r"reinforcement learning|q-learning|policy gradient|\brl\b", "Reinforcement Learning"),
    (r"deep learning|neural network|cnn|rnn|transformer", "Deep Learning"),
    (r"\bnlp\b|natural language processing|text mining", "Natural Language Processing"),
    (r"time series|temporal|forecast", "Time Series"),
    (r"statistics|probability|bayes|hypothesis testing", "Statistics"),
    (r"\bml\b|machine learning", "Machine Learning"),
    (r"data science|data analysis|\beda\b", "Data Science"),
    (r"\bai\b|artificial intelligence", "Artificial Intelligence"),
    (r"python", "Python Programming"),
]

import re
def infer_topic_from_playlist(playlist_name: str) -> str:
    """Map a playlist name to a topic label.

    Lookup order:
      1. exact (case-sensitive) key in PLAYLIST_TO_TOPIC;
      2. the same dictionary, comparing names after trimming, collapsing
         internal whitespace and ignoring case (so "  intro to  data science "
         still hits "Intro to Data Science");
      3. the first KEYWORD_RULES pattern that matches the lower-cased name.

    Args:
        playlist_name: raw value from the ``playlist_name`` column. Non-string
            values (e.g. NaN from a blank spreadsheet cell) are accepted.

    Returns:
        The mapped topic, or "Other" when the input is not a non-blank string
        or nothing matches.
    """
    if not isinstance(playlist_name, str) or not playlist_name.strip():
        return "Other"

    # 1) Fast path: exact key.
    if playlist_name in PLAYLIST_TO_TOPIC:
        return PLAYLIST_TO_TOPIC[playlist_name]

    # 2) Normalised lookup: whitespace- and case-insensitive comparison.
    def _norm(name: str) -> str:
        return " ".join(name.split()).casefold()

    wanted = _norm(playlist_name)
    for key, mapped in PLAYLIST_TO_TOPIC.items():
        if _norm(key) == wanted:
            return mapped

    # 3) Keyword fallback. Patterns are matched against the lower-cased name
    #    and the first hit wins, so the order of KEYWORD_RULES matters.
    lower = playlist_name.lower()
    for pattern, topic in KEYWORD_RULES:
        if re.search(pattern, lower):
            return topic
    return "Other"

# ----------------------------
# LENGTH BINNING
# ----------------------------
def length_bin(num_tokens: int) -> str:
    """Classify a token count into a length bucket.

    < 1000          -> "Short"
    1000 .. 3000    -> "Medium" (both ends inclusive)
    > 3000          -> "Long"

    A value that fails both comparisons (e.g. NaN) falls through to "Long".
    """
    return "Short" if num_tokens < 1000 else ("Medium" if num_tokens <= 3000 else "Long")

def approx_tokens(text: str) -> int:
    """Approximate a transcript's token count by its whitespace-separated word count.

    Anything that is not a string (None, NaN from an empty cell, numbers)
    counts as 0 tokens.
    """
    return len(text.split()) if isinstance(text, str) else 0

def ensure_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* that has ``topic``, ``num_tokens`` and ``length_bin``.

    - ``topic`` is kept if present, otherwise derived from ``playlist_name``
      with infer_topic_from_playlist.
    - ``num_tokens`` is kept if present, otherwise derived from ``transcript``
      with approx_tokens.
    - ``length_bin`` is always (re)computed from ``num_tokens``.

    The input frame is never modified.

    Raises:
        ValueError: if neither ``topic`` nor ``playlist_name`` exists, or if
            neither ``num_tokens`` nor ``transcript`` exists.
    """
    out = df.copy()

    # Topic: caller-supplied column wins; otherwise derive from the playlist name.
    if "topic" not in out.columns:
        if "playlist_name" not in out.columns:
            raise ValueError("DataFrame must have 'topic' or 'playlist_name'.")
        out["topic"] = out["playlist_name"].apply(infer_topic_from_playlist)

    # Token counts: same pattern, derived from the raw transcript text.
    if "num_tokens" not in out.columns:
        if "transcript" not in out.columns:
            raise ValueError("Need 'num_tokens' or 'transcript' to compute length bins.")
        out["num_tokens"] = out["transcript"].apply(approx_tokens)

    # Recomputed unconditionally, so a stale length_bin column is overwritten.
    out["length_bin"] = out["num_tokens"].apply(length_bin)
    return out

# ----------------------------
# ALLOCATION HELPERS
# ----------------------------
def largest_remainder_allocation(capacity: Dict[str, int], totalsize: int, seed: int) -> Dict[str, int]:
    """Split *totalsize* units across groups in proportion to their capacity.

    Largest-remainder (Hamilton) method:
      1. each group gets floor(totalsize * cap / sum(caps)), never above its cap;
      2. leftover units go one at a time to groups with the largest fractional
         share, skipping groups already at capacity.
    Ties in the fractional share are broken by a shuffle driven by *seed*, so
    the result is reproducible for a given seed.

    Negative capacities are treated as 0.

    Raises:
        ValueError: if *totalsize* exceeds the summed capacity.
    """
    limits: Dict[str, int] = {}
    for name, cap in capacity.items():
        limits[name] = max(0, int(cap))

    pool = sum(limits.values())
    if totalsize > pool:
        raise ValueError(f"Requested {totalsize} but only {pool} available across groups.")
    if not pool or totalsize <= 0:
        return dict.fromkeys(limits, 0)

    exact = {name: totalsize * limits[name] / pool for name in limits}
    alloc = {name: min(limits[name], int(math.floor(share))) for name, share in exact.items()}
    leftover = totalsize - sum(alloc.values())

    # Random order first, then a stable sort: equal fractions keep the shuffled order.
    order = list(limits.keys())
    np.random.default_rng(seed).shuffle(order)
    order = sorted(order, key=lambda name: exact[name] - math.floor(exact[name]), reverse=True)

    for name in order:
        if leftover == 0:
            break
        if alloc[name] < limits[name]:
            alloc[name] += 1
            leftover -= 1

    assert sum(alloc.values()) == totalsize
    return alloc

def allocate_per_topic(
    df: pd.DataFrame,
    total: int,
    min_per_topic: int,
    seed: int,
    enforced_topics: List[str]
) -> Dict[str, int]:
    """Decide how many rows each topic contributes to a sample of size *total*.

    Enforced topics first receive a floor of ``min(min_per_topic, rows in topic)``.
    Whatever is left of *total* is then spread over ALL topics present in *df*
    (including non-enforced ones such as "Other") in proportion to their
    remaining rows, via largest_remainder_allocation.

    Raises:
        ValueError: if the floors alone exceed *total*, or (from the helper)
            if there are not enough rows left to fill the remainder.
    """
    counts = df.groupby("topic").size().to_dict()

    # Guaranteed floor per topic; non-enforced topics start at zero.
    floor_alloc = {
        name: (min(min_per_topic, n) if name in enforced_topics else 0)
        for name, n in counts.items()
    }
    floor_total = sum(floor_alloc.values())
    if floor_total > total:
        raise ValueError(f"Sum of minima ({floor_total}) exceeds total ({total}). "
                         f"Lower min_per_topic or increase total.")

    # Rows still available in each topic once the floor has been taken.
    headroom = {name: max(0, counts[name] - floor_alloc[name]) for name in counts}
    bonus = largest_remainder_allocation(headroom, total - floor_total, seed)
    return {name: floor_alloc.get(name, 0) + bonus.get(name, 0) for name in counts}

def allocate_within_topic_bins(topic_df: pd.DataFrame, quota: int, seed: int) -> Dict[Tuple[str,str], int]:
    """Split one topic's *quota* across its length bins (Short/Medium/Long).

    - quota >= number of populated bins: every populated bin gets 1, and the
      rest is spread in proportion to each bin's remaining rows.
    - quota < number of populated bins: one row each to the largest bins.

    *topic_df* must contain rows of a single topic (the label is read from its
    first row). The result is keyed by ``(topic, length_bin)`` for every bin
    present in *topic_df*, and no count exceeds that bin's row count.
    """
    topic = topic_df["topic"].iloc[0]
    sizes = topic_df.groupby("length_bin").size().to_dict()
    populated = [b for b in sizes if sizes[b] > 0]

    if quota <= 0 or not populated:
        return {(topic, b): 0 for b in sizes}

    share = dict.fromkeys(sizes, 0)
    if quota < len(populated):
        # Too small to touch every bin: favour the biggest ones.
        for b in sorted(populated, key=sizes.__getitem__, reverse=True)[:quota]:
            share[b] = 1
        leftover = 0
    else:
        for b in populated:
            share[b] = 1
        leftover = quota - len(populated)

    if leftover > 0:
        room = {b: max(0, sizes[b] - share[b]) for b in sizes}
        for b, add in largest_remainder_allocation(room, leftover, seed).items():
            share[b] += add

    # Final clamp so no bin is asked for more rows than it holds.
    return {(topic, b): min(share[b], sizes[b]) for b in sizes}

# ----------------------------
# SAMPLING
# ----------------------------
def _stable_topic_offset(topic) -> int:
    """Return a reproducible per-topic seed offset in ``[0, 100000)``.

    Built-in ``hash()`` on ``str`` is salted per interpreter process
    (PYTHONHASHSEED), so deriving seeds from it gives a different result on
    every run. CRC-32 of the UTF-8 text is stable across processes.
    """
    return zlib.crc32(str(topic).encode("utf-8")) % 100000


def stratified_sample_2stage(
    df: pd.DataFrame,
    total: int,
    min_per_topic: int,
    seed: int,
    enforced_topics: List[str]
) -> pd.DataFrame:
    """Draw a two-stage stratified sample of exactly *total* rows.

    Stage 1: per-topic quotas (minima for enforced_topics) via allocate_per_topic.
    Stage 2: within each topic, split its quota across length bins via
             allocate_within_topic_bins.
    Then sample without replacement from every (topic, length_bin) stratum and
    shuffle the combined result.

    Args:
        df: input rows; 'topic'/'length_bin' are derived via ensure_columns.
        total: number of rows to return.
        min_per_topic: floor for each enforced topic (capped at its row count).
        seed: master seed; the same seed and input give the same sample.
        enforced_topics: topics that receive the *min_per_topic* floor.

    Returns:
        The sampled rows with their original index labels, in shuffled order.

    Raises:
        ValueError: propagated from ensure_columns or the allocation helpers.
    """
    df = ensure_columns(df)

    topic_quotas = allocate_per_topic(
        df, total=total, min_per_topic=min_per_topic, seed=seed, enforced_topics=enforced_topics
    )

    bin_targets: Dict[Tuple[str, str], int] = {}
    for topic, q in topic_quotas.items():
        tdf = df[df["topic"] == topic]
        # Per-topic seed must be process-independent: hash(str) is randomised
        # per run, which made the bin split irreproducible.
        within = allocate_within_topic_bins(tdf, q, seed + _stable_topic_offset(topic))
        bin_targets.update(within)

    rng = np.random.default_rng(seed)
    parts = []
    grouped = df.groupby(["topic", "length_bin"], dropna=False)
    for (topic, b), n in bin_targets.items():
        if n <= 0:
            continue
        if (topic, b) not in grouped.groups:  # empty stratum
            continue
        stratum = grouped.get_group((topic, b))
        if n >= len(stratum):
            parts.append(stratum)
        else:
            parts.append(stratum.sample(n=n, random_state=int(rng.integers(0, 2**31-1))))

    if not parts:
        return df.iloc[0:0].copy()
    out = pd.concat(parts, axis=0).sample(frac=1.0, random_state=seed)  # shuffle
    assert len(out) == sum(topic_quotas.values()) == total, "Sampling mismatch."
    return out

def build_experiment_and_holdout(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Draw two disjoint stratified samples from *df*: experiment and hold-out.

    1) EXPERIMENT_TOTAL rows (100 by default) with at least
       MIN_PER_TOPIC_EXPERIMENT per enforced topic (capped at the rows that
       topic has), split across length bins, using seed RANDOM_SEED.
    2) From the rows not picked in step 1, HOLDOUT_TOTAL rows (50 by default)
       with at least MIN_PER_TOPIC_HOLDOUT per enforced topic, split across
       length bins, using seed RANDOM_SEED + 1.

    Args:
        df: input rows; must carry 'topic' or 'playlist_name', and
            'num_tokens' or 'transcript' (see ensure_columns).

    Returns:
        (experiment_df, holdout_df). If *df* had duplicate index labels, both
        frames carry a fresh 0..n-1 index instead of the original labels.

    Raises:
        ValueError: too few rows overall, or from the sampling helpers
            (e.g. per-topic minima not satisfiable).
    """
    df = ensure_columns(df)

    # The hold-out pool is built by dropping the experiment rows *by label*.
    # With duplicate labels, df.drop would also discard never-sampled rows
    # that share a label, silently shrinking the hold-out pool.
    if not df.index.is_unique:
        df = df.reset_index(drop=True)

    needed = EXPERIMENT_TOTAL + HOLDOUT_TOTAL
    if len(df) < needed:
        raise ValueError(f"Need at least {needed} rows; got {len(df)}.")

    exp_df = stratified_sample_2stage(
        df, total=EXPERIMENT_TOTAL, min_per_topic=MIN_PER_TOPIC_EXPERIMENT,
        seed=RANDOM_SEED, enforced_topics=ENFORCED_TOPICS
    )
    remaining = df.drop(exp_df.index)
    hold_df = stratified_sample_2stage(
        remaining, total=HOLDOUT_TOTAL, min_per_topic=MIN_PER_TOPIC_HOLDOUT,
        seed=RANDOM_SEED + 1, enforced_topics=ENFORCED_TOPICS
    )
    return exp_df, hold_df

# ----------------------------
# MAIN: read input, run, write output
# ----------------------------
if __name__ == "__main__":
    from pathlib import Path

    # INPUT_CSV names an .xlsx workbook, hence read_excel rather than read_csv.
    df = pd.read_excel(INPUT_CSV)

    # No manual topic derivation is needed here: build_experiment_and_holdout
    # runs ensure_columns, which derives 'topic' from 'playlist_name' if missing.
    exp_df, hold_df = build_experiment_and_holdout(df)

    def report(x: pd.DataFrame, name: str) -> None:
        """Print the sample size plus per-topic and per-(topic, length_bin) counts."""
        print(f"\n{name} size = {len(x)}")
        print("Per topic counts:")
        print(x.groupby("topic").size().rename("count").sort_values(ascending=False))
        print("\nPer topic × length_bin:")
        print(x.groupby(["topic", "length_bin"]).size().rename("count"))

    # Labels follow the configured sizes so they cannot go stale.
    report(exp_df, f"Experiment ({EXPERIMENT_TOTAL})")
    report(hold_df, f"Hold-out ({HOLDOUT_TOTAL})")

    # DataFrame.to_csv does not create missing parent directories.
    for out_path in (OUT_EXP, OUT_HOLD):
        Path(out_path).parent.mkdir(parents=True, exist_ok=True)

    exp_df.to_csv(OUT_EXP, index=False)
    hold_df.to_csv(OUT_HOLD, index=False)
    print(f"\nSaved:\n - {OUT_EXP}\n - {OUT_HOLD}")
