In [2]:
import pandas as pd
import numpy as np
import lightgbm as lgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, precision_score, recall_score
import time
import warnings
import pyarrow.parquet as pq
warnings.filterwarnings('ignore')

data_path = '/kaggle/input/fullfinal/telangana_data_with_features_and_targets (1).parquet'
flag_map = {
    'Y': 1, 'YES': 1, 'Yes': 1, 'y': 1, 'yes': 1,
    'N': 0, 'NO': 0, 'No': 0, 'n': 0, 'no': 0,
    None: np.nan, 'None': np.nan, '': np.nan, 'nan': np.nan
}

def prepare_data_for_targets(data_path, flag_map, batch_size=10000):
    """Load and preprocess data from Parquet file in batches."""
    # Initialize an empty list to store processed chunks
    processed_chunks = []
    
    # Define required and numeric columns
    required_cols = ['MOTHER_ID', 'GRAVIDA']
    numeric_cols = ['AGE', 'AGE_preg', 'AGE_final', 'GRAVIDA', 'PARITY', 'ABORTIONS', 'TOTAL_ANC_VISITS',
                    'HEMOGLOBIN_mean', 'HEMOGLOBIN_min', 'WEIGHT_max', 'HEIGHT',
                    'PHQ_SCORE_max', 'GAD_SCORE_max', 'WEIGHT_last', 'WEIGHT_first',
                    'NO_OF_WEEKS_max', 'WEIGHT_min', 'WEIGHT_mean']
    
    # Open Parquet file using pyarrow
    parquet_file = pq.ParquetFile(data_path)
    num_rows = parquet_file.metadata.num_rows
    num_batches = (num_rows + batch_size - 1) // batch_size  # Ceiling division
    
    # Iterate through batches
    for batch_idx in range(num_batches):
        # Read a batch of rows
        start_row = batch_idx * batch_size
        end_row = min(start_row + batch_size, num_rows)
        batch = parquet_file.read_row_group(batch_idx).to_pandas() if parquet_file.num_row_groups > batch_idx else pd.DataFrame()
        
        if batch.empty:
            continue
            
        # Check for required columns
        for col in required_cols:
            if col not in batch.columns:
                raise ValueError(f"Required column {col} missing in data")

        # Debug: Check non-numeric values
        for col in numeric_cols:
            if col in batch.columns and batch[col].dtype == 'object':
                non_numeric = batch[col][~batch[col].apply(lambda x: str(x).replace('.', '').isdigit() if pd.notna(x) else False)]
                if not non_numeric.empty:
                    print(f"Non-numeric values in {col}: {non_numeric.unique()[:10]}")

        # Clean GRAVIDA specifically
        if 'GRAVIDA' in batch.columns:
            batch['GRAVIDA'] = batch['GRAVIDA'].replace('nan', np.nan)
            batch['GRAVIDA'] = pd.to_numeric(batch['GRAVIDA'], errors='coerce')
            # Replace NaN with median GRAVIDA (or 1 as default) within batch
            if batch['GRAVIDA'].notna().any():
                median_gravida = batch['GRAVIDA'].median()
                if pd.isna(median_gravida):
                    median_gravida = 1.0
                batch['GRAVIDA'] = batch['GRAVIDA'].fillna(0)
                print(f"Filled {batch['GRAVIDA'].isna().sum()} GRAVIDA NaN values with 0 in batch")

        # Convert other numeric columns
        for col in numeric_cols:
            if col in batch.columns and col != 'GRAVIDA':
                batch[col] = pd.to_numeric(batch[col], errors='coerce')

        # Map flag columns
        if 'IS_CHILD_DEATH' in batch.columns and batch['IS_CHILD_DEATH'].dtype == 'object':
            batch['IS_CHILD_DEATH'] = batch['IS_CHILD_DEATH'].map(flag_map)
        if 'IS_DEFECTIVE_BIRTH' in batch.columns and batch['IS_DEFECTIVE_BIRTH'].dtype == 'object':
            batch['IS_DEFECTIVE_BIRTH'] = batch['IS_DEFECTIVE_BIRTH'].map(flag_map)

        # Fill NaN for numeric columns
        batch_numeric_cols = batch.select_dtypes(include=['float64', 'float32', 'int64', 'int32', 'int8']).columns
        batch[batch_numeric_cols] = batch[batch_numeric_cols].fillna(0)

        # Append processed batch
        processed_chunks.append(batch)

    # Concatenate all chunks
    df = pd.concat(processed_chunks, ignore_index=True)
    return df

# Execute the function
df = prepare_data_for_targets(data_path, flag_map)

Non-numeric values in GRAVIDA: ['nan']
Filled 0 GRAVIDA NaN values with 0 in batch
Filled 0 GRAVIDA NaN values with 0 in batch
Filled 0 GRAVIDA NaN values with 0 in batch
Filled 0 GRAVIDA NaN values with 0 in batch


In [3]:
target_columns = [
            'stillbirth_risk'
        ]
def create_stratified_sample(df, target_column, sample_size=2000000):
    """Create a stratified sample ensuring all critical cases are included."""
    child_death_col = 'IS_CHILD_DEATH'
    critical_cases = pd.DataFrame()
    
    if child_death_col in df.columns:
        child_deaths = df[df[child_death_col] == 1]
        critical_cases = pd.concat([critical_cases, child_deaths])
    if 'maternal_mortality_risk' in df.columns:
        maternal_deaths = df[df['maternal_mortality_risk'] == 1]
        critical_cases = pd.concat([critical_cases, maternal_deaths])
    if 'stillbirth_risk' in df.columns:
        stillbirths = df[df['stillbirth_risk'] == 1]
        critical_cases = pd.concat([critical_cases, stillbirths])
    critical_cases = critical_cases.drop_duplicates()
    
    remaining_size = sample_size - len(critical_cases)
    
    if remaining_size > 0:
        other_cases = df[~df.index.isin(critical_cases.index)]
        sampled_others = other_cases.sample(n=remaining_size, random_state=42)
        final_sample = pd.concat([critical_cases, sampled_others])
    else:
        final_sample = critical_cases.sample(n=sample_size, random_state=42)
    
    return final_sample

sample_df=create_stratified_sample(df, 'stillbirth_risk')

exclude_cols = [
    # Targets & labels
    'maternal_mortality_risk', 'stillbirth_risk', 'premature_birth_risk',
    'birth_defect_risk', 'anc_dropout', 'high_risk_pregnancy',

    # Derived risk scores and flags
    'total_risk_factors', 'clinical_risk_score', 'risk_level', 'predicted_risk',
    'total_missed_visits', 'age_risk_score', 'demographic_risk', 'anemia_risk_score',
    'overall_risk_score',

    # Identifiers
    'MOTHER_ID', 'ANC_ID', 'CHILD_ID', 'unique_id', 'EID', 'UID_NUMBER',

    # Delivery outcome or post-delivery info (leakage)
    'WEIGHT_child_mean', 'DELIVERY_MODE', 'MATERNAL_OUTCOME', 'IS_DELIVERED',
    'DELIVERY_OUTCOME', 'DATE_OF_DELIVERY', 'PLACE_OF_DELIVERY', 'DEL_TIME',
    'DATE_OF_DISCHARGE', 'DISCHARGE_TIME', 'JSY_BENEFICIARY',
    'IS_MOTHER_ALIVE', 'IS_CHILD_DEATH', 'CHILD_DEATH_DATE', 'CHILD_DEATH_REASON',
    'DEFECT_HEALTH_CENTER', 'IS_DEFECTIVE_BIRTH', 'BIRTH_DEFECT_TYPE',
    'BIRTH_DEFECT_SUBTYPE', 'DEFECT_SUBTYPE_OTHER', 'DEFECT_TYPE_OTHER',
    'NOTIFICATION_SENT', 'FBIR_COMPLETED_BY_ANM', 'NEWBORN_SCREENING',
    'DEATH_REASON_OTHER', 'SNCU_ADMITTED', 'SNCU_REFERRAL_HOSPITAL',
    'TERTIARY_REFERRAL_HOSPITAL', 'OTHER_REFERRAL_HOSPITAL',
    'DATE_OF_DEATH', 'REASON_FOR_DEATH', 'PLACE_OF_DEATH',
    'INDICATION_FOR_C_SECTION', 'CH', 'CAH', 'GALACTOCEMIA', 'G6PDD', 'BIOTINIDASE',

    # Delivery-specific process info
    'DELIVERY_INSTITUTION', 'DELIVERY_DONE_BY', 'CONDUCT_BY',
    'MISOPROSTAL_TABLET', 'DEL_COMPLICATIONS', 'OTHER_DEL_COMPLICATIONS',
    'NOTIFICATION_SENT_del', 'FBIR_COMPLETED_BY_ANM_del',

    # Administrative / logging columns
    'REGISTRATION_DT', 'REGTYPE', 'CURRENT_USR', 'OTHER_STATE_PLACE',
    'OTHER_STATE_PLACE_FILEPATH', 'OTHER_GOVT_PLACE_FILEPATH',
    'ANC2_TAG_FAC_ID', 'ANC3_TAG_FAC_ID',

    # Facility and geographic info
    'ANC_INSTITUTE', 'FACILITY_TYPE', 'FACILITY_NAME', 'DOCTOR_ANM',
    'DISTRICT_anc', 'DISTRICT_child',

    # Feeding / newborn care
    'IS_BF_IN_HOUR', 'FEEDING_TYPE', 'DATE_OF_FIRST_FEEDING',
    'TIME_OF_FIRST_FEEDING', 'BABY_ON_MEDICATION', 'MEDICATION_REMARKS',
    'DATE_OF_BLOODSAMPLE_COLLECTION', 'TIME_OF_BLOODSAMPLE_COLLECTION',
    'HOURS_OF_SAMPLE_COLLECTION', 'TRANSFUSION_DONE',

    # Screening/test results
    'VDRL_DATE', 'VDRL_STATUS', 'HIV_DATE', 'HIV_STATUS', 'HBSAG_DATE',
    'HBSAG_STATUS', 'HEP_DATE', 'HEP_STATUS', 'VDRL_RESULT',
    'HIV_RESULT', 'HBSAG_RESULT', 'HEP_RESULT',

    # Missed ANC flag columns
    'MISSANC1FLG', 'MISSANC2FLG', 'MISSANC3FLG', 'MISSANC4FLG',

    # Manually added known leaky or post-hoc columns
    'HIGH_RISKS', 'DISEASES', 'CHILD_NAME',

    # Any column with 'risk' or 'score' in name not already excluded
    *[col for col in df.columns if (
        ('risk' in col.lower() or 'score' in col.lower()) and col not in {
            'age_risk_score', 'anemia_risk_score', 'overall_risk_score',
            'demographic_risk', 'clinical_risk_score', 'total_risk_factors',
            'maternal_mortality_risk', 'stillbirth_risk', 'premature_birth_risk',
            'birth_defect_risk', 'high_risk_pregnancy', 'mental_health_risk'
        }
    )],
]

more_leakage_cols = [
    # aggregate or binary risk flags
#     'bp_risk', 'mental_health_risk', 'age_category', 'multigravida', 'grand_multipara',
#      'no_anc',
#     'missed_first_anc', 'consecutive_missed',
#     # anemia + BP severity flags
#     'severe_hypertension',
#     # weight / BMI flags
#     'low_birth_weight', 'very_low_birth_weight', 'avg_birth_weight_low',
#     # mental-health flags
#     # trend or helper columns
#     'hemoglobin_trend', 'unique_id', 'TT_DATE', 'MAL_PRESENT','IS_ADMITTED_SNCU',
#  'IS_PREV_PREG','CHILD_ID', 'ANC1FLG',
#  'ANC2FLG',
#  'ANC3FLG',
#  'ANC4FLG',
#  'ANC_DATE',
#  'DEATH',
#  'EXP_DOD',
#  'DELIVERY_PLACE',
#  'EXP_DOD_preg',
#  'FASTING',
#  'LMP_DT',
#  'SCREENED_FOR_MENTAL_HEALTH',
#  'AGE',
#  'AGE_final',
#  'AGE_preg',
#  'WEIGHT_child_min',
# 'GENDER',
#     'TIME_OF_BIRTH',
#      'TWIN_PREGNANCY_max', 'TOTAL_ANC_VISITS',
#     'RNK', 

#     # any other “risk / score” not already in exclude list
#     *[c for c in df.columns if ('risk' in c.lower() or 'score' in c.lower()) and c not in exclude_cols]
    
    "stillbirth_risk",
    "DELIVERY_OUTCOME",
    "IS_DEFECTIVE_BIRTH",
    "BIRTH_DEFECT_TYPE",
    "BIRTH_DEFECT_SUBTYPE",
    "DEFECT_TYPE_OTHER",
    "CHILD_DEATH_DATE",
    "CHILD_DEATH_REASON",
    "IS_CHILD_DEATH",
    "DEATH_REASON_OTHER",
    "MATERNAL_OUTCOME",
    "REASON_FOR_DEATH",
    "DATE_OF_DEATH",
    "DATE_OF_DELIVERY",
    "PLACE_OF_DELIVERY",
    "MODE_OF_DELIVERY",
    "INDICATION_FOR_C_SECTION",
    "DEL_COMPLICATIONS",
    "OTHER_DEL_COMPLICATIONS",
    "NOTIFICATION_SENT_del",
    "FBIR_COMPLETED_BY_ANM_del",
    "IS_DELIVERED",
    "DATE_OF_DISCHARGE",
    "DISCHARGE_TIME",
    "CHILD_ID",
    "CHILD_NAME",
    "GENDER",
    "TIME_OF_BIRTH",
    "IS_BF_IN_HOUR",
    "FEEDING_TYPE",
    "DATE_OF_FIRST_FEEDING",
    "TIME_OF_FIRST_FEEDING",
    "WEIGHT_child_mean",
    "WEIGHT_child_min",
    "avg_birth_weight_low",
    "LOW_BIRTH_WEIGHT",
    "very_low_birth_weight",
    "IS_ADMITTED_SNCU",
    "SNCU_ADMITTED",
    "SNCU_REFERRAL_HOSPITAL",
    "TERTIARY_REFERRAL_HOSPITAL",
    "OTHER_REFERRAL_HOSPITAL",
    "IMMUNE_CYCLE_DONE",
    "BIRTH_SCREENING",
    "NEWBORN_SCREENING",
    "EID",
    "DATE_OF_BLOODSAMPLE_COLLECTION",
    "TIME_OF_BLOODSAMPLE_COLLECTION",
    "HOURS_OF_SAMPLE_COLLECTION",
    "TRANSFUSION_DONE",
    "BABY_ON_MEDICATION",
    "MEDICATION_REMARKS",
    "CH",
    "CAH",
    "GALACTOCEMIA",
    "G6PDD",
    "BIOTINIDASE",
    "total_risk_factors",
    "high_risk_pregnancy",
    "overall_risk_score",
    "clinical_risk_score",
    "demographic_risk",
    "maternal_mortality_risk",
    "premature_birth_risk",
    "anemia_risk_score",
    "bp_risk",
    "mental_health_risk",
    "age_risk_score",
    "PHQ_SCORE_max",
    "GAD_SCORE_max",
    "depression",
    "severe_depression",
    "anxiety",
    "severe_anxiety",
    "mental_health_risk",
    "SCREENED_FOR_MENTAL_HEALTH",
    "NO_OF_WEEKS_max",
    "TWIN_PREGNANCY_max",
    "DELIVERY_PLACE",
    "DELIVERY_INSTITUTION",
    "DELIVERY_DONE_BY",
    "CONDUCT_BY",
    "OTHER_NAME",
    "JSY_BENEFICIARY",
    "NOTIFICATION_SENT",
    "FBIR_COMPLETED_BY_ANM",
    "OTHER_STATE_PLACE",
    "OTHER_STATE_PLACE_FILEPATH",
    "OTHER_GOVT_PLACE_FILEPATH",
        'IS_DEFECTIVE_BIRTH',
    'EID',
    'UID_NUMBER',
    'IS_MOTHER_ALIVE',
    'WEIGHT_child_mean',
    'WEIGHT_child_min',
    'low_birth_weight',
    'very_low_birth_weight',
    'avg_birth_weight_low',
    'AGE_final',
    'premature_birth_risk',
    'high_risk_pregnancy',
    'clinical_risk_score',
    'overall_risk_score',
    'total_risk_factors'

]
# 'NO_OF_WEEKS_max','HEMOGLOBIN_min', 'HEMOGLOBIN_max', 'WEIGHT_anc_mean', 'WEIGHT_anc_min', 'WEIGHT_anc_max', 
    # 'previous_loss', 'recurrent_loss',
# Finally combine everything
exclude_cols = list(set(exclude_cols + more_leakage_cols))

features = [col for col in df.columns if col not in exclude_cols and df[col].dtype in ['float64', 'float32', 'int64', 'int32', 'int8']]

if not features:
    print(f"Skipping {target_columns}: No valid features available.")



In [3]:
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import roc_auc_score, f1_score
import lightgbm as lgb
import numpy as np
import time

# Assuming sample_df, features, and target_column are defined (correcting target_columns to target_column)
X = sample_df[features]
y = sample_df[target_columns]  # Changed from target_columns to target_column

# Initial train-test split
X_train_full, X_test, y_train_full, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Initialize k-fold cross-validation
n_splits = 5  # Number of folds
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

# Lists to store metrics for each fold
auc_scores = []
f1_scores = []
fold_times = []

# Calculate scale_pos_weight for class imbalance (based on training dataset)
neg_count = (y_train_full == 0).sum()
pos_count = (y_train_full == 1).sum()
scale_pos_weight = neg_count[0] / pos_count[0] if pos_count[0] > 0 else 1  # Simplified, assuming y is a Series


In [4]:
!pip show scikit-learn
!pip show imbalanced-learn

Name: scikit-learn
Version: 1.2.2
Summary: A set of python modules for machine learning and data mining
Home-page: http://scikit-learn.org
Author: 
Author-email: 
License: new BSD
Location: /usr/local/lib/python3.11/dist-packages
Requires: joblib, numpy, scipy, threadpoolctl
Required-by: bayesian-optimization, Boruta, category_encoders, cesium, eli5, fastai, hdbscan, hep_ml, imbalanced-learn, librosa, lime, mlxtend, nilearn, pyLDAvis, pynndescent, rgf-python, scikit-learn-intelex, scikit-optimize, scikit-plot, sentence-transformers, shap, sklearn-compat, sklearn-pandas, TPOT, tsfresh, umap-learn, woodwork, yellowbrick
Name: imbalanced-learn
Version: 0.13.0
Summary: Toolbox for imbalanced dataset in machine learning
Home-page: https://imbalanced-learn.org/
Author: 
Author-email: "G. Lemaitre" <g.lemaitre58@gmail.com>, "C. Aridas" <ichkoar@gmail.com>
License: 
Location: /usr/local/lib/python3.11/dist-packages
Requires: joblib, numpy, scikit-learn, scipy, sklearn-compat, threadpoolctl
Req

In [5]:
!pip install scikit-learn==1.2.2 imbalanced-learn==0.10.1 --no-deps -q
from imblearn.over_sampling import SMOTE
print("SMOTE imported successfully")

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m226.0/226.0 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[?25hSMOTE imported successfully


In [6]:
import warnings
warnings.filterwarnings('ignore', category=UserWarning)  # Suppress warnings for cleaner output
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score, confusion_matrix, precision_recall_curve
from sklearn.ensemble import RandomForestClassifier
from imblearn.over_sampling import SMOTE
import numpy as np
import time
import shap
import matplotlib.pyplot as plt
import pandas as pd

# Verify SMOTE availability
try:
    from imblearn.over_sampling import SMOTE
    SMOTE_AVAILABLE = True
    print("SMOTE imported successfully")
except ImportError:
    SMOTE_AVAILABLE = False
    print("Error: SMOTE not available. Please install imbalanced-learn: pip install imbalanced-learn")
    raise ImportError("SMOTE is required for this script.")

# Define exclude_cols (adjust based on your dataset)
# Stratified sampling function
def create_stratified_sample(df, target_column, sample_size=2000000, min_positive=100):
    child_death_col = 'IS_CHILD_DEATH'
    critical_cases = pd.DataFrame()
    
    if target_column not in df.columns:
        raise ValueError(f"Target column '{target_column}' not found in DataFrame. Available columns: {df.columns.tolist()}")
    
    if child_death_col in df.columns:
        child_deaths = df[df[child_death_col] == 1]
        critical_cases = pd.concat([critical_cases, child_deaths])
    if 'maternal_mortality_risk' in df.columns:
        maternal_deaths = df[df['maternal_mortality_risk'] == 1]
        critical_cases = pd.concat([critical_cases, maternal_deaths])
    if 'stillbirth_risk' in df.columns and target_column == 'stillbirth_risk':
        stillbirths = df[df['stillbirth_risk'] == 1]
        critical_cases = pd.concat([critical_cases, stillbirths])
    critical_cases = critical_cases.drop_duplicates()
    
    positive_cases = df[df[target_column] == 1]
    print(f"Found {len(positive_cases)} positive cases for {target_column}")
    if len(positive_cases) == 0:
        raise ValueError(f"No positive cases ({target_column}=1) found in the dataset. Cannot proceed with sampling.")
    if len(positive_cases) < min_positive:
        print(f"Warning: Only {len(positive_cases)} positive cases found. Oversampling to {min_positive}.")
        oversampled_positives = positive_cases.sample(n=min_positive, replace=True, random_state=42)
        critical_cases = pd.concat([critical_cases, oversampled_positives]).drop_duplicates()
    
    remaining_size = sample_size - len(critical_cases)
    
    if remaining_size > 0:
        other_cases = df[~df.index.isin(critical_cases.index)]
        if len(other_cases) < remaining_size:
            print(f"Warning: Only {len(other_cases)} non-critical cases available. Adjusting sample size.")
            remaining_size = len(other_cases)
        sampled_others = other_cases.sample(n=remaining_size, random_state=42)
        final_sample = pd.concat([critical_cases, sampled_others])
    else:
        final_sample = critical_cases.sample(n=sample_size, random_state=42)
    
    return final_sample

# Diagnostic checks
print("Columns in df:", df.columns.tolist())
print("stillbirth_risk distribution:")
print(df['stillbirth_risk'].value_counts(dropna=False))
print("IS_CHILD_DEATH distribution:")
print(df['IS_CHILD_DEATH'].value_counts(dropna=False))
print("DELIVERY_OUTCOME distribution:")
print(df['DELIVERY_OUTCOME'].value_counts(dropna=False))
print("maternal_mortality_risk distribution:")
print(df['maternal_mortality_risk'].value_counts(dropna=False))

# Attempt to derive stillbirth_risk
if df['stillbirth_risk'].eq(1).sum() == 0 and 'IS_CHILD_DEATH' in df.columns:
    print("No positive cases in stillbirth_risk. Attempting to derive from IS_CHILD_DEATH...")
    df['stillbirth_risk'] = df['IS_CHILD_DEATH'].fillna(0).astype(int)
    print("Redefined stillbirth_risk distribution:")
    print(df['stillbirth_risk'].value_counts(dropna=False))

# Select target
target_column = 'stillbirth_risk' if df['stillbirth_risk'].eq(1).sum() > 0 else 'maternal_mortality_risk'
if df[target_column].eq(1).sum() == 0:
    raise ValueError(f"No positive cases found for {target_column}. Cannot proceed.")
print(f"Using target: {target_column}")

# Select numeric features
features = [col for col in df.columns if col not in exclude_cols and 
            df[col].dtype in ['float64', 'float32', 'int64', 'int32', 'int8']]

# Create sample
try:
    sample_df = create_stratified_sample(df, target_column, min_positive=100)
    print(f"Class distribution in sample_df for {target_column}:")
    print(sample_df[target_column].value_counts())
except ValueError as e:
    print(f"Error: {e}")
    raise

# Check features
if not features:
    raise ValueError(f"No valid features available for {target_column}. Available columns: {df.columns.tolist()}")
print("Features used:", features)

# Prepare data
X = sample_df[features]
y = sample_df[target_column]
X_train_full, X_test, y_train_full, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Initialize k-fold cross-validation
n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

# Metrics storage with lower thresholds
metrics = {
    'thresh_0_3': {'f1': [], 'accuracy': [], 'precision': [], 'recall': []},
    'thresh_0_4': {'f1': [], 'accuracy': [], 'precision': [], 'recall': []},
    'thresh_0_5': {'f1': [], 'accuracy': [], 'precision': [], 'recall': []},
    'thresh_0_6': {'f1': [], 'accuracy': [], 'precision': [], 'recall': []},
    'auc': []
}
fold_models = []
fold_avg_f1 = []

# Random Forest parameters optimized for recall
params = {
    'n_estimators': 200,          # Increased for better learning
    'max_depth': 15,              # Increased for more complexity
    'min_samples_split': 20,      # Reduced for flexibility
    'min_samples_leaf': 10,       # Reduced for flexibility
    'class_weight': {0: 1, 1: 10}, # Heavily weight positive class
    'random_state': 42,
    'n_jobs': -1
}

# K-fold cross-validation
for fold, (train_idx, val_idx) in enumerate(skf.split(X_train_full, y_train_full)):
    print(f"Training fold {fold + 1}/{n_splits}")
    
    X_train, X_val = X_train_full.iloc[train_idx], X_train_full.iloc[val_idx]
    y_train, y_val = y_train_full.iloc[train_idx], y_train_full.iloc[val_idx]
    
    print(f"Training class counts:\n{y_train.value_counts()}")
    print(f"Validation class counts:\n{y_val.value_counts()}")
    
    if len(y_train.unique()) < 2 or len(y_val.unique()) < 2:
        print(f"Warning: Fold {fold + 1} has only one class. Skipping.")
        continue
    
    # Force SMOTE in every fold
    print(f"Forcing SMOTE for fold {fold + 1}")
    smote = SMOTE(random_state=42, sampling_strategy=0.5)  # 1:2 positive:negative ratio
    X_train, y_train = smote.fit_resample(X_train, y_train)
    print(f"Post-SMOTE training class counts:\n{y_train.value_counts()}")
    
    start_time = time.time()
    model = RandomForestClassifier(**params)
    model.fit(X_train, y_train)
    fold_time = time.time() - start_time
    fold_models.append(model)
    
    y_pred_proba = model.predict_proba(X_val)
    if y_pred_proba.shape[1] == 1:
        print(f"Warning: Fold {fold + 1} predicts only one class. Assigning zero probabilities for class 1.")
        y_pred_proba = np.zeros(len(X_val))
    else:
        y_pred_proba = y_pred_proba[:, 1]
    
    # Evaluate at lower thresholds
    y_pred_0_3 = (y_pred_proba > 0.3).astype(int)
    y_pred_0_4 = (y_pred_proba > 0.4).astype(int)
    y_pred_0_5 = (y_pred_proba > 0.5).astype(int)
    y_pred_0_6 = (y_pred_proba > 0.6).astype(int)
    
    f1_0_3 = f1_score(y_val, y_pred_0_3)
    accuracy_0_3 = accuracy_score(y_val, y_pred_0_3)
    precision_0_3 = precision_score(y_val, y_pred_0_3, zero_division=0)
    recall_0_3 = recall_score(y_val, y_pred_0_3, zero_division=0)
    cm_0_3 = confusion_matrix(y_val, y_pred_0_3)
    
    f1_0_4 = f1_score(y_val, y_pred_0_4)
    accuracy_0_4 = accuracy_score(y_val, y_pred_0_4)
    precision_0_4 = precision_score(y_val, y_pred_0_4, zero_division=0)
    recall_0_4 = recall_score(y_val, y_pred_0_4, zero_division=0)
    cm_0_4 = confusion_matrix(y_val, y_pred_0_4)
    
    f1_0_5 = f1_score(y_val, y_pred_0_5)
    accuracy_0_5 = accuracy_score(y_val, y_pred_0_5)
    precision_0_5 = precision_score(y_val, y_pred_0_5, zero_division=0)
    recall_0_5 = recall_score(y_val, y_pred_0_5, zero_division=0)
    cm_0_5 = confusion_matrix(y_val, y_pred_0_5)
    
    f1_0_6 = f1_score(y_val, y_pred_0_6)
    accuracy_0_6 = accuracy_score(y_val, y_pred_0_6)
    precision_0_6 = precision_score(y_val, y_pred_0_6, zero_division=0)
    recall_0_6 = recall_score(y_val, y_pred_0_6, zero_division=0)
    cm_0_6 = confusion_matrix(y_val, y_pred_0_6)
    
    auc = roc_auc_score(y_val, y_pred_proba) if len(np.unique(y_val)) > 1 else 0
    
    metrics['thresh_0_3']['f1'].append(f1_0_3)
    metrics['thresh_0_3']['accuracy'].append(accuracy_0_3)
    metrics['thresh_0_3']['precision'].append(precision_0_3)
    metrics['thresh_0_3']['recall'].append(recall_0_3)
    
    metrics['thresh_0_4']['f1'].append(f1_0_4)
    metrics['thresh_0_4']['accuracy'].append(accuracy_0_4)
    metrics['thresh_0_4']['precision'].append(precision_0_4)
    metrics['thresh_0_4']['recall'].append(recall_0_4)
    
    metrics['thresh_0_5']['f1'].append(f1_0_5)
    metrics['thresh_0_5']['accuracy'].append(accuracy_0_5)
    metrics['thresh_0_5']['precision'].append(precision_0_5)
    metrics['thresh_0_5']['recall'].append(recall_0_5)
    
    metrics['thresh_0_6']['f1'].append(f1_0_6)
    metrics['thresh_0_6']['accuracy'].append(accuracy_0_6)
    metrics['thresh_0_6']['precision'].append(precision_0_6)
    metrics['thresh_0_6']['recall'].append(recall_0_6)
    
    metrics['auc'].append(auc)
    
    avg_f1 = np.mean([f1_0_3, f1_0_4, f1_0_5, f1_0_6])
    fold_avg_f1.append(avg_f1)
    
    print(f"Fold {fold + 1}")
    print(f"  AUC: {auc:.4f}, Time: {fold_time:.2f} seconds")
    print(f"  Threshold 0.3 - F1: {f1_0_3:.4f}, Accuracy: {accuracy_0_3:.4f}, Precision: {precision_0_3:.4f}, Recall: {recall_0_3:.4f}")
    print(f"  Confusion Matrix (Threshold 0.3):\n{cm_0_3}")
    print(f"  Threshold 0.4 - F1: {f1_0_4:.4f}, Accuracy: {accuracy_0_4:.4f}, Precision: {precision_0_4:.4f}, Recall: {recall_0_4:.4f}")
    print(f"  Confusion Matrix (Threshold 0.4):\n{cm_0_4}")
    print(f"  Threshold 0.5 - F1: {f1_0_5:.4f}, Accuracy: {accuracy_0_5:.4f}, Precision: {precision_0_5:.4f}, Recall: {recall_0_5:.4f}")
    print(f"  Confusion Matrix (Threshold 0.5):\n{cm_0_5}")
    print(f"  Threshold 0.6 - F1: {f1_0_6:.4f}, Accuracy: {accuracy_0_6:.4f}, Precision: {precision_0_6:.4f}, Recall: {recall_0_6:.4f}")
    print(f"  Confusion Matrix (Threshold 0.6):\n{cm_0_6}")
    
    if len(np.unique(y_val)) > 1:
        precisions, recalls, thresholds = precision_recall_curve(y_val, y_pred_proba)
        f1_scores = 2 * (precisions * recalls) / (precisions + recalls + 1e-10)
        optimal_threshold = thresholds[np.argmax(f1_scores)]
        print(f"  Optimal threshold: {optimal_threshold:.4f}")

print(f"\nCross-Validation Mean Metrics:")
print(f"  AUC: {np.mean(metrics['auc']):.4f} ± {np.std(metrics['auc']):.4f}")
for thresh in ['thresh_0_3', 'thresh_0_4', 'thresh_0_5', 'thresh_0_6']:
    print(f"\n{thresh.replace('_', ' ').title()}:")
    print(f"  F1 Score: {np.mean(metrics[thresh]['f1']):.4f} ± {np.std(metrics[thresh]['f1']):.4f}")
    print(f"  Accuracy: {np.mean(metrics[thresh]['accuracy']):.4f} ± {np.std(metrics[thresh]['accuracy']):.4f}")
    print(f"  Precision: {np.mean(metrics[thresh]['precision']):.4f} ± {np.std(metrics[thresh]['precision']):.4f}")
    print(f"  Recall: {np.mean(metrics[thresh]['recall']):.4f} ± {np.std(metrics[thresh]['recall']):.4f}")

if fold_avg_f1:
    best_fold_idx = np.argmax(fold_avg_f1)
    best_model = fold_models[best_fold_idx]
    print(f"\nBest Model from Fold {best_fold_idx + 1} with Average F1 Score: {fold_avg_f1[best_fold_idx]:.4f}")
    
    y_test_pred_proba = best_model.predict_proba(X_test)
    if y_test_pred_proba.shape[1] == 1:
        print("Warning: Test set prediction has only one class. Assigning zero probabilities for class 1.")
        y_test_pred_proba = np.zeros(len(X_test))
    else:
        y_test_pred_proba = y_test_pred_proba[:, 1]
    
    y_test_pred_0_3 = (y_test_pred_proba > 0.3).astype(int)
    y_test_pred_0_4 = (y_test_pred_proba > 0.4).astype(int)
    y_test_pred_0_5 = (y_test_pred_proba > 0.5).astype(int)
    y_test_pred_0_6 = (y_test_pred_proba > 0.6).astype(int)
    
    test_auc = roc_auc_score(y_test, y_test_pred_proba) if len(np.unique(y_test)) > 1 else 0
    test_f1_0_3 = f1_score(y_test, y_test_pred_0_3)
    test_accuracy_0_3 = accuracy_score(y_test, y_test_pred_0_3)
    test_precision_0_3 = precision_score(y_test, y_test_pred_0_3, zero_division=0)
    test_recall_0_3 = recall_score(y_test, y_test_pred_0_3, zero_division=0)
    test_cm_0_3 = confusion_matrix(y_test, y_test_pred_0_3)
    
    test_f1_0_4 = f1_score(y_test, y_test_pred_0_4)
    test_accuracy_0_4 = accuracy_score(y_test, y_test_pred_0_4)
    test_precision_0_4 = precision_score(y_test, y_test_pred_0_4, zero_division=0)
    test_recall_0_4 = recall_score(y_test, y_test_pred_0_4, zero_division=0)
    test_cm_0_4 = confusion_matrix(y_test, y_test_pred_0_4)
    
    test_f1_0_5 = f1_score(y_test, y_test_pred_0_5)
    test_accuracy_0_5 = accuracy_score(y_test, y_test_pred_0_5)
    test_precision_0_5 = precision_score(y_test, y_test_pred_0_5, zero_division=0)
    test_recall_0_5 = recall_score(y_test, y_test_pred_0_5, zero_division=0)
    test_cm_0_5 = confusion_matrix(y_test, y_test_pred_0_5)
    
    test_f1_0_6 = f1_score(y_test, y_test_pred_0_6)
    test_accuracy_0_6 = accuracy_score(y_test, y_test_pred_0_6)
    test_precision_0_6 = precision_score(y_test, y_test_pred_0_6, zero_division=0)
    test_recall_0_6 = recall_score(y_test, y_test_pred_0_6, zero_division=0)
    test_cm_0_6 = confusion_matrix(y_test, y_test_pred_0_6)
    
    print(f"\nTest Set Metrics (Best Model from Fold {best_fold_idx + 1}):")
    print(f"  AUC: {test_auc:.4f}")
    print(f"\nThreshold 0.3:\n  F1: {test_f1_0_3:.4f}, Accuracy: {test_accuracy_0_3:.4f}, Precision: {test_precision_0_3:.4f}, Recall: {test_recall_0_3:.4f}\n  Confusion Matrix:\n{test_cm_0_3}")
    print(f"\nThreshold 0.4:\n  F1: {test_f1_0_4:.4f}, Accuracy: {test_accuracy_0_4:.4f}, Precision: {test_precision_0_4:.4f}, Recall: {test_recall_0_4:.4f}\n  Confusion Matrix:\n{test_cm_0_4}")
    print(f"\nThreshold 0.5:\n  F1: {test_f1_0_5:.4f}, Accuracy: {test_accuracy_0_5:.4f}, Precision: {test_precision_0_5:.4f}, Recall: {test_recall_0_5:.4f}\n  Confusion Matrix:\n{test_cm_0_5}")
    print(f"\nThreshold 0.6:\n  F1: {test_f1_0_6:.4f}, Accuracy: {test_accuracy_0_6:.4f}, Precision: {test_precision_0_6:.4f}, Recall: {test_recall_0_6:.4f}\n  Confusion Matrix:\n{test_cm_0_6}")
    
    test_f1_scores = {0.3: test_f1_0_3, 0.4: test_f1_0_4, 0.5: test_f1_0_5, 0.6: test_f1_0_6}
    best_threshold = max(test_f1_scores, key=test_f1_scores.get)
    print(f"\nBest Threshold on Test Set: {best_threshold} with F1 Score: {test_f1_scores[best_threshold]:.4f}")
    
    # SHAP analysis with sampled test set
    print("\nPerforming SHAP analysis...")
    X_test_sample = X_test.sample(n=min(1000, len(X_test)), random_state=42)  # Sample to reduce computation
    explainer = shap.TreeExplainer(best_model)
    shap_values = explainer.shap_values(X_test_sample)
    
    plt.figure()
    shap.summary_plot(shap_values[1], X_test_sample, show=False)
    plt.savefig("shap_summary_plot.png")
    plt.close()
    print("SHAP summary plot saved as 'shap_summary_plot.png'")
    
    plt.figure()
    shap.summary_plot(shap_values[1], X_test_sample, plot_type="bar", show=False)
    plt.savefig("shap_importance_bar.png")
    plt.close()
    print("SHAP feature importance bar plot saved as 'shap_importance_bar.png'")
    
    shap_importance = np.abs(shap_values[1]).mean(axis=0)
    importance_df = pd.DataFrame({
        'Feature': X_test_sample.columns,
        'SHAP_Importance': shap_importance
    }).sort_values(by='SHAP_Importance', ascending=False)
    print("\nSHAP Feature Importance:")
    print(importance_df)

else:
    print("\nNo valid models trained due to single-class folds.")
#     print("\nPerforming SHAP analysis...")
#     explainer = shap.TreeExplainer(best_model)
#     shap_values = explainer.shap_values(X_test)
    
#     plt.figure()
#     shap.summary_plot(shap_values[1], X_test, show=False)
#     plt.savefig("shap_summary_plot.png")
#     plt.close()
#     print("SHAP summary plot saved as 'shap_summary_plot.png'")
    
#     plt.figure()
#     shap.summary_plot(shap_values[1], X_test, plot_type="bar", show=False)
#     plt.savefig("shap_importance_bar.png")
#     plt.close()
#     print("SHAP feature importance bar plot saved as 'shap_importance_bar.png'")
    
#     shap_importance = np.abs(shap_values[1]).mean(axis=0)
#     importance_df = pd.DataFrame({
#         'Feature': X_test.columns,
#         'SHAP_Importance': shap_importance
#     }).sort_values(by='SHAP_Importance', ascending=False)
#     print("\nSHAP Feature Importance:")
#     print(importance_df)
# else:
#     print("\nNo valid models trained due to single-class folds.")#     print("\nPerforming SHAP analysis...")
#     explainer = shap.TreeExplainer(best_model)
#     shap_values = explainer.shap_values(X_test)
    
#     plt.figure()
#     shap.summary_plot(shap_values[1], X_test, show=False)
#     plt.savefig("shap_summary_plot.png")
#     plt.close()
#     print("SHAP summary plot saved as 'shap_summary_plot.png'")
    
#     plt.figure()
#     shap.summary_plot(shap_values[1], X_test, plot_type="bar", show=False)
#     plt.savefig("shap_importance_bar.png")
#     plt.close()
#     print("SHAP feature importance bar plot saved as 'shap_importance_bar.png'")
    
#     shap_importance = np.abs(shap_values[1]).mean(axis=0)
#     importance_df = pd.DataFrame({
#         'Feature': X_test.columns,
#         'SHAP_Importance': shap_importance
#     }).sort_values(by='SHAP_Importance', ascending=False)
#     print("\nSHAP Feature Importance:")
#     print(importance_df)
# else:
#     print("\nNo valid models trained due to single-class folds.")

SMOTE imported successfully
stillbirth_risk distribution:
stillbirth_risk
0    4029571
Name: count, dtype: int64
IS_CHILD_DEATH distribution:
IS_CHILD_DEATH
0.0    3980904
1.0      48667
Name: count, dtype: int64
DELIVERY_OUTCOME distribution:
DELIVERY_OUTCOME
live             4023217
none                3361
still birth         1832
twin birth           845
iud                  303
more than two         12
-1                     1
Name: count, dtype: int64
maternal_mortality_risk distribution:
maternal_mortality_risk
0    4028194
1       1377
Name: count, dtype: int64
No positive cases in stillbirth_risk. Attempting to derive from IS_CHILD_DEATH...
Redefined stillbirth_risk distribution:
stillbirth_risk
0    3980904
1      48667
Name: count, dtype: int64
Using target: stillbirth_risk
Found 48667 positive cases for stillbirth_risk
Class distribution in sample_df for stillbirth_risk:
stillbirth_risk
0    1951333
1      48667
Name: count, dtype: int64
Features used: ['GRAVIDA', 'RNK', 'A