# SNH-AI Customer Segmentation Analysis

This notebook analyzes the customer segments identified by the KMeans clustering model and explores predictions for transaction counts.

## 1. Setup & Imports

In [None]:
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from supabase import create_client, Client
import joblib # To load saved models/preprocessors
import subprocess # To optionally run the setup and main pipeline

# Add the project root and its src/ directory to sys.path so config and the project
# logger can be imported. Assumes the notebook is run from the project root (dev/snh-ai).
project_root = os.path.abspath('.')
src_path = os.path.join(project_root, 'src')
# Insert in this order so src/ ends up ahead of the project root on sys.path.
for _extra_path in (project_root, src_path):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)

# Import config and the logger independently so a failure in one does not disable the
# other. config is set to None when unavailable so later cells can degrade gracefully
# (attribute lookups fail inside their try blocks) instead of raising NameError.
try:
    import config
except ImportError as e:
    print(f"Error importing config: {e}. Ensure config.py exists.")
    config = None

try:
    from src import snh_logger as snh_logging
    logger = snh_logging.get_logger("AnalysisNotebook")
except ImportError as e:
    print(f"Error importing project logger: {e}. Ensure src/snh_logger.py exists.")
    # Fallback basic logger
    import logging
    logger = logging.getLogger("AnalysisNotebook_Fallback")
    # getLogger returns the same object on every call; attach the handler only once so
    # re-running this cell does not print every message multiple times.
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler(sys.stdout))
    logger.setLevel(logging.INFO)

# Configure plotting style
sns.set_theme(style="whitegrid")
plt.rcParams['figure.figsize'] = (10, 6)

## 1.1. (Optional) Run Setup and Full Data Pipeline

Set the flags in the cell below (`run_setup = True` and/or `run_pipeline = True`) **only** if you need to perform initial setup (directory checks and `.env.example` creation via `setup.py`) and/or regenerate all data in the Supabase tables (`raw_customer_data`, `cleaned_customer_data`, `transformed_customer_data`, `customer_segments`) from scratch by running the `process.py` script.

**Warning:** Running `process.py` will clear existing data in those tables and replace it based on the current pipeline configuration. It may take several seconds to run.

In [None]:
# Execution switches for this cell; flip either one to True to opt in.
run_setup: bool = False  # True -> run setup.py before anything else
run_pipeline: bool = False  # True -> run the full process.py pipeline

def run_script(script_name, cwd=None):
    """Run a Python script in a child process using the notebook kernel's interpreter.

    The child's stdout/stderr are captured and printed once the script finishes.

    Args:
        script_name: Path of the script, relative to ``cwd`` (an absolute path also works).
        cwd: Working directory for the child process. Defaults to the module-level
            ``project_root``.

    Returns:
        True if the script exited with status 0. False on any failure: the script is
        missing, it exits non-zero, or the process cannot be launched.
    """
    if cwd is None:
        cwd = project_root
    # Use the python executable from the currently running kernel/env
    python_executable = sys.executable
    # Check up front: a missing script otherwise only shows up as a generic non-zero exit
    # from the interpreter, not as the FileNotFoundError handled below.
    if not os.path.isfile(os.path.join(cwd, script_name)):
        print(f"ERROR: script '{script_name}' not found in '{cwd}'.")
        return False
    command = [python_executable, script_name]
    try:
        print(f"Running script: {' '.join(command)}...")
        # check=True raises CalledProcessError on a non-zero exit; capture output to show it here.
        result = subprocess.run(command, capture_output=True, text=True, check=True, cwd=cwd)
    except FileNotFoundError:
        print(f"ERROR: Python interpreter '{python_executable}' not found. Make sure Python is installed and in your PATH.")
        return False
    except subprocess.CalledProcessError as e:
        print(f"ERROR: {script_name} failed with exit code {e.returncode}")
        print(f"--- {script_name} stdout (error): ---")
        print(e.stdout)
        print(f"--- {script_name} stderr (error): ---")
        print(e.stderr)
        return False
    except Exception as e:
        print(f"An unexpected error occurred while running {script_name}: {e}")
        return False

    print(f"--- {script_name} stdout: ---")
    print(result.stdout)
    if result.stderr:
        print(f"--- {script_name} stderr: ---")
        print(result.stderr)
    print(f"--- {script_name} execution complete. ---")
    return True

# --- Execution Control ---
setup_success = True  # a skipped setup counts as a success
if not run_setup:
    print("Setup execution skipped (run_setup=False).")
else:
    print("*** Executing setup.py ***")
    setup_success = run_script('setup.py')
    print("*** Setup script finished. ***" if setup_success else "Halting due to setup.py failure.")

# The pipeline runs only when requested and setup either succeeded or was skipped.
pipeline_run_needed = run_pipeline and setup_success

if not run_pipeline:
    print("\nPipeline execution skipped (run_pipeline=False).")
elif not setup_success:
    # Pipeline was requested, but setup failed
    print("Skipped pipeline execution because setup failed.")
else:
    print("\n*** Executing process.py (Full Pipeline) ***")
    pipeline_success = run_script('process.py')
    print("*** Full pipeline execution finished. ***" if pipeline_success else "Pipeline execution failed.")

# Reminder shown when neither step was requested
if not (run_setup or run_pipeline):
    print("Setup and Pipeline execution skipped (set run_setup=True and/or run_pipeline=True in cell to execute).")

## 2. Configuration & Supabase Connection

In [None]:
# Build the Supabase client from project config. It stays None when credentials are
# absent or when client creation raises; both cases are logged.
supabase: Client | None = None
try:
    has_credentials = bool(config.SUPABASE_URL) and bool(config.SUPABASE_SERVICE_ROLE)
    if has_credentials:
        supabase = create_client(config.SUPABASE_URL, config.SUPABASE_SERVICE_ROLE)
        logger.info("Supabase client initialized successfully.")
    else:
        logger.error("Supabase URL or Service Role not found in config.")
except Exception as e:
    logger.error(f"Failed to initialize Supabase client: {e}")

# Supabase tables read by this notebook
CLEANED_TABLE = 'cleaned_customer_data'
SEGMENTS_TABLE = 'customer_segments'
TRANSFORMED_TABLE = 'transformed_customer_data'  # optional: scaled features for plotting
PREDICTIONS_TABLE = 'transaction_predictions'  # optional

## 3. Load Data from Supabase

In [None]:
def fetch_data(
    supabase_client: Client,
    table_name: str,
    select_cols: str = "*",
    page_size: int = 1000,
    order_by: str | None = None,
) -> pd.DataFrame:
    """Fetch every row of a Supabase table into a pandas DataFrame.

    Rows are requested page by page via ``.range()``. A single PostgREST ``select`` is
    capped by the project's "Max rows" API setting (1000 by default on Supabase), so one
    unpaginated request would silently truncate larger tables.

    Args:
        supabase_client: Initialised Supabase client; if falsy, nothing is fetched.
        table_name: Table to read.
        select_cols: Column list passed to ``select()`` (``"*"`` for all columns).
        page_size: Rows requested per round-trip; must be >= 1.
        order_by: Optional column to order by. Supplying one (e.g. a primary key) makes
            offset-based paging deterministic.

    Returns:
        A DataFrame of the fetched rows. It is empty when the client is missing, the table
        has no rows, or any request fails; errors are logged, not raised.

    Raises:
        ValueError: If ``page_size`` is less than 1.
    """
    if page_size < 1:
        raise ValueError(f"page_size must be >= 1, got {page_size}")
    df = pd.DataFrame()  # Initialize empty DataFrame
    if not supabase_client:
        logger.error("Supabase client is not initialized.")
        return df

    rows: list[dict] = []
    start = 0
    try:
        logger.info(f"Fetching data from {table_name}... ({select_cols})")
        while True:
            query = supabase_client.table(table_name).select(select_cols)
            if order_by:
                query = query.order(order_by)
            # NOTE(review): without order_by, Postgres does not guarantee a stable row order
            # across requests; pass order_by for large or concurrently-modified tables.
            batch = query.range(start, start + page_size - 1).execute().data or []
            if not batch:
                break
            rows.extend(batch)
            # Advance by the rows actually returned: the server's cap may be below page_size.
            start += len(batch)
    except Exception as e:
        logger.error(f"Error fetching data from {table_name}: {e}")
        # Return nothing rather than a silently partial table
        return df

    if rows:
        df = pd.DataFrame(rows)
        logger.info(f"Successfully fetched {len(df)} rows from {table_name}.")
    else:
        logger.warning(f"No data returned from {table_name}. Check if pipeline ran correctly.")
    return df

# Cleaned data carries the original (unscaled) customer features.
cleaned_df = fetch_data(supabase, CLEANED_TABLE)
# Values from the Supabase client can arrive as strings/objects, so coerce each column's
# dtype explicitly (applied in this order).
_CLEANED_DTYPE_CASTS = {
    'age': lambda s: pd.to_numeric(s, errors='coerce').astype('Int64'),
    'annual_income': lambda s: pd.to_numeric(s, errors='coerce'),
    'total_transactions': lambda s: pd.to_numeric(s, errors='coerce').astype('Int64'),
    'customer_id': lambda s: s.astype(str),
    'region': lambda s: s.astype(str),
}
if not cleaned_df.empty:
    for column_name, cast in _CLEANED_DTYPE_CASTS.items():
        cleaned_df[column_name] = cast(cleaned_df[column_name])

# Segment assignment for each customer
segments_df = fetch_data(supabase, SEGMENTS_TABLE, "customer_id, pattern_id")
if not segments_df.empty:
    segments_df['customer_id'] = segments_df['customer_id'].astype(str)
    segments_df['pattern_id'] = pd.to_numeric(segments_df['pattern_id'], errors='coerce').astype('Int64')

# Optional: transformed (scaled / one-hot) features for plotting
transformed_df = fetch_data(supabase, TRANSFORMED_TABLE)
if not transformed_df.empty:
    transformed_df['customer_id'] = transformed_df['customer_id'].astype(str)
    # Every column except the id and the timestamp is converted to numeric.
    feature_cols = [c for c in transformed_df.columns if c not in ('customer_id', 'transformed_at')]
    for col in feature_cols:
        transformed_df[col] = pd.to_numeric(transformed_df[col], errors='coerce')

# Optional: stored transaction-count predictions
predictions_df = fetch_data(supabase, PREDICTIONS_TABLE)
if not predictions_df.empty:
    predictions_df['customer_id'] = predictions_df['customer_id'].astype(str)
    predictions_df['predicted_total_transactions'] = pd.to_numeric(
        predictions_df['predicted_total_transactions'], errors='coerce'
    )

# Preview each table (header, frame, message shown when the frame is empty)
_previews = [
    ("--- Cleaned Data Head: ---", cleaned_df, "Cleaned data is empty."),
    ("\n--- Segment Assignments Head: ---", segments_df, "Segment assignments data is empty."),
    ("\n--- Transformed Data Head (Optional): ---", transformed_df, "Transformed data is empty or not loaded."),
    ("\n--- Predictions Data Head (Optional): ---", predictions_df, "Predictions data is empty or not loaded."),
]
for header, frame, empty_message in _previews:
    print(header)
    if frame.empty:
        print(empty_message)
    else:
        display(frame.head())

## 4. Merge Data for Analysis

In [None]:
# Attach each customer's segment (pattern_id) to the cleaned features. This is a left join,
# so customers without a segment get a missing pattern_id.
merged_df = pd.DataFrame()  # stays empty if either input is missing
if cleaned_df.empty or segments_df.empty:
    logger.error("Could not merge dataframes, cleaned_df or segments_df is empty.")
else:
    merged_df = cleaned_df.merge(segments_df, on='customer_id', how='left')
    has_segments = 'pattern_id' in merged_df.columns and merged_df['pattern_id'].notna().any()
    if not has_segments:
        logger.warning("Merge completed, but 'pattern_id' column is missing or all null.")
    else:
        # Treat segment ids as categories for grouping and plotting
        merged_df['pattern_id'] = pd.Categorical(merged_df['pattern_id'])
        logger.info(f"Merged data shape: {merged_df.shape}")
        print("--- Merged Data Head: ---")
        display(merged_df.head())
        print("\n--- Merged Data Info: ---")
        merged_df.info()

## 5. Segment Analysis

### 5.1. Segment Sizes

In [None]:
# Number of customers assigned to each segment, as a table and a bar chart
_sizes_ready = (
    not merged_df.empty
    and 'pattern_id' in merged_df.columns
    and merged_df['pattern_id'].notna().any()
)
if not _sizes_ready:
    logger.warning("Merged DataFrame is empty or missing 'pattern_id'. Cannot analyze segment sizes.")
else:
    segment_counts = merged_df['pattern_id'].value_counts().sort_index()
    print("--- Segment Sizes (Value Counts): ---")
    display(segment_counts)

    # Bars follow the same segment order as the table above
    plt.figure(figsize=(8, 5))
    ax = sns.countplot(
        data=merged_df.dropna(subset=['pattern_id']),
        x='pattern_id',
        palette='viridis',
        order=segment_counts.index,
    )
    ax.set_title('Customer Count per Segment (pattern_id)')
    ax.set_xlabel('Segment (pattern_id)')
    ax.set_ylabel('Number of Customers')
    plt.show()

### 5.2. Feature Analysis by Segment

In [None]:
# Per-segment descriptive statistics for the numeric customer features
numerical_features = ['age', 'annual_income', 'total_transactions']
if merged_df.empty or 'pattern_id' not in merged_df.columns or not merged_df['pattern_id'].notna().any():
    logger.warning("Cannot calculate segment summary statistics - check merged_df and pattern_id.")
else:
    # Coerce each available feature to numeric and remember which ones exist.
    valid_numerical_features = []
    for col in numerical_features:
        if col not in merged_df.columns:
            logger.warning(f"Numerical feature '{col}' not found in merged_df for summary stats.")
            continue
        merged_df[col] = pd.to_numeric(merged_df[col], errors='coerce')
        valid_numerical_features.append(col)

    can_aggregate = merged_df['pattern_id'].notna().any() and valid_numerical_features
    if not can_aggregate:
        logger.warning("'pattern_id' column contains only null values or no valid numerical features found. Cannot group for summary stats.")
    else:
        segment_summary = (
            merged_df.groupby('pattern_id', observed=False)[valid_numerical_features]
            .agg(['mean', 'median', 'std', 'count'])
        )
        print("--- Summary Statistics by Segment: ---")
        # Rounded display in Jupyter; plain print when display() is not defined
        try:
            display(segment_summary.style.format("{:.2f}"))
        except NameError:
            print(segment_summary.round(2))

In [None]:
# Visualize distributions of numerical features across segments
if not merged_df.empty and 'pattern_id' in merged_df.columns and merged_df['pattern_id'].notna().any():
    # Order the boxes by segment id when pattern_id is categorical. The isinstance check
    # replaces pd.api.types.is_categorical_dtype, which has been deprecated since pandas 2.1.
    # The order does not depend on the feature being plotted, so compute it once.
    pattern_dtype = merged_df['pattern_id'].dtype
    plot_order = sorted(pattern_dtype.categories) if isinstance(pattern_dtype, pd.CategoricalDtype) else None
    for col in numerical_features:  # full list, so missing columns get reported below
        if col in merged_df.columns:
            plt.figure(figsize=(10, 6))
            sns.boxplot(data=merged_df, x='pattern_id', y=col, palette='viridis', order=plot_order)
            plt.title(f'Distribution of {col} by Segment (pattern_id)')
            plt.xlabel('Segment (pattern_id)')
            plt.ylabel(col)
            plt.show()
        else:
            logger.warning(f"Column '{col}' not found for plotting distribution.")
else:
    logger.warning("Cannot visualize numerical feature distributions - check merged_df and pattern_id.")

### 5.3. Region Analysis by Segment

In [None]:
# How customers in each segment are spread across regions (counts and row percentages)
_region_ready = (
    not merged_df.empty
    and 'pattern_id' in merged_df.columns
    and 'region' in merged_df.columns
    and merged_df['pattern_id'].notna().any()
)
if not _region_ready:
    logger.warning("Cannot analyze region distribution - check merged_df, pattern_id, region.")
else:
    segment_col = merged_df['pattern_id']
    region_col = merged_df['region']
    region_distribution = pd.crosstab(segment_col, region_col)
    # normalize='index' makes each segment row sum to 1; scale to percent
    region_distribution_norm = pd.crosstab(segment_col, region_col, normalize='index') * 100

    print("--- Region Distribution (Counts) within each Segment: ---")
    try:
        display(region_distribution)
    except NameError:
        print(region_distribution)
    print("\n--- Region Distribution (%) within each Segment: ---")
    try:
        display(region_distribution_norm.style.format("{:.1f}%"))
    except NameError:
        print(region_distribution_norm.round(1).astype(str) + '%')

    # Stacked bars: one bar per segment, split by region share
    ax = region_distribution_norm.plot(kind='bar', stacked=True, figsize=(12, 7), colormap='viridis')
    ax.set_title('Region Distribution by Customer Segment')
    ax.set_xlabel('Segment (pattern_id)')
    ax.set_ylabel('Percentage of Customers (%)')
    ax.legend(title='Region', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(rotation=0)
    plt.tight_layout()
    plt.show()

## 6. Visualization (Optional Advanced Plots)

In [None]:
# Scatter of standardized age vs. income coloured by segment (needs transformed_df)
if transformed_df.empty or segments_df.empty:
    logger.warning("Transformed data or segment data not loaded or empty, skipping scatter plot.")
else:
    plot_df = transformed_df.merge(segments_df, on='customer_id', how='left')
    if 'pattern_id' not in plot_df.columns or plot_df['pattern_id'].isnull().all():
        logger.warning("Segment IDs missing or all null after merge, skipping scatter plot.")
    else:
        plot_df['pattern_id'] = pd.Categorical(plot_df['pattern_id'])
        required_cols = ('age_scaled', 'annual_income_scaled')
        if not all(c in plot_df.columns for c in required_cols):
            logger.warning("Scaled columns ('age_scaled', 'annual_income_scaled') not found. Skipping scatter plot.")
        else:
            plt.figure(figsize=(10, 7))
            sns.scatterplot(
                data=plot_df.dropna(subset=['pattern_id']),
                x='age_scaled',
                y='annual_income_scaled',
                hue='pattern_id',
                palette='viridis',
                s=70,
                alpha=0.7,
            )
            plt.title('Customer Segments based on Scaled Age and Income')
            plt.xlabel('Age (Standardized)')
            plt.ylabel('Annual Income (Standardized)')
            plt.legend(title='Segment (pattern_id)')
            plt.show()

## 7. Predictive Model Insights (Optional)

In [None]:
# Load the saved RandomForestRegressor model and preprocessor for transaction prediction.
# Use config.MODEL_OUTPUT_DIR when present. Fall back to ./models when config lacks the
# attribute, is None, or was never imported (NameError).
try:
    model_dir = getattr(config, 'MODEL_OUTPUT_DIR', 'models')
except NameError:
    model_dir = 'models'
model_filename = os.path.join(model_dir, "transactions_predictor_model.joblib")
preprocessor_filename = os.path.join(model_dir, "transactions_predictor_preprocessor.joblib")


def _processed_feature_names(fitted_preprocessor, n_features):
    """Return output-column names for *fitted_preprocessor*, or generic placeholders.

    Tries, in order:
      1. ``fitted_preprocessor.get_feature_names_out()``, which reports names in output
         column order. Transformer prefixes such as ``cat__`` are stripped for readability.
      2. The transformer named ``'cat'`` (presumably a OneHotEncoder), looked up by name,
         followed by the input columns it does not encode.
      3. ``feature_0`` ... ``feature_{n_features - 1}``.
    """
    try:
        return [str(name).split('__', 1)[-1] for name in fitted_preprocessor.get_feature_names_out()]
    except Exception as e_out:
        logger.warning(f"get_feature_names_out() unavailable on preprocessor ({e_out}); reconstructing names.")
    try:
        # Find the encoder's input columns by transformer name rather than assuming it is
        # transformers_[0]; otherwise names could be paired with the wrong importances.
        cat_cols = list(next(cols for name, _, cols in fitted_preprocessor.transformers_ if name == 'cat'))
        ohe_feature_names = list(fitted_preprocessor.named_transformers_['cat'].get_feature_names_out(cat_cols))
        input_features = getattr(fitted_preprocessor, 'feature_names_in_', None)
        if input_features is None:
            # NOTE(review): assumption carried over from earlier code; confirm against training.
            passthrough_features = ['age']
        else:
            passthrough_features = [col for col in input_features if col not in cat_cols]
        return ohe_feature_names + passthrough_features
    except Exception as e_feat:
        logger.warning(f"Could not get feature names from preprocessor: {e_feat}")
        return [f'feature_{i}' for i in range(n_features)]


try:
    # NOTE: joblib.load unpickles arbitrary objects; only load model files you trust
    # (here, the ones written by this project's pipeline).
    rf_model = joblib.load(model_filename)
    preprocessor = joblib.load(preprocessor_filename)
    logger.info("Successfully loaded transactions predictor model and preprocessor.")

    if hasattr(rf_model, 'feature_importances_'):
        importances = rf_model.feature_importances_
        processed_feature_names = _processed_feature_names(preprocessor, len(importances))

        # Pair names with scores only when the counts agree
        if len(processed_feature_names) == len(importances):
            feature_importance_df = pd.DataFrame({'feature': processed_feature_names, 'importance': importances})
            feature_importance_df = feature_importance_df.sort_values(by='importance', ascending=False)

            print("--- Feature Importances for Predicting Total Transactions: ---")
            try:
                display(feature_importance_df)
            except NameError:
                print(feature_importance_df)

            # Plot feature importances
            plt.figure(figsize=(10, 6))
            sns.barplot(data=feature_importance_df, x='importance', y='feature', palette='viridis')
            plt.title('Feature Importance for Predicting Total Transactions')
            plt.xlabel('Importance Score')
            plt.ylabel('Feature')
            plt.tight_layout()
            plt.show()
        else:
            logger.warning(f"Mismatch between number of feature names ({len(processed_feature_names)}) and importances ({len(importances)}). Skipping importance plot.")

    else:
        logger.warning("Could not retrieve feature importances from the loaded model.")

except FileNotFoundError:
    logger.error(f"Model or preprocessor file not found. Ensure '{model_filename}' and '{preprocessor_filename}' exist.")
except Exception as e:
    logger.error(f"Error loading or analyzing predictive model: {e}")

## 8. Insights & Conclusion

*   **(Fill In)** Summary of Segments: Briefly describe the key characteristics of each identified segment (pattern_id 0-5) based on the analysis above (e.g., "Segment 0 represents younger customers with lower income and fewer transactions, primarily from region X").
*   **(Fill In)** Predictive Insights: Mention key findings from the RandomForestRegressor if analyzed (e.g., "Age was found to be the most important predictor of total transactions... Region X also showed higher transaction counts...").
*   **(Fill In)** Business Implications: Suggest potential business actions based on these segments (e.g., targeted marketing campaigns, different service levels, promotions for specific regions/age groups).
*   **(Fill In)** Future Work: Mention potential next steps (e.g., using different clustering algorithms, adding more features, deploying models, A/B testing strategies based on segments).