### Step 1: Load data and prepare labels (read the processed CSV, normalize labels, show the class distribution, and create a stratified train/validation split)

In [2]:
### Step 1: Load data and prepare labels (reads processed CSV and shows class distribution)
import os
import random
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from collections import Counter

# --- Paths ---
# Hard-coded, machine-specific Windows locations: the processed, labelled CSV to
# read, and an output directory that is created up front (later cells
# presumably write model artefacts there — confirm).
DATA_PATH = r"D:\Data Science Projects\Data Citation Intent Classification\data\processed\train_labeled.csv"
OUTPUT_DIR = r"D:\Data Science Projects\Data Citation Intent Classification\models\bert_distil"
os.makedirs(OUTPUT_DIR, exist_ok=True)  # no-op if the directory already exists

# --- Read data ---
df = pd.read_csv(DATA_PATH)
row_count = len(df)
column_names = df.columns.tolist()
print("‚úÖ Total rows:", row_count)
print("üîç Columns found:", column_names)

# --- Detect label column automatically ---
# Prefer a conventionally named label column; only when none is present fall
# back to guessing from the column contents.
possible_labels = ["label", "Label", "labels", "Labels", "type", "Type", "labels_list"]
label_col = next((col for col in possible_labels if col in df.columns), None)

# Minimum fraction of a column's sampled non-null values that must mention
# "primary" or "secondary" before that column is accepted as the label column.
# A genuine label column scores close to 1.0, while a free-text column that only
# occasionally says e.g. "primary data" scores far lower. A single keyword hit
# (the previous rule) could therefore pick a text column and silently turn it
# into labels.
MIN_LABEL_MATCH_FRACTION = 0.5

# Try to infer if not found: choose the column whose first 100 non-null values
# most often contain a label keyword (ties keep the earliest column).
if label_col is None:
    best_fraction = 0.0
    for col in df.columns:
        sample = df[col].dropna().astype(str).head(100).str.lower()
        if sample.empty:
            continue
        fraction = sample.str.contains("primary|secondary", regex=True).mean()
        if fraction > best_fraction:
            label_col, best_fraction = col, fraction
    if best_fraction < MIN_LABEL_MATCH_FRACTION:
        label_col = None  # nothing looks enough like a label column
    else:
        print(f"üîé Inferred label column: '{label_col}'")

if label_col is None:
    raise KeyError("‚ùå Could not find a label column automatically. Please check df.columns and update manually.")
else:
    print(f"‚úÖ Using label column: '{label_col}'")

# --- Clean and standardize label values ---
def extract_label(value):
    """Normalise one raw label cell to ``"Primary"``, ``"Secondary"`` or ``np.nan``.

    Rules, applied in order to ``str(value)`` (case-insensitive):

    1. Contains ``"primary"``   -> ``"Primary"``.
    2. Contains ``"secondary"`` -> ``"Secondary"``.
    3. Parses as the number 0 -> ``"Primary"``; as 1 -> ``"Secondary"``.
       Ints, floats and numeric strings are all accepted, so ``1``, ``1.0``,
       ``"1"`` and ``" 1.0 "`` are recognised alike.
    4. Anything else (including NaN / None) -> ``np.nan``.

    A value mentioning both words resolves to ``"Primary"`` because rule 1 is
    checked first.

    Args:
        value: A single cell from the label column (any type).

    Returns:
        ``"Primary"``, ``"Secondary"``, or ``np.nan`` when unrecognised.
    """
    text = str(value).strip().lower()
    if "primary" in text:
        return "Primary"
    if "secondary" in text:
        return "Secondary"
    # pandas stores an integer column that also contains NaN as float64, so the
    # codes arrive as 0.0 / 1.0 and str() yields "0.0" / "1.0". Compare
    # numerically instead of against the literal strings "0" / "1".
    try:
        code = float(text)
    except ValueError:
        return np.nan
    if code == 0:
        return "Primary"
    if code == 1:
        return "Secondary"
    return np.nan

# Canonical label column: "Primary" / "Secondary", or NaN when unrecognised.
df["label"] = df[label_col].apply(extract_label)

# --- Drop rows without label ---
# Rows whose label could not be normalised cannot be trained on.
missing_labels = df["label"].isna().sum()
if missing_labels:
    print(f"‚ö†Ô∏è Dropping {missing_labels} rows with missing labels.")
    df = df.dropna(subset=["label"]).reset_index(drop=True)

# --- Display class distribution ---
class_counts = Counter(df["label"].astype(str))
print("‚úÖ Class distribution:", dict(class_counts))

# --- Map labels to numeric IDs ---
# Sorting makes the id assignment deterministic (alphabetical order).
unique_labels = sorted(df["label"].unique())
label2id = dict(zip(unique_labels, range(len(unique_labels))))
id2label = {idx: name for name, idx in label2id.items()}
df["label_id"] = df["label"].map(label2id)

print("‚úÖ Label mapping:", label2id)

# --- Stratified train/validation split ---
# 85/15 split that preserves the class ratio; fixed seed for reproducibility.
stratify_on = df["label_id"]
train_df, val_df = train_test_split(
    df,
    test_size=0.15,
    random_state=42,
    stratify=stratify_on,
)

print(f"‚úÖ Train size: {len(train_df)}, Validation size: {len(val_df)}")



‚úÖ Total rows: 44899
üîç Columns found: ['article_id', 'ref_id', 'context', 'labels']
‚úÖ Using label column: 'labels'
‚ö†Ô∏è Dropping 23684 rows with missing labels.
‚úÖ Class distribution: {'Primary': 15639, 'Secondary': 5576}
‚úÖ Label mapping: {'Primary': 0, 'Secondary': 1}
‚úÖ Train size: 18032, Validation size: 3183
