In [4]:
import pandas as pd
import numpy as np
import lightgbm as lgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, precision_score, recall_score
import time
import warnings
import pyarrow.parquet as pq
warnings.filterwarnings('ignore')

data_path = '/kaggle/input/fullfinal/telangana_data_with_features_and_targets (1).parquet'
flag_map = {
    'Y': 1, 'YES': 1, 'Yes': 1, 'y': 1, 'yes': 1,
    'N': 0, 'NO': 0, 'No': 0, 'n': 0, 'no': 0,
    None: np.nan, 'None': np.nan, '': np.nan, 'nan': np.nan
}

def prepare_data_for_targets(data_path, flag_map, batch_size=10000):
    """Load and preprocess data from Parquet file in batches."""
    # Initialize an empty list to store processed chunks
    processed_chunks = []
    
    # Define required and numeric columns
    required_cols = ['MOTHER_ID', 'GRAVIDA']
    numeric_cols = ['AGE', 'AGE_preg', 'AGE_final', 'GRAVIDA', 'PARITY', 'ABORTIONS', 'TOTAL_ANC_VISITS',
                    'HEMOGLOBIN_mean', 'HEMOGLOBIN_min', 'WEIGHT_max', 'HEIGHT',
                    'PHQ_SCORE_max', 'GAD_SCORE_max', 'WEIGHT_last', 'WEIGHT_first',
                    'NO_OF_WEEKS_max', 'WEIGHT_min', 'WEIGHT_mean']
    
    # Open Parquet file using pyarrow
    parquet_file = pq.ParquetFile(data_path)
    num_rows = parquet_file.metadata.num_rows
    num_batches = (num_rows + batch_size - 1) // batch_size  # Ceiling division
    
    # Iterate through batches
    for batch_idx in range(num_batches):
        # Read a batch of rows
        start_row = batch_idx * batch_size
        end_row = min(start_row + batch_size, num_rows)
        batch = parquet_file.read_row_group(batch_idx).to_pandas() if parquet_file.num_row_groups > batch_idx else pd.DataFrame()
        
        if batch.empty:
            continue
            
        # Check for required columns
        for col in required_cols:
            if col not in batch.columns:
                raise ValueError(f"Required column {col} missing in data")

        # Debug: Check non-numeric values
        for col in numeric_cols:
            if col in batch.columns and batch[col].dtype == 'object':
                non_numeric = batch[col][~batch[col].apply(lambda x: str(x).replace('.', '').isdigit() if pd.notna(x) else False)]
                if not non_numeric.empty:
                    print(f"Non-numeric values in {col}: {non_numeric.unique()[:10]}")

        # Clean GRAVIDA specifically
        if 'GRAVIDA' in batch.columns:
            batch['GRAVIDA'] = batch['GRAVIDA'].replace('nan', np.nan)
            batch['GRAVIDA'] = pd.to_numeric(batch['GRAVIDA'], errors='coerce')
            # Replace NaN with median GRAVIDA (or 1 as default) within batch
            if batch['GRAVIDA'].notna().any():
                median_gravida = batch['GRAVIDA'].median()
                if pd.isna(median_gravida):
                    median_gravida = 1.0
                batch['GRAVIDA'] = batch['GRAVIDA'].fillna(0)
                print(f"Filled {batch['GRAVIDA'].isna().sum()} GRAVIDA NaN values with 0 in batch")

        # Convert other numeric columns
        for col in numeric_cols:
            if col in batch.columns and col != 'GRAVIDA':
                batch[col] = pd.to_numeric(batch[col], errors='coerce')

        # Map flag columns
        if 'IS_CHILD_DEATH' in batch.columns and batch['IS_CHILD_DEATH'].dtype == 'object':
            batch['IS_CHILD_DEATH'] = batch['IS_CHILD_DEATH'].map(flag_map)
        if 'IS_DEFECTIVE_BIRTH' in batch.columns and batch['IS_DEFECTIVE_BIRTH'].dtype == 'object':
            batch['IS_DEFECTIVE_BIRTH'] = batch['IS_DEFECTIVE_BIRTH'].map(flag_map)

        # Fill NaN for numeric columns
        batch_numeric_cols = batch.select_dtypes(include=['float64', 'float32', 'int64', 'int32', 'int8']).columns
        batch[batch_numeric_cols] = batch[batch_numeric_cols].fillna(0)

        # Append processed batch
        processed_chunks.append(batch)

    # Concatenate all chunks
    df = pd.concat(processed_chunks, ignore_index=True)
    return df

# Execute the function
df = prepare_data_for_targets(data_path, flag_map)

Non-numeric values in GRAVIDA: ['nan']
Filled 0 GRAVIDA NaN values with 0 in batch
Filled 0 GRAVIDA NaN values with 0 in batch
Filled 0 GRAVIDA NaN values with 0 in batch
Filled 0 GRAVIDA NaN values with 0 in batch


In [5]:
import pandas as pd

# Define target column for maternal mortality
target_columns = ['maternal_mortality_risk']

def create_stratified_sample(df, target_column, sample_size=2000000):
    """Create a stratified sample ensuring all critical cases for maternal mortality are included."""
    maternal_death_col = 'maternal_mortality_risk'
    critical_cases = pd.DataFrame()
    
    # Prioritize maternal mortality cases
    if maternal_death_col in df.columns:
        maternal_deaths = df[df[maternal_death_col] == 1]
        critical_cases = pd.concat([critical_cases, maternal_deaths])
    
    # Optionally include related critical cases (e.g., stillbirths or child deaths) if relevant
    if 'stillbirth_risk' in df.columns:
        stillbirths = df[df['stillbirth_risk'] == 1]
        critical_cases = pd.concat([critical_cases, stillbirths])
    if 'IS_CHILD_DEATH' in df.columns:
        child_deaths = df[df['IS_CHILD_DEATH'] == 1]
        critical_cases = pd.concat([critical_cases, child_deaths])
    
    critical_cases = critical_cases.drop_duplicates()
    
    remaining_size = sample_size - len(critical_cases)
    
    if remaining_size > 0:
        other_cases = df[~df.index.isin(critical_cases.index)]
        sampled_others = other_cases.sample(n=remaining_size, random_state=42)
        final_sample = pd.concat([critical_cases, sampled_others])
    else:
        final_sample = critical_cases.sample(n=sample_size, random_state=42)
    
    return final_sample

# Create sample for maternal mortality
sample_df = create_stratified_sample(df, 'maternal_mortality_risk')

# Define columns to exclude to prevent data leakage
exclude_cols = [
    # Targets & labels
    'maternal_mortality_risk', 'stillbirth_risk', 'premature_birth_risk',
    'birth_defect_risk', 'anc_dropout', 'high_risk_pregnancy',

    # Derived risk scores and flags
    'total_risk_factors', 'clinical_risk_score', 'risk_level', 'predicted_risk',
    'total_missed_visits', 'age_risk_score', 'demographic_risk', 'anemia_risk_score',
    'overall_risk_score',

    # Identifiers
    'MOTHER_ID', 'ANC_ID', 'CHILD_ID', 'unique_id', 'EID', 'UID_NUMBER',

    # Delivery outcome or post-delivery info (leakage)
    'WEIGHT_child_mean', 'DELIVERY_MODE', 'MATERNAL_OUTCOME', 'IS_DELIVERED',
    'DELIVERY_OUTCOME', 'DATE_OF_DELIVERY', 'PLACE_OF_DELIVERY', 'DEL_TIME',
    'DATE_OF_DISCHARGE', 'DISCHARGE_TIME', 'JSY_BENEFICIARY',
    'IS_MOTHER_ALIVE', 'IS_CHILD_DEATH', 'CHILD_DEATH_DATE', 'CHILD_DEATH_REASON',
    'DEFECT_HEALTH_CENTER', 'IS_DEFECTIVE_BIRTH', 'BIRTH_DEFECT_TYPE',
    'BIRTH_DEFECT_SUBTYPE', 'DEFECT_SUBTYPE_OTHER', 'DEFECT_TYPE_OTHER',
    'NOTIFICATION_SENT', 'FBIR_COMPLETED_BY_ANM', 'NEWBORN_SCREENING',
    'DEATH_REASON_OTHER', 'SNCU_ADMITTED', 'SNCU_REFERRAL_HOSPITAL',
    'TERTIARY_REFERRAL_HOSPITAL', 'OTHER_REFERRAL_HOSPITAL',
    'DATE_OF_DEATH', 'REASON_FOR_DEATH', 'PLACE_OF_DEATH',
    'INDICATION_FOR_C_SECTION', 'CH', 'CAH', 'GALACTOCEMIA', 'G6PDD', 'BIOTINIDASE',

    # Delivery-specific process info
    'DELIVERY_INSTITUTION', 'DELIVERY_DONE_BY', 'CONDUCT_BY',
    'MISOPROSTAL_TABLET', 'DEL_COMPLICATIONS', 'OTHER_DEL_COMPLICATIONS',
    'NOTIFICATION_SENT_del', 'FBIR_COMPLETED_BY_ANM_del',

    # Administrative / logging columns
    'REGISTRATION_DT', 'REGTYPE', 'CURRENT_USR', 'OTHER_STATE_PLACE',
    'OTHER_STATE_PLACE_FILEPATH', 'OTHER_GOVT_PLACE_FILEPATH',
    'ANC2_TAG_FAC_ID', 'ANC3_TAG_FAC_ID',

    # Facility and geographic info
    'ANC_INSTITUTE', 'FACILITY_TYPE', 'FACILITY_NAME', 'DOCTOR_ANM',
    'DISTRICT_anc', 'DISTRICT_child',

    # Feeding / newborn care
    'IS_BF_IN_HOUR', 'FEEDING_TYPE', 'DATE_OF_FIRST_FEEDING',
    'TIME_OF_FIRST_FEEDING', 'BABY_ON_MEDICATION', 'MEDICATION_REMARKS',
    'DATE_OF_BLOODSAMPLE_COLLECTION', 'TIME_OF_BLOODSAMPLE_COLLECTION',
    'HOURS_OF_SAMPLE_COLLECTION', 'TRANSFUSION_DONE',

    # Screening/test results
    'VDRL_DATE', 'VDRL_STATUS', 'HIV_DATE', 'HIV_STATUS', 'HBSAG_DATE',
    'HBSAG_STATUS', 'HEP_DATE', 'HEP_STATUS', 'VDRL_RESULT',
    'HIV_RESULT', 'HBSAG_RESULT', 'HEP_RESULT',

    # Missed ANC flag columns
    'MISSANC1FLG', 'MISSANC2FLG', 'MISSANC3FLG', 'MISSANC4FLG',

    # Manually added known leaky or post-hoc columns
    'HIGH_RISKS', 'DISEASES', 'CHILD_NAME',

    # Additional leakage columns
    'age_category', 'multigravida', 'grand_multipara', 'no_anc',
    'missed_first_anc', 'consecutive_missed', 'severe_hypertension',
    'low_birth_weight', 'very_low_birth_weight', 'avg_birth_weight_low',
    'depression', 'severe_depression', 'anxiety', 'severe_anxiety',
    'hemoglobin_trend', 'TT_DATE', 'MAL_PRESENT', 'IS_ADMITTED_SNCU',
    'IS_PREV_PREG', 'ANC1FLG', 'ANC2FLG', 'ANC3FLG', 'ANC4FLG', 'ANC_DATE',
    'DEATH', 'EXP_DOD', 'DELIVERY_PLACE', 'EXP_DOD_preg', 'FASTING', 'LMP_DT',
    'SCREENED_FOR_MENTAL_HEALTH', 'AGE', 'AGE_final', 'AGE_preg',
    'GENDER', 'TIME_OF_BIRTH', 'TWIN_PREGNANCY_max', 'TOTAL_ANC_VISITS', 'RNK',
    'PHQ_SCORE_max', 'GAD_SCORE_max', 'mental_health_risk', 'NO_OF_WEEKS_max',
    'DELIVERY_INSTITUTION', 'DELIVERY_DONE_BY', 'CONDUCT_BY', 'OTHER_NAME',
    'NOTIFICATION_SENT', 'FBIR_COMPLETED_BY_ANM', 'OTHER_STATE_PLACE',
    'OTHER_STATE_PLACE_FILEPATH', 'OTHER_GOVT_PLACE_FILEPATH',
    'bp_risk'
]

# Combine exclude_cols
exclude_cols = list(set(exclude_cols))

# Select numeric features
features = [col for col in df.columns if col not in exclude_cols and df[col].dtype in ['float64', 'float32', 'int64', 'int32', 'int8']]

if not features:
    print(f"Skipping {target_columns}: No valid features available.")
else:
    print(f"Features used for maternal_mortality_risk: {features}")

Features used for maternal_mortality_risk: ['GRAVIDA', 'PARITY', 'ABORTIONS', 'HEIGHT', 'HEMOGLOBIN_mean', 'HEMOGLOBIN_min', 'HEMOGLOBIN_max', 'WEIGHT_anc_mean', 'WEIGHT_anc_min', 'WEIGHT_anc_max', 'WEIGHT_child_min', 'age_adolescent', 'age_elderly', 'age_very_young', 'previous_loss', 'recurrent_loss', 'gravida_parity_ratio', 'inadequate_anc', 'irregular_anc', 'anemia_mild', 'anemia_moderate', 'anemia_severe', 'ever_severe_anemia', 'systolic_bp', 'diastolic_bp', 'hypertension', 'BMI', 'underweight', 'obese', 'normal_weight', 'weight_gain', 'weight_gain_per_week', 'inadequate_weight_gain']


In [6]:
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import roc_auc_score, f1_score
import lightgbm as lgb
import numpy as np
import time

# Assuming sample_df, features, and target_column are defined (correcting target_columns to target_column)
X = sample_df[features]
y = sample_df[target_columns]  # Changed from target_columns to target_column

# Initial train-test split
X_train_full, X_test, y_train_full, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Initialize k-fold cross-validation
n_splits = 5  # Number of folds
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

# Lists to store metrics for each fold
auc_scores = []
f1_scores = []
fold_times = []

# Calculate scale_pos_weight for class imbalance (based on training dataset)
neg_count = (y_train_full == 0).sum()
pos_count = (y_train_full == 1).sum()
scale_pos_weight = neg_count[0] / pos_count[0] if pos_count[0] > 0 else 1  # Simplified, assuming y is a Series


In [7]:
!pip show scikit-learn
!pip show imbalanced-learn
!pip install scikit-learn==1.2.2 imbalanced-learn==0.10.1 --no-deps -q
from imblearn.over_sampling import SMOTE
print("SMOTE imported successfully")

Name: scikit-learn
Version: 1.2.2
Summary: A set of python modules for machine learning and data mining
Home-page: http://scikit-learn.org
Author: 
Author-email: 
License: new BSD
Location: /usr/local/lib/python3.11/dist-packages
Requires: joblib, numpy, scipy, threadpoolctl
Required-by: bayesian-optimization, Boruta, category_encoders, cesium, eli5, fastai, hdbscan, hep_ml, imbalanced-learn, librosa, lime, mlxtend, nilearn, pyLDAvis, pynndescent, rgf-python, scikit-learn-intelex, scikit-optimize, scikit-plot, sentence-transformers, shap, sklearn-compat, sklearn-pandas, TPOT, tsfresh, umap-learn, woodwork, yellowbrick
Name: imbalanced-learn
Version: 0.13.0
Summary: Toolbox for imbalanced dataset in machine learning
Home-page: https://imbalanced-learn.org/
Author: 
Author-email: "G. Lemaitre" <g.lemaitre58@gmail.com>, "C. Aridas" <ichkoar@gmail.com>
License: 
Location: /usr/local/lib/python3.11/dist-packages
Requires: joblib, numpy, scikit-learn, scipy, sklearn-compat, threadpoolctl
Req

In [8]:
import warnings
warnings.filterwarnings('ignore', category=UserWarning)
import pandas as pd
import numpy as np
import lightgbm as lgb
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score, confusion_matrix, precision_recall_curve, auc
from imblearn.combine import SMOTEENN
from category_encoders import TargetEncoder
import time
import shap
import matplotlib.pyplot as plt
import pyarrow.parquet as pq

# Verify SMOTEENN availability
try:
    from imblearn.combine import SMOTEENN
    print("SMOTEENN imported successfully")
except ImportError:
    print("Error: SMOTEENN not available. Install: pip install imbalanced-learn==0.10.1")
    raise ImportError("SMOTEENN is required.")

# Data loading and preprocessing function
def prepare_data_for_targets(data_path, flag_map, batch_size=10000):
    """Load and preprocess data from Parquet file in batches."""
    processed_chunks = []
    required_cols = ['MOTHER_ID', 'GRAVIDA']
    numeric_cols = [
        'GRAVIDA', 'PARITY', 'ABORTIONS', 'HEIGHT', 'HEMOGLOBIN_mean', 'HEMOGLOBIN_min', 'HEMOGLOBIN_max',
        'WEIGHT_anc_mean', 'WEIGHT_anc_min', 'WEIGHT_anc_max', 'WEIGHT_child_min', 'systolic_bp', 'diastolic_bp',
        'BMI', 'weight_gain', 'weight_gain_per_week', 'gravida_parity_ratio', 'TOTAL_ANC_VISITS', 'NO_OF_WEEKS_max',
        'PHQ_SCORE_max', 'GAD_SCORE_max', 'anemia_severe_systolic_bp', 'hypertension_hemoglobin'
    ]
    flag_cols = [
        'age_adolescent', 'age_elderly', 'age_very_young', 'previous_loss', 'recurrent_loss', 'inadequate_anc',
        'irregular_anc', 'anemia_mild', 'anemia_moderate', 'anemia_severe', 'ever_severe_anemia', 'hypertension',
        'underweight', 'obese', 'normal_weight', 'inadequate_weight_gain', 'IS_CHILD_DEATH', 'IS_DEFECTIVE_BIRTH'
    ]
    categorical_cols = ['FACILITY_TYPE', 'BLOOD_GRP', 'SYS_DISEASE']
    
    parquet_file = pq.ParquetFile(data_path)
    num_rows = parquet_file.metadata.num_rows
    num_batches = (num_rows + batch_size - 1) // batch_size
    
    for batch_idx in range(num_batches):
        batch = parquet_file.read_row_group(batch_idx).to_pandas() if parquet_file.num_row_groups > batch_idx else pd.DataFrame()
        if batch.empty:
            continue
        
        for col in required_cols:
            if col not in batch.columns:
                raise ValueError(f"Required column {col} missing in data")
        
        # Convert numeric columns
        for col in numeric_cols:
            if col in batch.columns:
                batch[col] = pd.to_numeric(batch[col], errors='coerce')
        
        # Map flag columns
        for col in flag_cols:
            if col in batch.columns:
                if batch[col].dtype == 'object':
                    batch[col] = batch[col].map(flag_map)
                batch[col] = batch[col].fillna(0)  # Impute NaNs in flags with 0
        
        # Impute numeric NaNs with median
        for col in numeric_cols:
            if col in batch.columns and batch[col].isna().any():
                batch[col] = batch[col].fillna(batch[col].median())
        
        # Impute categorical NaNs with mode
        for col in categorical_cols:
            if col in batch.columns:
                mode_val = batch[col].mode().iloc[0] if not batch[col].mode().empty else 'Unknown'
                batch[col] = batch[col].fillna(mode_val)
                batch[col] = batch[col].astype(str)  # Ensure string type
        
        processed_chunks.append(batch)
    
    df = pd.concat(processed_chunks, ignore_index=True)
    
    # Limit SYS_DISEASE to top 10 categories
    if 'SYS_DISEASE' in df.columns:
        top_categories = df['SYS_DISEASE'].value_counts().index[:10]
        df['SYS_DISEASE'] = df['SYS_DISEASE'].apply(lambda x: x if x in top_categories else 'Other')
    
    # Validate and create interaction features
    for col in ['anemia_severe', 'systolic_bp', 'hypertension', 'HEMOGLOBIN_mean']:
        if col in df.columns and df[col].isna().any():
            print(f"Warning: Imputing NaNs in {col} before interaction features")
            df[col] = df[col].fillna(df[col].median() if col in numeric_cols else 0)
    
    if 'anemia_severe' in df.columns and 'systolic_bp' in df.columns:
        df['anemia_severe_systolic_bp'] = df['anemia_severe'] * df['systolic_bp']
    if 'hypertension' in df.columns and 'HEMOGLOBIN_mean' in df.columns:
        df['hypertension_hemoglobin'] = df['hypertension'] * df['HEMOGLOBIN_mean']
    
    # Ensure no NaNs in interaction features
    for col in ['anemia_severe_systolic_bp', 'hypertension_hemoglobin']:
        if col in df.columns and df[col].isna().any():
            df[col] = df[col].fillna(df[col].median())
    
    return df

# Stratified sampling function
def create_stratified_sample(df, target_column, sample_size=1000000, min_positive=1000):
    """Create a stratified sample ensuring all maternal mortality cases are included."""
    if target_column not in df.columns:
        raise ValueError(f"Target column '{target_column}' not found in DataFrame. Available columns: {df.columns.tolist()}")
    
    maternal_deaths = df[df[target_column] == 1]
    print(f"Found {len(maternal_deaths)} maternal mortality cases")
    
    critical_cases = maternal_deaths.drop_duplicates()
    
    if len(critical_cases) == 0:
        raise ValueError(f"No positive cases ({target_column}=1) found.")
    if len(critical_cases) < min_positive:
        print(f"Warning: Only {len(critical_cases)} positive cases found. Oversampling to {min_positive}.")
        oversampled_positives = critical_cases.sample(n=min_positive, replace=True, random_state=42)
        critical_cases = pd.concat([critical_cases, oversampled_positives]).drop_duplicates()
    
    remaining_size = sample_size - len(critical_cases)
    if remaining_size > 0:
        other_cases = df[~df.index.isin(critical_cases.index)]
        if len(other_cases) < remaining_size:
            print(f"Warning: Only {len(other_cases)} non-critical cases available. Adjusting sample size.")
            remaining_size = len(other_cases)
        sampled_others = other_cases.sample(n=remaining_size, random_state=42)
        final_sample = pd.concat([critical_cases, sampled_others])
    else:
        final_sample = critical_cases.sample(n=sample_size, random_state=42)
    
    return final_sample

# Main script
data_path = '/kaggle/input/fullfinal/telangana_data_with_features_and_targets (1).parquet'
flag_map = {
    'Y': 1, 'YES': 1, 'Yes': 1, 'y': 1, 'yes': 1,
    'N': 0, 'NO': 0, 'No': 0, 'n': 0, 'no': 0,
    None: np.nan, 'None': np.nan, '': np.nan, 'nan': np.nan, 'NULL': np.nan, 'missing': np.nan
}

# Load and preprocess data
df = prepare_data_for_targets(data_path, flag_map)

# Diagnostic checks
print("Columns in df:", df.columns.tolist())
print("maternal_mortality_risk distribution:")
print(df['maternal_mortality_risk'].value_counts(dropna=False))

# Select target
target_column = 'maternal_mortality_risk'
if df[target_column].eq(1).sum() == 0:
    raise ValueError(f"No positive cases found for {target_column}. Cannot proceed.")
print(f"Using target: {target_column}")

# Create sample
sample_df = create_stratified_sample(df, target_column)
print(f"Class distribution in sample_df for {target_column}:")
print(sample_df[target_column].value_counts())

# Select features
numeric_features = [
    'GRAVIDA', 'PARITY', 'ABORTIONS', 'HEIGHT', 'HEMOGLOBIN_mean', 'HEMOGLOBIN_min', 'HEMOGLOBIN_max',
    'WEIGHT_anc_mean', 'WEIGHT_anc_min', 'WEIGHT_anc_max', 'WEIGHT_child_min', 'systolic_bp', 'diastolic_bp',
    'BMI', 'weight_gain', 'weight_gain_per_week', 'gravida_parity_ratio', 'TOTAL_ANC_VISITS', 'NO_OF_WEEKS_max',
    'anemia_severe_systolic_bp', 'hypertension_hemoglobin'
]
flag_features = [
    'age_adolescent', 'age_elderly', 'age_very_young', 'previous_loss', 'recurrent_loss', 'inadequate_anc',
    'irregular_anc', 'anemia_mild', 'anemia_moderate', 'anemia_severe', 'ever_severe_anemia', 'hypertension',
    'underweight', 'obese', 'normal_weight', 'inadequate_weight_gain'
]
categorical_features = ['FACILITY_TYPE', 'BLOOD_GRP', 'SYS_DISEASE']
features = numeric_features + flag_features + categorical_features

if not features:
    raise ValueError(f"No valid features available for {target_column}. Available columns: {df.columns.tolist()}")
print("Features used:", features)

# Prepare data
X = sample_df[features]
y = sample_df[target_column]

# Check for NaNs before encoding
print("Checking for NaNs in X before encoding:")
print(X.isna().sum())
for col in X.columns:
    if X[col].isna().any():
        if col in numeric_features:
            X[col] = X[col].fillna(X[col].median())
        elif col in categorical_features:
            mode_val = X[col].mode().iloc[0] if not X[col].mode().empty else 'Unknown'
            X[col] = X[col].fillna(mode_val)
        elif col in flag_features:
            X[col] = X[col].fillna(0)

# Encode categorical features
encoder = TargetEncoder(cols=categorical_features, handle_missing='value', handle_unknown='value')
X = encoder.fit_transform(X, y)

# Check for NaNs after encoding
print("Checking for NaNs in X after encoding:")
print(X.isna().sum())
X = X.fillna(X.median(numeric_only=True))  # Impute any remaining NaNs

# Initial train-test split
X_train_full, X_test, y_train_full, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Initialize k-fold cross-validation
n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

# Metrics storage
metrics = {
    'thresh_0_1': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'pr_auc': []},
    'thresh_0_2': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'pr_auc': []},
    'thresh_0_3': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'pr_auc': []},
    'thresh_0_4': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'pr_auc': []},
    'auc': []
}
fold_models = []
fold_times = []
fold_avg_f1 = []

# Calculate scale_pos_weight
neg_count = (y_train_full == 0).sum()
pos_count = (y_train_full == 1).sum()
scale_pos_weight = neg_count / pos_count * 1.5
print(f"Scale pos weight: {scale_pos_weight:.2f}")

# LightGBM parameters
params = {
    'objective': 'binary',
    'metric': 'auc',
    'n_estimators': 500,
    'max_depth': 7,
    'learning_rate': 0.05,
    'scale_pos_weight': scale_pos_weight,
    'min_child_weight': 5,
    'random_state': 42,
    'n_jobs': -1,
    'verbosity': -1
}

# K-fold cross-validation
for fold, (train_idx, val_idx) in enumerate(skf.split(X_train_full, y_train_full)):
    print(f"Training fold {fold + 1}/{n_splits}")
    
    X_train, X_val = X_train_full.iloc[train_idx], X_train_full.iloc[val_idx]
    y_train, y_val = y_train_full.iloc[train_idx], y_train_full.iloc[val_idx]
    
    print(f"Training class counts:\n{y_train.value_counts()}")
    print(f"Validation class counts:\n{y_val.value_counts()}")
    
    if len(y_train.unique()) < 2 or len(y_val.unique()) < 2:
        print(f"Warning: Fold {fold + 1} has only one class. Skipping.")
        continue
    
    # Check for NaNs before SMOTEENN
    print(f"Checking for NaNs in X_train before SMOTEENN for fold {fold + 1}:")
    print(X_train.isna().sum())
    if X_train.isna().any().any():
        for col in X_train.columns:
            if X_train[col].isna().any():
                if col in numeric_features:
                    X_train[col] = X_train[col].fillna(X_train[col].median())
                    X_val[col] = X_val[col].fillna(X_train[col].median())
                elif col in categorical_features:
                    mode_val = X_train[col].mode().iloc[0] if not X_train[col].mode().empty else 'Unknown'
                    X_train[col] = X_train[col].fillna(mode_val)
                    X_val[col] = X_val[col].fillna(mode_val)
                elif col in flag_features:
                    X_train[col] = X_train[col].fillna(0)
                    X_val[col] = X_val[col].fillna(0)
    
    # Verify no NaNs remain
    if X_train.isna().any().any():
        raise ValueError(f"NaNs remain in X_train after imputation in fold {fold + 1}")
    
    # Apply SMOTEENN
    print(f"Applying SMOTEENN for fold {fold + 1}")
    smoteenn = SMOTEENN(random_state=42, sampling_strategy=0.1)
    X_train, y_train = smoteenn.fit_resample(X_train, y_train)
    print(f"Post-SMOTEENN training class counts:\n{y_train.value_counts()}")
    
    start_time = time.time()
    model = lgb.LGBMClassifier(**params)
    model.fit(X_train, y_train, eval_set=[(X_val, y_val)], eval_metric='auc', callbacks=[lgb.early_stopping(50, verbose=False)])
    fold_time = time.time() - start_time
    fold_models.append(model)
    fold_times.append(fold_time)
    
    y_pred_proba = model.predict_proba(X_val)[:, 1]
    precisions, recalls, pr_thresholds = precision_recall_curve(y_val, y_pred_proba)
    pr_auc = auc(recalls, precisions)
    
    for thresh, thresh_name in [(0.1, 'thresh_0_1'), (0.2, 'thresh_0_2'), (0.3, 'thresh_0_3'), (0.4, 'thresh_0_4')]:
        y_pred = (y_pred_proba > thresh).astype(int)
        metrics[thresh_name]['f1'].append(f1_score(y_val, y_pred))
        metrics[thresh_name]['accuracy'].append(accuracy_score(y_val, y_pred))
        metrics[thresh_name]['precision'].append(precision_score(y_val, y_pred, zero_division=0))
        metrics[thresh_name]['recall'].append(recall_score(y_val, y_pred, zero_division=0))
        metrics[thresh_name]['pr_auc'].append(pr_auc)
    
    auc_score = roc_auc_score(y_val, y_pred_proba) if len(np.unique(y_val)) > 1 else 0
    metrics['auc'].append(auc_score)
    
    avg_f1 = np.mean([metrics[f'thresh_0_{i}']['f1'][-1] for i in [1, 2, 3, 4]])
    fold_avg_f1.append(avg_f1)
    
    print(f"Fold {fold + 1}")
    print(f"  AUC: {auc_score:.4f}, PR-AUC: {pr_auc:.4f}, Time: {fold_time:.2f} seconds")
    for thresh, thresh_name in [(0.1, 'thresh_0_1'), (0.2, 'thresh_0_2'), (0.3, 'thresh_0_3'), (0.4, 'thresh_0_4')]:
        print(f"  Threshold {thresh} - F1: {metrics[thresh_name]['f1'][-1]:.4f}, Accuracy: {metrics[thresh_name]['accuracy'][-1]:.4f}, "
              f"Precision: {metrics[thresh_name]['precision'][-1]:.4f}, Recall: {metrics[thresh_name]['recall'][-1]:.4f}")
        print(f"  Confusion Matrix (Threshold {thresh}):\n{confusion_matrix(y_val, (y_pred_proba > thresh).astype(int))}")

# Cross-validation results
print(f"\nCross-Validation Mean Metrics:")
print(f"  AUC: {np.mean(metrics['auc']):.4f} ± {np.std(metrics['auc']):.4f}")
for thresh_name in ['thresh_0_1', 'thresh_0_2', 'thresh_0_3', 'thresh_0_4']:
    print(f"\n{thresh_name.replace('_', ' ').title()}:")
    print(f"  F1 Score: {np.mean(metrics[thresh_name]['f1']):.4f} ± {np.std(metrics[thresh_name]['f1']):.4f}")
    print(f"  Accuracy: {np.mean(metrics[thresh_name]['accuracy']):.4f} ± {np.std(metrics[thresh_name]['accuracy']):.4f}")
    print(f"  Precision: {np.mean(metrics[thresh_name]['precision']):.4f} ± {np.std(metrics[thresh_name]['precision']):.4f}")
    print(f"  Recall: {np.mean(metrics[thresh_name]['recall']):.4f} ± {np.std(metrics[thresh_name]['recall']):.4f}")
    print(f"  PR-AUC: {np.mean(metrics[thresh_name]['pr_auc']):.4f} ± {np.std(metrics[thresh_name]['pr_auc']):.4f}")

# Evaluate best model on test set
if fold_avg_f1:
    best_fold_idx = np.argmax(fold_avg_f1)
    best_model = fold_models[best_fold_idx]
    print(f"\nBest Model from Fold {best_fold_idx + 1} with Average F1 Score: {fold_avg_f1[best_fold_idx]:.4f}")
    
    y_test_pred_proba = best_model.predict_proba(X_test)[:, 1]
    precisions, recalls, _ = precision_recall_curve(y_test, y_test_pred_proba)
    test_pr_auc = auc(recalls, precisions)
    
    test_metrics = {}
    for thresh in [0.1, 0.2, 0.3, 0.4]:
        y_test_pred = (y_test_pred_proba > thresh).astype(int)
        test_metrics[thresh] = {
            'f1': f1_score(y_test, y_test_pred),
            'accuracy': accuracy_score(y_test, y_test_pred),
            'precision': precision_score(y_test, y_test_pred, zero_division=0),
            'recall': recall_score(y_test, y_test_pred, zero_division=0),
            'cm': confusion_matrix(y_test, y_test_pred)
        }
    
    test_auc = roc_auc_score(y_test, y_test_pred_proba) if len(np.unique(y_test)) > 1 else 0
    print(f"\nTest Set Metrics (Best Model from Fold {best_fold_idx + 1}):")
    print(f"  AUC: {test_auc:.4f}, PR-AUC: {test_pr_auc:.4f}")
    for thresh in [0.1, 0.2, 0.3, 0.4]:
        print(f"\nThreshold {thresh}:\n  F1: {test_metrics[thresh]['f1']:.4f}, Accuracy: {test_metrics[thresh]['accuracy']:.4f}, "
              f"Precision: {test_metrics[thresh]['precision']:.4f}, Recall: {test_metrics[thresh]['recall']:.4f}\n  "
              f"Confusion Matrix:\n{test_metrics[thresh]['cm']}")
    
    test_f1_scores = {thresh: test_metrics[thresh]['f1'] for thresh in [0.1, 0.2, 0.3, 0.4]}
    best_threshold = max(test_f1_scores, key=test_f1_scores.get)
    print(f"\nBest Threshold on Test Set: {best_threshold} with F1 Score: {test_f1_scores[best_threshold]:.4f}")
    
    # SHAP analysis
    print("\nPerforming SHAP analysis...")
    X_test_sample = X_test.sample(n=min(1000, len(X_test)), random_state=42)
    explainer = shap.TreeExplainer(best_model)
    shap_values = explainer.shap_values(X_test_sample)[1]  # Use positive class SHAP values
    
    plt.figure()
    shap.summary_plot(shap_values, X_test_sample, show=False)
    plt.savefig("shap_summary_plot_maternal.png")
    plt.close()
    print("SHAP summary plot saved as 'shap_summary_plot_maternal.png'")
    
    plt.figure()
    shap.summary_plot(shap_values, X_test_sample, plot_type="bar", show=False)
    plt.savefig("shap_importance_bar_maternal.png")
    plt.close()
    print("SHAP feature importance bar plot saved as 'shap_importance_bar_maternal.png'")
    
    shap_importance = np.abs(shap_values).mean(axis=0)
    importance_df = pd.DataFrame({
        'Feature': X_test_sample.columns,
        'SHAP_Importance': shap_importance
    }).sort_values(by='SHAP_Importance', ascending=False)
    print("\nSHAP Feature Importance:")
    print(importance_df)
else:
    print("\nNo valid models trained due to single-class folds.")

SMOTEENN imported successfully
maternal_mortality_risk distribution:
maternal_mortality_risk
0    4028194
1       1377
Name: count, dtype: int64
Using target: maternal_mortality_risk
Found 1377 maternal mortality cases
Class distribution in sample_df for maternal_mortality_risk:
maternal_mortality_risk
0    998623
1      1377
Name: count, dtype: int64
Features used: ['GRAVIDA', 'PARITY', 'ABORTIONS', 'HEIGHT', 'HEMOGLOBIN_mean', 'HEMOGLOBIN_min', 'HEMOGLOBIN_max', 'WEIGHT_anc_mean', 'WEIGHT_anc_min', 'WEIGHT_anc_max', 'WEIGHT_child_min', 'systolic_bp', 'diastolic_bp', 'BMI', 'weight_gain', 'weight_gain_per_week', 'gravida_parity_ratio', 'TOTAL_ANC_VISITS', 'NO_OF_WEEKS_max', 'anemia_severe_systolic_bp', 'hypertension_hemoglobin', 'age_adolescent', 'age_elderly', 'age_very_young', 'previous_loss', 'recurrent_loss', 'inadequate_anc', 'irregular_anc', 'anemia_mild', 'anemia_moderate', 'anemia_severe', 'ever_severe_anemia', 'hypertension', 'underweight', 'obese', 'normal_weight', 'inadequa