In [2]:
import pandas as pd
import numpy as np
import lightgbm as lgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, precision_score, recall_score
import time
import warnings
import pyarrow.parquet as pq
warnings.filterwarnings('ignore')

data_path = '/kaggle/input/fullfinal/telangana_data_with_features_and_targets (1).parquet'
flag_map = {
    'Y': 1, 'YES': 1, 'Yes': 1, 'y': 1, 'yes': 1,
    'N': 0, 'NO': 0, 'No': 0, 'n': 0, 'no': 0,
    None: np.nan, 'None': np.nan, '': np.nan, 'nan': np.nan
}

def prepare_data_for_targets(data_path, flag_map, batch_size=10000):
    """Load and preprocess data from Parquet file in batches."""
    # Initialize an empty list to store processed chunks
    processed_chunks = []
    
    # Define required and numeric columns
    required_cols = ['MOTHER_ID', 'GRAVIDA']
    numeric_cols = ['AGE', 'AGE_preg', 'AGE_final', 'GRAVIDA', 'PARITY', 'ABORTIONS', 'TOTAL_ANC_VISITS',
                    'HEMOGLOBIN_mean', 'HEMOGLOBIN_min', 'WEIGHT_max', 'HEIGHT',
                    'PHQ_SCORE_max', 'GAD_SCORE_max', 'WEIGHT_last', 'WEIGHT_first',
                    'NO_OF_WEEKS_max', 'WEIGHT_min', 'WEIGHT_mean']
    
    # Open Parquet file using pyarrow
    parquet_file = pq.ParquetFile(data_path)
    num_rows = parquet_file.metadata.num_rows
    num_batches = (num_rows + batch_size - 1) // batch_size  # Ceiling division
    
    # Iterate through batches
    for batch_idx in range(num_batches):
        # Read a batch of rows
        start_row = batch_idx * batch_size
        end_row = min(start_row + batch_size, num_rows)
        batch = parquet_file.read_row_group(batch_idx).to_pandas() if parquet_file.num_row_groups > batch_idx else pd.DataFrame()
        
        if batch.empty:
            continue
            
        # Check for required columns
        for col in required_cols:
            if col not in batch.columns:
                raise ValueError(f"Required column {col} missing in data")

        # Debug: Check non-numeric values
        for col in numeric_cols:
            if col in batch.columns and batch[col].dtype == 'object':
                non_numeric = batch[col][~batch[col].apply(lambda x: str(x).replace('.', '').isdigit() if pd.notna(x) else False)]
                if not non_numeric.empty:
                    print(f"Non-numeric values in {col}: {non_numeric.unique()[:10]}")

        # Clean GRAVIDA specifically
        if 'GRAVIDA' in batch.columns:
            batch['GRAVIDA'] = batch['GRAVIDA'].replace('nan', np.nan)
            batch['GRAVIDA'] = pd.to_numeric(batch['GRAVIDA'], errors='coerce')
            # Replace NaN with median GRAVIDA (or 1 as default) within batch
            if batch['GRAVIDA'].notna().any():
                median_gravida = batch['GRAVIDA'].median()
                if pd.isna(median_gravida):
                    median_gravida = 1.0
                batch['GRAVIDA'] = batch['GRAVIDA'].fillna(0)
                print(f"Filled {batch['GRAVIDA'].isna().sum()} GRAVIDA NaN values with 0 in batch")

        # Convert other numeric columns
        for col in numeric_cols:
            if col in batch.columns and col != 'GRAVIDA':
                batch[col] = pd.to_numeric(batch[col], errors='coerce')

        # Map flag columns
        if 'IS_CHILD_DEATH' in batch.columns and batch['IS_CHILD_DEATH'].dtype == 'object':
            batch['IS_CHILD_DEATH'] = batch['IS_CHILD_DEATH'].map(flag_map)
        if 'IS_DEFECTIVE_BIRTH' in batch.columns and batch['IS_DEFECTIVE_BIRTH'].dtype == 'object':
            batch['IS_DEFECTIVE_BIRTH'] = batch['IS_DEFECTIVE_BIRTH'].map(flag_map)

        # Fill NaN for numeric columns
        batch_numeric_cols = batch.select_dtypes(include=['float64', 'float32', 'int64', 'int32', 'int8']).columns
        batch[batch_numeric_cols] = batch[batch_numeric_cols].fillna(0)

        # Append processed batch
        processed_chunks.append(batch)

    # Concatenate all chunks
    df = pd.concat(processed_chunks, ignore_index=True)
    return df

# Execute the function
df = prepare_data_for_targets(data_path, flag_map)

Non-numeric values in GRAVIDA: ['nan']
Filled 0 GRAVIDA NaN values with 0 in batch
Filled 0 GRAVIDA NaN values with 0 in batch
Filled 0 GRAVIDA NaN values with 0 in batch
Filled 0 GRAVIDA NaN values with 0 in batch


In [1]:
!pip show scikit-learn
!pip show imbalanced-learn
!pip install scikit-learn==1.2.2 imbalanced-learn==0.10.1 --no-deps -q
from imblearn.over_sampling import SMOTE
print("SMOTE imported successfully")

Name: scikit-learn
Version: 1.2.2
Summary: A set of python modules for machine learning and data mining
Home-page: http://scikit-learn.org
Author: 
Author-email: 
License: new BSD
Location: /usr/local/lib/python3.11/dist-packages
Requires: joblib, numpy, scipy, threadpoolctl
Required-by: bayesian-optimization, Boruta, category_encoders, cesium, eli5, fastai, hdbscan, hep_ml, imbalanced-learn, librosa, lime, mlxtend, nilearn, pyLDAvis, pynndescent, rgf-python, scikit-learn-intelex, scikit-optimize, scikit-plot, sentence-transformers, shap, sklearn-compat, sklearn-pandas, TPOT, tsfresh, umap-learn, woodwork, yellowbrick
Name: imbalanced-learn
Version: 0.10.1
Summary: Toolbox for imbalanced dataset in machine learning.
Home-page: https://github.com/scikit-learn-contrib/imbalanced-learn
Author: 
Author-email: 
License: MIT
Location: /usr/local/lib/python3.11/dist-packages
Requires: joblib, numpy, scikit-learn, scipy, threadpoolctl
Required-by: 
SMOTE imported successfully


In [2]:
import warnings
warnings.filterwarnings('ignore', category=UserWarning)
import pandas as pd
import numpy as np
import lightgbm as lgb
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score, confusion_matrix, precision_recall_curve, auc
from imblearn.combine import SMOTEENN
from imblearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from category_encoders import TargetEncoder
import time
import shap
import matplotlib.pyplot as plt
import pyarrow.parquet as pq
# from sklearn.ensemble import HistGradientBoostingClassifier  # Alternative model that handles NaNs

# Verify SMOTEENN availability
try:
    from imblearn.combine import SMOTEENN
    print("SMOTEENN imported successfully")
except ImportError:
    print("Error: SMOTEENN not available. Install: pip install imbalanced-learn==0.10.1")
    raise ImportError("SMOTEENN is required.")

# Data loading and preprocessing function
def prepare_data_for_targets(data_path, flag_map, batch_size=10000):
    """Load and preprocess data from Parquet file in batches."""
    processed_chunks = []
    required_cols = ['MOTHER_ID', 'GRAVIDA']
    numeric_cols = [
        'GRAVIDA', 'PARITY', 'ABORTIONS', 'HEIGHT', 'HEMOGLOBIN_mean', 'HEMOGLOBIN_min', 'HEMOGLOBIN_max',
        'WEIGHT_anc_mean', 'WEIGHT_anc_min', 'WEIGHT_anc_max', 'systolic_bp', 'diastolic_bp', 'BMI',
        'weight_gain', 'weight_gain_per_week', 'TOTAL_ANC_VISITS', 'NO_OF_WEEKS_max', 'PHQ_SCORE_max', 'GAD_SCORE_max',
        'gravida_parity_ratio'
    ]
    flag_cols = [
        'age_adolescent', 'age_elderly', 'age_very_young', 'previous_loss', 'recurrent_loss', 'inadequate_anc',
        'irregular_anc', 'anemia_mild', 'anemia_moderate', 'anemia_severe', 'ever_severe_anemia', 'hypertension',
        'underweight', 'obese', 'normal_weight', 'inadequate_weight_gain'
    ]
    categorical_cols = ['FACILITY_TYPE', 'BLOOD_GRP', 'SYS_DISEASE']
    
    parquet_file = pq.ParquetFile(data_path)
    num_rows = parquet_file.metadata.num_rows
    num_batches = (num_rows + batch_size - 1) // batch_size
    
    for batch_idx in range(num_batches):
        batch = parquet_file.read_row_group(batch_idx).to_pandas() if parquet_file.num_row_groups > batch_idx else pd.DataFrame()
        if batch.empty:
            continue
        
        for col in required_cols:
            if col not in batch.columns:
                raise ValueError(f"Required column {col} missing in data")
        
        # Convert numeric columns
        for col in numeric_cols:
            if col in batch.columns:
                batch[col] = pd.to_numeric(batch[col], errors='coerce')
        
        # Map flag columns and handle NaNs
        for col in flag_cols:
            if col in batch.columns:
                if batch[col].dtype == 'object':
                    batch[col] = batch[col].map(flag_map)
                batch[col] = batch[col].fillna(0)  # Impute NaNs with 0 for flags
        
        # Impute numeric NaNs with median
        if numeric_cols:
            batch[numeric_cols] = batch[numeric_cols].fillna(batch[numeric_cols].median())
        
        # Impute categorical NaNs with mode
        for col in categorical_cols:
            if col in batch.columns:
                mode_val = batch[col].mode().iloc[0] if not batch[col].mode().empty else 'Unknown'
                batch[col] = batch[col].fillna(mode_val)
        
        processed_chunks.append(batch)
    
    df = pd.concat(processed_chunks, ignore_index=True)
    
    # Limit SYS_DISEASE to top 10 categories
    if 'SYS_DISEASE' in df.columns:
        top_categories = df['SYS_DISEASE'].value_counts().index[:10]
        df['SYS_DISEASE'] = df['SYS_DISEASE'].apply(lambda x: x if x in top_categories else 'Other')
    
    # Add interaction features and impute NaNs
    # if 'anemia_mild' in df.columns and 'TOTAL_ANC_VISITS' in df.columns:
    #     df['anemia_mild_anc_visits'] = df['anemia_mild'] * df['TOTAL_ANC_VISITS']
    #     df['anemia_mild_anc_visits'] = df['anemia_mild_anc_visits'].fillna(df['anemia_mild_anc_visits'].median())
    # if 'hypertension' in df.columns and 'WEIGHT_anc_mean' in df.columns:
    #     df['hypertension_weight'] = df['hypertension'] * df['WEIGHT_anc_mean']
    #     df['hypertension_weight'] = df['hypertension_weight'].fillna(df['hypertension_weight'].median())
    if 'GRAVIDA' in df.columns and 'PARITY' in df.columns:
        df['gravida_parity_ratio'] = df['GRAVIDA'] / (df['PARITY'] + 1e-5)  # Avoid division by zero
        df['gravida_parity_ratio'] = df['gravida_parity_ratio'].fillna(df['gravida_parity_ratio'].median())
    
    # Final NaN check
    nan_cols = df[numeric_cols + flag_cols + categorical_cols].isna().sum()
    if nan_cols.any():
        print("Warning: NaNs remain in the following columns after preprocessing:")
        print(nan_cols[nan_cols > 0])
    
    return df

# Stratified sampling function
def create_stratified_sample(df, target_column, sample_size=10000, min_positive=1000):
    """Create a stratified sample ensuring both classes are included."""
    if target_column not in df.columns:
        raise ValueError(f"Target column '{target_column}' not found in DataFrame. Available columns: {df.columns.tolist()}")
    
    positive_cases = df[df[target_column] == 1]
    negative_cases = df[df[target_column] == 0]
    
    print(f"Found {len(positive_cases)} positive cases and {len(negative_cases)} negative cases for {target_column}")
    
    if len(positive_cases) == 0:
        raise ValueError(f"No positive cases ({target_column}=1) found.")
    if len(negative_cases) == 0:
        raise ValueError(f"No negative cases ({target_column}=0) found.")
    
    sampled_negatives = negative_cases
    target_negatives = len(negative_cases)
    target_positives = min(len(positive_cases), max(min_positive, target_negatives * 2))  # Aim for 2:1 positive:negative
    remaining_size = min(sample_size - len(sampled_negatives), len(positive_cases))
    
    if remaining_size > 0:
        sampled_positives = positive_cases.sample(n=remaining_size, random_state=42)
        final_sample = pd.concat([sampled_positives, sampled_negatives], ignore_index=True)
    else:
        final_sample = sampled_negatives
    
    print(f"Sampled {len(sampled_positives)} positive cases and {len(sampled_negatives)} negative cases")
    return final_sample

# Main script
data_path = '/kaggle/input/fullfinal/telangana_data_with_features_and_targets (1).parquet'
flag_map = {
    'Y': 1, 'YES': 1, 'Yes': 1, 'y': 1, 'yes': 1,
    'N': 0, 'NO': 0, 'No': 0, 'n': 0, 'no': 0,
    None: np.nan, 'None': np.nan, '': np.nan, 'nan': np.nan
}

# Load and preprocess data
df = prepare_data_for_targets(data_path, flag_map)

# Define columns to exclude to prevent data leakage
exclude_cols = [
    'anc_dropout', 'maternal_mortality_risk', 'stillbirth_risk', 'premature_birth_risk', 'birth_defect_risk',
    'high_risk_pregnancy', 'total_risk_factors', 'clinical_risk_score', 'risk_level', 'predicted_risk',
    'total_missed_visits', 'age_risk_score', 'demographic_risk', 'anemia_risk_score', 'overall_risk_score',
    'MOTHER_ID', 'ANC_ID', 'CHILD_ID', 'unique_id', 'EID', 'UID_NUMBER', 'WEIGHT_child_mean', 'WEIGHT_child_min',
    'DELIVERY_MODE', 'MATERNAL_OUTCOME', 'IS_DELIVERED', 'DELIVERY_OUTCOME', 'DATE_OF_DELIVERY', 'PLACE_OF_DELIVERY',
    'DEL_TIME', 'DATE_OF_DISCHARGE', 'DISCHARGE_TIME', 'JSY_BENEFICIARY', 'IS_MOTHER_ALIVE', 'IS_CHILD_DEATH',
    'CHILD_DEATH_DATE', 'CHILD_DEATH_REASON', 'DEFECT_HEALTH_CENTER', 'IS_DEFECTIVE_BIRTH', 'BIRTH_DEFECT_TYPE',
    'BIRTH_DEFECT_SUBTYPE', 'DEFECT_SUBTYPE_OTHER', 'DEFECT_TYPE_OTHER', 'NOTIFICATION_SENT',
    'FBIR_COMPLETED_BY_ANM', 'NEWBORN_SCREENING', 'DEATH_REASON_OTHER', 'SNCU_ADMITTED', 'SNCU_REFERRAL_HOSPITAL',
    'TERTIARY_REFERRAL_HOSPITAL', 'OTHER_REFERRAL_HOSPITAL', 'DATE_OF_DEATH', 'REASON_FOR_DEATH', 'PLACE_OF_DEATH',
    'INDICATION_FOR_C_SECTION', 'DELIVERY_INSTITUTION', 'DELIVERY_DONE_BY', 'CONDUCT_BY', 'MISOPROSTAL_TABLET',
    'DEL_COMPLICATIONS', 'OTHER_DEL_COMPLICATIONS', 'NOTIFICATION_SENT_del', 'FBIR_COMPLETED_BY_ANM_del',
    'REGISTRATION_DT', 'REGTYPE', 'CURRENT_USR', 'OTHER_STATE_PLACE', 'OTHER_STATE_PLACE_FILEPATH',
    'OTHER_GOVT_PLACE_FILEPATH', 'ANC2_TAG_FAC_ID', 'ANC3_TAG_FAC_ID', 'ANC_INSTITUTE', 'FACILITY_NAME',
    'DOCTOR_ANM', 'DISTRICT_anc', 'DISTRICT_child', 'IS_BF_IN_HOUR', 'FEEDING_TYPE', 'DATE_OF_FIRST_FEEDING',
    'TIME_OF_FIRST_FEEDING', 'BABY_ON_MEDICATION', 'MEDICATION_REMARKS', 'DATE_OF_BLOODSAMPLE_COLLECTION',
    'TIME_OF_BLOODSAMPLE_COLLECTION', 'HOURS_OF_SAMPLE_COLLECTION', 'TRANSFUSION_DONE', 'VDRL_DATE',
    'VDRL_STATUS', 'HIV_DATE', 'HIV_STATUS', 'HBSAG_DATE', 'HBSAG_STATUS', 'HEP_DATE', 'HEP_STATUS',
    'VDRL_RESULT', 'HIV_RESULT', 'HBSAG_RESULT', 'HEP_RESULT', 'MISSANC1FLG', 'MISSANC2FLG', 'MISSANC3FLG',
    'MISSANC4FLG', 'HIGH_RISKS', 'DISEASES', 'CHILD_NAME', 'age_category', 'multigravida', 'grand_multipara',
    'no_anc', 'missed_first_anc', 'consecutive_missed', 'severe_hypertension', 'low_birth_weight',
    'very_low_birth_weight', 'avg_birth_weight_low', 'depression', 'severe_depression', 'anxiety',
    'severe_anxiety', 'hemoglobin_trend', 'TT_DATE', 'MAL_PRESENT', 'IS_ADMITTED_SNCU', 'IS_PREV_PREG',
    'ANC1FLG', 'ANC2FLG', 'ANC3FLG', 'ANC4FLG', 'ANC_DATE', 'DEATH', 'EXP_DOD', 'DELIVERY_PLACE',
    'EXP_DOD_preg', 'FASTING', 'LMP_DT', 'SCREENED_FOR_MENTAL_HEALTH', 'AGE', 'AGE_final', 'AGE_preg',
    'GENDER', 'TIME_OF_BIRTH', 'TWIN_PREGNANCY_max', 'RNK', 'PHQ_SCORE_max', 'GAD_SCORE_max',
    'mental_health_risk', 'NO_OF_WEEKS_max', 'DELIVERY_INSTITUTION', 'DELIVERY_DONE_BY', 'CONDUCT_BY',
    'OTHER_NAME', 'NOTIFICATION_SENT', 'FBIR_COMPLETED_BY_ANM', 'OTHER_STATE_PLACE', 'OTHER_STATE_PLACE_FILEPATH',
    'OTHER_GOVT_PLACE_FILEPATH', 'bp_risk'
]

# Diagnostic checks
print("Columns in df:", df.columns.tolist())
print("anc_dropout distribution:")
print(df['anc_dropout'].value_counts(dropna=False))

# Select target
target_column = 'anc_dropout'
if df[target_column].eq(1).sum() == 0:
    raise ValueError(f"No positive cases found for {target_column}. Cannot proceed.")
print(f"Using target: {target_column}")

# Create sample
sample_df = create_stratified_sample(df, target_column, sample_size=10000)
print(f"Class distribution in sample_df for {target_column}:")
print(sample_df[target_column].value_counts())

# Select features
numeric_features = [
    'GRAVIDA', 'PARITY', 'ABORTIONS', 'HEIGHT', 'HEMOGLOBIN_mean', 'HEMOGLOBIN_min', 'HEMOGLOBIN_max',
    'WEIGHT_anc_mean', 'WEIGHT_anc_min', 'WEIGHT_anc_max', 'systolic_bp', 'diastolic_bp', 'BMI',
    'weight_gain', 'weight_gain_per_week', 'gravida_parity_ratio', 'TOTAL_ANC_VISITS', 'NO_OF_WEEKS_max'
    # 'hypertension_weight'
]
flag_features = [
    'age_adolescent', 'age_elderly', 'age_very_young', 'previous_loss', 'recurrent_loss', 'inadequate_anc',
    'irregular_anc', 'anemia_mild', 'anemia_moderate', 'anemia_severe', 'ever_severe_anemia', 'hypertension',
    'underweight', 'obese', 'normal_weight', 'inadequate_weight_gain'
]
categorical_features = ['FACILITY_TYPE', 'BLOOD_GRP', 'SYS_DISEASE']
features = numeric_features + flag_features + categorical_features

if not features:
    raise ValueError(f"No valid features available for {target_column}. Available columns: {sample_df.columns.tolist()}")
print("Features used:", features)

# Prepare data
X = sample_df[features]
y = sample_df[target_column]

# Check for NaNs in features
nan_summary = X.isna().sum()
if nan_summary.any():
    print("NaNs found in features before encoding:")
    print(nan_summary[nan_summary > 0])

# Encode categorical features
encoder = TargetEncoder(cols=categorical_features)
X = encoder.fit_transform(X, y)

# Check for NaNs after encoding
nan_summary = X.isna().sum()
if nan_summary.any():
    print("NaNs found in features after encoding:")
    print(nan_summary[nan_summary > 0])

# Initial train-test split
X_train_full, X_test, y_train_full, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Initialize k-fold cross-validation
n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

# Metrics storage
metrics = {
    'thresh_0_1': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'pr_auc': []},
    'thresh_0_2': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'pr_auc': []},
    'thresh_0_3': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'pr_auc': []},
    'thresh_0_4': {'f1': [], 'accuracy': [], 'precision': [], 'recall': [], 'pr_auc': []},
    'auc': []
}
fold_models = []
fold_times = []
fold_avg_f1 = []

# Calculate scale_pos_weight
neg_count = (y_train_full == 0).sum()
pos_count = (y_train_full == 1).sum()
scale_pos_weight = neg_count / pos_count * 1.5 if pos_count > 0 else 1.0
print(f"Scale pos weight: {scale_pos_weight:.2f}")

# LightGBM parameters
params = {
    'objective': 'binary',
    'metric': 'auc',
    'n_estimators': 500,
    'max_depth': 7,
    'learning_rate': 0.05,
    'scale_pos_weight': scale_pos_weight,
    'min_child_weight': 5,
    'random_state': 42,
    'n_jobs': -1,
    'verbosity': -1
}

# K-fold cross-validation
for fold, (train_idx, val_idx) in enumerate(skf.split(X_train_full, y_train_full)):
    print(f"Training fold {fold + 1}/{n_splits}")
    
    X_train, X_val = X_train_full.iloc[train_idx], X_train_full.iloc[val_idx]
    y_train, y_val = y_train_full.iloc[train_idx], y_train_full.iloc[val_idx]
    
    print(f"Training class counts:\n{y_train.value_counts()}")
    print(f"Validation class counts:\n{y_val.value_counts()}")
    
    if len(y_train.unique()) < 2 or len(y_val.unique()) < 2:
        print(f"Warning: Fold {fold + 1} has only one class. Skipping.")
        continue
    
    # Apply SMOTEENN with imputation pipeline
    print(f"Applying SMOTEENN for fold {fold + 1}")
    imputer = SimpleImputer(strategy='median')
    smoteenn = SMOTEENN(random_state=42, sampling_strategy=0.5)
    pipeline = Pipeline([
        ('imputer', imputer),
        ('smoteenn', smoteenn)
    ])
    
    try:
        X_train, y_train = pipeline.fit_resample(X_train, y_train)
        print(f"Post-SMOTEENN training class counts:\n{y_train.value_counts()}")
    except ValueError as e:
        print(f"SMOTEENN pipeline failed for fold {fold + 1}: {e}. Proceeding without resampling.")
    
    start_time = time.time()
    model = lgb.LGBMClassifier(**params)
    model.fit(X_train, y_train, eval_set=[(X_val, y_val)], eval_metric='auc', callbacks=[lgb.early_stopping(50, verbose=False)])
    fold_time = time.time() - start_time
    fold_models.append(model)
    fold_times.append(fold_time)
    
    y_pred_proba = model.predict_proba(X_val)[:, 1]
    precisions, recalls, pr_thresholds = precision_recall_curve(y_val, y_pred_proba)
    pr_auc = auc(recalls, precisions)
    
    for thresh, thresh_name in [(0.1, 'thresh_0_1'), (0.2, 'thresh_0_2'), (0.3, 'thresh_0_3'), (0.4, 'thresh_0_4')]:
        y_pred = (y_pred_proba > thresh).astype(int)
        metrics[thresh_name]['f1'].append(f1_score(y_val, y_pred))
        metrics[thresh_name]['accuracy'].append(accuracy_score(y_val, y_pred))
        metrics[thresh_name]['precision'].append(precision_score(y_val, y_pred, zero_division=0))
        metrics[thresh_name]['recall'].append(recall_score(y_val, y_pred, zero_division=0))
        metrics[thresh_name]['pr_auc'].append(pr_auc)
    
    auc_score = roc_auc_score(y_val, y_pred_proba) if len(np.unique(y_val)) > 1 else 0
    metrics['auc'].append(auc_score)
    
    avg_f1 = np.mean([metrics[f'thresh_0_{i}']['f1'][-1] for i in [1, 2, 3, 4]])
    fold_avg_f1.append(avg_f1)
    
    print(f"Fold {fold + 1}")
    print(f"  AUC: {auc_score:.4f}, PR-AUC: {pr_auc:.4f}, Time: {fold_time:.2f} seconds")
    for thresh, thresh_name in [(0.1, 'thresh_0_1'), (0.2, 'thresh_0_2'), (0.3, 'thresh_0_3'), (0.4, 'thresh_0_4')]:
        print(f"  Threshold {thresh} - F1: {metrics[thresh_name]['f1'][-1]:.4f}, Accuracy: {metrics[thresh_name]['accuracy'][-1]:.4f}, "
              f"Precision: {metrics[thresh_name]['precision'][-1]:.4f}, Recall: {metrics[thresh_name]['recall'][-1]:.4f}")
        print(f"  Confusion Matrix (Threshold {thresh}):\n{confusion_matrix(y_val, (y_pred_proba > thresh).astype(int))}")

# Cross-validation results
print(f"\nCross-Validation Mean Metrics:")
print(f"  AUC: {np.mean(metrics['auc']):.4f} ± {np.std(metrics['auc']):.4f}")
for thresh_name in ['thresh_0_1', 'thresh_0_2', 'thresh_0_3', 'thresh_0_4']:
    print(f"\n{thresh_name.replace('_', ' ').title()}:")
    print(f"  F1 Score: {np.mean(metrics[thresh_name]['f1']):.4f} ± {np.std(metrics[thresh_name]['f1']):.4f}")
    print(f"  Accuracy: {np.mean(metrics[thresh_name]['accuracy']):.4f} ± {np.std(metrics[thresh_name]['accuracy']):.4f}")
    print(f"  Precision: {np.mean(metrics[thresh_name]['precision']):.4f} ± {np.std(metrics[thresh_name]['precision']):.4f}")
    print(f"  Recall: {np.mean(metrics[thresh_name]['recall']):.4f} ± {np.std(metrics[thresh_name]['recall']):.4f}")
    print(f"  PR-AUC: {np.mean(metrics[thresh_name]['pr_auc']):.4f} ± {np.std(metrics[thresh_name]['pr_auc']):.4f}")

# Evaluate best model on test set
if fold_avg_f1:
    best_fold_idx = np.argmax(fold_avg_f1)
    best_model = fold_models[best_fold_idx]
    print(f"\nBest Model from Fold {best_fold_idx + 1} with Average F1 Score: {fold_avg_f1[best_fold_idx]:.4f}")
    
    y_test_pred_proba = best_model.predict_proba(X_test)[:, 1]
    precisions, recalls, _ = precision_recall_curve(y_test, y_test_pred_proba)
    test_pr_auc = auc(recalls, precisions)
    
    test_metrics = {}
    for thresh in [0.1, 0.2, 0.3, 0.4]:
        y_test_pred = (y_test_pred_proba > thresh).astype(int)
        test_metrics[thresh] = {
            'f1': f1_score(y_test, y_test_pred),
            'accuracy': accuracy_score(y_test, y_test_pred),
            'precision': precision_score(y_test, y_test_pred, zero_division=0),
            'recall': recall_score(y_test, y_test_pred, zero_division=0),
            'cm': confusion_matrix(y_test, y_test_pred)
        }
    
    test_auc = roc_auc_score(y_test, y_test_pred_proba) if len(np.unique(y_test)) > 1 else 0
    print(f"\nTest Set Metrics (Best Model from Fold {best_fold_idx + 1}):")
    print(f"  AUC: {test_auc:.4f}, PR-AUC: {test_pr_auc:.4f}")
    for thresh in [0.1, 0.2, 0.3, 0.4]:
        print(f"\nThreshold {thresh}:\n  F1: {test_metrics[thresh]['f1']:.4f}, Accuracy: {test_metrics[thresh]['accuracy']:.4f}, "
              f"Precision: {test_metrics[thresh]['precision']:.4f}, Recall: {test_metrics[thresh]['recall']:.4f}\n  "
              f"Confusion Matrix:\n{test_metrics[thresh]['cm']}")
    
    test_f1_scores = {thresh: test_metrics[thresh]['f1'] for thresh in [0.1, 0.2, 0.3, 0.4]}
    best_threshold = max(test_f1_scores, key=test_f1_scores.get)
    print(f"\nBest Threshold on Test Set: {best_threshold} with F1 Score: {test_f1_scores[best_threshold]:.4f}")
    
    # SHAP analysis
    print("\nPerforming SHAP analysis...")
    X_test_sample = X_test.sample(n=min(1000, len(X_test)), random_state=42)
    explainer = shap.TreeExplainer(best_model)
    shap_values = explainer.shap_values(X_test_sample)[1]
    
    plt.figure()
    shap.summary_plot(shap_values, X_test_sample, show=False)
    plt.savefig("shap_summary_plot_anc_dropout.png")
    plt.close()
    print("SHAP summary plot saved as 'shap_summary_plot_anc_dropout.png'")
    
    plt.figure()
    shap.summary_plot(shap_values, X_test_sample, plot_type="bar", show=False)
    plt.savefig("shap_importance_bar_anc_dropout.png")
    plt.close()
    print("SHAP feature importance bar plot saved as 'shap_importance_bar_anc_dropout.png'")
    
    shap_importance = np.abs(shap_values).mean(axis=0)
    importance_df = pd.DataFrame({
        'Feature': X_test_sample.columns,
        'SHAP_Importance': shap_importance
    }).sort_values(by='SHAP_Importance', ascending=False)
    print("\nSHAP Feature Importance:")
    print(importance_df)
else:
    print("\nNo valid models trained due to single-class folds.")

# Alternative model (uncomment to use)
"""
# Use HistGradientBoostingClassifier which handles NaNs natively
params = {
    'max_iter': 500,
    'max_depth': 7,
    'learning_rate': 0.05,
    'random_state': 42
}
model = HistGradientBoostingClassifier(**params)
"""

SMOTEENN imported successfully
anc_dropout distribution:
anc_dropout
1    4029441
0        130
Name: count, dtype: int64
Using target: anc_dropout
Found 4029441 positive cases and 130 negative cases for anc_dropout
Sampled 9870 positive cases and 130 negative cases
Class distribution in sample_df for anc_dropout:
anc_dropout
1    9870
0     130
Name: count, dtype: int64
Features used: ['GRAVIDA', 'PARITY', 'ABORTIONS', 'HEIGHT', 'HEMOGLOBIN_mean', 'HEMOGLOBIN_min', 'HEMOGLOBIN_max', 'WEIGHT_anc_mean', 'WEIGHT_anc_min', 'WEIGHT_anc_max', 'systolic_bp', 'diastolic_bp', 'BMI', 'weight_gain', 'weight_gain_per_week', 'gravida_parity_ratio', 'TOTAL_ANC_VISITS', 'NO_OF_WEEKS_max', 'age_adolescent', 'age_elderly', 'age_very_young', 'previous_loss', 'recurrent_loss', 'inadequate_anc', 'irregular_anc', 'anemia_mild', 'anemia_moderate', 'anemia_severe', 'ever_severe_anemia', 'hypertension', 'underweight', 'obese', 'normal_weight', 'inadequate_weight_gain', 'FACILITY_TYPE', 'BLOOD_GRP', 'SYS_DISEA

"\n# Use HistGradientBoostingClassifier which handles NaNs natively\nparams = {\n    'max_iter': 500,\n    'max_depth': 7,\n    'learning_rate': 0.05,\n    'random_state': 42\n}\nmodel = HistGradientBoostingClassifier(**params)\n"