In [None]:
# Function to validate the imputation results with improved visualizations
def validate_imputation(original_data, imputed_data, key_columns):
    """
    Validate imputation results by comparing distributions before and after.

    Only columns listed in ``key_columns`` that exist in both frames and had
    at least one missing value in ``original_data`` are examined.

    For every such column the function:
      * draws a comparison plot (KDE + histogram for numeric columns, grouped
        bar chart for everything else), grouped into one figure per
        missingness band (>=30%, 10-30%, <10%);
      * computes a two-sample KS test (numeric) or a chi-square test with a
        Hellinger-distance fallback (categorical);
      * assigns a quality label ('Good', 'Acceptable',
        'Acceptable (high missingness)', 'Review', 'Insufficient Data').

    Side effects (all files are written to the current working directory):
      * ``imputation_validation_<group>.png`` for each non-empty band
      * ``imputation_quality_metrics.csv``
      * ``imputation_quality_summary.png``
      * progress messages on stdout

    Parameters
    ----------
    original_data : pandas.DataFrame
        Data before imputation (still containing missing values).
    imputed_data : pandas.DataFrame
        Data after imputation.  The missing-value mask computed on
        ``original_data`` is applied to this frame with ``.loc``, so both
        frames must share the same row index.
    key_columns : iterable of str
        Candidate columns to validate.

    Returns
    -------
    None

    Notes
    -----
    Relies on the module-level names ``pd``, ``np``, ``plt`` and ``sns``.
    """
    print("\nValidating imputation results...")

    # Filter to only include columns that had missing values
    columns_with_missing = [col for col in key_columns
                           if col in original_data.columns
                           and col in imputed_data.columns
                           and original_data[col].isnull().sum() > 0]

    if not columns_with_missing:
        print("  No columns with missing values to validate.")
        return

    print(f"  Validating {len(columns_with_missing)} columns with missing values")

    # Organize columns by missingness percentage
    missing_pcts = {col: original_data[col].isnull().mean() * 100 for col in columns_with_missing}

    # Group columns by missing percentage
    high_missing = [col for col, pct in missing_pcts.items() if pct >= 30]
    mod_missing = [col for col, pct in missing_pcts.items() if 10 <= pct < 30]
    low_missing = [col for col, pct in missing_pcts.items() if pct < 10]

    # Create separate figures for each group to avoid overcrowding
    plot_groups = []
    if high_missing:
        plot_groups.append(("High Missingness Variables (≥30%)", high_missing))
    if mod_missing:
        plot_groups.append(("Moderate Missingness Variables (10-30%)", mod_missing))
    if low_missing:
        plot_groups.append(("Low Missingness Variables (<10%)", low_missing))

    # Process each group
    for group_title, cols in plot_groups:
        # Two plots per row; ceiling division gives the number of rows needed
        n_cols = len(cols)
        n_rows = (n_cols + 1) // 2

        # Create figure
        fig, axes = plt.subplots(n_rows, 2, figsize=(16, 6 * n_rows))
        fig.suptitle(group_title, fontsize=16, fontweight='bold', y=0.98)

        # plt.subplots returns a 1-D array for a single row and a 2-D array
        # otherwise; normalise both to a flat sequence of axes.
        axes = np.atleast_1d(axes).ravel()

        # Generate comparison plots for each column
        for i, col in enumerate(cols):
            ax = axes[i]

            # Get missing percentage for title
            missing_pct = missing_pcts[col]

            # Observed values, values filled in by imputation, and the full imputed column
            orig_data = original_data[col].dropna()
            missing_mask = original_data[col].isnull()
            imputed_missing = imputed_data.loc[missing_mask, col]
            all_imputed = imputed_data[col]

            # Set title with missing percentage
            ax.set_title(f"{col} ({missing_pct:.1f}% imputed)", fontsize=12, fontweight='bold')

            # Check if numeric or categorical
            if pd.api.types.is_numeric_dtype(orig_data):
                # Smooth density curves for the observed and the completed column
                sns.kdeplot(orig_data, ax=ax, color='#1f77b4', label='Original', linewidth=2.5)
                sns.kdeplot(all_imputed, ax=ax, color='#ff7f0e', label='After Imputation', linewidth=2)

                # Histogram of just the imputed values to see where they landed
                if len(imputed_missing) > 0:
                    sns.histplot(imputed_missing, ax=ax, color='#d62728', alpha=0.4,
                                label='Imputed Values Only', kde=False)

                # Add vertical lines for means
                ax.axvline(orig_data.mean(), color='#1f77b4', linestyle='--', linewidth=1.5)
                ax.axvline(all_imputed.mean(), color='#ff7f0e', linestyle='--', linewidth=1.5)

                # Add statistics as an inset text box
                orig_stats = f"Original (n={len(orig_data)}): mean={orig_data.mean():.2f}, std={orig_data.std():.2f}"
                imp_stats = f"After Imputation (n={len(all_imputed)}): mean={all_imputed.mean():.2f}, std={all_imputed.std():.2f}"
                imp_only_stats = f"Imputed Only (n={len(imputed_missing)}): mean={imputed_missing.mean():.2f}, std={imputed_missing.std():.2f}"

                stats_text = f"{orig_stats}\n{imp_stats}\n{imp_only_stats}"

                # Calculate and display KS test p-value if enough data
                if len(orig_data) >= 5 and len(all_imputed) >= 5:
                    from scipy.stats import ks_2samp
                    _, p_val = ks_2samp(orig_data, all_imputed)
                    stats_text += f"\nKS test p-value: {p_val:.3f}"

                    # Add a qualitative assessment
                    if p_val > 0.05:
                        stats_text += " (distributions not significantly different)"
                    else:
                        stats_text += " (distributions significantly different)"

                # Add text box
                ax.text(0.05, 0.95, stats_text, transform=ax.transAxes, va='top',
                       bbox=dict(boxstyle='round', facecolor='white', alpha=0.8), fontsize=9)

                # Print the mean shift for reference
                mean_diff = abs(orig_data.mean() - all_imputed.mean())

                # Express the mean shift relative to the original spread when possible
                if orig_data.std() > 0:
                    std_pct_diff = (mean_diff / orig_data.std()) * 100
                    print(f"  {col}: mean diff = {mean_diff:.2f} ({std_pct_diff:.1f}% of original std)")
                else:
                    print(f"  {col}: mean diff = {mean_diff:.2f}")

            else:
                # For categorical variables, create a grouped bar chart of proportions
                orig_counts = orig_data.value_counts(normalize=True).sort_index()
                all_imp_counts = all_imputed.value_counts(normalize=True).sort_index()

                if len(imputed_missing) > 0:
                    imp_only_counts = imputed_missing.value_counts(normalize=True).sort_index()
                else:
                    imp_only_counts = pd.Series(dtype='float64')

                # Combine all unique categories
                all_cats = sorted(set(list(orig_counts.index) +
                                   list(all_imp_counts.index) +
                                   list(imp_only_counts.index)))

                # Long-format records: one row per (category, source)
                plot_data = []

                # Add original distribution
                for cat in all_cats:
                    value = orig_counts.get(cat, 0)
                    plot_data.append({
                        'Category': cat,
                        'Proportion': value,
                        'Source': 'Original'
                    })

                # Add after imputation distribution
                for cat in all_cats:
                    value = all_imp_counts.get(cat, 0)
                    plot_data.append({
                        'Category': cat,
                        'Proportion': value,
                        'Source': 'After Imputation'
                    })

                # Add imputed-only distribution if we have imputed values
                if len(imputed_missing) > 0:
                    for cat in all_cats:
                        value = imp_only_counts.get(cat, 0)
                        plot_data.append({
                            'Category': cat,
                            'Proportion': value,
                            'Source': 'Imputed Only'
                        })

                # Convert to DataFrame
                plot_df = pd.DataFrame(plot_data)

                # Plot bar chart
                sns.barplot(data=plot_df, x='Category', y='Proportion', hue='Source',
                           palette=['#1f77b4', '#ff7f0e', '#d62728'], ax=ax)

                # Customize x-axis labels
                ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')

                # Add counts as text
                stats_text = f"Original n={len(orig_data)}, After Imputation n={len(all_imputed)}, Imputed Only n={len(imputed_missing)}"

                # Calculate chi-square test for distribution comparison
                try:
                    from scipy.stats import chi2_contingency

                    # Contingency table of absolute counts (proportion * sample size)
                    table = pd.DataFrame({
                        'Original': [orig_counts.get(cat, 0) * len(orig_data) for cat in all_cats],
                        'Imputed': [all_imp_counts.get(cat, 0) * len(all_imputed) for cat in all_cats]
                    })

                    # Chi-square test
                    _, p_val, _, _ = chi2_contingency(table)
                    stats_text += f"\nChi² test p-value: {p_val:.3f}"

                    # Add qualitative assessment
                    if p_val > 0.05:
                        stats_text += " (distributions not significantly different)"
                    else:
                        stats_text += " (distributions significantly different)"

                except Exception:
                    # If chi-square test fails, use Hellinger distance instead
                    hell_dist = np.sqrt(np.sum((np.sqrt(orig_counts.reindex(all_cats).fillna(0)) -
                                           np.sqrt(all_imp_counts.reindex(all_cats).fillna(0))) ** 2)) / np.sqrt(2)
                    stats_text += f"\nHellinger distance: {hell_dist:.3f}"

                    # Add qualitative assessment
                    if hell_dist < 0.2:
                        stats_text += " (distributions very similar)"
                    elif hell_dist < 0.4:
                        stats_text += " (distributions moderately similar)"
                    else:
                        stats_text += " (distributions differ substantially)"

                # Add text box
                ax.text(0.05, 0.95, stats_text, transform=ax.transAxes, va='top',
                       bbox=dict(boxstyle='round', facecolor='white', alpha=0.8), fontsize=9)

            # Consistent legend with better placement
            handles, labels = ax.get_legend_handles_labels()
            ax.legend(handles, labels, loc='best', frameon=True, fontsize=9,
                     title="Data Source")

        # Remove any unused subplots (odd number of columns leaves one empty cell)
        for i in range(len(cols), len(axes)):
            fig.delaxes(axes[i])

        # Improve spacing, leaving room for the figure title
        plt.tight_layout()
        plt.subplots_adjust(top=0.92)

        # Save figure
        group_name = group_title.lower().replace(' ', '_').replace('(', '').replace(')', '').replace('%', 'pct')
        plt.savefig(f'imputation_validation_{group_name}.png', dpi=300, bbox_inches='tight')
        plt.close()

        print(f"  Validation plots for {group_title} saved to 'imputation_validation_{group_name}.png'")

    # Create a summary table of imputation quality metrics
    print("\nCreating imputation quality summary table...")

    # Collect metrics for each variable
    metrics = []
    for col in columns_with_missing:
        orig_data = original_data[col].dropna()
        missing_mask = original_data[col].isnull()
        all_imputed = imputed_data[col]

        metric = {
            'Variable': col,
            'Missing %': missing_pcts[col],
            'Missing Count': missing_mask.sum(),
            'Total Count': len(original_data)
        }

        # Calculate appropriate metrics based on variable type
        if pd.api.types.is_numeric_dtype(orig_data):
            # For numeric variables
            metric['Original Mean'] = orig_data.mean()
            metric['Original Std'] = orig_data.std()
            metric['Imputed Mean'] = all_imputed.mean()
            metric['Imputed Std'] = all_imputed.std()
            metric['Mean Diff'] = abs(metric['Original Mean'] - metric['Imputed Mean'])

            if metric['Original Std'] > 0:
                metric['Mean Diff (% of Std)'] = (metric['Mean Diff'] / metric['Original Std']) * 100
            else:
                metric['Mean Diff (% of Std)'] = None

            # Calculate KS test if enough data
            if len(orig_data) >= 5 and len(all_imputed) >= 5:
                from scipy.stats import ks_2samp
                _, p_val = ks_2samp(orig_data, all_imputed)
                metric['KS Test p-value'] = p_val

                # Determine quality
                if p_val > 0.05:
                    if metric['Mean Diff (% of Std)'] is not None and metric['Mean Diff (% of Std)'] < 20:
                        metric['Quality'] = 'Good'
                    else:
                        metric['Quality'] = 'Acceptable'
                else:
                    # A distribution shift is tolerated when most of the column was imputed
                    if missing_pcts[col] > 50:
                        metric['Quality'] = 'Acceptable (high missingness)'
                    else:
                        metric['Quality'] = 'Review'
            else:
                metric['KS Test p-value'] = None
                metric['Quality'] = 'Insufficient Data'

        else:
            # For categorical variables
            orig_counts = orig_data.value_counts(normalize=True)
            all_imp_counts = all_imputed.value_counts(normalize=True)

            # Most frequent category.  mode() is empty when the column holds no
            # non-missing value (e.g. imputation left it untouched), so guard
            # both lookups instead of raising IndexError.
            imputed_modes = all_imputed.mode()
            metric['Original Mode'] = orig_data.mode().iloc[0] if len(orig_data) > 0 else None
            metric['Imputed Mode'] = imputed_modes.iloc[0] if len(imputed_modes) > 0 else None
            metric['Mode Changed'] = metric['Original Mode'] != metric['Imputed Mode']

            # Calculate chi-square or Hellinger distance
            try:
                from scipy.stats import chi2_contingency
                all_cats = sorted(set(list(orig_counts.index) + list(all_imp_counts.index)))

                # Create contingency table
                table = pd.DataFrame({
                    'Original': [orig_counts.get(cat, 0) * len(orig_data) for cat in all_cats],
                    'Imputed': [all_imp_counts.get(cat, 0) * len(all_imputed) for cat in all_cats]
                })

                # Chi-square test
                _, p_val, _, _ = chi2_contingency(table)
                metric['Chi² Test p-value'] = p_val

                # Determine quality
                if p_val > 0.05:
                    metric['Quality'] = 'Good'
                else:
                    if missing_pcts[col] > 50:
                        metric['Quality'] = 'Acceptable (high missingness)'
                    else:
                        metric['Quality'] = 'Review'

            except Exception:
                # Fall back to Hellinger distance
                all_cats = sorted(set(list(orig_counts.index) + list(all_imp_counts.index)))
                hell_dist = np.sqrt(np.sum((np.sqrt(orig_counts.reindex(all_cats).fillna(0)) -
                                       np.sqrt(all_imp_counts.reindex(all_cats).fillna(0))) ** 2)) / np.sqrt(2)
                metric['Hellinger Distance'] = hell_dist

                # Determine quality
                if hell_dist < 0.2:
                    metric['Quality'] = 'Good'
                elif hell_dist < 0.4:
                    metric['Quality'] = 'Acceptable'
                else:
                    metric['Quality'] = 'Review'

        metrics.append(metric)

    # Convert to DataFrame
    metrics_df = pd.DataFrame(metrics)

    # Sort by missing percentage
    metrics_df = metrics_df.sort_values('Missing %', ascending=False)

    # Save as CSV
    metrics_df.to_csv('imputation_quality_metrics.csv', index=False)
    print(f"  Imputation quality metrics saved to 'imputation_quality_metrics.csv'")

    # Create a visualization of the quality metrics
    plt.figure(figsize=(14, len(metrics_df) * 0.8 + 2))
    ax = plt.subplot(111)

    # Hide axes
    ax.axis('off')

    # Define colors for quality
    quality_colors = {
        'Good': '#1b9e77',           # Green
        'Acceptable': '#7570b3',      # Purple
        'Acceptable (high missingness)': '#7570b3', # Same purple
        'Review': '#d95f02',          # Orange
        'Insufficient Data': '#999999' # Gray
    }

    # Format the DataFrame for display
    display_df = metrics_df[['Variable', 'Missing %', 'Quality']].copy()
    display_df['Missing %'] = display_df['Missing %'].round(1).astype(str) + '%'

    # cellColours must have exactly the same shape as cellText (one colour per
    # cell); matplotlib raises ValueError otherwise.  Start every cell as
    # transparent white and colour the quality column afterwards.
    n_table_rows, n_table_cols = display_df.shape
    default_colours = [[(1, 1, 1, 0)] * n_table_cols for _ in range(n_table_rows)]

    # Create the table
    summary_table = ax.table(
        cellText=display_df.values,
        colLabels=display_df.columns,
        loc='center',
        cellLoc='center',
        cellColours=default_colours
    )

    # Manually set color for quality cells (row 0 is the header row)
    for i, quality in enumerate(metrics_df['Quality']):
        summary_table[(i+1, 2)].set_facecolor(quality_colors.get(quality, '#ffffff'))

    # Style the table
    summary_table.auto_set_font_size(False)
    summary_table.set_fontsize(12)
    summary_table.scale(1.2, 1.5)

    # Add a title
    plt.suptitle('Imputation Quality Summary', fontsize=16, fontweight='bold', y=0.98)

    # Add a legend for the quality colors that actually occur
    legend_elements = [plt.Rectangle((0, 0), 1, 1, facecolor=color, edgecolor='black',
                                   label=quality)
                     for quality, color in quality_colors.items()
                     if quality in metrics_df['Quality'].values]

    ax.legend(handles=legend_elements, loc='upper right', title='Imputation Quality')

    # Save the figure
    plt.tight_layout()
    plt.savefig('imputation_quality_summary.png', dpi=300, bbox_inches='tight')
    plt.close()

    print(f"  Imputation quality summary visualization saved to 'imputation_quality_summary.png'")

In [4]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.impute import KNNImputer
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.linear_model import BayesianRidge
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
import warnings
warnings.filterwarnings('ignore')

# Function to determine optimal k using hierarchical clustering (CAH)
def determine_optimal_k(data, columns, min_k=3, max_k=15, prefix=""):
    """Use hierarchical clustering (Ward linkage) to pick k for KNN imputation.

    The numeric subset of ``columns`` is reduced to complete rows,
    standardized and clustered.  The largest jump between consecutive merge
    distances is treated as the elbow; the number of clusters that exist just
    before that jump is the candidate k.  The candidate is finally clamped to
    ``[min_k, min(max_k, n_complete_rows // 3)]``.

    Side effect: saves the dendrogram to ``dendrogram_<prefix>.png`` (prefix
    lower-cased, spaces replaced by underscores) and prints progress.

    Parameters
    ----------
    data : pandas.DataFrame
        Source data.
    columns : list of str
        Columns to consider; non-numeric and all-NaN columns are ignored.
    min_k, max_k : int
        Bounds for the returned value.
    prefix : str
        Label used in messages and in the dendrogram file name.

    Returns
    -------
    int
        The chosen k.  The default 5 is returned when there is no usable
        numeric column, fewer than 5 complete rows, or clustering fails.
    """
    print(f"\nDetermining optimal k for {prefix} using CAH...")
    numeric_data = data[columns].select_dtypes(include=[np.number])

    # Drop columns with all NaN values
    numeric_data = numeric_data.loc[:, numeric_data.notna().any()]

    # If no columns remain, return default k
    if numeric_data.shape[1] == 0:
        print(f"  No valid numeric columns for {prefix}. Using default k=5")
        return 5

    # Get complete cases (rows with no NaN)
    clean_data = numeric_data.dropna()
    n_samples = len(clean_data)

    if n_samples < 5:
        print(f"  Not enough complete cases for {prefix}. Using default k=5")
        return 5

    try:
        # Standardize data so every feature contributes equally to distances
        scaled_data = StandardScaler().fit_transform(clean_data)

        # Perform hierarchical clustering
        Z = linkage(scaled_data, method='ward')

        # Plot dendrogram; always release the figure, even if plotting fails
        fig = plt.figure(figsize=(12, 6))
        try:
            dendrogram(Z, truncate_mode='level', p=5)
            plt.title(f'Hierarchical Clustering Dendrogram for {prefix}')
            plt.xlabel('Sample index')
            plt.ylabel('Distance')
            plt.tight_layout()
            plt.savefig(f'dendrogram_{prefix.lower().replace(" ", "_")}.png')
        finally:
            plt.close(fig)

        # Differences between consecutive merge distances
        merge_distances = Z[:, 2]
        distance_diffs = np.diff(merge_distances)

        # Used when no clear elbow exists
        fallback_k = min(n_samples // 5, max_k)

        if len(distance_diffs) > 0:
            # distance_diffs[i] is the jump between merge i and merge i + 1.
            # Cutting the tree inside the largest jump keeps merges 0..i,
            # i.e. i + 1 merges, which leaves n_samples - (i + 1) clusters.
            max_diff_idx = int(np.argmax(distance_diffs))
            optimal_k = n_samples - (max_diff_idx + 1)

            # Check if the differences are significant enough
            if distance_diffs[max_diff_idx] < 0.1 * np.mean(merge_distances):
                print(f"  No clear elbow found for {prefix}. Using alternative method.")
                # Alternative method: find a reasonable number of clusters
                optimal_k = fallback_k
        else:
            optimal_k = fallback_k

        # Constrain k to reasonable bounds
        optimal_k = max(min_k, min(optimal_k, max_k, n_samples // 3))

        print(f"  Optimal k determined for {prefix}: {optimal_k}")
        return optimal_k

    except Exception as e:
        # Best effort: any clustering/plotting failure degrades to the default k
        print(f"  Error in hierarchical clustering for {prefix}: {e}")
        print("  Using default k=5")
        return 5

# Function for Predictive Mean Matching (PMM)
def pmm_impute(data, target_cols, predictor_cols=None, n_neighbors=5, random_state=42):
    """Impute missing values by sampling an observed donor value.

    For each target column, rows with an observed value act as donors.  The
    predictors are median-filled and standardized, the ``n_neighbors``
    nearest donors of every incomplete row are located, and one donor is
    drawn with probability inversely proportional to its distance.  The
    donor's observed value is copied into the gap.  No regression model is
    fitted; matching is done directly in predictor space.

    Fallbacks:
      * a column with at most ``n_neighbors`` observed values is filled by
        uniform random sampling from its observed values;
      * if the neighbour search raises, the same random sampling is used.

    Parameters
    ----------
    data : pandas.DataFrame
        Input data; never modified.
    target_cols : list of str
        Columns to impute.  Unknown, complete and entirely missing columns
        are skipped.
    predictor_cols : list of str, optional
        Predictor columns.  Defaults to every numeric column not in
        ``target_cols``.  Predictors with 50% or more missing values are
        dropped.
    n_neighbors : int
        Size of the donor pool per incomplete row.
    random_state : int or None
        Seed for the private random generator.  The global NumPy random
        state is left untouched.

    Returns
    -------
    pandas.DataFrame
        ``data`` itself when ``target_cols`` is empty, otherwise an imputed
        copy.
    """
    if not target_cols:
        return data

    imputed_data = data.copy()
    # A private legacy generator yields the same draws as seeding the global
    # one would, without silently reseeding np.random for the whole process.
    rng = np.random.RandomState(random_state)

    # If no predictor columns are specified, use all numeric columns except target
    if predictor_cols is None:
        predictor_cols = [col for col in data.columns
                        if col not in target_cols
                        and pd.api.types.is_numeric_dtype(data[col])]

    # Drop predictor columns with 50% or more missing values
    valid_predictors = [col for col in predictor_cols
                       if data[col].isnull().mean() < 0.5]

    if not valid_predictors:
        print("  Warning: No valid predictor columns for PMM")
        return imputed_data

    print(f"  Using {len(valid_predictors)} predictor columns for PMM")

    def random_donor_positions(n_observed, n_draws):
        # One uniform draw per gap, positions refer to the observed values
        return [rng.choice(n_observed) for _ in range(n_draws)]

    for col in target_cols:
        if col not in data.columns:
            continue

        # Positional boolean mask: stays correct when index labels repeat
        missing_mask = data[col].isnull().to_numpy()
        n_missing = int(missing_mask.sum())

        if n_missing == 0:
            print(f"  No missing values in {col}")
            continue

        if n_missing == len(data):
            print(f"  All values missing in {col}")
            continue

        print(f"  Imputing {col} with PMM...")
        observed_mask = ~missing_mask
        observed_values = data.loc[observed_mask, col].values

        if len(observed_values) <= n_neighbors:
            print(f"  Too few observed values for {col}. Using random sampling.")
            positions = random_donor_positions(len(observed_values), n_missing)
            imputed_data.loc[missing_mask, col] = observed_values[np.asarray(positions, dtype=np.intp)]
            continue

        # Donor rows: gaps in predictors are filled with the donor-side median
        X_train = data.loc[observed_mask, valid_predictors].copy()
        for p in valid_predictors:
            X_train[p] = X_train[p].fillna(X_train[p].median())

        # Recipient rows: gaps in predictors are filled with the overall median
        X_test = data.loc[missing_mask, valid_predictors].copy()
        for p in valid_predictors:
            X_test[p] = X_test[p].fillna(data[p].median())

        try:
            # Standardize data for better distance calculation
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_test_scaled = scaler.transform(X_test)

            # Find k nearest neighbors
            nn_model = NearestNeighbors(n_neighbors=min(n_neighbors, len(X_train)))
            nn_model.fit(X_train_scaled)
            distances, indices = nn_model.kneighbors(X_test_scaled)

            # Draw one donor per recipient row.  `indices` holds row positions
            # within X_train, which is ordered exactly like observed_values.
            positions = []
            for row_distances, row_neighbors in zip(distances, indices):
                if np.any(row_distances):
                    # Closer donors are more likely; epsilon avoids division by zero
                    weights = 1 / (row_distances + 1e-10)
                    weights = weights / np.sum(weights)
                    positions.append(rng.choice(row_neighbors, p=weights))
                else:
                    # All neighbours are exact matches: pick uniformly
                    positions.append(rng.choice(row_neighbors))

        except Exception as e:
            print(f"  Error in PMM for {col}: {e}")
            print("  Falling back to random sampling")
            positions = random_donor_positions(len(observed_values), n_missing)

        imputed_data.loc[missing_mask, col] = observed_values[np.asarray(positions, dtype=np.intp)]

    return imputed_data

# Function for MICE (Multiple Imputation by Chained Equations)
def mice_impute(data, target_cols, categorical_cols=None, n_iter=10, random_state=42):
    """Impute values using Multiple Imputation by Chained Equations (MICE).

    An ``IterativeImputer`` with a ``BayesianRidge`` estimator is fitted on
    every numeric column that has less than 50% missing values.  Only the
    originally missing cells of the requested target columns are replaced.
    Target columns listed in ``categorical_cols`` that still contain gaps
    afterwards are filled with their mode.

    Fallbacks (both delegate to ``pmm_impute`` on the untouched input):
      * fewer than 3 eligible numeric columns;
      * any exception raised while running MICE.

    Parameters
    ----------
    data : pandas.DataFrame
        Input data; never modified.
    target_cols : list of str
        Columns to impute.
    categorical_cols : list of str, optional
        Columns to be mode-imputed when they are also targets.
    n_iter : int
        ``max_iter`` passed to the iterative imputer.
    random_state : int
        Seed for the imputer (and for the PMM fallback).

    Returns
    -------
    pandas.DataFrame
        ``data`` itself when ``target_cols`` is empty, otherwise an imputed
        copy.
    """
    if not target_cols:
        return data

    imputed_data = data.copy()

    # Numeric columns with less than 50% missing values take part in MICE
    numeric_cols = [col for col in data.columns
                   if pd.api.types.is_numeric_dtype(data[col])
                   and data[col].isnull().mean() < 0.5]

    if len(numeric_cols) < 3:
        print("  Not enough valid numeric columns for MICE. Using PMM instead.")
        return pmm_impute(data, target_cols, n_neighbors=5, random_state=random_state)

    # Create a MICE imputer with appropriate estimator
    print(f"  Imputing {len(target_cols)} columns with MICE...")
    try:
        mice_imputer = IterativeImputer(
            estimator=BayesianRidge(),
            max_iter=n_iter,
            random_state=random_state,
            skip_complete=True
        )

        # Prepare data for MICE: only the eligible numeric columns
        mice_data = imputed_data[numeric_cols].copy()

        # Run MICE imputation
        imputed_values = mice_imputer.fit_transform(mice_data)

        # Wrap the result so it can be addressed by label
        mice_result = pd.DataFrame(imputed_values, columns=numeric_cols, index=mice_data.index)

        # Replace only the target columns to avoid modifying other columns
        # NOTE(review): a numeric-coded column listed in categorical_cols is
        # imputed here with continuous values before the mode step runs —
        # confirm this is intended.
        for col in target_cols:
            if col in numeric_cols:
                # Copy only the previously missing values from MICE results
                missing_mask = imputed_data[col].isnull()
                imputed_data.loc[missing_mask, col] = mice_result.loc[missing_mask, col]
                print(f"    Imputed {missing_mask.sum()} values in {col}")

        # Handle categorical columns separately if specified
        if categorical_cols:
            cat_cols_to_impute = [col for col in categorical_cols if col in target_cols]
            if cat_cols_to_impute:
                print("  Handling categorical variables separately...")
                for col in cat_cols_to_impute:
                    # An unknown column must not abort (and discard) the MICE run
                    if col not in imputed_data.columns:
                        continue

                    # Use mode imputation for categorical variables
                    missing_mask = imputed_data[col].isnull()
                    if missing_mask.sum() > 0:
                        modes = imputed_data[col].mode()
                        if modes.empty:
                            # Entirely missing column: there is no mode to use
                            print(f"    All values missing in {col}; skipping mode imputation")
                            continue
                        mode_val = modes.iloc[0]
                        imputed_data.loc[missing_mask, col] = mode_val
                        print(f"    Imputed {missing_mask.sum()} values in {col} with mode")

    except Exception as e:
        # Best effort: any MICE failure degrades to PMM on the original data
        print(f"  Error in MICE: {e}")
        print("  Falling back to PMM imputation")
        return pmm_impute(data, target_cols, n_neighbors=5, random_state=random_state)

    return imputed_data

# Function for KNN imputation
def knn_impute(data, target_cols, n_neighbors=5, weights='uniform'):
    """Impute numeric columns with scikit-learn's ``KNNImputer``.

    The imputer is fitted on every numeric column that has at least one
    observed value, using the ``nan_euclidean`` metric.  Only the originally
    missing cells of the numeric target columns are replaced.  Entirely
    missing numeric columns are left out of the fit: the imputer would drop
    them from its output, which used to break the column alignment and push
    every target onto the median fallback.

    If the imputer raises, the targets are filled with their column median.

    Parameters
    ----------
    data : pandas.DataFrame
        Input data; never modified.
    target_cols : list of str
        Columns to impute; non-numeric or unknown columns are ignored.
    n_neighbors : int
        Number of neighbours used by the imputer.
    weights : str or callable
        Passed through to ``KNNImputer``.

    Returns
    -------
    pandas.DataFrame
        ``data`` itself when ``target_cols`` is empty, otherwise an imputed
        copy.
    """
    if not target_cols:
        return data

    imputed_data = data.copy()

    # Only select columns where KNN makes sense (numeric)
    numeric_cols = [col for col in data.columns
                  if pd.api.types.is_numeric_dtype(data[col])]

    # Prepare data for imputation
    target_numeric = [col for col in target_cols if col in numeric_cols]

    if not target_numeric:
        print("  No valid numeric columns for KNN imputation")
        return imputed_data

    # Columns with no observed value carry no information for the fit
    fit_cols = [col for col in numeric_cols if data[col].notna().any()]

    empty_targets = [col for col in target_numeric if col not in fit_cols]
    if empty_targets:
        print(f"  Skipping columns with no observed values: {empty_targets}")
        target_numeric = [col for col in target_numeric if col in fit_cols]
        if not target_numeric:
            return imputed_data

    print(f"  Imputing {len(target_numeric)} columns with KNN (k={n_neighbors})...")

    try:
        # Create KNN imputer
        knn_imputer = KNNImputer(
            n_neighbors=n_neighbors,
            weights=weights,
            metric='nan_euclidean'
        )

        # Apply KNN imputation
        imputed_values = knn_imputer.fit_transform(imputed_data[fit_cols])

        # Wrap the result so it can be addressed by label
        imputed_numeric = pd.DataFrame(
            imputed_values,
            columns=fit_cols,
            index=imputed_data.index
        )

        # Copy only target columns back to main dataframe
        for col in target_numeric:
            # Only update missing values
            missing_mask = imputed_data[col].isnull()
            imputed_data.loc[missing_mask, col] = imputed_numeric.loc[missing_mask, col]
            print(f"    Imputed {missing_mask.sum()} values in {col}")

    except Exception as e:
        print(f"  Error in KNN imputation: {e}")
        print("  Falling back to simple imputation")

        # Simple fallback: median for numeric
        for col in target_numeric:
            missing_mask = imputed_data[col].isnull()
            if missing_mask.sum() > 0:
                median_val = imputed_data[col].median()
                imputed_data.loc[missing_mask, col] = median_val
                print(f"    Imputed {missing_mask.sum()} values in {col} with median")

    return imputed_data

# Function for mode imputation
def mode_impute(data, target_cols):
    """Impute missing values in the given columns with each column's mode.

    Args:
        data: DataFrame to impute; it is not modified.
        target_cols: Column names to impute. Names that are absent from
            ``data`` and columns without missing values are skipped.

    Returns:
        ``data`` itself when ``target_cols`` is empty; otherwise a copy of
        ``data`` in which the missing entries of each target column are
        replaced by that column's mode (the first one when several values
        tie). Columns with no observed values at all are left unchanged.
    """
    if not target_cols:
        return data

    imputed_data = data.copy()
    print(f"  Imputing {len(target_cols)} columns with mode...")

    for col in target_cols:
        if col not in data.columns:
            continue

        missing_mask = imputed_data[col].isnull()
        n_missing = int(missing_mask.sum())
        if n_missing == 0:
            continue

        # mode() ignores NaN, so an all-missing column yields an empty
        # Series; indexing it with .iloc[0] would raise IndexError.
        modes = imputed_data[col].mode()
        if modes.empty:
            print(f"    Skipping {col}: no observed values to compute a mode")
            continue
        mode_val = modes.iloc[0]

        # Apply imputation
        imputed_data.loc[missing_mask, col] = mode_val
        print(f"    Imputed {n_missing} values in {col} with mode {mode_val}")

    return imputed_data

# Function for Random Forest imputation
def _simple_fill_value(series):
    """Return the median (numeric) or first mode (other dtypes) of a column, or None if it has no observed values."""
    if pd.api.types.is_numeric_dtype(series):
        value = series.median()
        return None if pd.isna(value) else value
    modes = series.mode()
    return modes.iloc[0] if len(modes) > 0 else None


def rf_impute(data, target_cols, categorical_cols=None, n_estimators=100, random_state=42):
    """Impute missing values in the target columns with Random Forest models.

    For each target column that has missing values, a model is trained on
    the rows where the column is observed and used to predict the rows where
    it is missing. Predictors are the other numeric columns, taken from a
    temporary copy of the data whose gaps were pre-filled with the column
    median (numeric) or mode (other dtypes).

    Args:
        data: DataFrame to impute; it is not modified.
        target_cols: Column names to impute. Names absent from ``data`` and
            columns without missing values are skipped.
        categorical_cols: Optional collection of column names to model with
            ``RandomForestClassifier``; every other target uses
            ``RandomForestRegressor``.
        n_estimators: Number of trees passed to the forest.
        random_state: Seed passed to the forest.

    Returns:
        ``data`` itself when ``target_cols`` is empty, otherwise an imputed
        copy. When fewer than two predictors are available, or when model
        fitting/prediction raises, the column falls back to median/mode
        imputation. A column with no observed values is left unchanged by
        that fallback.
    """
    if not target_cols:
        return data

    imputed_data = data.copy()

    # Start with median/mode imputation to get a complete dataset for RF.
    # Columns with no observed values have no median/mode and are left as-is;
    # the feature selection below excludes anything that still contains NaN.
    temp_data = imputed_data.copy()
    for col in data.columns:
        fill_value = _simple_fill_value(temp_data[col])
        if fill_value is not None:
            temp_data[col] = temp_data[col].fillna(fill_value)

    # Process each target column
    for col in target_cols:
        if col not in data.columns:
            continue

        missing_mask = imputed_data[col].isnull()
        n_missing = int(missing_mask.sum())
        if n_missing == 0:
            continue

        print(f"  Imputing {col} with Random Forest...")

        # Predictors: every other numeric column that is fully populated
        features = [c for c in temp_data.columns
                    if c != col
                    and pd.api.types.is_numeric_dtype(temp_data[c])
                    and temp_data[c].isnull().sum() == 0]

        if len(features) < 2:
            print(f"    Not enough features for {col}. Using simple imputation.")
            fallback = _simple_fill_value(imputed_data[col])
            if fallback is not None:
                imputed_data.loc[missing_mask, col] = fallback
            continue

        # Split into train (known values) and test (missing values)
        X_train = temp_data.loc[~missing_mask, features]
        y_train = temp_data.loc[~missing_mask, col]
        X_test = temp_data.loc[missing_mask, features]

        try:
            # Classifier for declared categorical targets, regressor otherwise
            if categorical_cols and col in categorical_cols:
                model = RandomForestClassifier(
                    n_estimators=n_estimators,
                    random_state=random_state
                )
            else:
                model = RandomForestRegressor(
                    n_estimators=n_estimators,
                    random_state=random_state
                )

            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)

            # Only the originally-missing cells are overwritten
            imputed_data.loc[missing_mask, col] = y_pred
            print(f"    Imputed {n_missing} values in {col}")

        except Exception as e:
            # Deliberate best-effort: any modelling failure degrades to the
            # simple median/mode fill instead of aborting the pipeline.
            print(f"    Error in RF imputation for {col}: {e}")
            print("    Falling back to simple imputation")
            fallback = _simple_fill_value(imputed_data[col])
            if fallback is not None:
                imputed_data.loc[missing_mask, col] = fallback

    return imputed_data

# Function for conditional imputation
def conditional_impute(data, conditions):
    """Impute values in one column for rows selected by a condition on another.

    All masks and statistics are computed from the original ``data``, so the
    rules do not see each other's results.

    Args:
        data: Source DataFrame; it is not modified.
        conditions: Iterable of rule dicts with the keys:
            ``'column'``: column to impute;
            ``'filter_column'`` / ``'filter_value'``: rows where
            ``filter_column == filter_value`` are eligible for imputation;
            ``'strategy'``: ``'zero'``, ``'mode'``, ``'median'`` or
            ``'custom'`` (``'custom'`` is only implemented for the
            ``'Age ménopause'`` column and needs an ``'Age'`` column);
            ``'handle_opposite'`` (optional): when true, observed values of
            ``column`` are set to NaN on rows whose ``filter_column`` is
            known and differs from ``filter_value``.

    Returns:
        A copy of ``data`` with the rules applied. The ``'custom'`` strategy
        draws from ``np.random``, so results vary unless the global NumPy
        seed is fixed by the caller.
    """
    imputed_data = data.copy()
    print("\nApplying conditional imputation...")

    for condition in conditions:
        col = condition['column']
        filter_col = condition['filter_column']
        filter_value = condition['filter_value']
        strategy = condition['strategy']

        if col not in data.columns or filter_col not in data.columns:
            print(f"  Column {col} or {filter_col} not found in data")
            continue

        filter_mask = data[filter_col] == filter_value
        observed = data[col].notna()
        missing_mask = filter_mask & ~observed
        n_missing = int(missing_mask.sum())

        if n_missing == 0:
            # Nothing to impute, but the opposite-condition cleanup below is
            # independent of that and must still run.
            print(f"  No conditional missing values in {col}")
        else:
            print(f"  Conditionally imputing {n_missing} values in {col} where {filter_col}={filter_value}")

            if strategy == 'zero':
                imputed_data.loc[missing_mask, col] = 0
            elif strategy == 'mode':
                # Prefer the mode within the filtered group, then the overall mode
                modes = data.loc[filter_mask & observed, col].mode()
                if len(modes) == 0:
                    modes = data[col].mode()
                if len(modes) > 0:
                    imputed_data.loc[missing_mask, col] = modes.iloc[0]
                else:
                    print(f"    No observed values in {col}; mode imputation skipped")
            elif strategy == 'median':
                # Prefer the median within the filtered group, then the overall median
                fill_value = data.loc[filter_mask & observed, col].median()
                if pd.isna(fill_value):
                    fill_value = data[col].median()
                imputed_data.loc[missing_mask, col] = fill_value
            elif strategy == 'custom':
                # Handle Age ménopause specifically
                if col != 'Age ménopause':
                    print(f"    No custom rule defined for {col}; values left missing")
                elif 'Age' not in data.columns:
                    print("    Column Age not found; custom imputation skipped")
                else:
                    # Estimate the typical gap between current age and age at
                    # menopause from rows where both are known, clamped to
                    # the range 5-15 years.
                    subset = data.loc[filter_mask & observed]
                    avg_diff = (subset['Age'] - subset['Age ménopause']).mean() if len(subset) > 0 else np.nan
                    if pd.isna(avg_diff):
                        avg_diff = 7  # Default assumption
                    else:
                        avg_diff = max(5, min(15, avg_diff))

                    for idx in data.index[missing_mask]:
                        age = data.loc[idx, 'Age']
                        if pd.notna(age):
                            # age minus (avg_diff +/- 2 years), floored at 35
                            random_diff = np.random.uniform(avg_diff - 2, avg_diff + 2)
                            imputed_data.loc[idx, col] = max(35, age - random_diff)
                        else:
                            imputed_data.loc[idx, col] = np.random.uniform(45, 55)  # Default range

                    print(f"    Used custom age-based approach with average difference of {avg_diff:.1f} years")
            else:
                print(f"    Unknown strategy '{strategy}' for {col}; values left missing")

        # Handle the opposite condition (e.g., non-menopausal women).
        # NaN != value evaluates to True, so rows whose filter value is
        # unknown are excluded explicitly to avoid erasing observed data.
        if condition.get('handle_opposite', False):
            opposite_mask = data[filter_col].notna() & (data[filter_col] != filter_value) & observed
            if opposite_mask.sum() > 0:
                print(f"  Setting {opposite_mask.sum()} values to NaN in {col} where {filter_col}!={filter_value}")
                imputed_data.loc[opposite_mask, col] = np.nan

    return imputed_data

# Function to fix logical constraints
def fix_logical_constraints(data):
    """Repair logically inconsistent values and return a corrected copy.

    Each rule is applied only when the columns it needs are present:

    * ``Age ménopause`` greater than ``Age`` is replaced by ``Age`` minus a
      random 1-5 years.
    * ``Ancienneté agricole`` greater than ``Age - 15`` is replaced by a
      random value no larger than ``max(0, Age - 15)``.
    * ``Nb enfants`` greater than ``(Age - 15) / 2`` is replaced by a random
      integer between 0 and ``int((Age - 15) / 2)`` (0 when that bound is
      not positive).
    * ``H travail / jour`` is capped at 16 and ``J travail / Sem`` at 7.

    Args:
        data: DataFrame to check; it is not modified.

    Returns:
        A corrected copy of ``data``. Replacement values are drawn from
        ``np.random``, so results vary unless the caller fixes the global
        NumPy seed.
    """
    print("\nChecking and fixing logical constraints...")
    fixed_data = data.copy()

    # Age ménopause <= Age for menopausal women
    if 'Age ménopause' in fixed_data.columns and 'Age' in fixed_data.columns:
        invalid_age = fixed_data[(fixed_data['Age ménopause'].notna()) &
                                (fixed_data['Age'].notna()) &
                                (fixed_data['Age ménopause'] > fixed_data['Age'])]

        if len(invalid_age) > 0:
            print(f"  Found {len(invalid_age)} cases where Age ménopause > Age")
            # Replacements are floats; make sure the column can hold them
            fixed_data['Age ménopause'] = fixed_data['Age ménopause'].astype(float)
            for idx, row in invalid_age.iterrows():
                fixed_data.loc[idx, 'Age ménopause'] = row['Age'] - np.random.uniform(1, 5)

    # Agricultural experience should be reasonable
    if 'Ancienneté agricole' in fixed_data.columns and 'Age' in fixed_data.columns:
        invalid_exp = fixed_data[(fixed_data['Ancienneté agricole'].notna()) &
                                (fixed_data['Age'].notna()) &
                                (fixed_data['Ancienneté agricole'] > (fixed_data['Age'] - 15))]

        if len(invalid_exp) > 0:
            print(f"  Found {len(invalid_exp)} cases with unrealistic work experience")
            # Replacements are floats; make sure the column can hold them
            fixed_data['Ancienneté agricole'] = fixed_data['Ancienneté agricole'].astype(float)
            for idx, row in invalid_exp.iterrows():
                max_exp = max(0, row['Age'] - 15)
                # Keep low <= high: with max_exp < 1 a lower bound of 1 would
                # make uniform() return a value above the allowed maximum.
                min_exp = min(1, max_exp)
                fixed_data.loc[idx, 'Ancienneté agricole'] = np.random.uniform(min_exp, max_exp)

    # Number of children should be realistic
    if 'Nb enfants' in fixed_data.columns and 'Age' in fixed_data.columns:
        invalid_children = fixed_data[(fixed_data['Nb enfants'].notna()) &
                                     (fixed_data['Age'].notna()) &
                                     (fixed_data['Nb enfants'] > (fixed_data['Age'] - 15) / 2)]

        if len(invalid_children) > 0:
            print(f"  Found {len(invalid_children)} cases with unrealistic number of children")
            for idx, row in invalid_children.iterrows():
                max_children = int((row['Age'] - 15) / 2)
                # randint's upper bound is exclusive; max(1, ...) yields 0
                # children when the computed maximum is not positive.
                fixed_data.loc[idx, 'Nb enfants'] = np.random.randint(0, max(1, max_children + 1))

    # Work hours per day should be realistic
    if 'H travail / jour' in fixed_data.columns:
        too_many_hours = fixed_data['H travail / jour'] > 16
        if too_many_hours.sum() > 0:
            print(f"  Capping {too_many_hours.sum()} cases with >16 work hours per day")
            fixed_data.loc[too_many_hours, 'H travail / jour'] = 16

    # Work days per week should be realistic
    if 'J travail / Sem' in fixed_data.columns:
        too_many_days = fixed_data['J travail / Sem'] > 7
        if too_many_days.sum() > 0:
            print(f"  Capping {too_many_days.sum()} cases with >7 work days per week")
            fixed_data.loc[too_many_days, 'J travail / Sem'] = 7

    return fixed_data

# Function to validate the imputation results
def validate_imputation(original_data, imputed_data, key_columns):
    """Compare original and imputed distributions and save the plots.

    One subplot is drawn per entry of ``key_columns`` on a 3-column grid.
    Numeric columns get overlaid histograms, mean lines, summary statistics
    and a two-sample Kolmogorov-Smirnov p-value; other columns get a grouped
    bar chart of category proportions and the Hellinger distance between the
    two distributions. Missing values are dropped from both series before
    any statistic is computed.

    Args:
        original_data: DataFrame before imputation.
        imputed_data: DataFrame after imputation.
        key_columns: Column names to compare. Names missing from either
            frame are skipped (their subplot stays empty).

    Returns:
        None. The figure is written to ``imputation_validation.png`` at
        300 dpi and then closed. Nothing is plotted or saved when
        ``key_columns`` is empty.
    """
    print("\nValidating imputation results...")

    # Setup plotting
    n_cols = len(key_columns)
    if n_cols == 0:
        return

    n_rows = (n_cols + 2) // 3  # Ceiling division by 3
    fig, axes = plt.subplots(n_rows, 3, figsize=(18, n_rows * 5))
    axes = axes.flatten()

    # Generate comparison plots for each key column
    for i, col in enumerate(key_columns):
        if col not in original_data.columns or col not in imputed_data.columns:
            continue

        ax = axes[i]

        # Observed values only: NaN left in the imputed column would turn
        # the KS statistic into NaN and inflate the reported sample size.
        orig_data = original_data[col].dropna()
        imp_data = imputed_data[col].dropna()

        # Check if numeric or categorical
        if pd.api.types.is_numeric_dtype(orig_data):
            if len(orig_data) > 5:  # Only plot if we have enough data
                sns.histplot(orig_data, color='blue', alpha=0.5, label='Original', ax=ax)
                sns.histplot(imp_data, color='red', alpha=0.5, label='Imputed', ax=ax)

                # Add vertical lines for means
                ax.axvline(orig_data.mean(), color='blue', linestyle='--', alpha=0.7)
                ax.axvline(imp_data.mean(), color='red', linestyle='--', alpha=0.7)

                # Add statistics
                orig_stats = f"Original (n={len(orig_data)}): mean={orig_data.mean():.2f}, std={orig_data.std():.2f}"
                imp_stats = f"Imputed (n={len(imp_data)}): mean={imp_data.mean():.2f}, std={imp_data.std():.2f}"
                ax.text(0.05, 0.95, orig_stats, transform=ax.transAxes, va='top', fontsize=9)
                ax.text(0.05, 0.90, imp_stats, transform=ax.transAxes, va='top', fontsize=9)

                # Calculate and display KS test p-value if enough data
                if len(orig_data) >= 5 and len(imp_data) >= 5:
                    from scipy.stats import ks_2samp
                    ks_stat, p_val = ks_2samp(orig_data, imp_data)
                    ax.text(0.05, 0.85, f"KS test p-value: {p_val:.3f}", transform=ax.transAxes, va='top', fontsize=9)

                # Print mean and std differences
                mean_diff = abs(orig_data.mean() - imp_data.mean())
                std_diff = abs(orig_data.std() - imp_data.std())
                print(f"  {col}: mean diff = {mean_diff:.2f}, std diff = {std_diff:.2f}")

        else:
            # For categorical variables, compare category proportions
            orig_counts = orig_data.value_counts(normalize=True)
            imp_counts = imp_data.value_counts(normalize=True)

            # Combine indices to ensure all categories are shown
            all_cats = sorted(set(list(orig_counts.index) + list(imp_counts.index)))

            # Collect the rows first and build the frame once;
            # DataFrame.append no longer exists in pandas >= 2.0.
            rows = []
            for cat in all_cats:
                if cat in orig_counts:
                    rows.append({'Category': cat, 'Proportion': orig_counts[cat], 'Source': 'Original'})
                if cat in imp_counts:
                    rows.append({'Category': cat, 'Proportion': imp_counts[cat], 'Source': 'Imputed'})
            plot_data = pd.DataFrame(rows, columns=['Category', 'Proportion', 'Source'])

            # Plot
            sns.barplot(x='Category', y='Proportion', hue='Source', data=plot_data, ax=ax)
            ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')

            # Calculate Hellinger distance as a measure of distribution similarity
            hell_dist = np.sqrt(np.sum((np.sqrt(orig_counts.reindex(all_cats).fillna(0)) -
                                       np.sqrt(imp_counts.reindex(all_cats).fillna(0))) ** 2)) / np.sqrt(2)
            ax.text(0.05, 0.95, f"Hellinger distance: {hell_dist:.3f}", transform=ax.transAxes, va='top', fontsize=9)
            print(f"  {col}: Hellinger distance = {hell_dist:.3f}")

        ax.set_title(f'Distribution Comparison: {col}')
        ax.legend()

    # Remove any unused subplots
    for i in range(len(key_columns), len(axes)):
        fig.delaxes(axes[i])

    plt.tight_layout()
    plt.savefig('imputation_validation.png', dpi=300)
    plt.close()

    print("  Validation plots saved to 'imputation_validation.png'")

# Main function to handle the entire imputation process
def impute_female_farmers_data():
    """Main function to impute the female farmers dataset"""
    print("\n=== BEGINNING IMPUTATION PROCESS FOR FEMALE FARMERS DATASET ===\n")
    
    # 1. Load the data
    try:
        print("Loading data...")
        data = pd.read_excel('encoded_female_farmers_data_no_text.xlsx')
        print(f"Loaded {data.shape[0]} rows and {data.shape[1]} columns")
    except Exception as e:
        print(f"Error loading data: {e}")
        return
    
    # 2. Define column types
    numerical_cols = ['Age', 'Nb enfants', 'Nb pers à charge', 'H travail / jour', 'Age ménopause', 
                      'Ancienneté agricole', 'J travail / Sem', 'Poids', 'Taille', 'TAS', 'TAD', 'GAD']
    
    binary_cols = ['Neffa', 'Fumées de Tabouna', 'AT en milieu agricole', 'Ménopause', 'Tabagisme']
    
    ordinal_cols = ['Situation maritale', 'Domicile', 'Niveau socio-économique', 'Statut', 'Niveau scolaire',
                   'Masque pour pesticides', 'Bottes', 'Gants', 'Casquette/Mdhalla', 'Manteau imperméable']
    
    # Check which columns actually exist in the dataset
    numerical_cols = [col for col in numerical_cols if col in data.columns]
    binary_cols = [col for col in binary_cols if col in data.columns]
    ordinal_cols = [col for col in ordinal_cols if col in data.columns]
    
    # Identify prefix patterns for one-hot encoded and binary indicator variables
    chemical_cols = [col for col in data.columns if col.startswith('Chemical_')]
    bio_cols = [col for col in data.columns if col.startswith('Bio_')]
    fertilizer_cols = [col for col in data.columns if col.startswith('Fertilizer_')]
    thermal_cols = [col for col in data.columns if col.startswith('Thermal_')]
    transport_cols = [col for col in data.columns if col.startswith('Transport_')]
    
    # Identify husband profession columns (one-hot encoded)
    husband_cols = [col for col in data.columns if col.startswith('Profession du mari_')]
    
    # Group columns by prefix
    binary_indicator_groups = {
        'Chemical': chemical_cols,
        'Bio': bio_cols,
        'Fertilizer': fertilizer_cols,
        'Thermal': thermal_cols,
        'Transport': transport_cols
    }
    
    one_hot_groups = {
        'Profession du mari': husband_cols
    }
    
    # 3. Check missing data
    print("\nAnalyzing missing data...")
    missing_counts = {}
    
    # Check numerical, binary, and ordinal columns
    for col in numerical_cols + binary_cols + ordinal_cols:
        missing_count = data[col].isnull().sum()
        if missing_count > 0:
            missing_counts[col] = {
                'count': missing_count,
                'percentage': missing_count / len(data) * 100,
                'type': 'direct'
            }
    
    # Check binary indicator groups
    for prefix, cols in binary_indicator_groups.items():
        if cols:
            missing_count = data[cols].isnull().all(axis=1).sum()
            if missing_count > 0:
                missing_counts[prefix] = {
                    'count': missing_count,
                    'percentage': missing_count / len(data) * 100,
                    'type': 'binary_indicator'
                }
    
    # Check one-hot encoded groups
    for group, cols in one_hot_groups.items():
        nan_col = f"{group}_nan"
        if nan_col in data.columns:
            missing_count = data[nan_col].sum()
            if missing_count > 0:
                missing_counts[group] = {
                    'count': missing_count,
                    'percentage': missing_count / len(data) * 100,
                    'type': 'one_hot'
                }
    
    # Display missing data summary
    print("\nMissing data summary:")
    for var, info in sorted(missing_counts.items(), key=lambda x: x[1]['percentage'], reverse=True):
        print(f"  {var}: {info['count']} missing values ({info['percentage']:.1f}%)")
    
    # 4. Apply initial fixes and logical constraints
    print("\nApplying initial fixes...")
    fixed_data = fix_logical_constraints(data)
    
    # 5. Define imputation strategy based on missing mechanism analysis
    # Based on the analysis in the previous step
    
    # 5.1 Conditional imputation first (for MNAR variables)
    print("\nStep 1: Applying conditional imputation for MNAR variables...")
    conditional_imputation_rules = [
        {
            'column': 'Age ménopause',
            'filter_column': 'Ménopause',
            'filter_value': 1,
            'strategy': 'custom',
            'handle_opposite': True
        }
    ]
    
    imputed_data = conditional_impute(fixed_data, conditional_imputation_rules)
    
    # 6. Determine optimal k for KNN-based methods using CAH
    print("\nStep 2: Determining optimal k values using hierarchical clustering (CAH)...")
    
    # Different k values for different variable groups
    k_demographic = determine_optimal_k(imputed_data, ['Age', 'Nb enfants', 'Nb pers à charge'], 
                                      min_k=3, max_k=10, prefix="Demographic Variables")
    
    k_health = determine_optimal_k(imputed_data, ['Poids', 'Taille', 'TAS', 'TAD', 'GAD'], 
                                 min_k=3, max_k=10, prefix="Health Variables")
    
    k_work = determine_optimal_k(imputed_data, ['H travail / jour', 'Ancienneté agricole', 'J travail / Sem'],
                               min_k=3, max_k=10, prefix="Work Variables")
    
    # 7. Apply imputation techniques based on missingness mechanisms
    print("\nStep 3: Applying specific imputation techniques based on missingness mechanisms...")
    
    # 7.1 Strong MAR variables (with MICE)
    strong_mar_cols = ['Nb pers à charge']
    strong_mar_cols = [col for col in strong_mar_cols if col in imputed_data.columns and imputed_data[col].isnull().sum() > 0]
    
    if strong_mar_cols:
        print("\nImputing strong MAR variables with MICE...")
        imputed_data = mice_impute(imputed_data, strong_mar_cols, n_iter=10)
    
    # 7.2 Moderate MAR variables (with PMM)
    moderate_mar_cols = ['Niveau socio-économique', 'Profession du mari']
    moderate_mar_cols = [col for col in moderate_mar_cols if col in imputed_data.columns and imputed_data[col].isnull().sum() > 0]
    
    if 'Niveau socio-économique' in moderate_mar_cols:
        print("\nImputing Niveau socio-économique with PMM...")
        imputed_data = pmm_impute(imputed_data, ['Niveau socio-économique'], n_neighbors=k_demographic)
    
    if 'Profession du mari' in moderate_mar_cols:
        print("\nImputing Profession du mari (one-hot encoded columns)...")
        # Use KNN for one-hot encoded husband profession
        if 'Profession du mari_nan' in imputed_data.columns:
            # Identify rows with missing husband profession
            missing_husband = imputed_data['Profession du mari_nan'] == 1
            
            if missing_husband.sum() > 0:
                print(f"  Found {missing_husband.sum()} rows with missing husband profession")
                
                # Use random forest to predict the most likely profession category
                rf_features = ['Age', 'Nb enfants', 'Nb pers à charge', 'Niveau socio-économique', 
                              'Ancienneté agricole', 'Niveau scolaire']
                rf_features = [f for f in rf_features if f in imputed_data.columns]
                
                # Get existing husband professions (non-nan columns)
                husband_categories = [col for col in husband_cols if not col.endswith('_nan')]
                
                if husband_categories and rf_features:
                    try:
                        # Create temporary dataset with one profession per row
                        temp_data = imputed_data.copy()
                        
                        # Fill missing values in features
                        for feature in rf_features:
                            if temp_data[feature].isnull().sum() > 0:
                                if pd.api.types.is_numeric_dtype(temp_data[feature]):
                                    temp_data[feature] = temp_data[feature].fillna(temp_data[feature].median())
                                else:
                                    temp_data[feature] = temp_data[feature].fillna(temp_data[feature].mode().iloc[0])
                        
                        # For rows with known husband profession, identify which profession
                        known_profession = pd.Series(index=temp_data.index, dtype='object')
                        for idx in temp_data[~missing_husband].index:
                            for prof in husband_categories:
                                if temp_data.loc[idx, prof] == 1:
                                    known_profession.loc[idx] = prof
                                    break
                        
                        # Remove rows with no clear profession
                        valid_mask = known_profession.notna()
                        
                        if valid_mask.sum() > 0:
                            X_train = temp_data.loc[valid_mask, rf_features]
                            y_train = known_profession[valid_mask]
                            
                            # Train random forest classifier
                            rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
                            rf_model.fit(X_train, y_train)
                            
                            # Predict profession for missing rows
                            X_predict = temp_data.loc[missing_husband, rf_features]
                            predictions = rf_model.predict(X_predict)
                            
                            # Apply predictions
                            for i, idx in enumerate(temp_data[missing_husband].index):
                                pred_profession = predictions[i]
                                # Reset all profession columns to 0
                                for prof in husband_categories:
                                    imputed_data.loc[idx, prof] = 0
                                # Set predicted profession to 1
                                imputed_data.loc[idx, pred_profession] = 1
                                # Set nan indicator to 0
                                imputed_data.loc[idx, 'Profession du mari_nan'] = 0
                                
                            print(f"  Successfully imputed husband professions using Random Forest")
                        else:
                            print("  No valid training data for husband profession prediction")
                    except Exception as e:
                        print(f"  Error in husband profession imputation: {e}")
                        print("  Using mode imputation instead")
                        
                        # Fallback: use most common profession
                        most_common_prof = None
                        max_count = 0
                        for prof in husband_categories:
                            count = imputed_data[prof].sum()
                            if count > max_count:
                                max_count = count
                                most_common_prof = prof
                        
                        if most_common_prof:
                            for idx in imputed_data[missing_husband].index:
                                # Reset all profession columns to 0
                                for prof in husband_categories:
                                    imputed_data.loc[idx, prof] = 0
                                # Set most common profession to 1
                                imputed_data.loc[idx, most_common_prof] = 1
                                # Set nan indicator to 0
                                imputed_data.loc[idx, 'Profession du mari_nan'] = 0
    
    # 7.3 MCAR variables (with mode/median imputation)
    mcar_cols = []
    for prefix, cols in binary_indicator_groups.items():
        if cols and prefix in missing_counts and missing_counts[prefix]['type'] == 'binary_indicator':
            mcar_cols.append(prefix)
    
    for prefix in mcar_cols:
        cols = binary_indicator_groups.get(prefix, [])
        if cols:
            print(f"\nImputing MCAR binary indicators for {prefix}...")
            # Get rows with all indicators missing
            all_missing = imputed_data[cols].isnull().all(axis=1)
            
            if all_missing.sum() > 0:
                # For each column, calculate the proportion of 1s in non-missing data
                proportions = {}
                for col in cols:
                    valid_data = imputed_data.loc[~all_missing, col].dropna()
                    if len(valid_data) > 0:
                        proportions[col] = valid_data.mean()
                    else:
                        proportions[col] = 0.5  # Default if no data
                
                # Impute missing indicators based on these proportions
                for idx in imputed_data[all_missing].index:
                    for col in cols:
                        # Generate random value based on proportion
                        imputed_data.loc[idx, col] = np.random.choice([0, 1], p=[1-proportions[col], proportions[col]])
    
    # 7.4 Low missingness variables (with KNN)
    low_missing_cols = []
    for col in numerical_cols + binary_cols + ordinal_cols:
        if col in missing_counts and missing_counts[col]['percentage'] < 10 and col != 'Age ménopause':
            low_missing_cols.append(col)
    
    if low_missing_cols:
        print("\nImputing low missingness variables with KNN...")
        
        # Group variables by type for better imputation
        low_missing_numerical = [col for col in low_missing_cols if col in numerical_cols]
        low_missing_binary = [col for col in low_missing_cols if col in binary_cols]
        low_missing_ordinal = [col for col in low_missing_cols if col in ordinal_cols]
        
        # Use appropriate k values for different variable types
        if low_missing_numerical:
            # Further split numerical variables by domain
            demographic_vars = [col for col in low_missing_numerical if col in ['Age', 'Nb enfants']]
            health_vars = [col for col in low_missing_numerical if col in ['Poids', 'Taille', 'TAS', 'TAD', 'GAD']]
            work_vars = [col for col in low_missing_numerical if col in ['H travail / jour', 'Ancienneté agricole', 'J travail / Sem']]
            
            if demographic_vars:
                imputed_data = knn_impute(imputed_data, demographic_vars, n_neighbors=k_demographic)
            if health_vars:
                imputed_data = knn_impute(imputed_data, health_vars, n_neighbors=k_health)
            if work_vars:
                imputed_data = knn_impute(imputed_data, work_vars, n_neighbors=k_work)
        
        if low_missing_binary:
            imputed_data = knn_impute(imputed_data, low_missing_binary, n_neighbors=k_demographic)
            # Round binary values
            for col in low_missing_binary:
                if col in imputed_data.columns:
                    imputed_data[col] = imputed_data[col].round().clip(0, 1)
        
        if low_missing_ordinal:
            imputed_data = knn_impute(imputed_data, low_missing_ordinal, n_neighbors=k_demographic)
            # Round ordinal values to nearest integer
            for col in low_missing_ordinal:
                if col in imputed_data.columns:
                    imputed_data[col] = imputed_data[col].round()
    
    # 8. Post-processing: Fix data types and constraints
    print("\nPost-processing: Fixing data types and constraints...")
    
    # 8.1 Round numerical values that should be integers
    integer_cols = ['Age', 'Nb enfants', 'Nb pers à charge', 'Age ménopause', 'J travail / Sem']
    for col in integer_cols:
        if col in imputed_data.columns:
            imputed_data[col] = imputed_data[col].round()
    
    # 8.2 Ensure binary columns are 0 or 1
    for col in binary_cols:
        if col in imputed_data.columns:
            imputed_data[col] = imputed_data[col].round().clip(0, 1)
    
    # 8.3 Ensure ordinal columns have valid values
    for col in ordinal_cols:
        if col in imputed_data.columns:
            if col in ['Masque pour pesticides', 'Bottes', 'Gants', 'Casquette/Mdhalla', 'Manteau imperméable']:
                # These should be 0, 1, 2, or 3
                imputed_data[col] = imputed_data[col].round().clip(0, 3)
            elif col == 'Situation maritale':
                # 0=célibataire, 1=mariée, 2=divorcée, 3=veuve
                imputed_data[col] = imputed_data[col].round().clip(0, 3)
            elif col == 'Domicile':
                # 0=monastir, 1=sfax, 2=mahdia
                imputed_data[col] = imputed_data[col].round().clip(0, 2)
            elif col == 'Niveau socio-économique':
                # 0=bas, 1=moyen, 2=bon
                imputed_data[col] = imputed_data[col].round().clip(0, 2)
            elif col == 'Statut':
                # 0=permanente, 1=saisonnière
                imputed_data[col] = imputed_data[col].round().clip(0, 1)
            elif col == 'Niveau scolaire':
                # 0=analphabète, 1=primaire, 2=secondaire, 3=supérieur
                imputed_data[col] = imputed_data[col].round().clip(0, 3)
    
    # 8.4 Fix one-hot encoding issues (ensure one value is 1, rest are 0)
    for group, cols in one_hot_groups.items():
        nan_col = f"{group}_nan"
        if nan_col in cols:
            cols.remove(nan_col)
        
        # Ensure each row has exactly one 1
        for idx in imputed_data.index:
            row_values = imputed_data.loc[idx, cols]
            if row_values.sum() != 1:
                # If no value is 1, set the most common value to 1
                if row_values.sum() == 0:
                    most_common = imputed_data[cols].sum().idxmax()
                    imputed_data.loc[idx, cols] = 0
                    imputed_data.loc[idx, most_common] = 1
                # If multiple values are 1, keep only the highest value
                elif row_values.sum() > 1:
                    max_col = row_values.idxmax()
                    imputed_data.loc[idx, cols] = 0
                    imputed_data.loc[idx, max_col] = 1
                    
            # Ensure nan indicator is 0
            if nan_col in imputed_data.columns:
                imputed_data.loc[idx, nan_col] = 0
    
    # 8.5 Final logical constraint check
    imputed_data = fix_logical_constraints(imputed_data)
    
    # NOTE: the validation step below replaces the earlier inline validation section.

    # 9. Validate imputation
    print("\nValidating imputation results...")
    # Select only key columns that actually had missing values for validation
    key_columns = []

    # Get columns with missing values for each column type
    for col in numerical_cols + binary_cols + ordinal_cols:
        if col in data.columns and data[col].isnull().sum() > 0:
            key_columns.append(col)
            
    # Add one-hot encoded group variables for validation if they had missing values
    for group, cols in one_hot_groups.items():
        nan_col = f"{group}_nan"
        if nan_col in data.columns and data[nan_col].sum() > 0:
            key_columns.append(group)

    # Call our enhanced validation function
    validate_imputation(data, imputed_data, key_columns)
    
    # 10. Save imputed data
    print("\nSaving imputed data...")
    imputed_data.to_excel('imputed_female_farmers_data.xlsx', index=False)
    print("Imputed data saved to 'imputed_female_farmers_data.xlsx'")
    
    # 11. Summary of imputation
    print("\n=== IMPUTATION SUMMARY ===")
    print(f"Total rows: {len(imputed_data)}")
    print(f"Total columns: {len(imputed_data.columns)}")
    
    # Check for any remaining missing values
    remaining_missing = imputed_data.isnull().sum()
    remaining_missing = remaining_missing[remaining_missing > 0]
    
    if len(remaining_missing) > 0:
        print("\nWarning: Some columns still have missing values:")
        for col, count in remaining_missing.items():
            print(f"  {col}: {count} missing values")
    else:
        print("\nAll columns successfully imputed!")
        
    print("\n=== IMPUTATION PROCESS COMPLETED ===\n")
    
    return imputed_data

# Entry point: run the full imputation pipeline only when this file is
# executed directly, not when it is imported as a module. The value returned
# by impute_female_farmers_data() is discarded here.
if __name__ == "__main__":
    impute_female_farmers_data()


=== BEGINNING IMPUTATION PROCESS FOR FEMALE FARMERS DATASET ===

Loading data...
Loaded 80 rows and 74 columns

Analyzing missing data...

Missing data summary:
  Age ménopause: 51 missing values (63.7%)
  Chemical: 41 missing values (51.2%)
  Thermal: 34 missing values (42.5%)
  Bio: 31 missing values (38.8%)
  Nb pers à charge: 30 missing values (37.5%)
  Niveau socio-économique: 27 missing values (33.8%)
  Profession du mari: 23 missing values (28.7%)
  Fertilizer: 16 missing values (20.0%)
  Statut: 8 missing values (10.0%)
  J travail / Sem: 4 missing values (5.0%)
  Age: 3 missing values (3.8%)
  H travail / jour: 3 missing values (3.8%)
  Ancienneté agricole: 3 missing values (3.8%)
  Neffa: 2 missing values (2.5%)
  Tabagisme: 2 missing values (2.5%)
  Niveau scolaire: 2 missing values (2.5%)
  Nb enfants: 1 missing values (1.2%)
  Fumées de Tabouna: 1 missing values (1.2%)
  Situation maritale: 1 missing values (1.2%)
  Transport: 1 missing values (1.2%)

Applying initial fix

In [6]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import os
import warnings

# Set the style for visualizations
plt.style.use('seaborn-v0_8-whitegrid')
# Silence all warnings for the rest of the run (keeps the printed report clean).
warnings.filterwarnings('ignore')

# Define paths
imputed_data_path = "imputed_female_farmers_data.xlsx"
output_dir = "./"  # Current directory

# Load the already imputed dataset.
# On failure, report the problem and stop with exit status 1. SystemExit is
# raised directly instead of calling exit(): exit() is a convenience object
# injected by the `site` module for interactive use and is not guaranteed to
# exist in every interpreter configuration.
try:
    imputed_df = pd.read_excel(imputed_data_path)
except FileNotFoundError:
    print(f"Error: The file {imputed_data_path} was not found. Please check the file path.")
    raise SystemExit(1)
except Exception as e:
    print(f"Error loading the dataset: {e}")
    raise SystemExit(1)

# 1. EXPLORATORY ANALYSIS OF THE STATUT FIELD
# ===========================================

def analyze_statut_field(df):
    """Print a summary of the 'Statut' column and save a bar chart of it.

    Reports the row count, the number of missing 'Statut' values and the
    value counts (missing values included), then writes
    'statut_distribution_before_imputation.png' into ``output_dir``.

    Args:
        df: DataFrame containing a 'Statut' column.
    """
    rule = "=" * 50
    print(rule)
    print("ANALYSIS OF THE STATUT FIELD")
    print(rule)

    print(f"Total rows: {df.shape[0]}")
    statut = df['Statut']
    print(f"Statut field - Missing values: {statut.isna().sum()}")

    # One count table (NaN kept as its own bucket) feeds both the report and the plot.
    counts = statut.value_counts(dropna=False)
    print(f"Statut field - Value counts:\n{counts}")

    # Visualize the distribution
    plt.figure(figsize=(7, 5))
    axis = counts.plot(kind='bar')
    plt.title('Statut Distribution Before Imputation')
    plt.xlabel('Statut Value')
    plt.ylabel('Count')

    # Write each bar's count just above the bar.
    for bar in axis.patches:
        height = bar.get_height()
        x_center = bar.get_x() + bar.get_width()/2.
        axis.annotate(
            str(int(height)),
            (x_center, height),
            ha='center',
            va='center',
            xytext=(0, 10),
            textcoords='offset points',
        )

    plt.tight_layout()
    target = os.path.join(output_dir, 'statut_distribution_before_imputation.png')
    plt.savefig(target, dpi=300)
    plt.close()

# 2. IMPUTE THE STATUT FIELD USING RANDOM FOREST
# =============================================

def impute_statut_field(df):
    """Fill missing 'Statut' values using a Random Forest classifier.

    The classifier is fitted on an 80% split of the rows whose 'Statut' is
    known, scored on the remaining 20%, and then used to predict the rows
    whose 'Statut' is missing. A feature-importance chart is written to
    'statut_feature_importance.png' in ``output_dir``.

    Args:
        df: DataFrame containing 'Statut' and (some of) the predictor columns.

    Returns:
        A copy of ``df`` with missing 'Statut' values filled. The copy is
        returned unchanged when no predictor column exists or fewer than 10
        rows have a known 'Statut'.
    """
    rule = "=" * 50
    print("\n" + rule)
    print("IMPUTING THE STATUT FIELD WITH RANDOM FOREST")
    print(rule)

    # Work on a copy so the caller's frame is left untouched.
    df = df.copy()

    # Predictor variables chosen from domain knowledge; keep only those present.
    wanted = (
        'Age', 'Ancienneté agricole', 'H travail / jour', 'J travail / Sem',
        'Niveau scolaire', 'Situation maritale', 'Nb enfants', 'Nb pers à charge',
        'Domicile', 'AT en milieu agricole',
    )
    predictors = [name for name in wanted if name in df.columns]
    if len(predictors) == 0:
        print("No valid predictors found. Cannot impute Statut.")
        return df

    # The dataset is expected to be imputed already; warn if predictors still have gaps.
    gaps = df[predictors].isna().sum()
    if gaps.sum() > 0:
        print("Warning: Some predictors have missing values, which should not happen in an already imputed dataset:")
        print(gaps[gaps > 0])
        print("Proceeding with imputation, but results may be affected.")

    # Rows with a known Statut are the only usable training material.
    labelled = df.dropna(subset=['Statut'])
    if len(labelled) < 10:
        print("Not enough non-missing Statut values to train a model. Skipping imputation.")
        return df

    # Hold out 20% of the labelled rows for evaluation.
    X_train, X_test, y_train, y_test = train_test_split(
        labelled[predictors],
        labelled['Statut'],
        test_size=0.2,
        random_state=42,
    )

    print("Training Random Forest model...")
    rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
    rf_model.fit(X_train, y_train)

    # Score the model on the held-out rows.
    y_pred = rf_model.predict(X_test)
    print(f"Model accuracy: {accuracy_score(y_test, y_pred):.4f}")
    print("Classification report:")
    print(classification_report(y_test, y_pred))

    # Rank the predictors by importance, most important first.
    importance_table = pd.DataFrame({
        'Feature': predictors,
        'Importance': rf_model.feature_importances_
    })
    feature_importances = importance_table.sort_values(by='Importance', ascending=False)
    print("\nFeature importances:")
    print(feature_importances)

    plt.figure(figsize=(10, 6))
    sns.barplot(x='Importance', y='Feature', data=feature_importances)
    plt.title('Feature Importance for Statut Imputation')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'statut_feature_importance.png'), dpi=300)
    plt.close()

    # Predict Statut for the rows where it is missing.
    to_fill = df['Statut'].isna()
    n_to_fill = to_fill.sum()
    if n_to_fill > 0:
        df.loc[to_fill, 'Statut'] = rf_model.predict(df.loc[to_fill, predictors])
        print(f"\nImputed {n_to_fill} missing values in Statut field")
        print(f"New Statut distribution:\n{df['Statut'].value_counts(dropna=False)}")
    else:
        print("\nNo missing values to impute in Statut field")

    return df

# 3. VALIDATE THE IMPUTATION
# ==========================

def validate_statut_imputation(original_df, imputed_df):
    """Compare the 'Statut' distribution before and after imputation.

    Saves a side-by-side bar chart ('statut_distribution_comparison.png' in
    ``output_dir``) and prints a chi-squared test on the observed categories.

    Missing values are shown in the charts but are excluded from the
    chi-squared table: a "missing" row would make the two columns differ by
    construction, and NaN keys cannot be sorted or de-duplicated reliably.

    Args:
        original_df: DataFrame holding 'Statut' before imputation.
        imputed_df: DataFrame holding 'Statut' after imputation.
    """
    print("\n" + "=" * 50)
    print("VALIDATING STATUT IMPUTATION")
    print("=" * 50)

    # Compare distributions: original on the left, imputed on the right.
    plt.figure(figsize=(14, 7))
    panels = (
        (1, original_df, 'blue', 'Statut Distribution (Original)'),
        (2, imputed_df, 'red', 'Statut Distribution (Imputed)'),
    )
    for position, frame, color, title in panels:
        plt.subplot(1, 2, position)
        ax = frame['Statut'].value_counts(dropna=False).plot(kind='bar', color=color, alpha=0.7)
        plt.title(title)
        plt.xlabel('Statut Value')
        plt.ylabel('Count')
        # Write each bar's count just above the bar.
        for p in ax.patches:
            ax.annotate(str(int(p.get_height())), (p.get_x() + p.get_width()/2., p.get_height()),
                        ha='center', va='center', xytext=(0, 10), textcoords='offset points')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'statut_distribution_comparison.png'), dpi=300)
    plt.close()

    # Chi-squared test to compare distributions (observed categories only).
    from scipy.stats import chi2_contingency
    orig_counts = original_df['Statut'].value_counts().sort_index()
    imp_counts = imputed_df['Statut'].value_counts().sort_index()
    all_cats = sorted(set(orig_counts.index).union(imp_counts.index))

    contingency_table = pd.DataFrame({
        'Original': [orig_counts.get(cat, 0) for cat in all_cats],
        'Imputed': [imp_counts.get(cat, 0) for cat in all_cats]
    }, index=all_cats)

    # chi2_contingency raises ValueError for an empty table or one with a zero
    # expected frequency (e.g. no observed Statut at all in one frame).
    try:
        chi2, p_val, _, _ = chi2_contingency(contingency_table)
    except ValueError as e:
        print(f"\nCould not perform Chi-squared test: {e}")
        return

    print(f"\nChi-squared test for distribution similarity:")
    print(f"Chi2 statistic: {chi2:.4f}, p-value: {p_val:.4f}")
    if p_val > 0.05:
        print("Distributions are not significantly different (p > 0.05)")
    else:
        print("Distributions are significantly different (p ≤ 0.05)")

# 4. MAIN FUNCTION
# ================

def main():
    """Run the Statut workflow: analyze, impute, validate, then save.

    Operates on the module-level ``imputed_df`` and writes the result back to
    ``imputed_data_path``; if that file cannot be written because of a
    permission error, a temporary file in ``output_dir`` is tried instead.
    """
    print("FEMALE FARMERS DATA - STATUT FIELD IMPUTATION")
    print("=" * 60)

    def save_to_temporary_file(frame):
        # Fallback used when the main workbook is locked or not writable.
        temp_path = os.path.join(output_dir, "imputed_female_farmers_data_temp.xlsx")
        try:
            frame.to_excel(temp_path, index=False)
            print(f"Dataset saved to temporary file: {temp_path}")
            print(f"Please close {imputed_data_path} in other programs and manually rename {temp_path} to {imputed_data_path}.")
        except Exception as e2:
            print(f"Error: Could not save to temporary file {temp_path} either: {e2}")
            print("Please check your permissions and ensure the output directory is writable.")

    # Step 1: describe the Statut column as loaded.
    analyze_statut_field(imputed_df)

    # Step 2: snapshot the input for the later comparison, then impute.
    original_df = imputed_df.copy()
    improved_df = impute_statut_field(imputed_df)

    # Step 3: compare the distributions before and after.
    validate_statut_imputation(original_df, improved_df)

    # Step 4: write the updated dataset over the input file.
    try:
        improved_df.to_excel(imputed_data_path, index=False)
        print(f"\nUpdated dataset saved to: {imputed_data_path}")
    except PermissionError:
        print(f"\nError: Could not save to {imputed_data_path}. Permission denied.")
        print("This might be because the file is open in another program (e.g., Excel) or you lack write permissions.")
        print("Trying to save to a temporary file instead...")
        save_to_temporary_file(improved_df)
    except Exception as e:
        print(f"\nError saving the dataset: {e}")
        print("Please check the file path and ensure you have write permissions.")

# Entry point: run the Statut imputation workflow only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

FEMALE FARMERS DATA - STATUT FIELD IMPUTATION
ANALYSIS OF THE STATUT FIELD
Total rows: 80
Statut field - Missing values: 8
Statut field - Value counts:
Statut
0.0    61
1.0    11
NaN     8
Name: count, dtype: int64

IMPUTING THE STATUT FIELD WITH RANDOM FOREST
Training Random Forest model...
Model accuracy: 0.8667
Classification report:
              precision    recall  f1-score   support

         0.0       0.87      1.00      0.93        13
         1.0       0.00      0.00      0.00         2

    accuracy                           0.87        15
   macro avg       0.43      0.50      0.46        15
weighted avg       0.75      0.87      0.80        15


Feature importances:
                 Feature  Importance
0                    Age    0.204697
1    Ancienneté agricole    0.158840
2       H travail / jour    0.133433
8               Domicile    0.112464
3        J travail / Sem    0.101972
7       Nb pers à charge    0.088951
6             Nb enfants    0.062787
5     Situation 

In [9]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import os
import warnings
import re

# Global plot appearance: base style first, then figure and font size overrides.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({
    'figure.figsize': (12, 7),
    'font.size': 12,
    'axes.titlesize': 14,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
})
# Silence all warnings for the rest of the run.
warnings.filterwarnings('ignore')

# Input workbook and folder for generated figures - adjust as needed
imputed_data_path = "imputed_female_farmers_data.xlsx"
output_dir = "./"  # Current directory

def clean_filename(varname):
    """Return ``varname`` as text that is safe to use in a file name.

    The value is converted with ``str()`` and every character that is not a
    word character (letter, digit, underscore) or a hyphen becomes '_'.
    """
    unsafe_chars = re.compile(r'[^\w\-]')
    text = str(varname)
    return unsafe_chars.sub('_', text)

def analyze_statut_field(df):
    """Summarize the 'Statut' column and save a bar chart of its values.

    Prints counts of NaN values, empty strings and valid entries, prints the
    value counts (empty strings treated as missing), and writes
    'statut_distribution_before.png' into ``output_dir``.

    Args:
        df: DataFrame containing a 'Statut' column.

    Returns:
        True-like when the column contains at least one empty string.
    """
    rule = "=" * 50
    print(rule)
    print("ANALYSIS OF THE STATUT FIELD")
    print(rule)

    # Two kinds of "missing": real NaN values and empty strings.
    n_nan = df['Statut'].isna().sum()
    n_blank = (df['Statut'] == '').sum()
    n_rows = len(df)
    n_valid = n_rows - (n_nan + n_blank)
    has_blanks = n_blank > 0

    print(f"Total rows: {n_rows}")
    print(f"Statut field - NaN values: {n_nan}")
    print(f"Statut field - Empty strings: {n_blank}")
    print(f"Statut field - Valid values: {n_valid}")

    # Fold empty strings into NaN so they are counted as a single "missing" bucket.
    if has_blanks:
        statut_vals = df['Statut'].copy().replace('', np.nan)
    else:
        statut_vals = df['Statut'].copy()

    counts = statut_vals.value_counts(dropna=False)
    print(f"Statut field - Value counts:\n{counts}")

    # Bar chart of the distribution, with NaN shown under a readable label.
    plt.figure(figsize=(10, 6))
    labels = ["Missing" if pd.isna(value) else str(value) for value in counts.index]
    plt.bar(labels, counts.values, color='#3498db', alpha=0.7)
    plt.title('Statut Distribution Before Imputation', fontsize=16)
    plt.xlabel('Statut Value', fontsize=14)
    plt.ylabel('Count', fontsize=14)
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Print each count slightly above its bar.
    for position, count in enumerate(counts.values):
        plt.text(position, count + 0.5, str(int(count)), ha='center', fontsize=12)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'statut_distribution_before.png'), dpi=300)
    plt.close()

    return has_blanks

def impute_statut_field(df):
    """Impute missing 'Statut' values with a Random Forest classifier.

    Steps: convert empty-string 'Statut' entries to NaN, fill gaps in the
    predictor columns (median for numeric, mode otherwise), fit the
    classifier on an 80% split of the rows with a known 'Statut', report
    accuracy / confusion matrix / feature importances, then predict the rows
    whose 'Statut' is missing. Figures are written into ``output_dir``.

    Args:
        df: DataFrame containing 'Statut' and (some of) the predictor columns.

    Returns:
        A copy of ``df`` with 'Statut' imputed. When fewer than 10 rows have
        a known 'Statut', the mode is used instead of a model; when no
        predictor is usable or no 'Statut' value is observed at all, the copy
        is returned without imputation.
    """
    print("\n" + "=" * 50)
    print("IMPUTING THE STATUT FIELD WITH RANDOM FOREST")
    print("=" * 50)

    # Create a copy of the dataset
    df = df.copy()

    # First, convert empty strings to NaN for proper imputation
    if 'Statut' in df.columns:
        empty_strings = (df['Statut'] == '').sum()
        if empty_strings > 0:
            print(f"Converting {empty_strings} empty strings to NaN in Statut field")
            df['Statut'] = df['Statut'].replace('', np.nan)

    # Define predictor variables based on domain knowledge
    predictors = [
        'Age', 'Ancienneté agricole', 'H travail / jour', 'J travail / Sem',
        'Niveau scolaire', 'Situation maritale', 'Nb enfants', 'Nb pers à charge',
        'Domicile', 'AT en milieu agricole', 'Chemical_pesticides', 'Thermal_chaleur'
    ]

    # Ensure all predictors exist in the dataset
    predictors = [col for col in predictors if col in df.columns]
    if not predictors:
        print("No valid predictors found. Cannot impute Statut.")
        return df

    print(f"Using {len(predictors)} predictor variables for imputation")

    # Handle missing values in predictors.
    # Iterate over a snapshot: columns may be removed from `predictors` inside
    # the loop, and removing from a list while iterating it skips elements.
    for col in list(predictors):
        missing_count = df[col].isna().sum()
        if missing_count > 0:
            print(f"Filling {missing_count} missing values in predictor '{col}'")

            # For numeric columns
            if pd.api.types.is_numeric_dtype(df[col]):
                df[col] = df[col].fillna(df[col].median())
            # For categorical/object columns
            else:
                # Get mode without NA values
                mode_val = df[col].dropna().mode()
                if not mode_val.empty:
                    df[col] = df[col].fillna(mode_val[0])
                else:
                    print(f"  Warning: Cannot find mode for '{col}'. Dropping from predictors.")
                    predictors.remove(col)

    # Every predictor may have been dropped above; a model cannot be fitted then.
    if not predictors:
        print("No valid predictors found. Cannot impute Statut.")
        return df

    # Get subset with non-missing Statut for training
    train_df = df.dropna(subset=['Statut'])
    if len(train_df) < 10:
        print("Not enough non-missing Statut values to train a model.")
        print("Using mode imputation instead.")

        # Impute with mode; with no observed value at all there is no mode.
        observed_modes = df['Statut'].dropna().mode()
        if observed_modes.empty:
            print("No observed Statut values available; leaving Statut unchanged.")
            return df
        mode_val = observed_modes[0]
        df['Statut'] = df['Statut'].fillna(mode_val)
        print(f"Imputed missing Statut values with mode: {mode_val}")

        return df

    print(f"Training on {len(train_df)} complete records")

    # Split data for training and evaluation
    X = train_df[predictors]
    y = train_df['Statut']
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Train a Random Forest classifier
    # NOTE(review): the model later used for imputation is this one, fitted on
    # the 80% training split only - confirm whether a refit on all labelled
    # rows is wanted.
    print("Training Random Forest model...")
    rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
    rf_model.fit(X_train, y_train)

    # Evaluate the model
    y_pred = rf_model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    print(f"Model accuracy: {accuracy:.4f}")
    print("Classification report:")
    print(classification_report(y_test, y_pred))

    # Show confusion matrix.
    # Fix the label set to the classes the model was trained on so the matrix
    # shape always matches the tick labels, even if the small test split
    # happens to contain a single class.
    class_labels = list(rf_model.classes_)
    readable_names = {0: 'Permanente (0)', 1: 'Saisonnière (1)'}
    tick_labels = [readable_names.get(label, str(label)) for label in class_labels]
    cm = confusion_matrix(y_test, y_pred, labels=class_labels)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=tick_labels,
                yticklabels=tick_labels)
    plt.title('Confusion Matrix', fontsize=16)
    plt.ylabel('True Label', fontsize=14)
    plt.xlabel('Predicted Label', fontsize=14)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'statut_confusion_matrix.png'), dpi=300)
    plt.close()

    # Feature importances
    feature_importances = pd.DataFrame({
        'Feature': predictors,
        'Importance': rf_model.feature_importances_
    }).sort_values(by='Importance', ascending=False)

    print("\nFeature importances:")
    print(feature_importances)

    # Visualize feature importances
    plt.figure(figsize=(10, 6))

    # Create horizontal bar chart
    bars = plt.barh(feature_importances['Feature'], feature_importances['Importance'],
                    color='#2ecc71', alpha=0.7)
    plt.title('Feature Importance for Statut Imputation', fontsize=16)
    plt.xlabel('Importance', fontsize=14)
    plt.grid(axis='x', linestyle='--', alpha=0.7)

    # Add value labels
    for i, bar in enumerate(bars):
        width = bar.get_width()
        plt.text(width + 0.01, i, f'{width:.4f}', va='center', fontsize=11)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'statut_feature_importance.png'), dpi=300)
    plt.close()

    # Impute missing Statut values
    missing_statut_mask = df['Statut'].isna()
    missing_count = missing_statut_mask.sum()

    if missing_count > 0:
        print(f"\nImputing {missing_count} missing values in Statut field")

        # Get predictors for missing values
        X_missing = df.loc[missing_statut_mask, predictors]

        # Predict values
        predicted_statut = rf_model.predict(X_missing)

        # Apply predictions
        df.loc[missing_statut_mask, 'Statut'] = predicted_statut

        # Show imputation results
        value_counts = df['Statut'].value_counts(dropna=False)
        print(f"New Statut distribution:\n{value_counts}")

        # Show probabilities for imputed records.
        # predict_proba has one column per entry of rf_model.classes_, in that
        # order; look the columns up by label instead of assuming two classes.
        proba = rf_model.predict_proba(X_missing)
        column_of = {label: idx for idx, label in enumerate(rf_model.classes_)}

        def probability_of(label):
            # A class absent from the training split has probability 0 everywhere.
            idx = column_of.get(label)
            if idx is None:
                return np.zeros(proba.shape[0])
            return proba[:, idx]

        class_0_prob = probability_of(0)  # Probability of class 0 (Permanente)
        class_1_prob = probability_of(1)  # Probability of class 1 (Saisonnière)

        print("\nImputation probability summary:")
        print(f"Class 0 (Permanente) - Mean: {class_0_prob.mean():.4f}, Min: {class_0_prob.min():.4f}, Max: {class_0_prob.max():.4f}")
        print(f"Class 1 (Saisonnière) - Mean: {class_1_prob.mean():.4f}, Min: {class_1_prob.min():.4f}, Max: {class_1_prob.max():.4f}")

        # Visualize the imputation probabilities
        plt.figure(figsize=(10, 6))

        # Plot histogram of probabilities
        plt.hist(class_0_prob, alpha=0.5, bins=10, label='Permanente (0)')
        plt.hist(class_1_prob, alpha=0.5, bins=10, label='Saisonnière (1)')

        plt.title('Imputation Probabilities', fontsize=16)
        plt.xlabel('Probability', fontsize=14)
        plt.ylabel('Frequency', fontsize=14)
        plt.grid(linestyle='--', alpha=0.7)
        plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'statut_imputation_probabilities.png'), dpi=300)
        plt.close()
    else:
        print("\nNo missing values to impute in Statut field")

    return df

def validate_statut_imputation(original_df, imputed_df):
    """Compare the 'Statut' distribution before and after imputation.

    Empty strings are treated as missing. Saves a side-by-side bar chart
    ('statut_distribution_comparison.png' in ``output_dir``), prints the
    contingency table of observed categories and runs a chi-squared test on
    it. A failure of the test is reported, not raised.

    Args:
        original_df: DataFrame holding 'Statut' before imputation.
        imputed_df: DataFrame holding 'Statut' after imputation.
    """
    print("\n" + "=" * 50)
    print("VALIDATING STATUT IMPUTATION")
    print("=" * 50)

    # Convert empty strings to NaN for proper analysis (on copies, so the
    # caller's frames are not modified).
    if 'Statut' in original_df.columns:
        original_df = original_df.copy()
        original_df['Statut'] = original_df['Statut'].replace('', np.nan)

    if 'Statut' in imputed_df.columns:
        imputed_df = imputed_df.copy()
        imputed_df['Statut'] = imputed_df['Statut'].replace('', np.nan)

    # Value counts, keeping NaN as its own bucket for the charts.
    orig_counts = original_df['Statut'].value_counts(dropna=False)
    imp_counts = imputed_df['Statut'].value_counts(dropna=False)

    # Compare distributions: original on the left, imputed on the right.
    plt.figure(figsize=(15, 7))
    panels = (
        (1, orig_counts, 'blue', 'Statut Distribution (Original)'),
        (2, imp_counts, 'red', 'Statut Distribution (Imputed)'),
    )
    for position, counts, color, title in panels:
        plt.subplot(1, 2, position)

        # Create readable labels
        labels = [str(val) if not pd.isna(val) else "Missing" for val in counts.index]

        # Create bar chart
        plt.bar(labels, counts.values, color=color, alpha=0.7)
        plt.title(title, fontsize=16)
        plt.xlabel('Statut Value', fontsize=14)
        plt.ylabel('Count', fontsize=14)
        plt.grid(axis='y', linestyle='--', alpha=0.7)

        # Add value labels
        for i, v in enumerate(counts.values):
            plt.text(i, v + 0.5, str(int(v)), ha='center', fontsize=12)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'statut_distribution_comparison.png'), dpi=300)
    plt.close()

    # Run statistical tests to compare distributions.
    # Align categories, dropping NaN BEFORE sorting: NaN compares false with
    # everything, so sorting a list that still contains it gives an
    # unspecified order.
    observed_cats = [
        cat for cat in set(list(orig_counts.index) + list(imp_counts.index))
        if not pd.isna(cat)
    ]

    # Try running chi-squared test
    try:
        from scipy.stats import chi2_contingency

        all_cats = sorted(observed_cats)

        # Create aligned frequency arrays
        orig_freq = [orig_counts.get(cat, 0) for cat in all_cats]
        imp_freq = [imp_counts.get(cat, 0) for cat in all_cats]

        # Create contingency table
        contingency_table = pd.DataFrame({
            'Original': orig_freq,
            'Imputed': imp_freq
        }, index=all_cats)

        print("\nContingency table for Chi-square test:")
        print(contingency_table)

        # Run chi-squared test
        chi2, p_val, _, _ = chi2_contingency(contingency_table)
        print(f"\nChi-squared test for distribution similarity:")
        print(f"Chi2 statistic: {chi2:.4f}, p-value: {p_val:.4f}")
        if p_val > 0.05:
            print("Distributions are not significantly different (p > 0.05)")
            print("This suggests the imputation preserved the original distribution.")
        else:
            print("Distributions are significantly different (p ≤ 0.05)")
            print("This suggests the imputation altered the distribution pattern.")
    except Exception as e:
        # Best-effort: the statistical test is informative only, so report and continue.
        print(f"\nCould not perform Chi-squared test: {e}")

def main():
    """Load the imputed workbook, impute and validate 'Statut', then save.

    The result is written back to ``imputed_data_path``; if that write fails
    for any reason, 'improved_female_farmers_data.xlsx' is used instead.
    Returns early (None) when the workbook cannot be loaded.
    """
    print("FEMALE FARMERS DATA - STATUT FIELD IMPUTATION")
    print("=" * 60)

    # Load the imputed dataset
    print(f"Loading dataset from: {imputed_data_path}")
    try:
        dataset = pd.read_excel(imputed_data_path)
        print(f"Dataset loaded successfully: {dataset.shape[0]} rows, {dataset.shape[1]} columns")
    except Exception as load_error:
        print(f"Error loading dataset: {load_error}")
        return

    # Step 1: describe the Statut column as loaded (its return value is not needed here).
    analyze_statut_field(dataset)

    # Step 2: keep an untouched snapshot for the before/after comparison.
    snapshot = dataset.copy()

    # Step 3: impute the Statut field.
    improved_df = impute_statut_field(dataset)

    # Step 4: compare the distributions before and after.
    validate_statut_imputation(snapshot, improved_df)

    # Step 5: write the result over the input workbook, with a fallback name.
    try:
        improved_df.to_excel(imputed_data_path, index=False)
        print(f"\nUpdated dataset saved to: {imputed_data_path}")
    except Exception as save_error:
        print(f"\nError saving dataset: {save_error}")
        # Save to an alternative filename if there's an error
        alt_path = "improved_female_farmers_data.xlsx"
        improved_df.to_excel(alt_path, index=False)
        print(f"Saved to alternative file instead: {alt_path}")

    print("\nProcess completed successfully!")

# Entry point: run the Statut imputation workflow only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

FEMALE FARMERS DATA - STATUT FIELD IMPUTATION
Loading dataset from: imputed_female_farmers_data.xlsx
Dataset loaded successfully: 80 rows, 74 columns
ANALYSIS OF THE STATUT FIELD
Total rows: 80
Statut field - NaN values: 0
Statut field - Empty strings: 0
Statut field - Valid values: 80
Statut field - Value counts:
Statut
0    68
1    12
Name: count, dtype: int64

IMPUTING THE STATUT FIELD WITH RANDOM FOREST
Using 12 predictor variables for imputation
Training on 80 complete records
Training Random Forest model...
Model accuracy: 0.8750
Classification report:
              precision    recall  f1-score   support

           0       0.88      1.00      0.93        14
           1       0.00      0.00      0.00         2

    accuracy                           0.88        16
   macro avg       0.44      0.50      0.47        16
weighted avg       0.77      0.88      0.82        16


Feature importances:
                  Feature  Importance
0                     Age    0.140082
8         