In [1]:
import pandas as pd
import numpy as np
import lightgbm as lgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, precision_score, recall_score
import time
import warnings
import pyarrow.parquet as pq
warnings.filterwarnings('ignore')

data_path = '/kaggle/input/fullfinal/telangana_data_with_features_and_targets (1).parquet'
flag_map = {
    'Y': 1, 'YES': 1, 'Yes': 1, 'y': 1, 'yes': 1,
    'N': 0, 'NO': 0, 'No': 0, 'n': 0, 'no': 0,
    None: np.nan, 'None': np.nan, '': np.nan, 'nan': np.nan
}

def prepare_data_for_targets(data_path, flag_map, batch_size=10000):
    """Load and preprocess data from Parquet file in batches."""
    # Initialize an empty list to store processed chunks
    processed_chunks = []
    
    # Define required and numeric columns
    required_cols = ['MOTHER_ID', 'GRAVIDA']
    numeric_cols = ['AGE', 'AGE_preg', 'AGE_final', 'GRAVIDA', 'PARITY', 'ABORTIONS', 'TOTAL_ANC_VISITS',
                    'HEMOGLOBIN_mean', 'HEMOGLOBIN_min', 'WEIGHT_max', 'HEIGHT',
                    'PHQ_SCORE_max', 'GAD_SCORE_max', 'WEIGHT_last', 'WEIGHT_first',
                    'NO_OF_WEEKS_max', 'WEIGHT_min', 'WEIGHT_mean']
    
    # Open Parquet file using pyarrow
    parquet_file = pq.ParquetFile(data_path)
    num_rows = parquet_file.metadata.num_rows
    num_batches = (num_rows + batch_size - 1) // batch_size  # Ceiling division
    
    # Iterate through batches
    for batch_idx in range(num_batches):
        # Read a batch of rows
        start_row = batch_idx * batch_size
        end_row = min(start_row + batch_size, num_rows)
        batch = parquet_file.read_row_group(batch_idx).to_pandas() if parquet_file.num_row_groups > batch_idx else pd.DataFrame()
        
        if batch.empty:
            continue
            
        # Check for required columns
        for col in required_cols:
            if col not in batch.columns:
                raise ValueError(f"Required column {col} missing in data")

        # Debug: Check non-numeric values
        for col in numeric_cols:
            if col in batch.columns and batch[col].dtype == 'object':
                non_numeric = batch[col][~batch[col].apply(lambda x: str(x).replace('.', '').isdigit() if pd.notna(x) else False)]
                if not non_numeric.empty:
                    print(f"Non-numeric values in {col}: {non_numeric.unique()[:10]}")

        # Clean GRAVIDA specifically
        if 'GRAVIDA' in batch.columns:
            batch['GRAVIDA'] = batch['GRAVIDA'].replace('nan', np.nan)
            batch['GRAVIDA'] = pd.to_numeric(batch['GRAVIDA'], errors='coerce')
            # Replace NaN with median GRAVIDA (or 1 as default) within batch
            if batch['GRAVIDA'].notna().any():
                median_gravida = batch['GRAVIDA'].median()
                if pd.isna(median_gravida):
                    median_gravida = 1.0
                batch['GRAVIDA'] = batch['GRAVIDA'].fillna(0)
                print(f"Filled {batch['GRAVIDA'].isna().sum()} GRAVIDA NaN values with 0 in batch")

        # Convert other numeric columns
        for col in numeric_cols:
            if col in batch.columns and col != 'GRAVIDA':
                batch[col] = pd.to_numeric(batch[col], errors='coerce')

        # Map flag columns
        if 'IS_CHILD_DEATH' in batch.columns and batch['IS_CHILD_DEATH'].dtype == 'object':
            batch['IS_CHILD_DEATH'] = batch['IS_CHILD_DEATH'].map(flag_map)
        if 'IS_DEFECTIVE_BIRTH' in batch.columns and batch['IS_DEFECTIVE_BIRTH'].dtype == 'object':
            batch['IS_DEFECTIVE_BIRTH'] = batch['IS_DEFECTIVE_BIRTH'].map(flag_map)

        # Fill NaN for numeric columns
        batch_numeric_cols = batch.select_dtypes(include=['float64', 'float32', 'int64', 'int32', 'int8']).columns
        batch[batch_numeric_cols] = batch[batch_numeric_cols].fillna(0)

        # Append processed batch
        processed_chunks.append(batch)

    # Concatenate all chunks
    df = pd.concat(processed_chunks, ignore_index=True)
    return df

# Execute the function
df = prepare_data_for_targets(data_path, flag_map)

Non-numeric values in GRAVIDA: ['nan']
Filled 0 GRAVIDA NaN values with 0 in batch
Filled 0 GRAVIDA NaN values with 0 in batch
Filled 0 GRAVIDA NaN values with 0 in batch
Filled 0 GRAVIDA NaN values with 0 in batch


In [2]:
target_columns = [
            'premature_birth_risk'
        ]
def create_stratified_sample(df, target_column, sample_size=2000000):
    """Create a stratified sample ensuring all critical cases are included."""
    child_death_col = 'IS_CHILD_DEATH'
    critical_cases = pd.DataFrame()
    
    if child_death_col in df.columns:
        child_deaths = df[df[child_death_col] == 1]
        critical_cases = pd.concat([critical_cases, child_deaths])
    if 'maternal_mortality_risk' in df.columns:
        maternal_deaths = df[df['maternal_mortality_risk'] == 1]
        critical_cases = pd.concat([critical_cases, maternal_deaths])
    if 'stillbirth_risk' in df.columns:
        stillbirths = df[df['stillbirth_risk'] == 1]
        critical_cases = pd.concat([critical_cases, stillbirths])
    critical_cases = critical_cases.drop_duplicates()
    
    remaining_size = sample_size - len(critical_cases)
    
    if remaining_size > 0:
        other_cases = df[~df.index.isin(critical_cases.index)]
        sampled_others = other_cases.sample(n=remaining_size, random_state=42)
        final_sample = pd.concat([critical_cases, sampled_others])
    else:
        final_sample = critical_cases.sample(n=sample_size, random_state=42)
    
    return final_sample

sample_df=create_stratified_sample(df, 'high_risk_pregnancy')

# Exclude target-related columns to prevent leakage
# exclude_cols = target_columns + [
#     'total_risk_factors', 'clinical_risk_score', 'risk_level', 'predicted_risk',
#     'total_missed_visits', 'age_risk_score', 'demographic_risk', 'anemia_risk_score',
#     'overall_risk_score', 'MOTHER_ID', 'ANC_ID', 'IS_CHILD_DEATH',
#     'IS_DEFECTIVE_BIRTH', 'NO_OF_WEEKS_max', 'DELIVERY_OUTCOME', 'IS_MOTHER_ALIVE',
#     'MISSANC1FLG', 'MISSANC2FLG', 'MISSANC3FLG', 'MISSANC4FLG', 'stillbirth_risk', 'premature_birth_risk',
#             'maternal_mortality_risk', 'birth_defect_risk', 'anc_dropout',
    
#     # Newly added columns
#     'ANC_INSTITUTE', 'FACILITY_TYPE', 'FACILITY_NAME', 'DOCTOR_ANM', 'IFA_TABLET',
#     'TT_GIVEN', 'USG_SCAN', 'DISTRICT_anc', 'TIME_OF_BIRTH', 'IS_BF_IN_HOUR',
#     'SNCU_REFERRAL_HOSPITAL', 'TERTIARY_REFERRAL_HOSPITAL', 'OTHER_REFERRAL_HOSPITAL',
#     'IMMUNE_CYCLE_DONE', 'CHILD_DEATH_DATE', 'CHILD_DEATH_REASON', 'DEFECT_HEALTH_CENTER',
#     'DISTRICT_child', 'BIRTH_DEFECT_TYPE', 'BIRTH_DEFECT_SUBTYPE', 'DEFECT_SUBTYPE_OTHER',
#     'NOTIFICATION_SENT', 'FBIR_COMPLETED_BY_ANM', 'CHILD_NAME', 'DEFECT_TYPE_OTHER',
#     'DEATH_REASON_OTHER', 'NEWBORN_SCREENING', 'EID', 'UID_NUMBER', 'SNCU_ADMITTED',
#     'CONSANGUINITY', 'HIGH_RISKS', 'DISEASES', 'FEEDING_TYPE', 'DATE_OF_FIRST_FEEDING',
#     'TIME_OF_FIRST_FEEDING', 'DATE_OF_BLOODSAMPLE_COLLECTION', 'TIME_OF_BLOODSAMPLE_COLLECTION',
#     'HOURS_OF_SAMPLE_COLLECTION', 'TRANSFUSION_DONE', 'BABY_ON_MEDICATION', 'MEDICATION_REMARKS',
#     'DATE_OF_DELIVERY', 'PLACE_OF_DELIVERY', 'MODE_OF_DELIVERY', 'MATERNAL_OUTCOME',
#     'REASON_FOR_DEATH', 'DATE_OF_DEATH', 'PLACE_OF_DEATH', 'INDICATION_FOR_C_SECTION',
#     'CH', 'CAH', 'GALACTOCEMIA', 'G6PDD', 'BIOTINIDASE', 'DELIVERY_INSTITUTION',
#     'DELIVERY_DONE_BY', 'IS_DELIVERED', 'DATE_OF_DISCHARGE', 'DEL_TIME', 'MISOPROSTAL_TABLET',
#     'CONDUCT_BY', 'OTHER_NAME', 'JSY_BENEFICIARY', 'DISCHARGE_TIME', 'DEL_COMPLICATIONS',
#     'OTHER_DEL_COMPLICATIONS', 'NOTIFICATION_SENT_del', 'FBIR_COMPLETED_BY_ANM_del',
#     'OTHER_STATE_PLACE', 'OTHER_STATE_PLACE_FILEPATH', 'OTHER_GOVT_PLACE_FILEPATH',
#     'REGISTRATION_DT', 'REGTYPE', 'CURRENT_USR', 'ANC2_TAG_FAC_ID', 'ANC3_TAG_FAC_ID',
#     'VDRL_DATE', 'VDRL_STATUS', 'HIV_DATE', 'HIV_STATUS', 'HBSAG_DATE', 'HBSAG_STATUS',
#     'HEP_DATE', 'HEP_STATUS'
# ]

exclude_cols = [
    # Labels & targets
    'maternal_mortality_risk', 'stillbirth_risk', 'premature_birth_risk',
    'birth_defect_risk', 'anc_dropout', 'high_risk_pregnancy',

    # Derived risk scores and labels
    'total_risk_factors', 'clinical_risk_score', 'risk_level', 'predicted_risk',
    'total_missed_visits', 'age_risk_score', 'demographic_risk', 'anemia_risk_score',
    'overall_risk_score',

    # Identifiers
    'MOTHER_ID', 'ANC_ID', 'CHILD_ID', 'unique_id', 'EID', 'UID_NUMBER',

    # Post-delivery or outcome leakage
    'WEIGHT_child_mean', 'DELIVERY_MODE', 'MATERNAL_OUTCOME', 'IS_DELIVERED',
    'DELIVERY_OUTCOME', 'DATE_OF_DELIVERY', 'PLACE_OF_DELIVERY', 'DEL_TIME',
    'DATE_OF_DISCHARGE', 'DISCHARGE_TIME', 'JSY_BENEFICIARY',
    'IS_MOTHER_ALIVE', 'IS_CHILD_DEATH', 'CHILD_DEATH_DATE', 'CHILD_DEATH_REASON',
    'DEFECT_HEALTH_CENTER', 'IS_DEFECTIVE_BIRTH', 'BIRTH_DEFECT_TYPE',
    'BIRTH_DEFECT_SUBTYPE', 'DEFECT_SUBTYPE_OTHER', 'DEFECT_TYPE_OTHER',
    'NOTIFICATION_SENT', 'FBIR_COMPLETED_BY_ANM', 'NEWBORN_SCREENING',
    'DEATH_REASON_OTHER', 'SNCU_ADMITTED', 'SNCU_REFERRAL_HOSPITAL',
    'TERTIARY_REFERRAL_HOSPITAL', 'OTHER_REFERRAL_HOSPITAL',
    'DATE_OF_DEATH', 'REASON_FOR_DEATH', 'PLACE_OF_DEATH',
    'INDICATION_FOR_C_SECTION', 'CH', 'CAH', 'GALACTOCEMIA', 'G6PDD', 'BIOTINIDASE',
    'low_birth_weight', 'very_low_birth_weight', 'avg_birth_weight_low',
    'TIME_OF_BIRTH', 'IS_BF_IN_HOUR',

    # Delivery-specific process info
    'DELIVERY_INSTITUTION', 'DELIVERY_DONE_BY', 'CONDUCT_BY',
    'MISOPROSTAL_TABLET', 'DEL_COMPLICATIONS', 'OTHER_DEL_COMPLICATIONS',
    'NOTIFICATION_SENT_del', 'FBIR_COMPLETED_BY_ANM_del',

    # Administrative / logging columns
    'REGISTRATION_DT', 'REGTYPE', 'CURRENT_USR', 'OTHER_STATE_PLACE',
    'OTHER_STATE_PLACE_FILEPATH', 'OTHER_GOVT_PLACE_FILEPATH',
    'ANC2_TAG_FAC_ID', 'ANC3_TAG_FAC_ID',

    # Facility and geographic info
    'ANC_INSTITUTE', 'FACILITY_TYPE', 'FACILITY_NAME', 'DOCTOR_ANM',
    'DISTRICT_anc', 'DISTRICT_child',

    # Feeding / newborn care
    'FEEDING_TYPE', 'DATE_OF_FIRST_FEEDING', 'TIME_OF_FIRST_FEEDING',
    'BABY_ON_MEDICATION', 'MEDICATION_REMARKS',
    'DATE_OF_BLOODSAMPLE_COLLECTION', 'TIME_OF_BLOODSAMPLE_COLLECTION',
    'HOURS_OF_SAMPLE_COLLECTION', 'TRANSFUSION_DONE',

    # Screening/test results (post-diagnosis leakage)
    'VDRL_DATE', 'VDRL_STATUS', 'HIV_DATE', 'HIV_STATUS', 'HBSAG_DATE',
    'HBSAG_STATUS', 'HEP_DATE', 'HEP_STATUS', 'VDRL_RESULT',
    'HIV_RESULT', 'HBSAG_RESULT', 'HEP_RESULT',

    # Missed ANC flag columns (already captured by better engineered features)
    'MISSANC1FLG', 'MISSANC2FLG', 'MISSANC3FLG', 'MISSANC4FLG',

    # Manually added known leaky or post-hoc columns
    'HIGH_RISKS', 'DISEASES', 'CHILD_NAME','RNK', 'AGE_preg', 'TOTAL_ANC_VISITS', 'hemoglobin_trend', 'mental_health_risk', 
    'TWIN_PREGNANCY_max', 'NO_OF_WEEKS_max', 'WEIGHT_child_min', 'AGE_final', 

    # Any other derived risk/score not already excluded
    *[col for col in df.columns if (
        ('risk' in col.lower() or 'score' in col.lower()) and col not in {
            'mental_health_risk'  # keep this one since it's pre-delivery
        }
    )],
]

# more_leakage_cols = [
#     # aggregate or binary risk flags
#     'bp_risk', 'mental_health_risk', 'age_category', 'multigravida', 'grand_multipara',
#      'no_anc',
#     'missed_first_anc', 'consecutive_missed',
#     # anemia + BP severity flags
#     'severe_hypertension',
#     # weight / BMI flags
#     'low_birth_weight', 'very_low_birth_weight', 'avg_birth_weight_low',
#     # mental-health flags
#     # trend or helper columns
#     'hemoglobin_trend', 'unique_id', 'TT_DATE', 'MAL_PRESENT','IS_ADMITTED_SNCU',
#  'IS_PREV_PREG','CHILD_ID', 'ANC1FLG',
#  'ANC2FLG',
#  'ANC3FLG',
#  'ANC4FLG',
#  'ANC_DATE',
#  'DEATH',
#  'EXP_DOD',
#  'DELIVERY_PLACE',
#  'EXP_DOD_preg',
#  'FASTING',
#  'LMP_DT',
#  'SCREENED_FOR_MENTAL_HEALTH',
#  'AGE',
#  'AGE_final',
#  'AGE_preg',
#  'WEIGHT_child_min',
# 'GENDER',
#     'TIME_OF_BIRTH',
#     'RNK', 'HEMOGLOBIN_min', 'HEMOGLOBIN_max', 'WEIGHT_anc_mean', 'WEIGHT_anc_min', 'WEIGHT_anc_max', 'NO_OF_WEEKS_max', 'TWIN_PREGNANCY_max', 'TOTAL_ANC_VISITS',
#     # any other “risk / score” not already in exclude list
#     *[c for c in df.columns if ('risk' in c.lower() or 'score' in c.lower()) and c not in exclude_cols]
# ]

# Finally combine everything
# exclude_cols = list(set(exclude_cols + more_leakage_cols))

features = [col for col in df.columns if col not in exclude_cols and df[col].dtype in ['float64', 'float32', 'int64', 'int32', 'int8']]

if not features:
    print(f"Skipping {target_columns}: No valid features available.")



In [3]:
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import roc_auc_score, f1_score
import lightgbm as lgb
import numpy as np
import time

# Assuming sample_df, features, and target_column are defined (correcting target_columns to target_column)
X = sample_df[features]
y = sample_df[target_columns]  # Changed from target_columns to target_column

# Initial train-test split
X_train_full, X_test, y_train_full, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Initialize k-fold cross-validation
n_splits = 5  # Number of folds
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

# Lists to store metrics for each fold
auc_scores = []
f1_scores = []
fold_times = []

# Calculate scale_pos_weight for class imbalance (based on training dataset)
neg_count = (y_train_full == 0).sum()
pos_count = (y_train_full == 1).sum()
scale_pos_weight = neg_count[0] / pos_count[0] if pos_count[0] > 0 else 1  # Simplified, assuming y is a Series


In [4]:
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score, confusion_matrix
from sklearn.ensemble import RandomForestClassifier
import numpy as np
import time
import shap
import matplotlib.pyplot as plt
import pandas as pd

# Assuming sample_df, features, and target_column are defined
# X = sample_df[features]
# y = sample_df[target_columns]

# # Initial train-test split
# X_train_full, X_test, y_train_full, y_test = train_test_split(
#     X, y, test_size=0.2, random_state=42, stratify=y
# )

# Initialize k-fold cross-validation
n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

# Lists to store metrics and models for each fold
metrics = {
    'thresh_0_1': {'f1': [], 'accuracy': [], 'precision': [], 'recall': []},
    'thresh_0_2': {'f1': [], 'accuracy': [], 'precision': [], 'recall': []},
    'thresh_0_3': {'f1': [], 'accuracy': [], 'precision': [], 'recall': []},
    'thresh_0_4': {'f1': [], 'accuracy': [], 'precision': [], 'recall': []},
    'thresh_0_5': {'f1': [], 'accuracy': [], 'precision': [], 'recall': []},
    'thresh_0_6': {'f1': [], 'accuracy': [], 'precision': [], 'recall': []},
    'thresh_0_7': {'f1': [], 'accuracy': [], 'precision': [], 'recall': []},
    'auc': []  # AUC is threshold-independent
}
fold_models = []  # Store models from each fold
fold_avg_f1 = []  # Store average F1 score across thresholds for each fold

# Random Forest parameters (tuned for stability)
params = {
    'n_estimators': 100,  # Number of trees
    'max_depth': 10,      # Limit depth to prevent overfitting
    'min_samples_split': 50,  # Minimum samples to split a node
    'min_samples_leaf': 25,   # Minimum samples per leaf
    'class_weight': 'balanced',  # Handle class imbalance
    'random_state': 42,
    'n_jobs': -1  # Use all available cores
}

# Perform k-fold cross-validation
for fold, (train_idx, val_idx) in enumerate(skf.split(X_train_full, y_train_full)):
    print(f"Training fold {fold + 1}/{n_splits}")
    
    # Split data
    X_train, X_val = X_train_full.iloc[train_idx], X_train_full.iloc[val_idx]
    y_train, y_val = y_train_full.iloc[train_idx], y_train_full.iloc[val_idx]
    
    # Print class distribution
    print(f"Training class counts:\n{y_train.value_counts()}")
    print(f"Validation class counts:\n{y_val.value_counts()}")
    
    # Train Random Forest model
    start_time = time.time()
    model = RandomForestClassifier(**params)
    model.fit(X_train, y_train)
    fold_time = time.time() - start_time
    fold_models.append(model)  # Store the model
    
    # Predict and evaluate on validation set
    y_pred_proba = model.predict_proba(X_val)[:, 1]  # Probability for positive class
    
    # Fixed thresholds (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7)
    y_pred_0_1 = (y_pred_proba > 0.1).astype(int)
    y_pred_0_2 = (y_pred_proba > 0.2).astype(int)
    y_pred_0_3 = (y_pred_proba > 0.3).astype(int)
    y_pred_0_4 = (y_pred_proba > 0.4).astype(int)
    y_pred_0_5 = (y_pred_proba > 0.5).astype(int)
    y_pred_0_6 = (y_pred_proba > 0.6).astype(int)
    y_pred_0_7 = (y_pred_proba > 0.7).astype(int)
    
    # Metrics for threshold 0.1
    f1_0_1 = f1_score(y_val, y_pred_0_1)
    accuracy_0_1 = accuracy_score(y_val, y_pred_0_1)
    precision_0_1 = precision_score(y_val, y_pred_0_1, zero_division=0)
    recall_0_1 = recall_score(y_val, y_pred_0_1, zero_division=0)
    cm_0_1 = confusion_matrix(y_val, y_pred_0_1)
    
    # Metrics for threshold 0.2
    f1_0_2 = f1_score(y_val, y_pred_0_2)
    accuracy_0_2 = accuracy_score(y_val, y_pred_0_2)
    precision_0_2 = precision_score(y_val, y_pred_0_2, zero_division=0)
    recall_0_2 = recall_score(y_val, y_pred_0_2, zero_division=0)
    cm_0_2 = confusion_matrix(y_val, y_pred_0_2)
    
    # Metrics for threshold 0.3
    f1_0_3 = f1_score(y_val, y_pred_0_3)
    accuracy_0_3 = accuracy_score(y_val, y_pred_0_3)
    precision_0_3 = precision_score(y_val, y_pred_0_3, zero_division=0)
    recall_0_3 = recall_score(y_val, y_pred_0_3, zero_division=0)
    cm_0_3 = confusion_matrix(y_val, y_pred_0_3)
    
    # Metrics for threshold 0.4
    f1_0_4 = f1_score(y_val, y_pred_0_4)
    accuracy_0_4 = accuracy_score(y_val, y_pred_0_4)
    precision_0_4 = precision_score(y_val, y_pred_0_4, zero_division=0)
    recall_0_4 = recall_score(y_val, y_pred_0_4, zero_division=0)
    cm_0_4 = confusion_matrix(y_val, y_pred_0_4)
    
    # Metrics for threshold 0.5
    f1_0_5 = f1_score(y_val, y_pred_0_5)
    accuracy_0_5 = accuracy_score(y_val, y_pred_0_5)
    precision_0_5 = precision_score(y_val, y_pred_0_5, zero_division=0)
    recall_0_5 = recall_score(y_val, y_pred_0_5, zero_division=0)
    cm_0_5 = confusion_matrix(y_val, y_pred_0_5)
    
    # Metrics for threshold 0.6
    f1_0_6 = f1_score(y_val, y_pred_0_6)
    accuracy_0_6 = accuracy_score(y_val, y_pred_0_6)
    precision_0_6 = precision_score(y_val, y_pred_0_6, zero_division=0)
    recall_0_6 = recall_score(y_val, y_pred_0_6, zero_division=0)
    cm_0_6 = confusion_matrix(y_val, y_pred_0_6)
    
    # Metrics for threshold 0.7
    f1_0_7 = f1_score(y_val, y_pred_0_7)
    accuracy_0_7 = accuracy_score(y_val, y_pred_0_7)
    precision_0_7 = precision_score(y_val, y_pred_0_7, zero_division=0)
    recall_0_7 = recall_score(y_val, y_pred_0_7, zero_division=0)
    cm_0_7 = confusion_matrix(y_val, y_pred_0_7)
    
    # AUC (threshold-independent)
    auc = roc_auc_score(y_val, y_pred_proba)
    
    # Store metrics
    metrics['thresh_0_1']['f1'].append(f1_0_1)
    metrics['thresh_0_1']['accuracy'].append(accuracy_0_1)
    metrics['thresh_0_1']['precision'].append(precision_0_1)
    metrics['thresh_0_1']['recall'].append(recall_0_1)
    
    metrics['thresh_0_2']['f1'].append(f1_0_2)
    metrics['thresh_0_2']['accuracy'].append(accuracy_0_2)
    metrics['thresh_0_2']['precision'].append(precision_0_2)
    metrics['thresh_0_2']['recall'].append(recall_0_2)
    
    metrics['thresh_0_3']['f1'].append(f1_0_3)
    metrics['thresh_0_3']['accuracy'].append(accuracy_0_3)
    metrics['thresh_0_3']['precision'].append(precision_0_3)
    metrics['thresh_0_3']['recall'].append(recall_0_3)
    
    metrics['thresh_0_4']['f1'].append(f1_0_4)
    metrics['thresh_0_4']['accuracy'].append(accuracy_0_4)
    metrics['thresh_0_4']['precision'].append(precision_0_4)
    metrics['thresh_0_4']['recall'].append(recall_0_4)
    
    metrics['thresh_0_5']['f1'].append(f1_0_5)
    metrics['thresh_0_5']['accuracy'].append(accuracy_0_5)
    metrics['thresh_0_5']['precision'].append(precision_0_5)
    metrics['thresh_0_5']['recall'].append(recall_0_5)
    
    metrics['thresh_0_6']['f1'].append(f1_0_6)
    metrics['thresh_0_6']['accuracy'].append(accuracy_0_6)
    metrics['thresh_0_6']['precision'].append(precision_0_6)
    metrics['thresh_0_6']['recall'].append(recall_0_6)
    
    metrics['thresh_0_7']['f1'].append(f1_0_7)
    metrics['thresh_0_7']['accuracy'].append(accuracy_0_7)
    metrics['thresh_0_7']['precision'].append(precision_0_7)
    metrics['thresh_0_7']['recall'].append(recall_0_7)
    
    metrics['auc'].append(auc)
    
    # Calculate average F1 score across thresholds for this fold
    avg_f1 = np.mean([f1_0_1, f1_0_2, f1_0_3, f1_0_4, f1_0_5, f1_0_6, f1_0_7])
    fold_avg_f1.append(avg_f1)
    
    print(f"Fold {fold + 1}")
    print(f"  AUC: {auc:.4f}, Time: {fold_time:.2f} seconds")
    print(f"  Threshold 0.1 - F1: {f1_0_1:.4f}, Accuracy: {accuracy_0_1:.4f}, Precision: {precision_0_1:.4f}, Recall: {recall_0_1:.4f}")
    print("  Confusion Matrix (Threshold 0.1):")
    print(f"    TN: {cm_0_1[0,0]}, FP: {cm_0_1[0,1]}")
    print(f"    FN: {cm_0_1[1,0]}, TP: {cm_0_1[1,1]}")
    print(f"  Threshold 0.2 - F1: {f1_0_2:.4f}, Accuracy: {accuracy_0_2:.4f}, Precision: {precision_0_2:.4f}, Recall: {recall_0_2:.4f}")
    print("  Confusion Matrix (Threshold 0.2):")
    print(f"    TN: {cm_0_2[0,0]}, FP: {cm_0_2[0,1]}")
    print(f"    FN: {cm_0_2[1,0]}, TP: {cm_0_2[1,1]}")
    print(f"  Threshold 0.3 - F1: {f1_0_3:.4f}, Accuracy: {accuracy_0_3:.4f}, Precision: {precision_0_3:.4f}, Recall: {recall_0_3:.4f}")
    print("  Confusion Matrix (Threshold 0.3):")
    print(f"    TN: {cm_0_3[0,0]}, FP: {cm_0_3[0,1]}")
    print(f"    FN: {cm_0_3[1,0]}, TP: {cm_0_3[1,1]}")
    print(f"  Threshold 0.4 - F1: {f1_0_4:.4f}, Accuracy: {accuracy_0_4:.4f}, Precision: {precision_0_4:.4f}, Recall: {recall_0_4:.4f}")
    print("  Confusion Matrix (Threshold 0.4):")
    print(f"    TN: {cm_0_4[0,0]}, FP: {cm_0_4[0,1]}")
    print(f"    FN: {cm_0_4[1,0]}, TP: {cm_0_4[1,1]}")
    print(f"  Threshold 0.5 - F1: {f1_0_5:.4f}, Accuracy: {accuracy_0_5:.4f}, Precision: {precision_0_5:.4f}, Recall: {recall_0_5:.4f}")
    print("  Confusion Matrix (Threshold 0.5):")
    print(f"    TN: {cm_0_5[0,0]}, FP: {cm_0_5[0,1]}")
    print(f"    FN: {cm_0_5[1,0]}, TP: {cm_0_5[1,1]}")
    print(f"  Threshold 0.6 - F1: {f1_0_6:.4f}, Accuracy: {accuracy_0_6:.4f}, Precision: {precision_0_6:.4f}, Recall: {recall_0_6:.4f}")
    print("  Confusion Matrix (Threshold 0.6):")
    print(f"    TN: {cm_0_6[0,0]}, FP: {cm_0_6[0,1]}")
    print(f"    FN: {cm_0_6[1,0]}, TP: {cm_0_6[1,1]}")
    print(f"  Threshold 0.7 - F1: {f1_0_7:.4f}, Accuracy: {accuracy_0_7:.4f}, Precision: {precision_0_7:.4f}, Recall: {recall_0_7:.4f}")
    print("  Confusion Matrix (Threshold 0.7):")
    print(f"    TN: {cm_0_7[0,0]}, FP: {cm_0_7[0,1]}")
    print(f"    FN: {cm_0_7[1,0]}, TP: {cm_0_7[1,1]}")

print("Features used for training:", list(X_train_full.columns))

# Summary statistics
print(f"\nCross-Validation Mean Metrics:")
print(f"  AUC: {np.mean(metrics['auc']):.4f} ± {np.std(metrics['auc']):.4f}")
print(f"\nThreshold 0.1:")
print(f"  F1 Score: {np.mean(metrics['thresh_0_1']['f1']):.4f} ± {np.std(metrics['thresh_0_1']['f1']):.4f}")
print(f"  Accuracy: {np.mean(metrics['thresh_0_1']['accuracy']):.4f} ± {np.std(metrics['thresh_0_1']['accuracy']):.4f}")
print(f"  Precision: {np.mean(metrics['thresh_0_1']['precision']):.4f} ± {np.std(metrics['thresh_0_1']['precision']):.4f}")
print(f"  Recall: {np.mean(metrics['thresh_0_1']['recall']):.4f} ± {np.std(metrics['thresh_0_1']['recall']):.4f}")

print(f"\nThreshold 0.2:")
print(f"  F1 Score: {np.mean(metrics['thresh_0_2']['f1']):.4f} ± {np.std(metrics['thresh_0_2']['f1']):.4f}")
print(f"  Accuracy: {np.mean(metrics['thresh_0_2']['accuracy']):.4f} ± {np.std(metrics['thresh_0_2']['accuracy']):.4f}")
print(f"  Precision: {np.mean(metrics['thresh_0_2']['precision']):.4f} ± {np.std(metrics['thresh_0_2']['precision']):.4f}")
print(f"  Recall: {np.mean(metrics['thresh_0_2']['recall']):.4f} ± {np.std(metrics['thresh_0_2']['recall']):.4f}")

print(f"\nThreshold 0.3:")
print(f"  F1 Score: {np.mean(metrics['thresh_0_3']['f1']):.4f} ± {np.std(metrics['thresh_0_3']['f1']):.4f}")
print(f"  Accuracy: {np.mean(metrics['thresh_0_3']['accuracy']):.4f} ± {np.std(metrics['thresh_0_3']['accuracy']):.4f}")
print(f"  Precision: {np.mean(metrics['thresh_0_3']['precision']):.4f} ± {np.std(metrics['thresh_0_3']['precision']):.4f}")
print(f"  Recall: {np.mean(metrics['thresh_0_3']['recall']):.4f} ± {np.std(metrics['thresh_0_3']['recall']):.4f}")

print(f"\nThreshold 0.4:")
print(f"  F1 Score: {np.mean(metrics['thresh_0_4']['f1']):.4f} ± {np.std(metrics['thresh_0_4']['f1']):.4f}")
print(f"  Accuracy: {np.mean(metrics['thresh_0_4']['accuracy']):.4f} ± {np.std(metrics['thresh_0_4']['accuracy']):.4f}")
print(f"  Precision: {np.mean(metrics['thresh_0_4']['precision']):.4f} ± {np.std(metrics['thresh_0_4']['precision']):.4f}")
print(f"  Recall: {np.mean(metrics['thresh_0_4']['recall']):.4f} ± {np.std(metrics['thresh_0_4']['recall']):.4f}")

print(f"\nThreshold 0.5:")
print(f"  F1 Score: {np.mean(metrics['thresh_0_5']['f1']):.4f} ± {np.std(metrics['thresh_0_5']['f1']):.4f}")
print(f"  Accuracy: {np.mean(metrics['thresh_0_5']['accuracy']):.4f} ± {np.std(metrics['thresh_0_5']['accuracy']):.4f}")
print(f"  Precision: {np.mean(metrics['thresh_0_5']['precision']):.4f} ± {np.std(metrics['thresh_0_5']['precision']):.4f}")
print(f"  Recall: {np.mean(metrics['thresh_0_5']['recall']):.4f} ± {np.std(metrics['thresh_0_5']['recall']):.4f}")

print(f"\nThreshold 0.6:")
print(f"  F1 Score: {np.mean(metrics['thresh_0_6']['f1']):.4f} ± {np.std(metrics['thresh_0_6']['f1']):.4f}")
print(f"  Accuracy: {np.mean(metrics['thresh_0_6']['accuracy']):.4f} ± {np.std(metrics['thresh_0_6']['accuracy']):.4f}")
print(f"  Precision: {np.mean(metrics['thresh_0_6']['precision']):.4f} ± {np.std(metrics['thresh_0_6']['precision']):.4f}")
print(f"  Recall: {np.mean(metrics['thresh_0_6']['recall']):.4f} ± {np.std(metrics['thresh_0_6']['recall']):.4f}")

print(f"\nThreshold 0.7:")
print(f"  F1 Score: {np.mean(metrics['thresh_0_7']['f1']):.4f} ± {np.std(metrics['thresh_0_7']['f1']):.4f}")
print(f"  Accuracy: {np.mean(metrics['thresh_0_7']['accuracy']):.4f} ± {np.std(metrics['thresh_0_7']['accuracy']):.4f}")
print(f"  Precision: {np.mean(metrics['thresh_0_7']['precision']):.4f} ± {np.std(metrics['thresh_0_7']['precision']):.4f}")
print(f"  Recall: {np.mean(metrics['thresh_0_7']['recall']):.4f} ± {np.std(metrics['thresh_0_7']['recall']):.4f}")

Training fold 1/5
Training class counts:
premature_birth_risk
0.0                     760444
1.0                     519556
Name: count, dtype: int64
Validation class counts:
premature_birth_risk
0.0                     190111
1.0                     129889
Name: count, dtype: int64
Fold 1
  AUC: 0.9593, Time: 81.74 seconds
  Threshold 0.1 - F1: 0.8012, Accuracy: 0.7986, Precision: 0.6683, Recall: 1.0000
  Confusion Matrix (Threshold 0.1):
    TN: 125656, FP: 64455
    FN: 0, TP: 129889
  Threshold 0.2 - F1: 0.9024, Accuracy: 0.9122, Precision: 0.8222, Recall: 1.0000
  Confusion Matrix (Threshold 0.2):
    TN: 162017, FP: 28094
    FN: 2, TP: 129887
  Threshold 0.3 - F1: 0.9234, Accuracy: 0.9326, Precision: 0.8576, Recall: 1.0000
  Confusion Matrix (Threshold 0.3):
    TN: 168551, FP: 21560
    FN: 2, TP: 129887
  Threshold 0.4 - F1: 0.9237, Accuracy: 0.9330, Precision: 0.8583, Recall: 1.0000
  Confusion Matrix (Threshold 0.4):
    TN: 168670, FP: 21441
    FN: 5, TP: 129884
  Threshol

In [5]:
# Select the best model based on average F1 score across thresholds
best_fold_idx = np.argmax(fold_avg_f1)
best_model = fold_models[best_fold_idx]
print(f"\nBest Model from Fold {best_fold_idx + 1} with Average F1 Score: {fold_avg_f1[best_fold_idx]:.4f}")

# Evaluate best model on test set
y_test_pred_proba = best_model.predict_proba(X_test)[:, 1]
y_test_pred_0_1 = (y_test_pred_proba > 0.1).astype(int)
y_test_pred_0_2 = (y_test_pred_proba > 0.2).astype(int)
y_test_pred_0_3 = (y_test_pred_proba > 0.3).astype(int)
y_test_pred_0_4 = (y_test_pred_proba > 0.4).astype(int)
y_test_pred_0_5 = (y_test_pred_proba > 0.5).astype(int)
y_test_pred_0_6 = (y_test_pred_proba > 0.6).astype(int)
y_test_pred_0_7 = (y_test_pred_proba > 0.7).astype(int)

# Test set metrics
test_auc = roc_auc_score(y_test, y_test_pred_proba)
test_f1_0_1 = f1_score(y_test, y_test_pred_0_1)
test_accuracy_0_1 = accuracy_score(y_test, y_test_pred_0_1)
test_precision_0_1 = precision_score(y_test, y_test_pred_0_1, zero_division=0)
test_recall_0_1 = recall_score(y_test, y_test_pred_0_1, zero_division=0)
test_cm_0_1 = confusion_matrix(y_test, y_test_pred_0_1)

test_f1_0_2 = f1_score(y_test, y_test_pred_0_2)
test_accuracy_0_2 = accuracy_score(y_test, y_test_pred_0_2)
test_precision_0_2 = precision_score(y_test, y_test_pred_0_2, zero_division=0)
test_recall_0_2 = recall_score(y_test, y_test_pred_0_2, zero_division=0)
test_cm_0_2 = confusion_matrix(y_test, y_test_pred_0_2)

test_f1_0_3 = f1_score(y_test, y_test_pred_0_3)
test_accuracy_0_3 = accuracy_score(y_test, y_test_pred_0_3)
test_precision_0_3 = precision_score(y_test, y_test_pred_0_3, zero_division=0)
test_recall_0_3 = recall_score(y_test, y_test_pred_0_3, zero_division=0)
test_cm_0_3 = confusion_matrix(y_test, y_test_pred_0_3)

test_f1_0_4 = f1_score(y_test, y_test_pred_0_4)
test_accuracy_0_4 = accuracy_score(y_test, y_test_pred_0_4)
test_precision_0_4 = precision_score(y_test, y_test_pred_0_4, zero_division=0)
test_recall_0_4 = recall_score(y_test, y_test_pred_0_4, zero_division=0)
test_cm_0_4 = confusion_matrix(y_test, y_test_pred_0_4)

test_f1_0_5 = f1_score(y_test, y_test_pred_0_5)
test_accuracy_0_5 = accuracy_score(y_test, y_test_pred_0_5)
test_precision_0_5 = precision_score(y_test, y_test_pred_0_5, zero_division=0)
test_recall_0_5 = recall_score(y_test, y_test_pred_0_5, zero_division=0)
test_cm_0_5 = confusion_matrix(y_test, y_test_pred_0_5)

test_f1_0_6 = f1_score(y_test, y_test_pred_0_6)
test_accuracy_0_6 = accuracy_score(y_test, y_test_pred_0_6)
test_precision_0_6 = precision_score(y_test, y_test_pred_0_6, zero_division=0)
test_recall_0_6 = recall_score(y_test, y_test_pred_0_6, zero_division=0)
test_cm_0_6 = confusion_matrix(y_test, y_test_pred_0_6)

test_f1_0_7 = f1_score(y_test, y_test_pred_0_7)
test_accuracy_0_7 = accuracy_score(y_test, y_test_pred_0_7)
test_precision_0_7 = precision_score(y_test, y_test_pred_0_7, zero_division=0)
test_recall_0_7 = recall_score(y_test, y_test_pred_0_7, zero_division=0)
test_cm_0_7 = confusion_matrix(y_test, y_test_pred_0_7)

print(f"\nTest Set Metrics (Best Model from Fold {best_fold_idx + 1}):")
print(f"  AUC: {test_auc:.4f}")
print(f"\nThreshold 0.1:")
print(f"  F1 Score: {test_f1_0_1:.4f}")
print(f"  Accuracy: {test_accuracy_0_1:.4f}")
print(f"  Precision: {test_precision_0_1:.4f}")
print(f"  Recall: {test_recall_0_1:.4f}")
print("  Confusion Matrix:")
print(f"    TN: {test_cm_0_1[0,0]}, FP: {test_cm_0_1[0,1]}")
print(f"    FN: {test_cm_0_1[1,0]}, TP: {test_cm_0_1[1,1]}")

print(f"\nThreshold 0.2:")
print(f"  F1 Score: {test_f1_0_2:.4f}")
print(f"  Accuracy: {test_accuracy_0_2:.4f}")
print(f"  Precision: {test_precision_0_2:.4f}")
print(f"  Recall: {test_recall_0_2:.4f}")
print("  Confusion Matrix:")
print(f"    TN: {test_cm_0_2[0,0]}, FP: {test_cm_0_2[0,1]}")
print(f"    FN: {test_cm_0_2[1,0]}, TP: {test_cm_0_2[1,1]}")

print(f"\nThreshold 0.3:")
print(f"  F1 Score: {test_f1_0_3:.4f}")
print(f"  Accuracy: {test_accuracy_0_3:.4f}")
print(f"  Precision: {test_precision_0_3:.4f}")
print(f"  Recall: {test_recall_0_3:.4f}")
print("  Confusion Matrix:")
print(f"    TN: {test_cm_0_3[0,0]}, FP: {test_cm_0_3[0,1]}")
print(f"    FN: {test_cm_0_3[1,0]}, TP: {test_cm_0_3[1,1]}")

print(f"\nThreshold 0.4:")
print(f"  F1 Score: {test_f1_0_4:.4f}")
print(f"  Accuracy: {test_accuracy_0_4:.4f}")
print(f"  Precision: {test_precision_0_4:.4f}")
print(f"  Recall: {test_recall_0_4:.4f}")
print("  Confusion Matrix:")
print(f"    TN: {test_cm_0_4[0,0]}, FP: {test_cm_0_4[0,1]}")
print(f"    FN: {test_cm_0_4[1,0]}, TP: {test_cm_0_4[1,1]}")

print(f"\nThreshold 0.5:")
print(f"  F1 Score: {test_f1_0_5:.4f}")
print(f"  Accuracy: {test_accuracy_0_5:.4f}")
print(f"  Precision: {test_precision_0_5:.4f}")
print(f"  Recall: {test_recall_0_5:.4f}")
print("  Confusion Matrix:")
print(f"    TN: {test_cm_0_5[0,0]}, FP: {test_cm_0_5[0,1]}")
print(f"    FN: {test_cm_0_5[1,0]}, TP: {test_cm_0_5[1,1]}")

print(f"\nThreshold 0.6:")
print(f"  F1 Score: {test_f1_0_6:.4f}")
print(f"  Accuracy: {test_accuracy_0_6:.4f}")
print(f"  Precision: {test_precision_0_6:.4f}")
print(f"  Recall: {test_recall_0_6:.4f}")
print("  Confusion Matrix:")
print(f"    TN: {test_cm_0_6[0,0]}, FP: {test_cm_0_6[0,1]}")
print(f"    FN: {test_cm_0_6[1,0]}, TP: {test_cm_0_6[1,1]}")

print(f"\nThreshold 0.7:")
print(f"  F1 Score: {test_f1_0_7:.4f}")
print(f"  Accuracy: {test_accuracy_0_7:.4f}")
print(f"  Precision: {test_precision_0_7:.4f}")
print(f"  Recall: {test_recall_0_7:.4f}")
print("  Confusion Matrix:")
print(f"    TN: {test_cm_0_7[0,0]}, FP: {test_cm_0_7[0,1]}")
print(f"    FN: {test_cm_0_7[1,0]}, TP: {test_cm_0_7[1,1]}")

# Determine best threshold based on test set F1 score
test_f1_scores = {
    0.1: test_f1_0_1,
    0.2: test_f1_0_2,
    0.3: test_f1_0_3,
    0.4: test_f1_0_4,
    0.5: test_f1_0_5,
    0.6: test_f1_0_6,
    0.7: test_f1_0_7
}
best_threshold = max(test_f1_scores, key=test_f1_scores.get)
best_f1 = test_f1_scores[best_threshold]
print(f"\nBest Threshold on Test Set: {best_threshold} with F1 Score: {best_f1:.4f}")

# SHAP analysis using the best model on test set
print("\nPerforming SHAP analysis on test set with best model...")
explainer = shap.TreeExplainer(best_model)
shap_values = explainer.shap_values(X_test)

# Use SHAP values for the positive class (index 1) for binary classification
# Summary plot (beeswarm)
plt.figure()
shap.summary_plot(shap_values[1], X_test, show=False)
plt.savefig("shap_summary_plot.png")
plt.close()
print("SHAP summary plot saved as 'shap_summary_plot.png'")

# Feature importance based on mean absolute SHAP values
shap_importance = np.abs(shap_values[1]).mean(axis=0)
importance_df = pd.DataFrame({
    'Feature': X_test.columns,
    'SHAP_Importance': shap_importance
}).sort_values(by='SHAP_Importance', ascending=False)
print("\nSHAP Feature Importance:")
print(importance_df)


Best Model from Fold 3 with Average F1 Score: 0.9092

Test Set Metrics (Best Model from Fold 3):
  AUC: 0.9600

Threshold 0.1:
  F1 Score: 0.8264
  Accuracy: 0.8295
  Precision: 0.7042
  Recall: 1.0000
  Confusion Matrix:
    TN: 169437, FP: 68202
    FN: 2, TP: 162359

Threshold 0.2:
  F1 Score: 0.9233
  Accuracy: 0.9326
  Precision: 0.8576
  Recall: 1.0000
  Confusion Matrix:
    TN: 210682, FP: 26957
    FN: 4, TP: 162357

Threshold 0.3:
  F1 Score: 0.9235
  Accuracy: 0.9327
  Precision: 0.8578
  Recall: 1.0000
  Confusion Matrix:
    TN: 210729, FP: 26910
    FN: 4, TP: 162357

Threshold 0.4:
  F1 Score: 0.9239
  Accuracy: 0.9331
  Precision: 0.8586
  Recall: 0.9999
  Confusion Matrix:
    TN: 210895, FP: 26744
    FN: 9, TP: 162352

Threshold 0.5:
  F1 Score: 0.9242
  Accuracy: 0.9334
  Precision: 0.8592
  Recall: 0.9997
  Confusion Matrix:
    TN: 211041, FP: 26598
    FN: 42, TP: 162319

Threshold 0.6:
  F1 Score: 0.9247
  Accuracy: 0.9341
  Precision: 0.8621
  Recall: 0.9972
 

In [6]:
import joblib
 
# Save the best model to a file
model_save_path = f"best_model_fold{best_fold_idx + 1}.pkl"
joblib.dump(best_model, model_save_path)

print(f"✅ Best model from fold {best_fold_idx + 1} saved to: {model_save_path}")


✅ Best model from fold 3 saved to: best_model_fold3.pkl


In [17]:
import pandas as pd
import numpy as np
import pickle
import warnings
import re

class HighRiskPregnancyClinicalSupport:
    def __init__(self, model_path):
        self.model = joblib.load(model_path)
        self.model_threshold = 0.4  # Optimal threshold from previous implementation
        self.feature_names = [
            'GRAVIDA', 'PARITY', 'ABORTIONS', 'HEIGHT', 'HEMOGLOBIN_mean',
            'age_adolescent', 'age_elderly', 'age_very_young', 'previous_loss',
            'recurrent_loss', 'gravida_parity_ratio', 'inadequate_anc', 'irregular_anc',
            'anemia_mild', 'anemia_moderate', 'anemia_severe', 'ever_severe_anemia',
            'systolic_bp', 'diastolic_bp', 'hypertension', 'BMI', 'underweight',
            'obese', 'normal_weight', 'depression', 'severe_depression', 'anxiety',
            'severe_anxiety', 'weight_gain', 'weight_gain_per_week', 'inadequate_weight_gain'
        ]

    def preprocess_patient(self, raw_data):
        try:
            required_keys = ['AGE', 'HEMOGLOBIN', 'HEMOGLOBIN_min', 'ABORTIONS', 'BP_last', 'GRAVIDA', 'PARITY',
                             'HEIGHT', 'WEIGHT_first', 'WEIGHT_last', 'NO_OF_WEEKS_max', 'TOTAL_ANC_VISITS',
                             'total_missed_visits', 'PHQ_SCORE_max', 'GAD_SCORE_max', 'MISSANC1FLG', 'MISSANC2FLG',
                             'MISSANC3FLG', 'MISSANC4FLG']
            for key in required_keys:
                if key not in raw_data:
                    raise KeyError(f"Missing required field: {key}")

            # Calculate total missed visits if not provided accurately
            total_missed = sum([raw_data['MISSANC1FLG'], raw_data['MISSANC2FLG'], raw_data['MISSANC3FLG'], raw_data['MISSANC4FLG']])
            if raw_data['total_missed_visits'] != total_missed:
                raw_data['total_missed_visits'] = total_missed

            processed = {
                'age_adolescent': 1 if raw_data['AGE'] < 18 else 0,
                'age_very_young': 1 if raw_data['AGE'] < 16 else 0,
                'age_elderly': 1 if raw_data['AGE'] > 35 else 0,
                'GRAVIDA': raw_data['GRAVIDA'],
                'PARITY': raw_data['PARITY'],
                'ABORTIONS': raw_data['ABORTIONS'],
                'previous_loss': 1 if raw_data['ABORTIONS'] > 0 else 0,
                'recurrent_loss': 1 if raw_data['ABORTIONS'] >= 2 else 0,
                'gravida_parity_ratio': raw_data['GRAVIDA'] / raw_data['PARITY'] if raw_data['PARITY'] > 0 else raw_data['GRAVIDA'],
                'inadequate_anc': 1 if raw_data['TOTAL_ANC_VISITS'] < 4 else 0,
                'irregular_anc': 1 if raw_data['total_missed_visits'] >= 2 else 0,
                'HEMOGLOBIN_mean': raw_data['HEMOGLOBIN'],
                'anemia_mild': 1 if 10 <= raw_data['HEMOGLOBIN'] < 11 else 0,
                'anemia_moderate': 1 if 7 <= raw_data['HEMOGLOBIN'] < 10 else 0,
                'anemia_severe': 1 if raw_data['HEMOGLOBIN'] < 7 else 0,
                'ever_severe_anemia': 1 if raw_data['HEMOGLOBIN_min'] < 7 else 0,
                'HEIGHT': raw_data['HEIGHT'],
                'BMI': raw_data['WEIGHT_last'] / ((raw_data['HEIGHT'] / 100) ** 2) if raw_data['HEIGHT'] > 0 else 0,
                'underweight': 1 if (raw_data['WEIGHT_last'] / ((raw_data['HEIGHT'] / 100) ** 2) < 18.5) and raw_data['HEIGHT'] > 0 else 0,
                'obese': 1 if (raw_data['WEIGHT_last'] / ((raw_data['HEIGHT'] / 100) ** 2) > 30) and raw_data['HEIGHT'] > 0 else 0,
                'normal_weight': 1 if (18.5 <= raw_data['WEIGHT_last'] / ((raw_data['HEIGHT'] / 100) ** 2) <= 25) and raw_data['HEIGHT'] > 0 else 0,
                'depression': 1 if raw_data['PHQ_SCORE_max'] >= 10 else 0,
                'severe_depression': 1 if raw_data['PHQ_SCORE_max'] >= 15 else 0,
                'anxiety': 1 if raw_data['GAD_SCORE_max'] >= 10 else 0,
                'severe_anxiety': 1 if raw_data['GAD_SCORE_max'] >= 15 else 0,
                'weight_gain': raw_data['WEIGHT_last'] - raw_data['WEIGHT_first'],
                'weight_gain_per_week': (raw_data['WEIGHT_last'] - raw_data['WEIGHT_first']) / (raw_data['NO_OF_WEEKS_max'] or 1),
                'inadequate_weight_gain': 1 if ((raw_data['WEIGHT_last'] - raw_data['WEIGHT_first']) / (raw_data['NO_OF_WEEKS_max'] or 1)) < 0.2 else 0
            }

            # Blood Pressure processing with error handling
            bp_match = re.match(r'(\d+)/(\d+)', raw_data['BP_last'])
            if not bp_match:
                raise ValueError("Invalid BP_last format. Expected 'systolic/diastolic' (e.g., '120/80')")
            processed['systolic_bp'] = float(bp_match.group(1))
            processed['diastolic_bp'] = float(bp_match.group(2))
            processed['hypertension'] = 1 if processed['systolic_bp'] >= 140 or processed['diastolic_bp'] >= 90 else 0

            return pd.DataFrame([processed], columns=self.feature_names)
        except Exception as e:
            raise ValueError(f"Error in preprocessing: {str(e)}")

    def generate_high_risk_pregnancy_recommendations(self, raw_data):
        try:
            processed_data = self.preprocess_patient(raw_data)
            probability = self.model.predict_proba(processed_data)[:, 1][0]
            prediction = 1 if probability >= self.model_threshold else 0
            risk_level = self._categorize_high_risk(probability)

            recommendations = {
                'risk_assessment': {
                    'probability': float(probability),
                    'risk_level': risk_level,
                    'classification': bool(prediction)
                },
                'immediate_actions': [],
                'medication_protocols': [],
                'monitoring_protocols': {}
            }

            # Immediate Actions
            if risk_level == 'Critical High-Risk':
                recommendations['immediate_actions'].extend([
                    "🚨 CRITICAL HIGH-RISK PREGNANCY: Immediate specialized care required",
                    "Schedule emergency maternal-fetal medicine consultation within 24 hours",
                    "Consider hospitalization for comprehensive evaluation"
                ])
            elif risk_level == 'Very High-Risk':
                recommendations['immediate_actions'].extend([
                    "⚠️ VERY HIGH-RISK PREGNANCY: Enhanced monitoring required",
                    "Schedule high-risk pregnancy clinic appointment within 48 hours"
                ])
            elif risk_level == 'High-Risk':
                recommendations['immediate_actions'].extend([
                    "📋 HIGH-RISK PREGNANCY: Structured care plan required",
                    "Schedule high-risk pregnancy consultation within 1 week"
                ])
            else:
                recommendations['immediate_actions'].append("No immediate high-risk actions needed. Continue regular prenatal care.")

            # Medication Protocols
            if raw_data['HEMOGLOBIN'] < 7:
                recommendations['medication_protocols'].extend([
                    "🩸 SEVERE ANEMIA PROTOCOL:",
                    "• Iron sucrose IV 200mg in 100ml NS over 20 minutes",
                    "• Consider blood transfusion if symptomatic"
                ])
            elif raw_data['HEMOGLOBIN'] < 10:
                recommendations['medication_protocols'].extend([
                    "🩸 MODERATE ANEMIA MANAGEMENT:",
                    "• Ferrous sulfate 200mg TDS with Vitamin C",
                    "• Folic acid 5mg daily"
                ])
            if raw_data['ABORTIONS'] >= 2:
                recommendations['medication_protocols'].extend([
                    "🔄 RECURRENT LOSS PROTOCOL:",
                    "• Low-dose aspirin 75mg daily from 12 weeks",
                    "• Consider progesterone supplementation"
                ])
            if processed_data['hypertension'][0] == 1:
                recommendations['medication_protocols'].extend([
                    "🫀 HYPERTENSION MANAGEMENT:",
                    "• Methyldopa 250mg TDS (first line)",
                    "• Monitor blood pressure daily"
                ])
            if raw_data['PHQ_SCORE_max'] >= 10:
                recommendations['medication_protocols'].append("🧠 MENTAL HEALTH SUPPORT: Refer to mental health specialist for counseling")

            # Monitoring Protocols
            if risk_level in ['Critical High-Risk', 'Very High-Risk']:
                recommendations['monitoring_protocols'] = {
                    'anc_frequency': 'Weekly visits',
                    'fetal_monitoring': ['Biweekly non-stress tests (NST)', 'Weekly biophysical profile (BPP)'],
                    'maternal_monitoring': ['Daily blood pressure monitoring', 'Weekly hemoglobin checks']
                }
            else:
                recommendations['monitoring_protocols'] = {
                    'anc_frequency': 'Standard ANC schedule',
                    'fetal_monitoring': ['Standard fetal monitoring'],
                    'maternal_monitoring': ['Standard maternal monitoring']
                }

            return recommendations
        except Exception as e:
            raise ValueError(f"Error generating recommendations: {str(e)}")

    def _categorize_high_risk(self, probability):
        if probability >= 0.80:
            return 'Critical High-Risk'
        elif probability >= 0.60:
            return 'Very High-Risk'
        elif probability >= 0.40:
            return 'High-Risk'
        elif probability >= 0.20:
            return 'Moderate-Risk'
        else:
            return 'Low-Risk'

class PretermBirthClinicalSupport:
    """
    Dedicated Clinical Decision Support for Preterm Birth Prediction
    Based on model performance: AUC 0.9600, F1 Score 0.9247, Best Threshold 0.6
    Uses raw features from Telangana Maternal Health dataset
    """
    
    def __init__(self, model_path='/kaggle/working/best_model_fold3.pkl'):
        self.model_threshold = 0.6  # Updated to best threshold from new metrics
        self.model = None
        self.feature_columns = [
            'GRAVIDA', 'AGE', 'PARITY', 'ABORTIONS', 'HEIGHT', 'HEMOGLOBIN_mean',
            'HEMOGLOBIN_min', 'HEMOGLOBIN_max', 'WEIGHT_anc_mean', 'WEIGHT_anc_min',
            'WEIGHT_anc_max', 'age_adolescent', 'age_elderly', 'age_very_young',
            'multigravida', 'grand_multipara', 'previous_loss', 'recurrent_loss',
            'gravida_parity_ratio', 'inadequate_anc', 'no_anc', 'irregular_anc',
            'missed_first_anc', 'consecutive_missed', 'anemia_mild', 'anemia_moderate',
            'anemia_severe', 'ever_severe_anemia', 'systolic_bp', 'diastolic_bp',
            'hypertension', 'severe_hypertension', 'BMI', 'underweight', 'obese',
            'normal_weight', 'depression', 'severe_depression', 'anxiety',
            'severe_anxiety', 'weight_gain', 'weight_gain_per_week',
            'inadequate_weight_gain'
        ]
        try:
            self.model = self._load_model(model_path)
        except Exception as e:
            warnings.warn(f"Failed to load model: {str(e)}. Using rule-based fallback.")
    
    def _load_model(self, model_path):
        """Load the pretrained model from file"""
        try:
            with open(model_path, 'rb') as file:
                model = pickle.load(file)
            return model
        except Exception as e:
            raise Exception(f"Error loading model from {model_path}: {str(e)}")
    
    def _engineer_features(self, patient_data):
        """Engineer features from raw patient data based on Telangana Maternal Health feature engineering"""
        try:
            required_keys = ['AGE', 'WEIGHT_last', 'WEIGHT_first', 'HEIGHT', 'HEMOGLOBIN', 'NO_OF_WEEKS_max',
                             'MISSANC1FLG', 'MISSANC2FLG', 'MISSANC3FLG', 'MISSANC4FLG', 'GRAVIDA', 'PARITY',
                             'ABORTIONS', 'TOTAL_ANC_VISITS', 'BP_last', 'PHQ_SCORE_max', 'GAD_SCORE_max']
            for key in required_keys:
                if key not in patient_data:
                    raise KeyError(f"Missing required field: {key}")

            processed = {
                'GRAVIDA': patient_data['GRAVIDA'],
                'AGE': patient_data['AGE'],
                'PARITY': patient_data['PARITY'],
                'ABORTIONS': patient_data['ABORTIONS'],
                'HEIGHT': patient_data['HEIGHT'],
                'HEMOGLOBIN_mean': patient_data['HEMOGLOBIN'],
                'HEMOGLOBIN_min': patient_data['HEMOGLOBIN'],
                'HEMOGLOBIN_max': patient_data['HEMOGLOBIN'],
                'WEIGHT_anc_mean': (patient_data['WEIGHT_last'] + patient_data['WEIGHT_first']) / 2,
                'WEIGHT_anc_min': patient_data['WEIGHT_first'],
                'WEIGHT_anc_max': patient_data['WEIGHT_last'],
                'age_adolescent': 1 if patient_data['AGE'] < 18 else 0,
                'age_very_young': 1 if patient_data['AGE'] < 16 else 0,
                'age_elderly': 1 if patient_data['AGE'] > 35 else 0,
                'multigravida': 1 if patient_data['GRAVIDA'] > 1 else 0,
                'grand_multipara': 1 if patient_data['PARITY'] > 5 else 0,
                'previous_loss': 1 if patient_data['ABORTIONS'] > 0 else 0,
                'recurrent_loss': 1 if patient_data['ABORTIONS'] >= 2 else 0,
                'gravida_parity_ratio': patient_data['GRAVIDA'] / (patient_data['PARITY'] + 1) if (patient_data['PARITY'] + 1) > 0 else 1,
                'inadequate_anc': 1 if patient_data['TOTAL_ANC_VISITS'] < 4 else 0,
                'no_anc': 1 if patient_data['TOTAL_ANC_VISITS'] == 0 else 0,
                'irregular_anc': 1 if (patient_data['MISSANC1FLG'] + patient_data['MISSANC2FLG'] + patient_data['MISSANC3FLG'] + patient_data['MISSANC4FLG']) >= 2 else 0,
                'missed_first_anc': patient_data['MISSANC1FLG'],
                'consecutive_missed': 1 if (patient_data['MISSANC1FLG'] + patient_data['MISSANC2FLG'] >= 2 or 
                                            patient_data['MISSANC2FLG'] + patient_data['MISSANC3FLG'] >= 2 or 
                                            patient_data['MISSANC3FLG'] + patient_data['MISSANC4FLG'] >= 2) else 0,
                'anemia_mild': 1 if 10 <= patient_data['HEMOGLOBIN'] < 11 else 0,
                'anemia_moderate': 1 if 7 <= patient_data['HEMOGLOBIN'] < 10 else 0,
                'anemia_severe': 1 if patient_data['HEMOGLOBIN'] < 7 else 0,
                'ever_severe_anemia': 1 if patient_data['HEMOGLOBIN'] < 7 else 0,
                'BMI': patient_data['WEIGHT_last'] / ((patient_data['HEIGHT'] / 100) ** 2) if patient_data['HEIGHT'] > 0 else 0,
                'underweight': 1 if (patient_data['WEIGHT_last'] / ((patient_data['HEIGHT'] / 100) ** 2) < 18.5) and patient_data['HEIGHT'] > 0 else 0,
                'obese': 1 if (patient_data['WEIGHT_last'] / ((patient_data['HEIGHT'] / 100) ** 2) > 30) and patient_data['HEIGHT'] > 0 else 0,
                'normal_weight': 1 if (18.5 <= patient_data['WEIGHT_last'] / ((patient_data['HEIGHT'] / 100) ** 2) <= 25) and patient_data['HEIGHT'] > 0 else 0,
                'depression': 1 if patient_data['PHQ_SCORE_max'] >= 10 else 0,
                'severe_depression': 1 if patient_data['PHQ_SCORE_max'] >= 15 else 0,
                'anxiety': 1 if patient_data['GAD_SCORE_max'] >= 10 else 0,
                'severe_anxiety': 1 if patient_data['GAD_SCORE_max'] >= 15 else 0,
                'weight_gain': patient_data['WEIGHT_last'] - patient_data['WEIGHT_first'],
                'weight_gain_per_week': (patient_data['WEIGHT_last'] - patient_data['WEIGHT_first']) / (patient_data['NO_OF_WEEKS_max'] or 1),
                'inadequate_weight_gain': 1 if ((patient_data['WEIGHT_last'] - patient_data['WEIGHT_first']) / (patient_data['NO_OF_WEEKS_max'] or 1)) < 0.2 else 0
            }

            # Blood Pressure processing
            bp_match = re.match(r'(\d+)/(\d+)', patient_data['BP_last'])
            if not bp_match:
                raise ValueError("Invalid BP_last format. Expected 'systolic/diastolic' (e.g., '120/80')")
            processed['systolic_bp'] = float(bp_match.group(1))
            processed['diastolic_bp'] = float(bp_match.group(2))
            processed['hypertension'] = 1 if processed['systolic_bp'] >= 140 or processed['diastolic_bp'] >= 90 else 0
            processed['severe_hypertension'] = 1 if processed['systolic_bp'] >= 160 or processed['diastolic_bp'] >= 110 else 0

            return pd.DataFrame([processed], columns=self.feature_columns)
        except Exception as e:
            raise ValueError(f"Error in preprocessing patient data: {str(e)}")

    def _rule_based_risk(self, patient_data):
        """Fallback rule-based risk assessment if model loading fails"""
        weight_last = patient_data.get('WEIGHT_last', 60)
        weight_first = patient_data.get('WEIGHT_first', 55)
        no_of_weeks_max = patient_data.get('NO_OF_WEEKS_max', 28)
        hemoglobin = patient_data.get('HEMOGLOBIN', 12)
        height = patient_data.get('HEIGHT', 160)
        age = patient_data.get('AGE', 25)
        missanc1flg = patient_data.get('MISSANC1FLG', 0)
        
        weight_gain = weight_last - weight_first
        weight_gain_per_week = weight_gain / no_of_weeks_max if no_of_weeks_max > 0 else 0
        bmi = weight_last / ((height / 100) ** 2) if height > 0 else 22
        
        risk_score = 0
        if weight_gain_per_week < 0.2:
            risk_score += 0.3  # High importance from SHAP
        if bmi < 18.5 or bmi > 30:
            risk_score += 0.2
        if height < 150:
            risk_score += 0.15
        if age < 20 or age > 35:
            risk_score += 0.1
        if hemoglobin < 11:
            risk_score += 0.15
        if missanc1flg == 1:
            risk_score += 0.1
        
        return min(risk_score, 1.0)  # Cap at 1.0
    
    def predict_risk(self, patient_data):
        """Predict preterm birth risk probability"""
        if self.model:
            X = self._engineer_features(patient_data)
            try:
                return self.model.predict_proba(X)[:, 1][0]
            except Exception as e:
                warnings.warn(f"Model prediction failed: {str(e)}. Using rule-based fallback.")
        return self._rule_based_risk(patient_data)
    
    def generate_preterm_birth_recommendations(self, patient_data):
        """Generate clinical recommendations for preterm birth prevention"""
        risk_probability = self.predict_risk(patient_data)
        risk_level = self._categorize_preterm_risk(risk_probability)
        gestational_age = patient_data.get('NO_OF_WEEKS_max', 28)
        
        recommendations = {
            'risk_assessment': {
                'probability': float(risk_probability),
                'risk_level': risk_level,
                'classification': risk_probability >= self.model_threshold
            },
            'immediate_interventions': [],
            'preterm_prevention_protocols': {},
            'fetal_maturation_interventions': [],
            'delivery_planning': {},
            'nicu_preparation': {},
            'medication_protocols': [],
            'monitoring_schedule': {},
            'healthcare_team_actions': {
                'asha_worker': [],
                'anm_nurse': [],
                'medical_officer': [],
                'obstetrician': [],
                'neonatologist': []
            }
        }
        
        self._add_preterm_specific_actions(recommendations, patient_data, risk_probability, 
                                         risk_level, gestational_age)
        return recommendations
    
    def _categorize_preterm_risk(self, probability):
        """Categorize preterm birth risk probability"""
        if probability >= 0.75:
            return 'Critical Preterm Risk'
        elif probability >= 0.55:
            return 'Very High Preterm Risk'
        elif probability >= 0.35:
            return 'High Preterm Risk'
        elif probability >= 0.15:
            return 'Moderate Preterm Risk'
        else:
            return 'Low Preterm Risk'
    
    def _add_preterm_specific_actions(self, recommendations, patient_data, probability, 
                                    risk_level, gestational_age):
        """Add preterm birth specific recommendations"""
        weight_last = patient_data.get('WEIGHT_last', 60)
        weight_first = patient_data.get('WEIGHT_first', 55)
        height = patient_data.get('HEIGHT', 160)
        hemoglobin = patient_data.get('HEMOGLOBIN', 12)
        missanc1flg = patient_data.get('MISSANC1FLG', 0)
        no_of_weeks_max = patient_data.get('NO_OF_WEEKS_max', 28)
        age = patient_data.get('AGE', 25)
        
        weight_gain = weight_last - weight_first
        weight_gain_per_week = weight_gain / no_of_weeks_max if no_of_weeks_max > 0 else 0
        inadequate_weight_gain = 1 if weight_gain_per_week < 0.2 else 0
        bmi = weight_last / ((height / 100) ** 2) if height > 0 else 22
        
        if risk_level == 'Critical Preterm Risk':
            recommendations['immediate_interventions'].extend([
                "🚨 CRITICAL PRETERM BIRTH RISK: Immediate intervention required",
                "Assess for signs of preterm labor immediately",
                "Hospitalization for intensive monitoring",
                "Implement comprehensive preterm prevention protocol",
                "Notify NICU team for potential preterm delivery",
                "Prepare for emergency delivery if labor progresses"
            ])
        elif risk_level == 'Very High Preterm Risk':
            recommendations['immediate_interventions'].extend([
                "⚠️ VERY HIGH PRETERM RISK: Urgent prevention measures",
                "Initiate preterm prevention protocol within 24 hours",
                "Urgent cervical length assessment via ultrasound",
                "Consider prophylactic interventions"
            ])
        elif risk_level == 'High Preterm Risk':
            recommendations['immediate_interventions'].extend([
                "📋 HIGH PRETERM RISK: Enhanced surveillance required",
                "Implement preterm monitoring protocol",
                "Schedule frequent follow-up visits"
            ])
        
        self._add_preterm_prevention_protocols(recommendations, patient_data, risk_level, gestational_age,
                                             inadequate_weight_gain, bmi, height, age, hemoglobin, missanc1flg)
        
        if risk_level in ['Critical Preterm Risk', 'Very High Preterm Risk'] and 24 <= gestational_age <= 34:
            recommendations['fetal_maturation_interventions'].extend([
                '💉 ANTENATAL CORTICOSTEROIDS:',
                '• Betamethasone 12mg IM x 2 doses, 24 hours apart',
                '• OR Dexamethasone 6mg IM q12h x 4 doses',
                '• Optimal benefit 24 hours to 7 days post-administration',
                '• Repeat course if high risk persists >14 days'
            ])
            
            if gestational_age < 32:
                recommendations['fetal_maturation_interventions'].extend([
                    '🧠 NEUROPROTECTION:',
                    '• Magnesium sulfate 4g IV bolus, then 1g/hr',
                    '• Continue until delivery or 24 hours',
                    '• For fetal neuroprotection against cerebral palsy'
                ])
        
        self._add_delivery_planning(recommendations, risk_level, gestational_age)
        
        if risk_level in ['Critical Preterm Risk', 'Very High Preterm Risk']:
            recommendations['nicu_preparation'] = {
                'nicu_notification': 'Immediate notification to NICU team',
                'bed_reservation': 'Reserve NICU bed for potential preterm delivery',
                'equipment_preparation': 'Prepare respiratory support and thermal regulation equipment',
                'staff_notification': 'Alert neonatal resuscitation team',
                'family_counseling': 'Counsel family on preterm birth outcomes and NICU expectations'
            }
        
        self._add_preterm_medications(recommendations, patient_data, risk_level, gestational_age)
        self._add_preterm_monitoring(recommendations, risk_level)
        self._add_preterm_team_actions(recommendations, risk_level)
    
    def _add_preterm_prevention_protocols(self, recommendations, patient_data, risk_level, gestational_age,
                                        inadequate_weight_gain, bmi, height, age, hemoglobin, missanc1flg):
        """Add preterm birth prevention protocols"""
        protocols = {}
        
        if inadequate_weight_gain == 1:
            protocols['nutritional_intervention'] = [
                'Urgent nutritional assessment and counseling',
                'High-calorie, high-protein diet plan (2500-3000 kcal/day)',
                'Target weight gain: 0.4-0.5kg/week in 2nd/3rd trimester',
                'Bi-weekly weight monitoring',
                'Prescribe prenatal vitamins and nutritional supplements'
            ]
        
        if bmi < 18.5 or bmi > 30:
            protocols['bmi_management'] = [
                'Detailed BMI assessment and monitoring',
                'For BMI <18.5: Nutritional supplementation to achieve healthy weight gain',
                'For BMI >30: Dietary counseling to manage weight gain',
                'Collaborate with dietitian for personalized plan'
            ]
        
        if height < 150:
            protocols['height_risk_management'] = [
                'Assess pelvic capacity due to short stature',
                'Monitor for cephalopelvic disproportion risk',
                'Consider early ultrasound for fetal size estimation'
            ]
        
        if age < 20 or age > 35:
            protocols['age_risk_management'] = [
                'Enhanced monitoring for adolescent (<20) or advanced maternal age (>35)',
                'Counsel on age-related preterm risks',
                'Weekly assessments for age-related complications'
            ]
        
        if hemoglobin < 11.0:
            protocols['anemia_correction'] = [
                'Initiate aggressive iron therapy for hemoglobin <11g/dL',
                'Target hemoglobin >11g/dL to reduce preterm risk',
                'Weekly hemoglobin monitoring during treatment',
                'Consider IV iron if oral intolerance or severe anemia'
            ]
        
        if missanc1flg == 1:
            protocols['enhanced_engagement'] = [
                'Intensive follow-up to ensure ANC compliance',
                'Weekly home visits by ASHA worker',
                'Address barriers to ANC attendance (transport, financial, cultural)',
                'Flexible scheduling for ANC appointments'
            ]
        
        if risk_level in ['High Preterm Risk', 'Very High Preterm Risk', 'Critical Preterm Risk']:
            protocols['activity_modification'] = [
                'Restrict heavy lifting and strenuous physical activities',
                'Ensure 8-10 hours sleep and afternoon rest periods',
                'Implement stress reduction techniques (e.g., mindfulness)',
                'Avoid long-distance travel after 28 weeks',
                'Pelvic rest if cervical shortening detected'
            ]
            protocols['infection_prevention'] = [
                'Screen and treat urogenital infections promptly',
                'Educate on hygiene practices to prevent infections',
                'Avoid crowded places during infection seasons',
                'Treat bacterial vaginosis if present'
            ]
        
        recommendations['preterm_prevention_protocols'] = protocols
    
    def _add_delivery_planning(self, recommendations, risk_level, gestational_age):
        """Add delivery planning for preterm risk"""
        if risk_level in ['Critical Preterm Risk', 'Very High Preterm Risk']:
            if gestational_age < 34:
                delivery_plan = {
                    'delivery_location': 'Tertiary center with Level III NICU',
                    'delivery_team': 'Obstetric team + Neonatal resuscitation team',
                    'anesthesia': 'Early anesthesia consultation for delivery planning',
                    'timing': 'Individualized based on maternal-fetal status',
                    'mode': 'Route of delivery based on obstetric indications'
                }
            else:
                delivery_plan = {
                    'delivery_location': 'Hospital with Level II nursery minimum',
                    'delivery_team': 'Standard obstetric team with neonatal support',
                    'anesthesia': 'Standard anesthesia consultation',
                    'timing': 'Aim for term delivery if stable',
                    'mode': 'Standard obstetric management'
                }
        else:
            delivery_plan = {
                'delivery_location': 'Standard delivery facility',
                'delivery_team': 'Standard obstetric care',
                'timing': 'Term delivery',
                'mode': 'Standard management'
            }
        recommendations['delivery_planning'] = delivery_plan
    
    def _add_preterm_medications(self, recommendations, patient_data, risk_level, gestational_age):
        """Add preterm birth specific medications"""
        medications = []
        if risk_level == 'Critical Preterm Risk' and 24 <= gestational_age <= 34:
            medications.extend([
                '🛑 TOCOLYTIC THERAPY (if in preterm labor):',
                '• Nifedipine 10mg sublingual, then 20mg oral q6h',
                '• OR Indomethacin 25mg q6h x 48 hours (if <32 weeks)',
                '• Limit tocolytics to 48 hours for steroid completion',
                '• Monitor maternal vital signs and fetal heart rate'
            ])
        if risk_level in ['High Preterm Risk', 'Very High Preterm Risk', 'Critical Preterm Risk']:
            medications.extend([
                '🤰 PROGESTERONE SUPPLEMENTATION:',
                '• 17α-hydroxyprogesterone caproate 250mg IM weekly',
                '• Start 16-20 weeks, continue until 36 weeks',
                '• Indicated for history of spontaneous preterm birth',
                '• Monitor for injection site reactions'
            ])
        if risk_level in ['Very High Preterm Risk', 'Critical Preterm Risk']:
            medications.extend([
                '🪡 CERVICAL CERCLAGE CONSIDERATION:',
                '• Perform transvaginal ultrasound for cervical length',
                '• Consider cerclage if cervical length <25mm before 24 weeks',
                '• Administer prophylactic antibiotics during procedure',
                '.• Recommend modified activity post-cerclage'
            ])
        recommendations['medication_protocols'] = medications
    
    def _add_preterm_monitoring(self, recommendations, risk_level):
        """Add preterm-specific monitoring schedule"""
        if risk_level == 'Critical Preterm Risk':
            monitoring = {
                'visit_frequency': 'Weekly visits',
                'cervical_assessment': 'Cervical length every 1-2 weeks',
                'fetal_monitoring': 'Non-stress test (NST) twice weekly',
                'contraction_monitoring': 'Daily symptom assessment for contractions',
                'laboratory': 'Weekly CBC, CRP if infection suspected'
            }
        elif risk_level == 'Very High Preterm Risk':
            monitoring = {
                'visit_frequency': 'Bi-weekly visits',
                'cervical_assessment': 'Cervical length every 2-3 weeks',
                'fetal_monitoring': 'Non-stress test (NST) weekly',
                'contraction_monitoring': 'Daily symptom diary',
                'laboratory': 'CBC every 2 weeks'
            }
        else:
            monitoring = {
                'visit_frequency': 'Standard ANC schedule',
                'cervical_assessment': 'Cervical length as indicated',
                'fetal_monitoring': 'Standard fetal monitoring',
                'contraction_monitoring': 'Educate on preterm labor warning signs',
                'laboratory': 'Standard ANC labs'
            }
        recommendations['monitoring_schedule'] = monitoring
    
    def _add_preterm_team_actions(self, recommendations, risk_level):
        """Add healthcare team actions for preterm birth prevention"""
        if risk_level in ['Critical Preterm Risk', 'Very High Preterm Risk']:
            recommendations['healthcare_team_actions']['asha_worker'].extend([
                'Daily home visits during high-risk periods',
                'Monitor for preterm labor symptoms (contractions, discharge)',
                'Ensure compliance with bed rest if prescribed',
                'Coordinate urgent transportation for labor signs',
                'Provide emotional support for preterm risk anxiety'
            ])
            recommendations['healthcare_team_actions']['anm_nurse'].extend([
                'Assess for preterm labor signs at each visit',
                'Monitor cervical changes if trained',
                'Educate on preterm labor warning signs',
                'Coordinate urgent specialist referrals',
                'Administer tocolytic medications if prescribed'
            ])
            recommendations['healthcare_team_actions']['medical_officer'].extend([
                'Weekly preterm labor risk assessments',
                'Perform cervical examinations as needed',
                'Coordinate referrals to tertiary care centers',
                'Manage preterm labor emergencies',
                'Counsel family on preterm risks and outcomes'
            ])
            recommendations['healthcare_team_actions']['obstetrician'].extend([
                'Comprehensive preterm risk evaluation',
                'Assess need for cervical cerclage',
                'Manage tocolytic and steroid therapy',
                'Plan delivery timing and mode',
                'Coordinate with neonatology for high-risk cases'
            ])
            recommendations['healthcare_team_actions']['neonatologist'].extend([
                'Antenatal consultation for preterm risk counseling',
                'Prepare NICU for potential preterm delivery',
                'Attend delivery for preterm births',
                'Plan ongoing care for preterm infants'
            ])

def print_high_risk_pregnancy_recommendations(recommendations):
    output = ["🤰 HIGH-RISK PREGNANCY CLINICAL DECISION SUPPORT"]
    output.append("=" * 80)

    risk_assessment = recommendations['risk_assessment']
    output.append("\n📊 HIGH-RISK PREGNANCY ASSESSMENT:")
    output.append(f"   Risk Probability: {risk_assessment['probability']:.1%}")
    output.append(f"   Risk Level: {risk_assessment['risk_level']}")
    output.append(f"   High-Risk Classification: {'YES' if risk_assessment['classification'] else 'NO'}")

    if recommendations['immediate_actions']:
        output.append("\n🚨 IMMEDIATE ACTIONS FOR HIGH-RISK PREGNANCY:")
        for i, action in enumerate(recommendations['immediate_actions'], 1):
            output.append(f"   {i}. {action}")

    if recommendations['medication_protocols']:
        output.append("\n💊 HIGH-RISK MEDICATION PROTOCOLS:")
        for protocol in recommendations['medication_protocols']:
            output.append(f"   {protocol}")

    if recommendations['monitoring_protocols']:
        output.append("\n📈 HIGH-RISK MONITORING PROTOCOLS:")
        monitoring = recommendations['monitoring_protocols']
        output.append(f"   • ANC Frequency: {monitoring.get('anc_frequency', 'Standard')}")
        if 'fetal_monitoring' in monitoring:
            output.append("   • Fetal Monitoring:")
            for item in monitoring['fetal_monitoring']:
                output.append(f"     - {item}")
        if 'maternal_monitoring' in monitoring:
            output.append("   • Maternal Monitoring:")
            for item in monitoring['maternal_monitoring']:
                output.append(f"     - {item}")

    output.append("\n" + "=" * 80)
    return "\n".join(output)

def print_preterm_birth_recommendations(recommendations):
    """Print preterm birth specific recommendations"""
    output = ["👶 PRETERM BIRTH PREVENTION CLINICAL DECISION SUPPORT"]
    output.append("="*80)

    risk_assessment = recommendations['risk_assessment']
    output.append("\n📊 PRETERM BIRTH RISK ASSESSMENT:")
    output.append(f"   Risk Probability: {risk_assessment['probability']:.1%}")
    output.append(f"   Risk Level: {risk_assessment['risk_level']}")
    output.append(f"   Preterm Risk Classification: {'YES' if risk_assessment['classification'] else 'NO'}")

    if recommendations['immediate_interventions']:
        output.append("\n🚨 IMMEDIATE PRETERM PREVENTION INTERVENTIONS:")
        for i, intervention in enumerate(recommendations['immediate_interventions'], 1):
            output.append(f"   {i}. {intervention}")

    if recommendations['preterm_prevention_protocols']:
        output.append("\n🛡️ PRETERM BIRTH PREVENTION PROTOCOLS:")
        protocols = recommendations['preterm_prevention_protocols']
        for protocol_name, actions in protocols.items():
            output.append(f"   • {protocol_name.replace('_', ' ').title()}:")
            for action in actions:
                output.append(f"     - {action}")

    if recommendations['fetal_maturation_interventions']:
        output.append("\n🧠 FETAL MATURATION INTERVENTIONS:")
        for intervention in recommendations['fetal_maturation_interventions']:
            output.append(f"   {intervention}")

    if recommendations['delivery_planning']:
        output.append("\n🏥 DELIVERY PLANNING:")
        delivery = recommendations['delivery_planning']
        output.append(f"   • Location: {delivery.get('delivery_location', 'Standard')}")
        output.append(f"   • Team: {delivery.get('delivery_team', 'Standard')}")
        output.append(f"   • Timing: {delivery.get('timing', 'Term')}")
        output.append(f"   • Mode: {delivery.get('mode', 'Standard')}")

    if recommendations['nicu_preparation']:
        output.append("\n🍼 NICU PREPARATION:")
        nicu = recommendations['nicu_preparation']
        for key, value in nicu.items():
            output.append(f"   • {key.replace('_', ' ').title()}: {value}")

    if recommendations['medication_protocols']:
        output.append("\n💊 PRETERM PREVENTION MEDICATIONS:")
        for protocol in recommendations['medication_protocols']:
            output.append(f"   {protocol}")

    if recommendations['monitoring_schedule']:
        output.append("\n📅 MONITORING SCHEDULE:")
        monitoring = recommendations['monitoring_schedule']
        output.append(f"   • Visit Frequency: {monitoring.get('visit_frequency', 'Standard')}")
        output.append(f"   • Cervical Assessment: {monitoring.get('cervical_assessment', 'Standard')}")
        output.append(f"   • Fetal Monitoring: {monitoring.get('fetal_monitoring', 'Standard')}")
        output.append(f"   • Contraction Monitoring: {monitoring.get('contraction_monitoring', 'Standard')}")
        output.append(f"   • Laboratory: {monitoring.get('laboratory', 'Standard')}")

    if recommendations['healthcare_team_actions']:
        output.append("\n👥 HEALTHCARE TEAM ACTIONS:")
        team = recommendations['healthcare_team_actions']
        for role, actions in team.items():
            if actions:
                output.append(f"   • {role.replace('_', ' ').title()}:")
                for action in actions:
                    output.append(f"     - {action}")

    output.append("\n" + "=" * 80)
    return "\n".join(output)

In [23]:
if __name__ == "__main__":
    # patient_data = {
    #     'AGE': 22,
    #     'WEIGHT_last': 62,
    #     'WEIGHT_first': 58,
    #     'HEIGHT': 155,
    #     'HEMOGLOBIN': 10.5,
    #     'NO_OF_WEEKS_max': 30,
    #     'MISSANC1FLG': 1,
    #     'MISSANC2FLG': 0,
    #     'MISSANC3FLG': 0,
    #     'MISSANC4FLG': 0,
    #     'GRAVIDA': 2,
    #     'PARITY': 1,
    #     'ABORTIONS': 0,
    #     'TOTAL_ANC_VISITS': 3,
    #     'BP_last': '130/85',
    #     'PHQ_SCORE_max': 8
    # }
    patient_data = {
        'AGE': 17,
        'WEIGHT_last': 50,
        'WEIGHT_first': 56,
        'HEIGHT': 160,
        'HEMOGLOBIN': 12.0,
        'NO_OF_WEEKS_max': 3,
        'MISSANC1FLG': 0,
        'MISSANC2FLG': 0,
        'MISSANC3FLG': 0,
        'MISSANC4FLG': 0,
        'GRAVIDA': 1,
        'PARITY': 0,
        'ABORTIONS': 5,
        'TOTAL_ANC_VISITS': 4,
        'BP_last': '118/76',
        'PHQ_SCORE_max': 2,
        'GAD_SCORE_max': 2  # Added to reflect no anxiety (similar to PHQ_SCORE_max)
    }
    
    clinical_support = PretermBirthClinicalSupport(model_path='/kaggle/working/best_model_fold3.pkl')
    recommendations = clinical_support.generate_preterm_birth_recommendations(patient_data)
    print(print_preterm_birth_recommendations(recommendations))

👶 PRETERM BIRTH PREVENTION CLINICAL DECISION SUPPORT

📊 PRETERM BIRTH RISK ASSESSMENT:
   Risk Probability: 40.0%
   Risk Level: High Preterm Risk
   Preterm Risk Classification: NO

🚨 IMMEDIATE PRETERM PREVENTION INTERVENTIONS:
   1. 📋 HIGH PRETERM RISK: Enhanced surveillance required
   2. Implement preterm monitoring protocol
   3. Schedule frequent follow-up visits

🛡️ PRETERM BIRTH PREVENTION PROTOCOLS:
   • Nutritional Intervention:
     - Urgent nutritional assessment and counseling
     - High-calorie, high-protein diet plan (2500-3000 kcal/day)
     - Target weight gain: 0.4-0.5kg/week in 2nd/3rd trimester
     - Bi-weekly weight monitoring
     - Prescribe prenatal vitamins and nutritional supplements
   • Age Risk Management:
     - Enhanced monitoring for adolescent (<20) or advanced maternal age (>35)
     - Counsel on age-related preterm risks
     - Weekly assessments for age-related complications
   • Activity Modification:
     - Restrict heavy lifting and strenuous ph