In [1]:
import os
import re
import json
import pickle
from datetime import datetime
from typing import Tuple, Dict, Any, Optional

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import h5py

try:
    import yaml
    HAVE_YAML = True
except Exception:
    HAVE_YAML = False


# --------- USER PATHS (edit if needed) ----------
# Input CSV handed to load_csv(); the path is specific to the author's machine.
CSV_PATH = r"C:\Users\sagni\Downloads\Suicidal Detection\archive\suicidal_ideation_reddit_annotated.csv"
# Directory that receives every generated artifact; main() creates it if missing.
OUT_DIR  = r"C:\Users\sagni\Downloads\Suicidal Detection"
# ------------------------------------------------


# Common name candidates
# Compared against lower-cased column names, in list order, when looking for
# the text column.
TEXT_CANDIDATES = [
    "text", "message", "content", "body", "post", "comment", "clean_text", "utterance",
    "selftext", "title"
]
# Compared against lower-cased column names, in list order, when looking for
# the label column.
LABEL_CANDIDATES = [
    "label", "class", "target", "is_suicidal", "suicidal", "suicide", "risk", "y"
]

# String->binary mapping heuristics
# Label values are stripped and lower-cased before being looked up in these
# sets: POS_TOKENS map to 1, NEG_TOKENS map to 0.
POS_TOKENS = {"1","true","t","yes","y","suicidal","suicide","positive","pos","high","at risk"}
NEG_TOKENS = {"0","false","f","no","n","non-suicidal","non suicidal","negative","neg","low","not at risk"}

# Seed passed to both train_test_split calls so the splits are reproducible.
RANDOM_STATE = 42
# Fractions of the data assigned to each split.
SPLIT_FRAC = (0.70, 0.15, 0.15)  # train, valid, test


def ensure_out_dir(path: str) -> None:
    """Create directory *path* (and any missing parents); do nothing if it exists."""
    os.makedirs(name=path, exist_ok=True)


def load_csv(path: str) -> pd.DataFrame:
    """Read the CSV at *path* into a DataFrame.

    The file is first parsed with pandas' default encoding; if that fails
    with a ``UnicodeDecodeError`` it is parsed again as Latin-1.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    if os.path.exists(path):
        try:
            return pd.read_csv(path)
        except UnicodeDecodeError:
            # Not valid UTF-8: Latin-1 can decode any byte sequence.
            return pd.read_csv(path, encoding="latin-1")
    raise FileNotFoundError(f"CSV not found: {path}")


def _series_as_text(series: pd.Series) -> pd.Series:
    """Return *series* as strings, with missing values as "" instead of "nan"."""
    # astype(str) alone would turn NaN/None into the literal string "nan",
    # which would then survive the empty-text filter.
    return series.astype(str).where(series.notna(), "")


def detect_text_and_label(df: pd.DataFrame) -> Tuple[pd.Series, pd.Series, str, str]:
    """Locate the text and label columns of *df* and return them cleaned and aligned.

    Text column: the first name in ``TEXT_CANDIDATES`` that matches a
    lower-cased column name. If that first match is ``title`` or ``selftext``
    and both columns exist, the two are concatenated (title first). If nothing
    matches, the first object/string-typed column is used. Whitespace runs are
    collapsed, and rows whose text is empty or missing are dropped.

    Label column: the first name in ``LABEL_CANDIDATES`` that matches;
    otherwise the first column with at most 6 distinct values of which at
    least one is a known positive/negative token. Values are mapped to
    {0, 1} by ``normalize_labels``; rows whose label cannot be mapped are
    dropped from both returned series.

    Returns:
        ``(text, labels, text_column_name, label_column_name)`` where *text*
        and *labels* share the same index. ``text_column_name`` is
        ``"title+selftext"`` when the two Reddit columns were combined.

    Raises:
        ValueError: if no text column or no label column can be detected, or
            (from ``normalize_labels``) if no label value can be mapped.
    """
    cols_lower = {c.lower(): c for c in df.columns}
    has_reddit_pair = "title" in cols_lower and "selftext" in cols_lower

    # Text detection: first candidate present wins. "title"/"selftext" are
    # themselves candidates, so the combination has to be decided here,
    # otherwise it could never be reached.
    text_col: Optional[str] = None
    combine_reddit = False
    for cand in TEXT_CANDIDATES:
        if cand not in cols_lower:
            continue
        if has_reddit_pair and cand in ("title", "selftext"):
            combine_reddit = True
        else:
            text_col = cols_lower[cand]
        break

    if combine_reddit:
        title_part = _series_as_text(df[cols_lower["title"]])
        body_part = _series_as_text(df[cols_lower["selftext"]])
        text = (title_part + " " + body_part).str.strip()
    else:
        if text_col is None:
            # Fallback: use the first object/string dtype column. The string
            # dtype check covers pandas versions that do not store text as
            # plain object columns.
            obj_cols = [
                c for c in df.columns
                if df[c].dtype == object or pd.api.types.is_string_dtype(df[c].dtype)
            ]
            if not obj_cols:
                raise ValueError("Could not detect a text column. Please rename your text column to 'text'.")
            text_col = obj_cols[0]
        text = _series_as_text(df[text_col])

    # Clean empties
    text = text.str.replace(r"\s+", " ", regex=True).str.strip()
    text = text[text != ""]
    df = df.loc[text.index]  # align

    # Label detection:
    label_col: Optional[str] = None
    for cand in LABEL_CANDIDATES:
        if cand in cols_lower:
            label_col = cols_lower[cand]
            break
    if label_col is None:
        # Try to infer from columns that look binary (0/1 or True/False)
        known_tokens = list(POS_TOKENS | NEG_TOKENS)
        for c in df.columns:
            unique = pd.Series(df[c].dropna().unique()).astype(str).str.lower().str.strip()
            if len(unique) <= 6 and unique.isin(known_tokens).any():
                label_col = c
                break

    if label_col is None:
        raise ValueError(
            f"Could not detect label column. "
            f"Please rename your label column to one of {LABEL_CANDIDATES}."
        )

    labels = normalize_labels(df[label_col])
    # normalize_labels drops unmappable rows, so keep only rows present in
    # both series; indexing labels with the full text index would raise
    # KeyError for the dropped rows.
    text = text[text.index.isin(labels.index)]
    labels = labels.loc[text.index]
    return text, labels, text_col if text_col else "title+selftext", label_col


def normalize_labels(series: pd.Series) -> pd.Series:
    """Map label values to {0,1}, robust to text variants.

    Numeric values are accepted only when they are exactly 0 or 1; other
    numbers (e.g. 0.7, 1.5, 2, inf) are treated as unmappable rather than
    being truncated to an integer. Strings are stripped and lower-cased, then
    matched against ``POS_TOKENS`` / ``NEG_TOKENS``, a few common phrasings,
    and finally parsed as a number.

    Returns:
        An integer Series holding only the rows that could be mapped; it
        keeps the original index labels of those rows.

    Raises:
        ValueError: if no value at all can be mapped to 0 or 1.
    """

    def to01(x: Any) -> Optional[int]:
        if pd.isna(x):
            return None
        if isinstance(x, (int, np.integer, float, np.floating)):
            # Exact comparison: int(x) would truncate 0.7 -> 0 and 1.5 -> 1,
            # and would raise OverflowError on inf.
            return int(x) if x in (0, 1) else None
        xs = str(x).strip().lower()
        if xs in POS_TOKENS:
            return 1
        if xs in NEG_TOKENS:
            return 0
        # common words
        if re.fullmatch(r"(non[-\s]?suicidal|no risk|not suicidal)", xs):
            return 0
        if re.fullmatch(r"(suicidal|at[-\s]?risk|high risk|suicide)", xs):
            return 1
        # last resort: numeric strings such as "1.0" or " 0 "
        try:
            parsed = float(xs)
        except ValueError:
            return None
        return int(parsed) if parsed in (0.0, 1.0) else None

    mapped = series.map(to01)
    # Drop rows that can't be mapped
    mask = mapped.isin([0, 1])
    if mask.sum() == 0:
        raise ValueError("Could not map labels to {0,1}. Please ensure labels indicate suicidal vs non.")
    return mapped[mask].astype(int)


def make_splits(text: pd.Series, labels: pd.Series) -> Dict[str, pd.DataFrame]:
    """Split text/label pairs into stratified train, valid and test frames.

    Rows with a missing text or label are dropped first. The data is split
    into train vs. the rest, then the rest is split into valid and test, both
    times stratified on the label and seeded with ``RANDOM_STATE``. Each
    returned frame has columns ``text`` and ``label`` and a fresh 0..n-1 index.
    """
    full = pd.DataFrame({"text": text, "label": labels}).dropna()

    # Stage 1: carve off the training portion.
    holdout_share = 1.0 - SPLIT_FRAC[0]
    train_part, holdout = train_test_split(
        full,
        test_size=holdout_share,
        stratify=full["label"],
        random_state=RANDOM_STATE,
        shuffle=True,
    )

    # Stage 2: divide the holdout between valid and test in their relative
    # proportions.
    valid_size = SPLIT_FRAC[1] / (SPLIT_FRAC[1] + SPLIT_FRAC[2])
    valid_part, test_part = train_test_split(
        holdout,
        test_size=(1.0 - valid_size),
        stratify=holdout["label"],
        random_state=RANDOM_STATE,
        shuffle=True,
    )

    return {
        "train": train_part.reset_index(drop=True),
        "valid": valid_part.reset_index(drop=True),
        "test": test_part.reset_index(drop=True),
    }


def write_h5(path: str, splits: Dict[str, pd.DataFrame]) -> None:
    """Write HDF5 without PyTables (uses h5py).

    One group is created per split, each holding a gzip-compressed ``text``
    dataset (variable-length UTF-8 strings) and a gzip-compressed ``label``
    dataset (int8). An existing file at *path* is overwritten.
    """
    # variable-length UTF-8 strings
    utf8_vlen = h5py.string_dtype(encoding="utf-8")
    with h5py.File(path, "w") as store:
        for split_name, frame in splits.items():
            group = store.create_group(split_name)
            text_values = frame["text"].astype(str).values
            label_values = frame["label"].astype(np.int8).values
            group.create_dataset("text", data=text_values, dtype=utf8_vlen, compression="gzip")
            group.create_dataset("label", data=label_values, dtype=np.int8, compression="gzip")


def write_pkl(path: str, splits: Dict[str, pd.DataFrame], meta: Dict[str, Any]) -> None:
    """Pickle the splits and metadata to *path*.

    The file contains one dict with keys ``"splits"`` (a copy of every split
    frame, keyed by split name) and ``"meta"`` (the metadata dict as given),
    serialized with the highest available pickle protocol.
    """
    frames = {}
    for split_name, frame in splits.items():
        frames[split_name] = frame.copy()
    with open(path, "wb") as sink:
        pickle.dump({"splits": frames, "meta": meta}, sink, protocol=pickle.HIGHEST_PROTOCOL)


def write_yaml(path: str, meta: Dict[str, Any]) -> None:
    """Write *meta* to *path* as UTF-8 text.

    When PyYAML was importable (``HAVE_YAML``), the metadata is dumped as YAML
    with key order preserved. Otherwise it is written as indented JSON, still
    under the given path, so the file is produced either way.
    """
    with open(path, "w", encoding="utf-8") as sink:
        if not HAVE_YAML:
            # Fallback: write JSON but with .yaml extension (so you get *something*)
            sink.write(json.dumps(meta, ensure_ascii=False, indent=2))
        else:
            yaml.safe_dump(meta, sink, sort_keys=False, allow_unicode=True)


def write_jsonl(path: str, splits: Dict[str, pd.DataFrame]) -> None:
    """Write every row of every split to *path* as JSON Lines.

    Each line is an object with keys ``split``, ``text`` and ``label`` (the
    label cast to ``int``). Splits are written in dict order, rows in frame
    order; non-ASCII characters are written as-is in UTF-8.
    """

    def _encoded_rows():
        # Yield one serialized record per DataFrame row, newline-terminated.
        for split_name, frame in splits.items():
            for _, record in frame.iterrows():
                payload = {
                    "split": split_name,
                    "text": record["text"],
                    "label": int(record["label"]),
                }
                yield json.dumps(payload, ensure_ascii=False) + "\n"

    with open(path, "w", encoding="utf-8") as sink:
        sink.writelines(_encoded_rows())


def class_counts(df: pd.DataFrame) -> Dict[str, int]:
    """Count the rows per label in *df*.

    Keys are the label values rendered as integer strings (e.g. ``"0"``,
    ``"1"``); values are plain ``int`` counts, ordered most frequent first.
    """
    tally: Dict[str, int] = {}
    for value, count in df["label"].value_counts().items():
        tally[str(int(value))] = int(count)
    return tally


def main():
    """Run the whole pipeline: load the CSV, split it, and write all artifacts.

    Reads ``CSV_PATH``, detects and normalizes the text/label columns, builds
    stratified train/valid/test splits, then writes HDF5, pickle, YAML, JSONL
    and JSON-summary files into ``OUT_DIR`` and prints a short report.
    """
    # Local import: the module header only imports ``datetime`` itself.
    from datetime import timezone

    ensure_out_dir(OUT_DIR)
    df = load_csv(CSV_PATH)
    text, labels, text_col_name, label_col_name = detect_text_and_label(df)

    # Align indices after normalization/cleaning
    aligned = pd.DataFrame({"text": text, "label": labels}).dropna()
    splits = make_splits(aligned["text"], aligned["label"])

    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # timestamp and drop the tzinfo so the string keeps its "...Z" form
    # instead of gaining a "+00:00" suffix.
    created_utc = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"

    # Build metadata
    meta = {
        "dataset_name": "suicidal_ideation_reddit_annotated",
        "source_csv": CSV_PATH,
        "created_utc": created_utc,
        "text_column_used": text_col_name,
        "label_column_used": label_col_name,
        "label_mapping": {0: "non-suicidal", 1: "suicidal"},
        "sizes": {k: int(v.shape[0]) for k, v in splits.items()},
        "class_balance": {
            split: class_counts(df_split) for split, df_split in splits.items()
        },
        "splits": {"train": SPLIT_FRAC[0], "valid": SPLIT_FRAC[1], "test": SPLIT_FRAC[2]},
        "random_state": RANDOM_STATE
    }

    # File paths
    out_h5   = os.path.join(OUT_DIR, "mindshield_dataset.h5")
    out_pkl  = os.path.join(OUT_DIR, "mindshield_dataset.pkl")
    out_yaml = os.path.join(OUT_DIR, "mindshield_config.yaml")
    out_jsonl= os.path.join(OUT_DIR, "mindshield_dataset.jsonl")
    out_sum  = os.path.join(OUT_DIR, "mindshield_summary.json")

    # Write artifacts
    write_h5(out_h5, splits)
    write_pkl(out_pkl, splits, meta)
    write_yaml(out_yaml, meta)
    write_jsonl(out_jsonl, splits)
    # json.dump stringifies the integer keys of "label_mapping" ("0"/"1").
    with open(out_sum, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)

    # Console summary
    print("=== MindShield Artifacts Written ===")
    print(f"H5:    {out_h5}")
    print(f"PKL:   {out_pkl}")
    print(f"YAML:  {out_yaml}")
    print(f"JSONL: {out_jsonl}")
    print(f"SUM:   {out_sum}")
    print("\nSizes:", meta["sizes"])
    print("Class balance:", meta["class_balance"])
    print("Label mapping:", meta["label_mapping"])


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


=== MindShield Artifacts Written ===
H5:    C:\Users\sagni\Downloads\Suicidal Detection\mindshield_dataset.h5
PKL:   C:\Users\sagni\Downloads\Suicidal Detection\mindshield_dataset.pkl
YAML:  C:\Users\sagni\Downloads\Suicidal Detection\mindshield_config.yaml
JSONL: C:\Users\sagni\Downloads\Suicidal Detection\mindshield_dataset.jsonl
SUM:   C:\Users\sagni\Downloads\Suicidal Detection\mindshield_summary.json

Sizes: {'train': 8858, 'valid': 1898, 'test': 1899}
Class balance: {'train': {'1': 4625, '0': 4233}, 'valid': {'1': 991, '0': 907}, 'test': {'1': 992, '0': 907}}
Label mapping: {0: 'non-suicidal', 1: 'suicidal'}
