In [2]:
import pandas as pd
import numpy as np
import lightgbm as lgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, precision_score, recall_score
import time
import warnings
import pyarrow.parquet as pq
warnings.filterwarnings('ignore')

data_path = '/kaggle/input/fullfinal/telangana_data_with_features_and_targets (1).parquet'
flag_map = {
    'Y': 1, 'YES': 1, 'Yes': 1, 'y': 1, 'yes': 1,
    'N': 0, 'NO': 0, 'No': 0, 'n': 0, 'no': 0,
    None: np.nan, 'None': np.nan, '': np.nan, 'nan': np.nan
}

def prepare_data_for_targets(data_path, flag_map, batch_size=10000):
    """Load and preprocess data from Parquet file in batches."""
    # Initialize an empty list to store processed chunks
    processed_chunks = []
    
    # Define required and numeric columns
    required_cols = ['MOTHER_ID', 'GRAVIDA']
    numeric_cols = ['AGE', 'AGE_preg', 'AGE_final', 'GRAVIDA', 'PARITY', 'ABORTIONS', 'TOTAL_ANC_VISITS',
                    'HEMOGLOBIN_mean', 'HEMOGLOBIN_min', 'WEIGHT_max', 'HEIGHT',
                    'PHQ_SCORE_max', 'GAD_SCORE_max', 'WEIGHT_last', 'WEIGHT_first',
                    'NO_OF_WEEKS_max', 'WEIGHT_min', 'WEIGHT_mean']
    
    # Open Parquet file using pyarrow
    parquet_file = pq.ParquetFile(data_path)
    num_rows = parquet_file.metadata.num_rows
    num_batches = (num_rows + batch_size - 1) // batch_size  # Ceiling division
    
    # Iterate through batches
    for batch_idx in range(num_batches):
        # Read a batch of rows
        start_row = batch_idx * batch_size
        end_row = min(start_row + batch_size, num_rows)
        batch = parquet_file.read_row_group(batch_idx).to_pandas() if parquet_file.num_row_groups > batch_idx else pd.DataFrame()
        
        if batch.empty:
            continue
            
        # Check for required columns
        for col in required_cols:
            if col not in batch.columns:
                raise ValueError(f"Required column {col} missing in data")

        # Debug: Check non-numeric values
        for col in numeric_cols:
            if col in batch.columns and batch[col].dtype == 'object':
                non_numeric = batch[col][~batch[col].apply(lambda x: str(x).replace('.', '').isdigit() if pd.notna(x) else False)]
                if not non_numeric.empty:
                    print(f"Non-numeric values in {col}: {non_numeric.unique()[:10]}")

        # Clean GRAVIDA specifically
        if 'GRAVIDA' in batch.columns:
            batch['GRAVIDA'] = batch['GRAVIDA'].replace('nan', np.nan)
            batch['GRAVIDA'] = pd.to_numeric(batch['GRAVIDA'], errors='coerce')
            # Replace NaN with median GRAVIDA (or 1 as default) within batch
            if batch['GRAVIDA'].notna().any():
                median_gravida = batch['GRAVIDA'].median()
                if pd.isna(median_gravida):
                    median_gravida = 1.0
                batch['GRAVIDA'] = batch['GRAVIDA'].fillna(0)
                print(f"Filled {batch['GRAVIDA'].isna().sum()} GRAVIDA NaN values with 0 in batch")

        # Convert other numeric columns
        for col in numeric_cols:
            if col in batch.columns and col != 'GRAVIDA':
                batch[col] = pd.to_numeric(batch[col], errors='coerce')

        # Map flag columns
        if 'IS_CHILD_DEATH' in batch.columns and batch['IS_CHILD_DEATH'].dtype == 'object':
            batch['IS_CHILD_DEATH'] = batch['IS_CHILD_DEATH'].map(flag_map)
        if 'IS_DEFECTIVE_BIRTH' in batch.columns and batch['IS_DEFECTIVE_BIRTH'].dtype == 'object':
            batch['IS_DEFECTIVE_BIRTH'] = batch['IS_DEFECTIVE_BIRTH'].map(flag_map)

        # Fill NaN for numeric columns
        batch_numeric_cols = batch.select_dtypes(include=['float64', 'float32', 'int64', 'int32', 'int8']).columns
        batch[batch_numeric_cols] = batch[batch_numeric_cols].fillna(0)

        # Append processed batch
        processed_chunks.append(batch)

    # Concatenate all chunks
    df = pd.concat(processed_chunks, ignore_index=True)
    return df

# Execute the function
df = prepare_data_for_targets(data_path, flag_map)

Non-numeric values in GRAVIDA: ['nan']
Filled 0 GRAVIDA NaN values with 0 in batch
Filled 0 GRAVIDA NaN values with 0 in batch
Filled 0 GRAVIDA NaN values with 0 in batch
Filled 0 GRAVIDA NaN values with 0 in batch


In [3]:
target_columns = [
            'high_risk_pregnancy'
        ]
def create_stratified_sample(df, target_column, sample_size=2000000):
    """Create a stratified sample ensuring all critical cases are included."""
    child_death_col = 'IS_CHILD_DEATH'
    critical_cases = pd.DataFrame()
    
    if child_death_col in df.columns:
        child_deaths = df[df[child_death_col] == 1]
        critical_cases = pd.concat([critical_cases, child_deaths])
    if 'maternal_mortality_risk' in df.columns:
        maternal_deaths = df[df['maternal_mortality_risk'] == 1]
        critical_cases = pd.concat([critical_cases, maternal_deaths])
    if 'stillbirth_risk' in df.columns:
        stillbirths = df[df['stillbirth_risk'] == 1]
        critical_cases = pd.concat([critical_cases, stillbirths])
    critical_cases = critical_cases.drop_duplicates()
    
    remaining_size = sample_size - len(critical_cases)
    
    if remaining_size > 0:
        other_cases = df[~df.index.isin(critical_cases.index)]
        sampled_others = other_cases.sample(n=remaining_size, random_state=42)
        final_sample = pd.concat([critical_cases, sampled_others])
    else:
        final_sample = critical_cases.sample(n=sample_size, random_state=42)
    
    return final_sample

sample_df=create_stratified_sample(df, 'high_risk_pregnancy')

# Exclude target-related columns to prevent leakage
# exclude_cols = target_columns + [
#     'total_risk_factors', 'clinical_risk_score', 'risk_level', 'predicted_risk',
#     'total_missed_visits', 'age_risk_score', 'demographic_risk', 'anemia_risk_score',
#     'overall_risk_score', 'MOTHER_ID', 'ANC_ID', 'IS_CHILD_DEATH',
#     'IS_DEFECTIVE_BIRTH', 'NO_OF_WEEKS_max', 'DELIVERY_OUTCOME', 'IS_MOTHER_ALIVE',
#     'MISSANC1FLG', 'MISSANC2FLG', 'MISSANC3FLG', 'MISSANC4FLG', 'stillbirth_risk', 'premature_birth_risk',
#             'maternal_mortality_risk', 'birth_defect_risk', 'anc_dropout',
    
#     # Newly added columns
#     'ANC_INSTITUTE', 'FACILITY_TYPE', 'FACILITY_NAME', 'DOCTOR_ANM', 'IFA_TABLET',
#     'TT_GIVEN', 'USG_SCAN', 'DISTRICT_anc', 'TIME_OF_BIRTH', 'IS_BF_IN_HOUR',
#     'SNCU_REFERRAL_HOSPITAL', 'TERTIARY_REFERRAL_HOSPITAL', 'OTHER_REFERRAL_HOSPITAL',
#     'IMMUNE_CYCLE_DONE', 'CHILD_DEATH_DATE', 'CHILD_DEATH_REASON', 'DEFECT_HEALTH_CENTER',
#     'DISTRICT_child', 'BIRTH_DEFECT_TYPE', 'BIRTH_DEFECT_SUBTYPE', 'DEFECT_SUBTYPE_OTHER',
#     'NOTIFICATION_SENT', 'FBIR_COMPLETED_BY_ANM', 'CHILD_NAME', 'DEFECT_TYPE_OTHER',
#     'DEATH_REASON_OTHER', 'NEWBORN_SCREENING', 'EID', 'UID_NUMBER', 'SNCU_ADMITTED',
#     'CONSANGUINITY', 'HIGH_RISKS', 'DISEASES', 'FEEDING_TYPE', 'DATE_OF_FIRST_FEEDING',
#     'TIME_OF_FIRST_FEEDING', 'DATE_OF_BLOODSAMPLE_COLLECTION', 'TIME_OF_BLOODSAMPLE_COLLECTION',
#     'HOURS_OF_SAMPLE_COLLECTION', 'TRANSFUSION_DONE', 'BABY_ON_MEDICATION', 'MEDICATION_REMARKS',
#     'DATE_OF_DELIVERY', 'PLACE_OF_DELIVERY', 'MODE_OF_DELIVERY', 'MATERNAL_OUTCOME',
#     'REASON_FOR_DEATH', 'DATE_OF_DEATH', 'PLACE_OF_DEATH', 'INDICATION_FOR_C_SECTION',
#     'CH', 'CAH', 'GALACTOCEMIA', 'G6PDD', 'BIOTINIDASE', 'DELIVERY_INSTITUTION',
#     'DELIVERY_DONE_BY', 'IS_DELIVERED', 'DATE_OF_DISCHARGE', 'DEL_TIME', 'MISOPROSTAL_TABLET',
#     'CONDUCT_BY', 'OTHER_NAME', 'JSY_BENEFICIARY', 'DISCHARGE_TIME', 'DEL_COMPLICATIONS',
#     'OTHER_DEL_COMPLICATIONS', 'NOTIFICATION_SENT_del', 'FBIR_COMPLETED_BY_ANM_del',
#     'OTHER_STATE_PLACE', 'OTHER_STATE_PLACE_FILEPATH', 'OTHER_GOVT_PLACE_FILEPATH',
#     'REGISTRATION_DT', 'REGTYPE', 'CURRENT_USR', 'ANC2_TAG_FAC_ID', 'ANC3_TAG_FAC_ID',
#     'VDRL_DATE', 'VDRL_STATUS', 'HIV_DATE', 'HIV_STATUS', 'HBSAG_DATE', 'HBSAG_STATUS',
#     'HEP_DATE', 'HEP_STATUS'
# ]
exclude_cols = [
    # Targets & labels
    'maternal_mortality_risk', 'stillbirth_risk', 'premature_birth_risk',
    'birth_defect_risk', 'anc_dropout', 'high_risk_pregnancy',

    # Derived risk scores and flags
    'total_risk_factors', 'clinical_risk_score', 'risk_level', 'predicted_risk',
    'total_missed_visits', 'age_risk_score', 'demographic_risk', 'anemia_risk_score',
    'overall_risk_score',

    # Identifiers
    'MOTHER_ID', 'ANC_ID', 'CHILD_ID', 'unique_id', 'EID', 'UID_NUMBER',

    # Delivery outcome or post-delivery info (leakage)
    'WEIGHT_child_mean', 'DELIVERY_MODE', 'MATERNAL_OUTCOME', 'IS_DELIVERED',
    'DELIVERY_OUTCOME', 'DATE_OF_DELIVERY', 'PLACE_OF_DELIVERY', 'DEL_TIME',
    'DATE_OF_DISCHARGE', 'DISCHARGE_TIME', 'JSY_BENEFICIARY',
    'IS_MOTHER_ALIVE', 'IS_CHILD_DEATH', 'CHILD_DEATH_DATE', 'CHILD_DEATH_REASON',
    'DEFECT_HEALTH_CENTER', 'IS_DEFECTIVE_BIRTH', 'BIRTH_DEFECT_TYPE',
    'BIRTH_DEFECT_SUBTYPE', 'DEFECT_SUBTYPE_OTHER', 'DEFECT_TYPE_OTHER',
    'NOTIFICATION_SENT', 'FBIR_COMPLETED_BY_ANM', 'NEWBORN_SCREENING',
    'DEATH_REASON_OTHER', 'SNCU_ADMITTED', 'SNCU_REFERRAL_HOSPITAL',
    'TERTIARY_REFERRAL_HOSPITAL', 'OTHER_REFERRAL_HOSPITAL',
    'DATE_OF_DEATH', 'REASON_FOR_DEATH', 'PLACE_OF_DEATH',
    'INDICATION_FOR_C_SECTION', 'CH', 'CAH', 'GALACTOCEMIA', 'G6PDD', 'BIOTINIDASE',

    # Delivery-specific process info
    'DELIVERY_INSTITUTION', 'DELIVERY_DONE_BY', 'CONDUCT_BY',
    'MISOPROSTAL_TABLET', 'DEL_COMPLICATIONS', 'OTHER_DEL_COMPLICATIONS',
    'NOTIFICATION_SENT_del', 'FBIR_COMPLETED_BY_ANM_del',

    # Administrative / logging columns
    'REGISTRATION_DT', 'REGTYPE', 'CURRENT_USR', 'OTHER_STATE_PLACE',
    'OTHER_STATE_PLACE_FILEPATH', 'OTHER_GOVT_PLACE_FILEPATH',
    'ANC2_TAG_FAC_ID', 'ANC3_TAG_FAC_ID',

    # Facility and geographic info
    'ANC_INSTITUTE', 'FACILITY_TYPE', 'FACILITY_NAME', 'DOCTOR_ANM',
    'DISTRICT_anc', 'DISTRICT_child',

    # Feeding / newborn care
    'IS_BF_IN_HOUR', 'FEEDING_TYPE', 'DATE_OF_FIRST_FEEDING',
    'TIME_OF_FIRST_FEEDING', 'BABY_ON_MEDICATION', 'MEDICATION_REMARKS',
    'DATE_OF_BLOODSAMPLE_COLLECTION', 'TIME_OF_BLOODSAMPLE_COLLECTION',
    'HOURS_OF_SAMPLE_COLLECTION', 'TRANSFUSION_DONE',

    # Screening/test results
    'VDRL_DATE', 'VDRL_STATUS', 'HIV_DATE', 'HIV_STATUS', 'HBSAG_DATE',
    'HBSAG_STATUS', 'HEP_DATE', 'HEP_STATUS', 'VDRL_RESULT',
    'HIV_RESULT', 'HBSAG_RESULT', 'HEP_RESULT',

    # Missed ANC flag columns
    'MISSANC1FLG', 'MISSANC2FLG', 'MISSANC3FLG', 'MISSANC4FLG',

    # Manually added known leaky or post-hoc columns
    'HIGH_RISKS', 'DISEASES', 'CHILD_NAME',

    # Any column with 'risk' or 'score' in name not already excluded
    *[col for col in df.columns if (
        ('risk' in col.lower() or 'score' in col.lower()) and col not in {
            'age_risk_score', 'anemia_risk_score', 'overall_risk_score',
            'demographic_risk', 'clinical_risk_score', 'total_risk_factors',
            'maternal_mortality_risk', 'stillbirth_risk', 'premature_birth_risk',
            'birth_defect_risk', 'high_risk_pregnancy', 'mental_health_risk'
        }
    )],
]

more_leakage_cols = [
    # aggregate or binary risk flags
    'bp_risk', 'mental_health_risk', 'age_category', 'multigravida', 'grand_multipara',
     'no_anc',
    'missed_first_anc', 'consecutive_missed',
    # anemia + BP severity flags
    'severe_hypertension',
    # weight / BMI flags
    'low_birth_weight', 'very_low_birth_weight', 'avg_birth_weight_low',
    # mental-health flags
    # trend or helper columns
    'hemoglobin_trend', 'unique_id', 'TT_DATE', 'MAL_PRESENT','IS_ADMITTED_SNCU',
 'IS_PREV_PREG','CHILD_ID', 'ANC1FLG',
 'ANC2FLG',
 'ANC3FLG',
 'ANC4FLG',
 'ANC_DATE',
 'DEATH',
 'EXP_DOD',
 'DELIVERY_PLACE',
 'EXP_DOD_preg',
 'FASTING',
 'LMP_DT',
 'SCREENED_FOR_MENTAL_HEALTH',
 'AGE',
 'AGE_final',
 'AGE_preg',
 'WEIGHT_child_min',
'GENDER',
    'TIME_OF_BIRTH',
    'RNK', 'HEMOGLOBIN_min', 'HEMOGLOBIN_max', 'WEIGHT_anc_mean', 'WEIGHT_anc_min', 'WEIGHT_anc_max', 'NO_OF_WEEKS_max', 'TWIN_PREGNANCY_max', 'TOTAL_ANC_VISITS',
    # any other “risk / score” not already in exclude list
    *[c for c in df.columns if ('risk' in c.lower() or 'score' in c.lower()) and c not in exclude_cols]
]

# Finally combine everything
exclude_cols = list(set(exclude_cols + more_leakage_cols))

features = [col for col in df.columns if col not in exclude_cols and df[col].dtype in ['float64', 'float32', 'int64', 'int32', 'int8']]

if not features:
    print(f"Skipping {target_columns}: No valid features available.")



In [5]:
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import roc_auc_score, f1_score
import lightgbm as lgb
import numpy as np
import time

# Assuming sample_df, features, and target_column are defined (correcting target_columns to target_column)
X = sample_df[features]
y = sample_df[target_columns]  # Changed from target_columns to target_column

# Initial train-test split
X_train_full, X_test, y_train_full, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Initialize k-fold cross-validation
n_splits = 5  # Number of folds
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

# Lists to store metrics for each fold
auc_scores = []
f1_scores = []
fold_times = []

# Calculate scale_pos_weight for class imbalance (based on training dataset)
neg_count = (y_train_full == 0).sum()
pos_count = (y_train_full == 1).sum()
scale_pos_weight = neg_count[0] / pos_count[0] if pos_count[0] > 0 else 1  # Simplified, assuming y is a Series


In [6]:
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score, confusion_matrix
from sklearn.ensemble import RandomForestClassifier
import numpy as np
import time
import shap
import matplotlib.pyplot as plt
import pandas as pd

# Assuming sample_df, features, and target_column are defined
# X = sample_df[features]
# y = sample_df[target_columns]

# # Initial train-test split
# X_train_full, X_test, y_train_full, y_test = train_test_split(
#     X, y, test_size=0.2, random_state=42, stratify=y
# )

# Initialize k-fold cross-validation
n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

# Lists to store metrics and models for each fold
metrics = {
    'thresh_0_1': {'f1': [], 'accuracy': [], 'precision': [], 'recall': []},
    'thresh_0_2': {'f1': [], 'accuracy': [], 'precision': [], 'recall': []},
    'thresh_0_3': {'f1': [], 'accuracy': [], 'precision': [], 'recall': []},
    'thresh_0_4': {'f1': [], 'accuracy': [], 'precision': [], 'recall': []},
    'auc': []  # AUC is threshold-independent
}
fold_models = []  # Store models from each fold
fold_avg_f1 = []  # Store average F1 score across thresholds for each fold

# Random Forest parameters (tuned for stability)
params = {
    'n_estimators': 100,  # Number of trees
    'max_depth': 10,      # Limit depth to prevent overfitting
    'min_samples_split': 50,  # Minimum samples to split a node
    'min_samples_leaf': 25,   # Minimum samples per leaf
    'class_weight': 'balanced',  # Handle class imbalance
    'random_state': 42,
    'n_jobs': -1  # Use all available cores
}

# Perform k-fold cross-validation
for fold, (train_idx, val_idx) in enumerate(skf.split(X_train_full, y_train_full)):
    print(f"Training fold {fold + 1}/{n_splits}")
    
    # Split data
    X_train, X_val = X_train_full.iloc[train_idx], X_train_full.iloc[val_idx]
    y_train, y_val = y_train_full.iloc[train_idx], y_train_full.iloc[val_idx]
    
    # Print class distribution
    print(f"Training class counts:\n{y_train.value_counts()}")
    print(f"Validation class counts:\n{y_val.value_counts()}")
    
    # Train Random Forest model
    start_time = time.time()
    model = RandomForestClassifier(**params)
    model.fit(X_train, y_train)
    fold_time = time.time() - start_time
    fold_models.append(model)  # Store the model
    
    # Predict and evaluate on validation set
    y_pred_proba = model.predict_proba(X_val)[:, 1]  # Probability for positive class
    
    # Fixed thresholds (0.1, 0.2, 0.3, 0.4)
    y_pred_0_1 = (y_pred_proba > 0.1).astype(int)
    y_pred_0_2 = (y_pred_proba > 0.2).astype(int)
    y_pred_0_3 = (y_pred_proba > 0.3).astype(int)
    y_pred_0_4 = (y_pred_proba > 0.4).astype(int)
    
    # Metrics for threshold 0.1
    f1_0_1 = f1_score(y_val, y_pred_0_1)
    accuracy_0_1 = accuracy_score(y_val, y_pred_0_1)
    precision_0_1 = precision_score(y_val, y_pred_0_1, zero_division=0)
    recall_0_1 = recall_score(y_val, y_pred_0_1, zero_division=0)
    cm_0_1 = confusion_matrix(y_val, y_pred_0_1)
    
    # Metrics for threshold 0.2
    f1_0_2 = f1_score(y_val, y_pred_0_2)
    accuracy_0_2 = accuracy_score(y_val, y_pred_0_2)
    precision_0_2 = precision_score(y_val, y_pred_0_2, zero_division=0)
    recall_0_2 = recall_score(y_val, y_pred_0_2, zero_division=0)
    cm_0_2 = confusion_matrix(y_val, y_pred_0_2)
    
    # Metrics for threshold 0.3
    f1_0_3 = f1_score(y_val, y_pred_0_3)
    accuracy_0_3 = accuracy_score(y_val, y_pred_0_3)
    precision_0_3 = precision_score(y_val, y_pred_0_3, zero_division=0)
    recall_0_3 = recall_score(y_val, y_pred_0_3, zero_division=0)
    cm_0_3 = confusion_matrix(y_val, y_pred_0_3)
    
    # Metrics for threshold 0.4
    f1_0_4 = f1_score(y_val, y_pred_0_4)
    accuracy_0_4 = accuracy_score(y_val, y_pred_0_4)
    precision_0_4 = precision_score(y_val, y_pred_0_4, zero_division=0)
    recall_0_4 = recall_score(y_val, y_pred_0_4, zero_division=0)
    cm_0_4 = confusion_matrix(y_val, y_pred_0_4)
    
    # AUC (threshold-independent)
    auc = roc_auc_score(y_val, y_pred_proba)
    
    # Store metrics
    metrics['thresh_0_1']['f1'].append(f1_0_1)
    metrics['thresh_0_1']['accuracy'].append(accuracy_0_1)
    metrics['thresh_0_1']['precision'].append(precision_0_1)
    metrics['thresh_0_1']['recall'].append(recall_0_1)
    
    metrics['thresh_0_2']['f1'].append(f1_0_2)
    metrics['thresh_0_2']['accuracy'].append(accuracy_0_2)
    metrics['thresh_0_2']['precision'].append(precision_0_2)
    metrics['thresh_0_2']['recall'].append(recall_0_2)
    
    metrics['thresh_0_3']['f1'].append(f1_0_3)
    metrics['thresh_0_3']['accuracy'].append(accuracy_0_3)
    metrics['thresh_0_3']['precision'].append(precision_0_3)
    metrics['thresh_0_3']['recall'].append(recall_0_3)
    
    metrics['thresh_0_4']['f1'].append(f1_0_4)
    metrics['thresh_0_4']['accuracy'].append(accuracy_0_4)
    metrics['thresh_0_4']['precision'].append(precision_0_4)
    metrics['thresh_0_4']['recall'].append(recall_0_4)
    
    metrics['auc'].append(auc)
    
    # Calculate average F1 score across thresholds for this fold
    avg_f1 = np.mean([f1_0_1, f1_0_2, f1_0_3, f1_0_4])
    fold_avg_f1.append(avg_f1)
    
    print(f"Fold {fold + 1}")
    print(f"  AUC: {auc:.4f}, Time: {fold_time:.2f} seconds")
    print(f"  Threshold 0.1 - F1: {f1_0_1:.4f}, Accuracy: {accuracy_0_1:.4f}, Precision: {precision_0_1:.4f}, Recall: {recall_0_1:.4f}")
    print("  Confusion Matrix (Threshold 0.1):")
    print(f"    TN: {cm_0_1[0,0]}, FP: {cm_0_1[0,1]}")
    print(f"    FN: {cm_0_1[1,0]}, TP: {cm_0_1[1,1]}")
    print(f"  Threshold 0.2 - F1: {f1_0_2:.4f}, Accuracy: {accuracy_0_2:.4f}, Precision: {precision_0_2:.4f}, Recall: {recall_0_2:.4f}")
    print("  Confusion Matrix (Threshold 0.2):")
    print(f"    TN: {cm_0_2[0,0]}, FP: {cm_0_2[0,1]}")
    print(f"    FN: {cm_0_2[1,0]}, TP: {cm_0_2[1,1]}")
    print(f"  Threshold 0.3 - F1: {f1_0_3:.4f}, Accuracy: {accuracy_0_3:.4f}, Precision: {precision_0_3:.4f}, Recall: {recall_0_3:.4f}")
    print("  Confusion Matrix (Threshold 0.3):")
    print(f"    TN: {cm_0_3[0,0]}, FP: {cm_0_3[0,1]}")
    print(f"    FN: {cm_0_3[1,0]}, TP: {cm_0_3[1,1]}")
    print(f"  Threshold 0.4 - F1: {f1_0_4:.4f}, Accuracy: {accuracy_0_4:.4f}, Precision: {precision_0_4:.4f}, Recall: {recall_0_4:.4f}")
    print("  Confusion Matrix (Threshold 0.4):")
    print(f"    TN: {cm_0_4[0,0]}, FP: {cm_0_4[0,1]}")
    print(f"    FN: {cm_0_4[1,0]}, TP: {cm_0_4[1,1]}")

print("Features used for training:", list(X_train_full.columns))

# Summary statistics
print(f"\nCross-Validation Mean Metrics:")
print(f"  AUC: {np.mean(metrics['auc']):.4f} ± {np.std(metrics['auc']):.4f}")
print(f"\nThreshold 0.1:")
print(f"  F1 Score: {np.mean(metrics['thresh_0_1']['f1']):.4f} ± {np.std(metrics['thresh_0_1']['f1']):.4f}")
print(f"  Accuracy: {np.mean(metrics['thresh_0_1']['accuracy']):.4f} ± {np.std(metrics['thresh_0_1']['accuracy']):.4f}")
print(f"  Precision: {np.mean(metrics['thresh_0_1']['precision']):.4f} ± {np.std(metrics['thresh_0_1']['precision']):.4f}")
print(f"  Recall: {np.mean(metrics['thresh_0_1']['recall']):.4f} ± {np.std(metrics['thresh_0_1']['recall']):.4f}")

print(f"\nThreshold 0.2:")
print(f"  F1 Score: {np.mean(metrics['thresh_0_2']['f1']):.4f} ± {np.std(metrics['thresh_0_2']['f1']):.4f}")
print(f"  Accuracy: {np.mean(metrics['thresh_0_2']['accuracy']):.4f} ± {np.std(metrics['thresh_0_2']['accuracy']):.4f}")
print(f"  Precision: {np.mean(metrics['thresh_0_2']['precision']):.4f} ± {np.std(metrics['thresh_0_2']['precision']):.4f}")
print(f"  Recall: {np.mean(metrics['thresh_0_2']['recall']):.4f} ± {np.std(metrics['thresh_0_2']['recall']):.4f}")

print(f"\nThreshold 0.3:")
print(f"  F1 Score: {np.mean(metrics['thresh_0_3']['f1']):.4f} ± {np.std(metrics['thresh_0_3']['f1']):.4f}")
print(f"  Accuracy: {np.mean(metrics['thresh_0_3']['accuracy']):.4f} ± {np.std(metrics['thresh_0_3']['accuracy']):.4f}")
print(f"  Precision: {np.mean(metrics['thresh_0_3']['precision']):.4f} ± {np.std(metrics['thresh_0_3']['precision']):.4f}")
print(f"  Recall: {np.mean(metrics['thresh_0_3']['recall']):.4f} ± {np.std(metrics['thresh_0_3']['recall']):.4f}")

print(f"\nThreshold 0.4:")
print(f"  F1 Score: {np.mean(metrics['thresh_0_4']['f1']):.4f} ± {np.std(metrics['thresh_0_4']['f1']):.4f}")
print(f"  Accuracy: {np.mean(metrics['thresh_0_4']['accuracy']):.4f} ± {np.std(metrics['thresh_0_4']['accuracy']):.4f}")
print(f"  Precision: {np.mean(metrics['thresh_0_4']['precision']):.4f} ± {np.std(metrics['thresh_0_4']['precision']):.4f}")
print(f"  Recall: {np.mean(metrics['thresh_0_4']['recall']):.4f} ± {np.std(metrics['thresh_0_4']['recall']):.4f}")

# Select the best model based on average F1 score across thresholds
best_fold_idx = np.argmax(fold_avg_f1)
best_model = fold_models[best_fold_idx]
print(f"\nBest Model from Fold {best_fold_idx + 1} with Average F1 Score: {fold_avg_f1[best_fold_idx]:.4f}")

# Evaluate best model on test set
y_test_pred_proba = best_model.predict_proba(X_test)[:, 1]
y_test_pred_0_1 = (y_test_pred_proba > 0.1).astype(int)
y_test_pred_0_2 = (y_test_pred_proba > 0.2).astype(int)
y_test_pred_0_3 = (y_test_pred_proba > 0.3).astype(int)
y_test_pred_0_4 = (y_test_pred_proba > 0.4).astype(int)

# Test set metrics
test_auc = roc_auc_score(y_test, y_test_pred_proba)
test_f1_0_1 = f1_score(y_test, y_test_pred_0_1)
test_accuracy_0_1 = accuracy_score(y_test, y_test_pred_0_1)
test_precision_0_1 = precision_score(y_test, y_test_pred_0_1, zero_division=0)
test_recall_0_1 = recall_score(y_test, y_test_pred_0_1, zero_division=0)
test_cm_0_1 = confusion_matrix(y_test, y_test_pred_0_1)

test_f1_0_2 = f1_score(y_test, y_test_pred_0_2)
test_accuracy_0_2 = accuracy_score(y_test, y_test_pred_0_2)
test_precision_0_2 = precision_score(y_test, y_test_pred_0_2, zero_division=0)
test_recall_0_2 = recall_score(y_test, y_test_pred_0_2, zero_division=0)
test_cm_0_2 = confusion_matrix(y_test, y_test_pred_0_2)

test_f1_0_3 = f1_score(y_test, y_test_pred_0_3)
test_accuracy_0_3 = accuracy_score(y_test, y_test_pred_0_3)
test_precision_0_3 = precision_score(y_test, y_test_pred_0_3, zero_division=0)
test_recall_0_3 = recall_score(y_test, y_test_pred_0_3, zero_division=0)
test_cm_0_3 = confusion_matrix(y_test, y_test_pred_0_3)

test_f1_0_4 = f1_score(y_test, y_test_pred_0_4)
test_accuracy_0_4 = accuracy_score(y_test, y_test_pred_0_4)
test_precision_0_4 = precision_score(y_test, y_test_pred_0_4, zero_division=0)
test_recall_0_4 = recall_score(y_test, y_test_pred_0_4, zero_division=0)
test_cm_0_4 = confusion_matrix(y_test, y_test_pred_0_4)

print(f"\nTest Set Metrics (Best Model from Fold {best_fold_idx + 1}):")
print(f"  AUC: {test_auc:.4f}")
print(f"\nThreshold 0.1:")
print(f"  F1 Score: {test_f1_0_1:.4f}")
print(f"  Accuracy: {test_accuracy_0_1:.4f}")
print(f"  Precision: {test_precision_0_1:.4f}")
print(f"  Recall: {test_recall_0_1:.4f}")
print("  Confusion Matrix:")
print(f"    TN: {test_cm_0_1[0,0]}, FP: {test_cm_0_1[0,1]}")
print(f"    FN: {test_cm_0_1[1,0]}, TP: {test_cm_0_1[1,1]}")

print(f"\nThreshold 0.2:")
print(f"  F1 Score: {test_f1_0_2:.4f}")
print(f"  Accuracy: {test_accuracy_0_2:.4f}")
print(f"  Precision: {test_precision_0_2:.4f}")
print(f"  Recall: {test_recall_0_2:.4f}")
print("  Confusion Matrix:")
print(f"    TN: {test_cm_0_2[0,0]}, FP: {test_cm_0_2[0,1]}")
print(f"    FN: {test_cm_0_2[1,0]}, TP: {test_cm_0_2[1,1]}")

print(f"\nThreshold 0.3:")
print(f"  F1 Score: {test_f1_0_3:.4f}")
print(f"  Accuracy: {test_accuracy_0_3:.4f}")
print(f"  Precision: {test_precision_0_3:.4f}")
print(f"  Recall: {test_recall_0_3:.4f}")
print("  Confusion Matrix:")
print(f"    TN: {test_cm_0_3[0,0]}, FP: {test_cm_0_3[0,1]}")
print(f"    FN: {test_cm_0_3[1,0]}, TP: {test_cm_0_3[1,1]}")

print(f"\nThreshold 0.4:")
print(f"  F1 Score: {test_f1_0_4:.4f}")
print(f"  Accuracy: {test_accuracy_0_4:.4f}")
print(f"  Precision: {test_precision_0_4:.4f}")
print(f"  Recall: {test_recall_0_4:.4f}")
print("  Confusion Matrix:")
print(f"    TN: {test_cm_0_4[0,0]}, FP: {test_cm_0_4[0,1]}")
print(f"    FN: {test_cm_0_4[1,0]}, TP: {test_cm_0_4[1,1]}")

# Determine best threshold based on test set F1 score
test_f1_scores = {
    0.1: test_f1_0_1,
    0.2: test_f1_0_2,
    0.3: test_f1_0_3,
    0.4: test_f1_0_4
}
best_threshold = max(test_f1_scores, key=test_f1_scores.get)
best_f1 = test_f1_scores[best_threshold]
print(f"\nBest Threshold on Test Set: {best_threshold} with F1 Score: {best_f1:.4f}")

# SHAP analysis using the best model on test set
print("\nPerforming SHAP analysis on test set with best model...")
explainer = shap.TreeExplainer(best_model)
shap_values = explainer.shap_values(X_test)

# Use SHAP values for the positive class (index 1) for binary classification
# Summary plot (beeswarm)
plt.figure()
shap.summary_plot(shap_values[1], X_test, show=False)
plt.savefig("shap_summary_plot.png")
plt.close()
print("SHAP summary plot saved as 'shap_summary_plot.png'")

# Feature importance based on mean absolute SHAP values
shap_importance = np.abs(shap_values[1]).mean(axis=0)
importance_df = pd.DataFrame({
    'Feature': X_test.columns,
    'SHAP_Importance': shap_importance
}).sort_values(by='SHAP_Importance', ascending=False)
print("\nSHAP Feature Importance:")
print(importance_df)

Training fold 1/5
Training class counts:
high_risk_pregnancy
0                      1225573
1                        54427
Name: count, dtype: int64
Validation class counts:
high_risk_pregnancy
0                      306394
1                       13606
Name: count, dtype: int64
Fold 1
  AUC: 0.9858, Time: 45.21 seconds
  Threshold 0.1 - F1: 0.2639, Accuracy: 0.7699, Precision: 0.1527, Recall: 0.9699
  Confusion Matrix (Threshold 0.1):
    TN: 233171, FP: 73223
    FN: 410, TP: 13196
  Threshold 0.2 - F1: 0.6783, Accuracy: 0.9613, Precision: 0.5248, Recall: 0.9587
  Confusion Matrix (Threshold 0.2):
    TN: 294584, FP: 11810
    FN: 562, TP: 13044
  Threshold 0.3 - F1: 0.9676, Accuracy: 0.9973, Precision: 0.9769, Recall: 0.9585
  Confusion Matrix (Threshold 0.3):
    TN: 306086, FP: 308
    FN: 565, TP: 13041
  Threshold 0.4 - F1: 0.9738, Accuracy: 0.9978, Precision: 0.9901, Recall: 0.9581
  Confusion Matrix (Threshold 0.4):
    TN: 306263, FP: 131
    FN: 570, TP: 13036
Training fold 

In [8]:
import pickle
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, confusion_matrix, roc_auc_score
from google.colab import files
import os

# Save the best model (Fold 2) to a pickle file
model_filename = 'rf_model_fold2.pkl'
print(f"Saving best model (Fold {best_fold_idx + 1}) to {model_filename}...")
with open(model_filename, 'wb') as file:
    pickle.dump(best_model, file)
print(f"Model saved successfully as {model_filename}")

# Verify the saved model by loading it and recomputing test set metrics
print("\nVerifying saved model...")
with open(model_filename, 'rb') as file:
    loaded_model = pickle.load(file)

# Recompute test set predictions (threshold 0.4, best threshold)
y_test_pred_proba = loaded_model.predict_proba(X_test)[:, 1]
y_test_pred = (y_test_pred_proba > 0.4).astype(int)

# Recompute metrics
test_auc = roc_auc_score(y_test, y_test_pred_proba)
test_f1 = f1_score(y_test, y_test_pred)
test_accuracy = accuracy_score(y_test, y_test_pred)
test_precision = precision_score(y_test, y_test_pred, zero_division=0)
test_recall = recall_score(y_test, y_test_pred, zero_division=0)
test_cm = confusion_matrix(y_test, y_test_pred)
 
print(f"\nLoaded Model Test Set Metrics (Threshold 0.4):")
print(f"  AUC: {test_auc:.4f}")
print(f"  F1 Score: {test_f1:.4f}")
print(f"  Accuracy: {test_accuracy:.4f}")
print(f"  Precision: {test_precision:.4f}")
print(f"  Recall: {test_recall:.4f}")
print("  Confusion Matrix:")
print(f"    TN: {test_cm[0,0]}, FP: {test_cm[0,1]}")
print(f"    FN: {test_cm[1,0]}, TP: {test_cm[1,1]}")

# Download the model file (Colab-specific)
if os.path.exists(model_filename):
    print(f"\nInitiating download of {model_filename}...")
    files.download(model_filename)
else:
    print(f"Error: {model_filename} not found.")


Saving best model (Fold 2) to rf_model_fold2.pkl...
Model saved successfully as rf_model_fold2.pkl

Verifying saved model...

Loaded Model Test Set Metrics (Threshold 0.4):
  AUC: 0.9852
  F1 Score: 0.9691
  Accuracy: 0.9974
  Precision: 0.9828
  Recall: 0.9557
  Confusion Matrix:
    TN: 382708, FP: 284
    FN: 754, TP: 16254

Initiating download of rf_model_fold2.pkl...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [11]:
import joblib
import re
import numpy as np
import pandas as pd

class HighRiskPregnancyClinicalSupport:
    """
    Dedicated Clinical Decision Support for High-Risk Pregnancy Prediction
    Based on Random Forest model (Fold 2) performance:
    AUC: 0.9852, F1 Score: 0.9691, Optimal Threshold: 0.4
    Key Features (SHAP): HEMOGLOBIN_mean, anemia_moderate, recurrent_loss, ABORTIONS
    Feature Engineering: Based on Telangana Maternal Health pipeline
    """

    def __init__(self, model_path='/kaggle/working/rf_model_fold2.pkl'):
        self.model_threshold = 0.4  # Optimal threshold from model
        self.model = joblib.load(model_path)  # Load trained Random Forest model
        self.feature_names = [
            'GRAVIDA', 'PARITY', 'ABORTIONS', 'HEIGHT', 'HEMOGLOBIN_mean',
            'age_adolescent', 'age_elderly', 'age_very_young', 'previous_loss',
            'recurrent_loss', 'gravida_parity_ratio', 'inadequate_anc', 'irregular_anc',
            'anemia_mild', 'anemia_moderate', 'anemia_severe', 'ever_severe_anemia',
            'systolic_bp', 'diastolic_bp', 'hypertension', 'BMI', 'underweight',
            'obese', 'normal_weight', 'depression', 'severe_depression', 'anxiety',
            'severe_anxiety', 'weight_gain', 'weight_gain_per_week', 'inadequate_weight_gain'
        ]  # Features used in your model

    def preprocess_patient_data(self, raw_data):
        """Preprocess raw clinical data into model-compatible features per Telangana Maternal Health pipeline"""
        processed_data = {}

        # 1. Age-Based Features
        age = raw_data.get('AGE', 25)
        processed_data['age_adolescent'] = 1 if age < 18 else 0
        processed_data['age_very_young'] = 1 if age < 16 else 0
        processed_data['age_elderly'] = 1 if age > 35 else 0

        # 2. Obstetric History Features
        gravida = raw_data.get('GRAVIDA', 1)
        parity = raw_data.get('PARITY', 0)
        abortions = raw_data.get('ABORTIONS', 0)
        processed_data['GRAVIDA'] = gravida
        processed_data['PARITY'] = parity
        processed_data['ABORTIONS'] = abortions
        processed_data['previous_loss'] = 1 if abortions > 0 else 0
        processed_data['recurrent_loss'] = 1 if abortions >= 2 else 0
        processed_data['gravida_parity_ratio'] = gravida / parity if parity > 0 else gravida

        # 3. ANC Visit Pattern Features
        total_anc_visits = raw_data.get('TOTAL_ANC_VISITS', 4)
        processed_data['inadequate_anc'] = 1 if total_anc_visits < 4 else 0
        processed_data['irregular_anc'] = 1 if raw_data.get('total_missed_visits', 0) >= 2 else 0

        # 4. Anemia Classification Features
        hemoglobin = raw_data.get('HEMOGLOBIN', 12)  # Assume single value or mean
        processed_data['HEMOGLOBIN_mean'] = hemoglobin
        processed_data['anemia_mild'] = 1 if 10 <= hemoglobin < 11 else 0
        processed_data['anemia_moderate'] = 1 if 7 <= hemoglobin < 10 else 0
        processed_data['anemia_severe'] = 1 if hemoglobin < 7 else 0
        processed_data['ever_severe_anemia'] = 1 if raw_data.get('HEMOGLOBIN_min', hemoglobin) < 7 else 0

        # 5. Blood Pressure Features
        bp_last = raw_data.get('BP_last', '120/80')
        bp_match = re.match(r'(\d+)/(\d+)', bp_last)
        systolic_bp = float(bp_match.group(1)) if bp_match else 120
        diastolic_bp = float(bp_match.group(2)) if bp_match else 80
        processed_data['systolic_bp'] = systolic_bp
        processed_data['diastolic_bp'] = diastolic_bp
        processed_data['hypertension'] = 1 if systolic_bp >= 140 or diastolic_bp >= 90 else 0

        # 6. BMI and Nutritional Status
        weight = raw_data.get('WEIGHT_max', 60)
        height = raw_data.get('HEIGHT', 160)
        bmi = weight / (height / 100) ** 2 if height > 0 else 22
        processed_data['HEIGHT'] = height
        processed_data['BMI'] = bmi
        processed_data['underweight'] = 1 if bmi < 18.5 else 0
        processed_data['obese'] = 1 if bmi > 30 else 0
        processed_data['normal_weight'] = 1 if 18.5 <= bmi <= 25 else 0

        # 7. Mental Health Features
        phq_score = raw_data.get('PHQ_SCORE_max', 0)
        gad_score = raw_data.get('GAD_SCORE_max', 0)
        processed_data['depression'] = 1 if phq_score >= 10 else 0
        processed_data['severe_depression'] = 1 if phq_score >= 15 else 0
        processed_data['anxiety'] = 1 if gad_score >= 10 else 0
        processed_data['severe_anxiety'] = 1 if gad_score >= 15 else 0

        # 8. Weight Change Features
        weight_gain = raw_data.get('weight_gain', raw_data.get('WEIGHT_last', 60) - raw_data.get('WEIGHT_first', 60))
        no_of_weeks = raw_data.get('NO_OF_WEEKS_max', 40)
        weight_gain_per_week = weight_gain / no_of_weeks if no_of_weeks > 0 else 0
        processed_data['weight_gain'] = weight_gain
        processed_data['weight_gain_per_week'] = weight_gain_per_week
        processed_data['inadequate_weight_gain'] = 1 if weight_gain_per_week < 0.2 else 0

        return processed_data

    def predict_risk_probability(self, raw_data):
        """Predict high-risk pregnancy probability using the trained model"""
        processed_data = self.preprocess_patient_data(raw_data)
        # Create a DataFrame with features in the correct order
        X = pd.DataFrame([processed_data], columns=self.feature_names)
        # Ensure all features are numeric and handle missing values
        X = X.fillna(0).astype({name: 'int8' if name not in ['HEMOGLOBIN_mean', 'systolic_bp', 'diastolic_bp', 'BMI', 'gravida_parity_ratio', 'weight_gain', 'weight_gain_per_week', 'HEIGHT'] else 'float32' for name in self.feature_names})
        # Get probability for high-risk class
        risk_probability = self.model.predict_proba(X)[:, 1][0]
        return risk_probability

    def generate_high_risk_pregnancy_recommendations(self, raw_data):
        """
        Generate clinical recommendations for high-risk pregnancy
        """
        risk_probability = self.predict_risk_probability(raw_data)
        patient_data = self.preprocess_patient_data(raw_data)
        risk_level = self._categorize_high_risk(risk_probability)

        recommendations = {
            'risk_assessment': {
                'probability': float(risk_probability),
                'risk_level': risk_level,
                'classification': risk_probability >= self.model_threshold
            },
            'immediate_actions': [],
            'monitoring_protocols': {},
            'referral_pathways': {},
            'medication_protocols': [],
            'specialist_consultations': [],
            'emergency_protocols': [],
            'healthcare_team_actions': {
                'asha_worker': [],
                'anm_nurse': [],
                'medical_officer': [],
                'obstetrician': []
            }
        }

        self._add_high_risk_specific_actions(recommendations, patient_data, risk_probability, risk_level)
        return recommendations

    def _categorize_high_risk(self, probability):
        """Categorize high-risk pregnancy probability"""
        if probability >= 0.80:
            return 'Critical High-Risk'
        elif probability >= 0.60:
            return 'Very High-Risk'
        elif probability >= 0.40:
            return 'High-Risk'
        elif probability >= 0.20:
            return 'Moderate-Risk'
        else:
            return 'Low-Risk'

    def _add_high_risk_specific_actions(self, recommendations, patient_data, probability, risk_level):
        """Add high-risk pregnancy specific recommendations"""
        # IMMEDIATE ACTIONS
        if risk_level == 'Critical High-Risk':
            recommendations['immediate_actions'].extend([
                "🚨 CRITICAL HIGH-RISK PREGNANCY: Immediate specialized care required",
                "Activate high-risk pregnancy protocol within 24 hours",
                "Schedule emergency maternal-fetal medicine consultation",
                "Consider immediate hospitalization for comprehensive evaluation",
                "Notify tertiary care center and reserve high-risk pregnancy bed",
                "Complete comprehensive risk factor assessment immediately"
            ])
        elif risk_level == 'Very High-Risk':
            recommendations['immediate_actions'].extend([
                "⚠️ VERY HIGH-RISK PREGNANCY: Enhanced monitoring protocol",
                "Schedule urgent high-risk pregnancy clinic appointment within 48 hours",
                "Initiate intensive antenatal surveillance program",
                "Complete detailed maternal-fetal assessment"
            ])
        elif risk_level == 'High-Risk':
            recommendations['immediate_actions'].extend([
                "📋 HIGH-RISK PREGNANCY: Structured care plan required",
                "Schedule high-risk pregnancy consultation within 1 week",
                "Upgrade to enhanced antenatal care protocol"
            ])

        # MONITORING PROTOCOLS
        if risk_level in ['Critical High-Risk', 'Very High-Risk']:
            recommendations['monitoring_protocols'] = {
                'anc_frequency': 'Weekly visits',
                'fetal_monitoring': [
                    'Biweekly non-stress tests (NST)',
                    'Weekly biophysical profile (BPP)',
                    'Doppler studies every 2 weeks',
                    'Daily fetal movement counts'
                ],
                'maternal_monitoring': [
                    'Daily blood pressure monitoring',
                    'Weekly weight monitoring',
                    'Biweekly laboratory assessments (focus on hemoglobin)',
                    'Monthly cardiac evaluation if indicated'
                ],
                'growth_surveillance': [
                    'Serial growth scans every 3-4 weeks',
                    'Amniotic fluid volume assessment',
                    'Placental function evaluation'
                ]
            }

        # REFERRAL PATHWAYS
        self._add_high_risk_referrals(recommendations, patient_data, risk_level)

        # MEDICATION PROTOCOLS
        self._add_high_risk_medications(recommendations, patient_data, risk_level)

        # SPECIALIST CONSULTATIONS
        if risk_level in ['Critical High-Risk', 'Very High-Risk']:
            recommendations['specialist_consultations'].extend([
                'Maternal-Fetal Medicine specialist',
                'Anesthesiology consultation for delivery planning',
                'Neonatology consultation if fetal concerns',
                'Hematology if severe anemia detected',
                'Endocrinology if diabetes/thyroid issues'
            ])

        # EMERGENCY PROTOCOLS
        if risk_level in ['Critical High-Risk', 'Very High-Risk']:
            recommendations['emergency_protocols'].extend([
                '24/7 high-risk pregnancy helpline access',
                'Direct admission protocol to high-risk unit',
                'Emergency delivery team on standby',
                'NICU bed reservation if needed',
                'Blood bank cross-match and hold'
            ])

        # HEALTHCARE TEAM ACTIONS
        self._add_high_risk_team_actions(recommendations, risk_level)

    def _add_high_risk_referrals(self, recommendations, patient_data, risk_level):
        """Add high-risk pregnancy specific referrals"""
        if risk_level == 'Critical High-Risk':
            recommendations['referral_pathways'] = {
                'primary_referral': 'Tertiary Care Center - Maternal-Fetal Medicine Unit',
                'urgency': 'Immediate (within 24 hours)',
                'backup_facility': 'District Hospital with high-risk pregnancy services',
                'transport': 'Emergency ambulance with obstetric capability'
            }
        elif risk_level == 'Very High-Risk':
            recommendations['referral_pathways'] = {
                'primary_referral': 'District Hospital - High-Risk Pregnancy Clinic',
                'urgency': 'Urgent (within 48-72 hours)',
                'backup_facility': 'Community Health Center with specialist coverage',
                'transport': 'Patient transport service or private vehicle'
            }
        elif risk_level == 'High-Risk':
            recommendations['referral_pathways'] = {
                'primary_referral': 'Community Health Center - Enhanced ANC',
                'urgency': 'Routine (within 1 week)',
                'backup_facility': 'Primary Health Center with medical officer',
                'transport': 'Standard transport arrangements'
            }

    def _add_high_risk_medications(self, recommendations, patient_data, risk_level):
        """Add high-risk pregnancy specific medications"""
        hemoglobin = patient_data['HEMOGLOBIN_mean']
        recurrent_loss = patient_data['recurrent_loss']
        abortions = patient_data['ABORTIONS']
        hypertension = patient_data['hypertension']
        age_adolescent = patient_data['age_adolescent']
        depression = patient_data['depression']
        anxiety = patient_data['anxiety']

        if patient_data['anemia_severe'] == 1 or hemoglobin < 7.0:
            recommendations['medication_protocols'].extend([
                '🩸 SEVERE ANEMIA PROTOCOL:',
                '• Iron sucrose IV 200mg in 100ml NS over 20 minutes',
                '• Consider blood transfusion if Hb <7g/dL with symptoms',
                '• Folic acid 5mg daily',
                '• Weekly hemoglobin monitoring'
            ])
        elif patient_data['anemia_moderate'] == 1 or hemoglobin < 10.0:
            recommendations['medication_protocols'].extend([
                '🩸 MODERATE ANEMIA MANAGEMENT:',
                '• Ferrous sulfate 200mg TDS with Vitamin C',
                '• Folic acid 5mg daily',
                '• Fortnightly hemoglobin monitoring'
            ])

        if recurrent_loss == 1 or abortions >= 2:
            recommendations['medication_protocols'].extend([
                '🔄 RECURRENT LOSS PROTOCOL:',
                '• Low-dose aspirin 75mg daily from 12 weeks',
                '• Consider progesterone supplementation if indicated',
                '• Detailed workup for thrombophilia or autoimmune disorders',
                '• Genetic counseling if recurrent losses'
            ])

        if age_adolescent == 1:
            recommendations['medication_protocols'].extend([
                '👶 ADOLESCENT PREGNANCY PROTOCOL:',
                '• Enhanced folic acid 5mg daily',
                '• Calcium 1500mg daily',
                '• Iron with higher bioavailability',
                '• Vitamin D supplementation 1000 IU daily'
            ])

        if hypertension == 1:
            recommendations['medication_protocols'].extend([
                '🫀 HYPERTENSION MANAGEMENT:',
                '• Methyldopa 250mg TDS (first line)',
                '• Nifedipine SR 30mg BD if methyldopa fails',
                '• Low-dose aspirin 75mg from 12 weeks',
                '• Avoid ACE inhibitors and ARBs'
            ])

        if depression == 1 or anxiety == 1:
            recommendations['medication_protocols'].extend([
                '🧠 MENTAL HEALTH SUPPORT:',
                '• Refer to mental health specialist for counseling',
                '• Consider SSRIs (e.g., sertraline) if clinically indicated',
                '• Weekly mental health check-ins'
            ])

    def _add_high_risk_team_actions(self, recommendations, risk_level):
        """Add healthcare team actions for high-risk pregnancy"""
        if risk_level in ['Critical High-Risk', 'Very High-Risk']:
            recommendations['healthcare_team_actions']['asha_worker'].extend([
                'Daily home visits during critical periods',
                'Monitor for pregnancy danger signs',
                'Ensure medication compliance (focus on anemia, hypertension)',
                'Coordinate urgent transportation',
                'Provide continuous emotional support'
            ])

            recommendations['healthcare_team_actions']['anm_nurse'].extend([
                'Detailed assessment at each visit (hemoglobin, BP, mental health)',
                'Monitor fetal heart rate and movements',
                'Blood pressure and weight monitoring',
                'Coordinate specialist appointments',
                'Manage emergency situations'
            ])

            recommendations['healthcare_team_actions']['medical_officer'].extend([
                'Weekly clinical review (focus on anemia, recurrent loss, mental health)',
                'Manage complications',
                'Coordinate tertiary care referrals',
                'Emergency delivery decisions',
                'Family counseling'
            ])

            recommendations['healthcare_team_actions']['obstetrician'].extend([
                'Comprehensive high-risk assessment',
                'Delivery planning and timing',
                'Manage obstetric emergencies',
                'Coordinate multidisciplinary care (hematology, MFM, mental health)',
                'Postpartum high-risk management'
            ])

def print_high_risk_pregnancy_recommendations(recommendations):
    """Print high-risk pregnancy specific recommendations"""
    print("="*80)
    print("🤰 HIGH-RISK PREGNANCY CLINICAL DECISION SUPPORT")
    print("="*80)

    risk_assessment = recommendations['risk_assessment']
    print(f"\n📊 HIGH-RISK PREGNANCY ASSESSMENT:")
    print(f"   Risk Probability: {risk_assessment['probability']:.1%}")
    print(f"   Risk Level: {risk_assessment['risk_level']}")
    print(f"   High-Risk Classification: {'YES' if risk_assessment['classification'] else 'NO'}")

    if recommendations['immediate_actions']:
        print(f"\n🚨 IMMEDIATE ACTIONS FOR HIGH-RISK PREGNANCY:")
        for i, action in enumerate(recommendations['immediate_actions'], 1):
            print(f"   {i}. {action}")

    if recommendations['monitoring_protocols']:
        print(f"\n📈 HIGH-RISK MONITORING PROTOCOLS:")
        monitoring = recommendations['monitoring_protocols']
        print(f"   • ANC Frequency: {monitoring.get('anc_frequency', 'Standard')}")
        if 'fetal_monitoring' in monitoring:
            print(f"   • Fetal Monitoring:")
            for item in monitoring['fetal_monitoring']:
                print(f"     - {item}")
        if 'maternal_monitoring' in monitoring:
            print(f"   • Maternal Monitoring:")
            for item in monitoring['maternal_monitoring']:
                print(f"     - {item}")

    if recommendations['referral_pathways']:
        print(f"\n🏥 HIGH-RISK REFERRAL PATHWAYS:")
        referral = recommendations['referral_pathways']
        print(f"   • Primary Referral: {referral.get('primary_referral', 'Standard')}")
        print(f"   • Urgency: {referral.get('urgency', 'Routine')}")
        print(f"   • Transport: {referral.get('transport', 'Standard')}")

    if recommendations['medication_protocols']:
        print(f"\n💊 HIGH-RISK MEDICATION PROTOCOLS:")
        for protocol in recommendations['medication_protocols']:
            print(f"   {protocol}")

    if recommendations['specialist_consultations']:
        print(f"\n🔬 SPECIALIST CONSULTATIONS:")
        for specialist in recommendations['specialist_consultations']:
            print(f"   • {specialist}")

    print("\n📋 FOR PATIENTS: WHAT YOU NEED TO DO:")
    if risk_assessment['risk_level'] in ['Critical High-Risk', 'Very High-Risk']:
        print("   • Visit a specialist hospital immediately.")
        print("   • Follow your doctor’s instructions for medicines and checkups.")
        print("   • Call the emergency helpline if you feel unwell.")
    if recommendations['medication_protocols']:
        print("   • Take your prescribed medicines (e.g., iron, blood pressure pills) daily.")
    print("   • Attend all scheduled ANC visits and report any symptoms to your doctor.")

    print("="*80)

In [13]:
# Raw patient data from ANC visit
# raw_data = {
#     'AGE': 24,
#     'HEMOGLOBIN': 6.5,
#     'HEMOGLOBIN_min': 6.0,
#     'ABORTIONS': 2,
#     'BP_last': '145/95',
#     'GRAVIDA': 3,
#     'PARITY': 1,
#     'WEIGHT_max': 60,
#     'HEIGHT': 160,
#     'WEIGHT_first': 55,
#     'WEIGHT_last': 62,
#     'NO_OF_WEEKS_max': 30,
#     'TOTAL_ANC_VISITS': 2,
#     'total_missed_visits': 2,
#     'PHQ_SCORE_max': 12,
#     'GAD_SCORE_max': 10
# }
raw_data = {
    'AGE': 26,
    'HEMOGLOBIN': 12.2,
    'HEMOGLOBIN_min': 11.5,
    'ABORTIONS': 0,
    'BP_last': '120/75',
    'GRAVIDA': 1,
    'PARITY': 0,
    'WEIGHT_max': 62,
    'HEIGHT': 160,
    'WEIGHT_first': 54,
    'WEIGHT_last': 61,
    'NO_OF_WEEKS_max': 30,
    'TOTAL_ANC_VISITS': 5,
    'total_missed_visits': 0,  
    'PHQ_SCORE_max': 2,
    'GAD_SCORE_max': 1
}


# Initialize the class with the model path
clinical_support = HighRiskPregnancyClinicalSupport(model_path='/kaggle/working/rf_model_fold2.pkl')
recommendations = clinical_support.generate_high_risk_pregnancy_recommendations(raw_data)
print_high_risk_pregnancy_recommendations(recommendations)

🤰 HIGH-RISK PREGNANCY CLINICAL DECISION SUPPORT

📊 HIGH-RISK PREGNANCY ASSESSMENT:
   Risk Probability: 2.5%
   Risk Level: Low-Risk
   High-Risk Classification: NO

📋 FOR PATIENTS: WHAT YOU NEED TO DO:
   • Attend all scheduled ANC visits and report any symptoms to your doctor.
