In [None]:
# Import required libraries for Snowflake ML
import pandas as pd
import numpy as np
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
from snowflake.snowpark import Window
import snowflake.snowpark.functions as F
from snowflake.snowpark.types import LongType
from snowflake.ml.modeling.preprocessing import OrdinalEncoder, OneHotEncoder, StandardScaler
from snowflake.ml.modeling.xgboost import XGBClassifier
from snowflake.ml.modeling.model_selection import GridSearchCV
from snowflake.ml.modeling.metrics import accuracy_score, roc_auc_score
from snowflake.ml.registry import Registry

# NOTE(review): this silences every Python warning for the rest of the kernel,
# including deprecation notices from the libraries imported above; consider
# narrowing it to specific categories.
warnings.filterwarnings('ignore')

# Obtain the Snowpark session used by every later cell as `session`.
# Presumably this only works when the notebook runs inside an environment that
# already has an active Snowflake session - confirm before running elsewhere.
from snowflake.snowpark.context import get_active_session
session = get_active_session()

print("Libraries imported successfully")

In [None]:
# Report the active Snowflake context and library versions so that results can
# be traced back to the environment that produced them.
import sys
snowflake_environment = session.sql('select current_user(), current_version()').collect()
from snowflake.snowpark.version import VERSION
from snowflake.ml import version

# Current Environment Details
print('User                        : {}'.format(snowflake_environment[0][0]))
print('Role                        : {}'.format(session.get_current_role()))
print('Database                    : {}'.format(session.get_current_database()))
print('Schema                      : {}'.format(session.get_current_schema()))
print('Warehouse                   : {}'.format(session.get_current_warehouse()))
print('Snowflake version           : {}'.format(snowflake_environment[0][1]))
print('Snowpark for Python version : {}.{}.{}'.format(VERSION[0],VERSION[1],VERSION[2]))
# snowflake.ml exposes its version as a plain string (e.g. "1.10.0"). Picking
# characters 0/2/4 only works while every component is a single digit, so the
# string is printed whole instead.
print('Snowflake ML version        : {}'.format(version.VERSION))

In [None]:
-- Preview every column and row of the source table used throughout the notebook.
select * from HR_EMPLOYEE_ATTRITION;

In [None]:
# Load raw data from HR_EMPLOYEE_ATTRITION table
# `raw_data_df` is the starting point for every later cell; it is never
# reassigned in the cells below, so it always refers to the unfiltered table.
raw_data_df = session.table("HR_EMPLOYEE_ATTRITION")

# We can also make references to the cells in the notebook
# (alternative, left disabled: presumably `raw_data` is the name of the SQL
# cell above - confirm the cell name before enabling)
# raw_data_df = raw_data.to_df()

In [None]:
# Basic dataset overview using Snowpark DataFrame methods
print("📋 DATASET OVERVIEW")
print("=" * 50)

# Get row count and column info
row_count = raw_data_df.count()
column_count = len(raw_data_df.columns)
print(f"Shape: {row_count} rows × {column_count} columns")

# Show column names and types
print("\n📊 Columns and Types:")
for field in raw_data_df.schema.fields:
    print(f"   {field.name}: {field.datatype}")

print("\n📈 First 5 rows:")
st.dataframe(raw_data_df.limit(5))

# Data quality checks. Only exact duplicate rows are examined here; this cell
# does not look for missing values.
print("🔍 DATA QUALITY CHECKS")
print("=" * 50)

# Duplicates = total rows - distinct rows. The total was already counted above,
# so it is reused instead of issuing a second identical count().
total_rows = row_count
distinct_rows = raw_data_df.distinct().count()
duplicates = total_rows - distinct_rows
print(f"\n🔄 Duplicate rows: {duplicates}")

# Basic statistical summary for numerical columns
print("\n📊 Statistical Summary (key numerical columns):")
numerical_stats = raw_data_df.select([
    F.avg("AGE").alias("avg_age"),
    F.min("AGE").alias("min_age"),
    F.max("AGE").alias("max_age"),
    F.avg("MONTHLY_INCOME").alias("avg_income"),
    F.min("MONTHLY_INCOME").alias("min_income"),
    F.max("MONTHLY_INCOME").alias("max_income"),
    F.avg("YEARS_AT_COMPANY").alias("avg_tenure"),
    F.max("YEARS_AT_COMPANY").alias("max_tenure")
])
st.dataframe(numerical_stats)

In [None]:
# Basic Data Discovery - inspect a fixed list of suspicious columns and drop
# the ones that are not kept for reference
st.header("🔍 Basic Data Discovery")
st.markdown("---")

st.write("Data is cleaned and have no missing value.")
st.write("Since the number of columns are overwhelming, it is plausible to retain the most crucial valuesets in relation to our analytical purposes.")

# Candidate columns, following the Medium post this notebook is based on
problematic_cols = ['EMPLOYEE_COUNT', 'EMPLOYEE_NUMBER', 'STANDARD_HOURS', 'OVER18', 'PERFORMANCE_RATING']
# These three are reported with the same "1 value only" message and dropped
_SINGLE_VALUE_COLS = ('EMPLOYEE_COUNT', 'OVER18', 'STANDARD_HOURS')
cols_to_drop = []

st.subheader("Columns to analyze:")
for col_name in problematic_cols:
    # Columns missing from the table are silently skipped
    if col_name not in raw_data_df.columns:
        continue

    # Pull the distinct values to the client so they can be shown in the message
    unique_vals = raw_data_df.select(col_name).distinct().collect()
    vals = [row[col_name] for row in unique_vals]
    unique_count = len(unique_vals)

    if col_name in _SINGLE_VALUE_COLS:
        st.write(f"- **{col_name}**: consists of 1 value only ({vals}) - can be omitted")
        cols_to_drop.append(col_name)
    elif col_name == 'EMPLOYEE_NUMBER':
        # The identifier is reported but deliberately not dropped
        st.write(f"- **{col_name}**: is employee ID, which is unique for each entry - keeping for reference")
    elif col_name == 'PERFORMANCE_RATING':
        st.write(f"- **{col_name}**: consists of merely {unique_count} values ({vals}) - limited analytical value")
        cols_to_drop.append(col_name)

# Drop columns (or pass the raw frame through untouched when nothing qualified)
if not cols_to_drop:
    cleaned_df = raw_data_df
else:
    st.subheader("Drop some columns")
    st.write(f"Dropping: {cols_to_drop}")
    cleaned_df = raw_data_df.drop(*cols_to_drop)
    st.success(f"Dropped {len(cols_to_drop)} columns")

# Removing Outliers - derive IQR fences for MONTHLY_INCOME and visualise them
st.subheader("Removing Outliers")
st.write("There are outliers in some columns, but either the number of outliers is small (<5%) or the outliers values are in realistic, reasonable range. Except for monthly_income column.")

# One aggregate query returns both quartiles and the row count
quartile_query = cleaned_df.select([
    F.expr("percentile_cont(0.25) within group (order by MONTHLY_INCOME)").alias("Q1"),
    F.expr("percentile_cont(0.75) within group (order by MONTHLY_INCOME)").alias("Q3"),
    F.count("*").alias("total_rows")
])
income_stats = quartile_query.collect()[0]

# Cast to float so the arithmetic below does not mix Decimal with float
Q1, Q3 = (float(income_stats[quartile]) for quartile in ('Q1', 'Q3'))
IQR = Q3 - Q1
lower_bound_calc = Q1 - 1.5 * IQR
upper_bound = Q3 + 1.5 * IQR

# Income cannot be negative, so the lower fence is clamped at zero
import builtins
lower_bound = builtins.max(0, lower_bound_calc)

# Show outlier analysis before removal
st.subheader("📊 Outlier Analysis - Monthly Income")

# First row of metrics: the quartiles and their spread
quartile_metrics = [
    ("Q1 (25th percentile)", Q1),
    ("Q3 (75th percentile)", Q3),
    ("IQR", IQR),
]
for slot, (caption, amount) in zip(st.columns(3), quartile_metrics):
    with slot:
        st.metric(caption, f"${amount:,.0f}")

# Second row of metrics: the fences actually used for filtering
fence_metrics = [
    ("Lower Bound", lower_bound),
    ("Upper Bound", upper_bound),
]
for slot, (caption, amount) in zip(st.columns(2), fence_metrics):
    with slot:
        st.metric(caption, f"${amount:,.0f}")

# At most 1000 rows are pulled to pandas for plotting
income_sample = cleaned_df.select("MONTHLY_INCOME").limit(1000).to_pandas()

# Histogram with the fences (red, dashed) and quartiles (green, solid) overlaid
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(income_sample['MONTHLY_INCOME'], bins=50, alpha=0.7, color='lightblue', edgecolor='black')
reference_lines = [
    (lower_bound, f'Lower Bound: ${lower_bound:,.0f}', dict(color='red', linestyle='--', linewidth=2)),
    (upper_bound, f'Upper Bound: ${upper_bound:,.0f}', dict(color='red', linestyle='--', linewidth=2)),
    (Q1, f'Q1: ${Q1:,.0f}', dict(color='green', linestyle='-', alpha=0.7)),
    (Q3, f'Q3: ${Q3:,.0f}', dict(color='green', linestyle='-', alpha=0.7)),
]
for position, caption, line_style in reference_lines:
    ax.axvline(position, label=caption, **line_style)
ax.set(
    xlabel='Monthly Income ($)',
    ylabel='Frequency',
    title='Monthly Income Distribution with Outlier Boundaries',
)
ax.legend()
ax.grid(axis='y', alpha=0.3)
st.pyplot(fig)
plt.close()

# Count and identify outliers: rows strictly outside [lower_bound, upper_bound]
outliers_df = cleaned_df.filter(
    (F.col("MONTHLY_INCOME") < lower_bound) | (F.col("MONTHLY_INCOME") > upper_bound)
)
outlier_count = outliers_df.count()
# count(*) over this same frame was already fetched together with the
# quartiles, so it is reused rather than running another count query.
total_count = income_stats['TOTAL_ROWS']  # Snowflake uses uppercase column names
# Guard against an empty table so the cell does not die on a division by zero
outlier_percentage = (outlier_count / total_count) * 100 if total_count else 0.0

st.subheader("🎯 Rationale for Outlier Removal")

col1, col2 = st.columns(2)
with col1:
    st.metric("Outliers Found", f"{outlier_count} records")
with col2:
    st.metric("Percentage of Data", f"{outlier_percentage:.1f}%")

# Dollar signs in Streamlit markdown are escaped (\$) below: an unescaped pair
# of "$" is treated as an inline LaTeX expression and garbles the amounts.
if lower_bound_calc < 0:
    st.warning(f"⚠️ Standard IQR lower bound would be \\${lower_bound_calc:,.0f} (negative), adjusted to \\$0")

st.markdown(f"""
**Why we're removing these outliers:**

1. **Statistical Reason**: Using modified IQR method - values above \\${upper_bound:,.0f} are considered outliers.
   - Original lower bound: \\${lower_bound_calc:,.0f} → Adjusted to \\${lower_bound:,.0f} (income cannot be negative)

2. **Business Logic**: 
   - **Lower bound** (\\${lower_bound:,.0f}): Set to zero since income cannot be negative
   - **Upper outliers** (> \\${upper_bound:,.0f}): Likely executive compensation that could skew our analysis of typical employee attrition patterns

3. **Model Performance**: Extreme high values can:
   - Skew statistical measures (mean, standard deviation)
   - Reduce the effectiveness of machine learning algorithms
   - Make it harder to identify patterns for the majority of employees

4. **Data Quality**: Only {outlier_percentage:.1f}% of the data - small enough that removal won't significantly impact our analysis but will improve model quality.

**Decision**: Remove {outlier_count} outlier records to focus on typical employee compensation patterns.
""")

# Remove outliers: keep rows inside the inclusive fence interval
final_df = cleaned_df.filter(
    (F.col("MONTHLY_INCOME") >= lower_bound) & (F.col("MONTHLY_INCOME") <= upper_bound)
)

original_count = total_count
final_count = final_df.count()

st.success(f"✅ Shape of the dataset after cleaning: ({final_count}, {len(final_df.columns)})")
st.info(f"📉 Removed {original_count - final_count} outlier records from monthly_income")

# Store cleaned dataset: later cells read `cleaned_df`
cleaned_df = final_df

In [None]:
# DEPARTMENT AND JOB ROLE ANALYSIS using Snowpark DataFrame with Interactive Streamlit Visualizations
st.header("🏢 Department & Job Role Analysis")
st.markdown("---")


def _distinct_sorted(df, column):
    """Return the distinct values of `column` in a stable, sorted order.

    distinct() gives no ordering guarantee, so without sorting the dropdown
    order (and therefore the default selection) can change between runs.
    None values, if any, are kept and placed last.
    """
    values = [row[column] for row in df.select(column).distinct().collect()]
    return sorted(values, key=lambda value: (value is None, str(value)))


def _render_attrition_panel(df, column, options, heading, prompt, widget_key,
                            title_template, colors, empty_message):
    """Render one selectbox + metric + pie-chart panel for a categorical column.

    Args:
        df: Snowpark DataFrame containing `column` and ATTRITION.
        column: name of the column to filter on.
        options: values offered in the dropdown.
        heading: subheader text shown above the panel.
        prompt: label of the dropdown.
        widget_key: Streamlit widget key for the dropdown.
        title_template: pie title with one "{}" placeholder for the selection.
        colors: two pie colours, in (retained, left) order.
        empty_message: warning shown when the selection matches no rows.

    Returns:
        (selected, total, attritioned, retained, attrition_rate_pct)
    """
    st.subheader(heading)
    selected = st.selectbox(prompt, options, key=widget_key)

    # Two counts: everyone matching the selection, and those with ATTRITION == "Yes"
    subset = df.filter(F.col(column) == selected)
    total = subset.count()
    attritioned = subset.filter(F.col("ATTRITION") == "Yes").count()
    retained = total - attritioned
    rate = (attritioned / total) * 100 if total > 0 else 0

    st.metric(
        label="Attrition Rate",
        value=f"{rate:.1f}%",
        delta=f"{attritioned} out of {total} employees"
    )

    if total > 0:
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.pie([retained, attritioned], labels=['Retained', 'Left Company'],
               autopct='%1.1f%%', colors=colors,
               explode=(0.05, 0.05),  # slightly separate the slices
               shadow=True, startangle=90)
        ax.set_title(title_template.format(selected), fontsize=14, fontweight='bold')
        st.pyplot(fig)
        plt.close(fig)
    else:
        st.warning(empty_message)

    return selected, total, attritioned, retained, rate


# Use cleaned dataset if available, otherwise use raw dataset
analysis_df = cleaned_df if 'cleaned_df' in locals() else raw_data_df
departments = _distinct_sorted(analysis_df, "DEPARTMENT")
job_roles = _distinct_sorted(analysis_df, "JOB_ROLE")

# Create side-by-side columns for the interactive charts
col1, col2 = st.columns(2)

with col1:
    (selected_dept, dept_total, dept_attritioned,
     dept_retained, dept_attrition_rate) = _render_attrition_panel(
        analysis_df, "DEPARTMENT", departments,
        heading="📊 Department Analysis",
        prompt="Select Department:",
        widget_key="dept_selector",
        title_template="Attrition in {}",
        colors=['lightblue', 'salmon'],
        empty_message="No data available for selected department",
    )

with col2:
    (selected_role, role_total, role_attritioned,
     role_retained, role_attrition_rate) = _render_attrition_panel(
        analysis_df, "JOB_ROLE", job_roles,
        heading="👔 Job Role Analysis",
        prompt="Select Job Role:",
        widget_key="role_selector",
        title_template="Attrition for {}",
        colors=['lightgreen', 'orange'],
        empty_message="No data available for selected job role",
    )

# Summary tables: attrition statistics for every department and every job role
st.subheader("📈 Summary Tables")


def _attrition_breakdown(group_column):
    """Aggregate `analysis_df` per `group_column`, highest attrition rate first."""
    # 1 for employees with ATTRITION == "Yes", 0 otherwise
    left_flag = F.when(F.col("ATTRITION") == "Yes", 1).otherwise(0)
    per_group = analysis_df.group_by(group_column).agg([
        F.count("*").alias("total_employees"),
        F.sum(left_flag).alias("attritioned"),
        F.avg(left_flag).alias("attrition_rate_decimal")
    ])
    with_pct = per_group.with_column("attrition_rate_pct", F.col("attrition_rate_decimal") * 100)
    return with_pct.order_by(F.col("attrition_rate_pct").desc())


# One tab per breakdown
dept_tab, role_tab = st.tabs(["Department Breakdown", "Job Role Breakdown"])

with dept_tab:
    dept_analysis = _attrition_breakdown("DEPARTMENT")
    st.dataframe(dept_analysis, use_container_width=True)

with role_tab:
    role_analysis = _attrition_breakdown("JOB_ROLE")
    st.dataframe(role_analysis, use_container_width=True)

In [None]:
# OVERTIME ANALYSIS using Snowpark DataFrame
print("⏰ OVERTIME ANALYSIS")
print("=" * 50)

# Overtime impact on attrition, computed on the raw (unfiltered) table
overtime_analysis = raw_data_df.group_by("OVER_TIME").agg([
    F.count("*").alias("total_employees"),
    F.sum(F.when(F.col("ATTRITION") == "Yes", 1).otherwise(0)).alias("attritioned"),
    F.avg(F.when(F.col("ATTRITION") == "Yes", 1).otherwise(0)).alias("attrition_rate_decimal")
]).with_column("attrition_rate_pct", F.col("attrition_rate_decimal") * 100).order_by("OVER_TIME")

print("📊 Attrition by Overtime Status:")
overtime_results = overtime_analysis.collect()
overtime_dict = {}
for row in overtime_results:
    status = row['OVER_TIME']
    total = row['TOTAL_EMPLOYEES']
    attritioned = row['ATTRITIONED']
    # Cast to float, as the income cell does, so the arithmetic and plotting
    # below never have to deal with Decimal values.
    rate = float(row['ATTRITION_RATE_PCT'])
    print(f"  {status}: {total} total, {attritioned} left ({rate:.1f}%)")
    overtime_dict[status] = rate

# Calculate the difference if we have both Yes and No
if 'Yes' in overtime_dict and 'No' in overtime_dict:
    overtime_diff = overtime_dict['Yes'] - overtime_dict['No']
    # "+.1f" emits the real sign; a hard-coded "+" would print "+-3.0" when
    # overtime employees leave less often than the others.
    print(f"\n📈 Overtime Impact: {overtime_diff:+.1f} percentage points")
    print("🎯 Expected from Medium article: ~20 percentage points difference")
    print("   (30.5% for overtime vs 10.4% for non-overtime)")

# Visualization using collected data (one bar per overtime status)
overtime_status = list(overtime_dict.keys())
overtime_rates = list(overtime_dict.values())

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(overtime_status, overtime_rates, color=['mediumblue', 'lightblue'])
ax.set_title('Attrition Rate by Overtime Status', fontsize=14, fontweight='bold')
ax.set_xlabel('Overtime Status')
ax.set_ylabel('Attrition Rate (%)')
ax.grid(axis='y', alpha=0.3)
fig.tight_layout()
plt.show()


In [None]:
# 1. CORRELATION HEATMAP ANALYSIS
st.subheader("🔥 Feature Correlation Heatmap")
st.markdown("*Analyzing relationships between numerical features to identify patterns*")

# Numerical columns are recognised by the textual name of their Snowpark type;
# the two ID-like columns are left out of the correlation analysis.
numerical_columns = [
    field.name
    for field in cleaned_df.schema.fields
    if any(num_type in str(field.datatype)
           for num_type in ['LongType', 'IntegerType', 'FloatType', 'DoubleType', 'DecimalType'])
    and field.name not in ['EMPLOYEE_NUMBER', 'EMPLOYEE_COUNT']
]

if not numerical_columns:
    st.warning("No numerical columns found for correlation analysis")
else:
    st.info(f"📈 Analyzing correlations for {len(numerical_columns)} numerical features")

    # At most 1000 rows are pulled to pandas, which computes the matrix
    correlation_sample = cleaned_df.select(*numerical_columns).limit(1000).to_pandas()
    correlation_matrix = correlation_sample.corr()

    # Heatmap of the lower triangle only (upper triangle incl. diagonal is masked)
    fig, ax = plt.subplots(figsize=(14, 12))
    mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))
    heatmap_options = dict(
        mask=mask,
        annot=True,
        cmap='RdYlBu_r',
        center=0,
        square=True,
        fmt='.2f',
        cbar_kws={"shrink": .8},
        ax=ax,
    )
    sns.heatmap(correlation_matrix, **heatmap_options)

    ax.set_title('Feature Correlation Heatmap\n(Lower Triangle Only)', fontsize=16, fontweight='bold', pad=20)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    st.pyplot(fig)
    plt.close()

    # List every unordered feature pair whose |correlation| exceeds 0.7
    st.subheader("🔍 Highly Correlated Feature Pairs")
    import builtins  # Python's own abs, addressed explicitly
    feature_names = correlation_matrix.columns
    feature_total = len(feature_names)
    high_corr_pairs = [
        {
            'Feature 1': feature_names[i],
            'Feature 2': feature_names[j],
            'Correlation': correlation_matrix.iloc[i, j]
        }
        for i in range(feature_total)
        for j in range(i + 1, feature_total)
        if builtins.abs(correlation_matrix.iloc[i, j]) > 0.7
    ]

    if not high_corr_pairs:
        st.info("No highly correlated feature pairs found (threshold: |correlation| > 0.7)")
    else:
        # Strongest relationships first, regardless of sign
        corr_df = pd.DataFrame(high_corr_pairs)
        strongest_first = corr_df['Correlation'].abs().sort_values(ascending=False).index
        corr_df = corr_df.reindex(strongest_first)
        st.dataframe(corr_df, use_container_width=True)


In [None]:
# 2. GENDER AND AGE DISTRIBUTIONS
st.subheader("👥 Gender & Age Distribution Analysis")
st.markdown("*Understanding demographic patterns in our workforce*")

# Create side-by-side analysis
col1, col2 = st.columns(2)

with col1:
    st.markdown("**Gender Distribution**")

    # Head count and attrition rate per gender. Ordering by GENDER keeps the
    # slice order (and therefore the colour given to each gender) stable
    # between runs.
    gender_dist = cleaned_df.group_by("GENDER").agg([
        F.count("*").alias("count"),
        F.avg(F.when(F.col("ATTRITION") == "Yes", 1).otherwise(0)).alias("attrition_rate")
    ]).with_column("attrition_rate_pct", F.col("attrition_rate") * 100).order_by("GENDER")

    gender_results = gender_dist.collect()

    # Prepare data for visualization
    genders = [row['GENDER'] for row in gender_results]
    gender_counts = [row['COUNT'] for row in gender_results]
    gender_attrition_rates = [float(row['ATTRITION_RATE_PCT']) for row in gender_results]

    # Create gender distribution pie chart
    fig1, ax1 = plt.subplots(figsize=(8, 6))
    colors = ['lightblue', 'lightpink']
    # `explode` must contain one entry per slice; a fixed two-element tuple
    # raises as soon as GENDER holds anything other than exactly two values.
    explode = [0.05] * len(gender_counts)
    ax1.pie(gender_counts, labels=genders, autopct='%1.1f%%', colors=colors,
            startangle=90, explode=explode)
    ax1.set_title('Employee Distribution by Gender', fontsize=12, fontweight='bold')
    st.pyplot(fig1)
    plt.close(fig1)

    # Display gender attrition rates
    st.markdown("**Attrition Rate by Gender:**")
    for gender, gender_rate in zip(genders, gender_attrition_rates):
        st.write(f"- {gender}: {gender_rate:.1f}%")

with col2:
    st.markdown("**Age Distribution**")

    # Get age statistics using Snowpark (single aggregate query)
    age_stats = cleaned_df.select([
        F.min("AGE").alias("min_age"),
        F.max("AGE").alias("max_age"),
        F.avg("AGE").alias("avg_age"),
        F.expr("percentile_cont(0.25) within group (order by AGE)").alias("Q1"),
        F.expr("percentile_cont(0.50) within group (order by AGE)").alias("median"),
        F.expr("percentile_cont(0.75) within group (order by AGE)").alias("Q3")
    ]).collect()[0]

    # Plain floats for formatting and plotting, as done for the income quartiles
    mean_age = float(age_stats['AVG_AGE'])
    median_age = float(age_stats['MEDIAN'])

    # Display age statistics
    st.metric("Average Age", f"{mean_age:.1f} years")
    st.metric("Age Range", f"{age_stats['MIN_AGE']} - {age_stats['MAX_AGE']} years")
    st.metric("Median Age", f"{median_age:.0f} years")

    # Histogram data: at most 1000 rows are pulled to pandas
    age_sample = cleaned_df.select("AGE", "ATTRITION").limit(1000).to_pandas()

    # Create age distribution histogram with mean and median markers
    fig2, ax2 = plt.subplots(figsize=(8, 6))
    ax2.hist(age_sample['AGE'], bins=20, alpha=0.7, color='skyblue', edgecolor='black')
    ax2.axvline(mean_age, color='red', linestyle='--', linewidth=2,
                label=f'Mean: {mean_age:.1f}')
    ax2.axvline(median_age, color='green', linestyle='--', linewidth=2,
                label=f'Median: {median_age:.0f}')
    ax2.set_xlabel('Age')
    ax2.set_ylabel('Frequency')
    ax2.set_title('Age Distribution of Employees')
    ax2.legend()
    ax2.grid(axis='y', alpha=0.3)
    st.pyplot(fig2)
    plt.close(fig2)


In [None]:
# 3. ATTRITION VS DENSITY BY GENDER
st.subheader("⚖️ Attrition Density Analysis by Gender")
st.markdown("*Comparing attrition patterns between male and female employees*")

# All four columns needed by the grid are pulled to pandas in one go
gender_attrition_data = cleaned_df.select("GENDER", "AGE", "ATTRITION", "MONTHLY_INCOME").to_pandas()

# 2 x 2 grid: rows are measures (age, income), columns are genders
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# (grid row, column plotted, title stem, x label, (stayed colour, left colour))
panel_specs = [
    (0, 'AGE', 'Age Distribution', 'Age', ('lightblue', 'salmon')),
    (1, 'MONTHLY_INCOME', 'Income Distribution', 'Monthly Income ($)', ('lightgreen', 'orange')),
]

for grid_row, measure, title_stem, x_caption, (stayed_colour, left_colour) in panel_specs:
    for i, gender in enumerate(['Male', 'Female']):
        gender_data = gender_attrition_data[gender_attrition_data['GENDER'] == gender]
        stayed = gender_data[gender_data['ATTRITION'] == 'No'][measure]
        left = gender_data[gender_data['ATTRITION'] == 'Yes'][measure]

        # Overlaid, density-normalised histograms: stayed vs. left
        ax = axes[grid_row, i]
        ax.hist(stayed, bins=20, alpha=0.7, label='Stayed', color=stayed_colour, density=True)
        ax.hist(left, bins=20, alpha=0.7, label='Left', color=left_colour, density=True)
        ax.set(title=f'{title_stem} - {gender}', xlabel=x_caption, ylabel='Density')
        ax.legend()
        ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
st.pyplot(fig)
plt.close()

In [None]:
# MONTHLY INCOME CORRELATION WITH ATTRITION
st.subheader("💰 Monthly Income Correlation with Attrition")
st.markdown("*Analyzing how monthly income levels correlate with employee attrition rates*")

# Pull the two columns needed for every panel into pandas
income_attrition_data = cleaned_df.select("MONTHLY_INCOME", "ATTRITION").to_pandas()

# 2 x 2 grid of income/attrition views
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Top-left: income distribution by attrition status
ax = axes[0, 0]
stayed = income_attrition_data[income_attrition_data['ATTRITION'] == 'No']['MONTHLY_INCOME']
left = income_attrition_data[income_attrition_data['ATTRITION'] == 'Yes']['MONTHLY_INCOME']
ax.hist(stayed, bins=30, alpha=0.7, label='Stayed', color='lightblue', density=True)
ax.hist(left, bins=30, alpha=0.7, label='Left', color='salmon', density=True)
ax.set_title('Monthly Income Distribution by Attrition')
ax.set_xlabel('Monthly Income ($)')
ax.set_ylabel('Density')
ax.legend()
ax.grid(axis='y', alpha=0.3)

# Top-right: box plot comparison
ax = axes[0, 1]
income_attrition_data.boxplot(column='MONTHLY_INCOME', by='ATTRITION', ax=ax)
# boxplot(by=...) stamps "Boxplot grouped by ATTRITION" on the whole figure as
# a suptitle; clear it so it does not sit on top of the four-panel grid.
fig.suptitle('')
ax.set_title('Monthly Income Box Plot by Attrition Status')
ax.set_xlabel('Attrition Status')
ax.set_ylabel('Monthly Income ($)')

# Five equal-width income bins
income_attrition_data['INCOME_BIN'] = pd.cut(income_attrition_data['MONTHLY_INCOME'],
                                           bins=5, labels=['Very Low', 'Low', 'Medium', 'High', 'Very High'])

# Bottom-left: attrition rate by income bin.
# observed=False is spelled out so every bin keeps its slot on the axis no
# matter which default the installed pandas version applies to categoricals.
ax = axes[1, 0]
attrition_by_income = income_attrition_data.groupby('INCOME_BIN', observed=False)['ATTRITION'].apply(
    lambda x: (x == 'Yes').mean() * 100
).reset_index()
attrition_by_income.columns = ['INCOME_BIN', 'ATTRITION_RATE']
bars = ax.bar(attrition_by_income['INCOME_BIN'], attrition_by_income['ATTRITION_RATE'],
              color='coral', alpha=0.7)
ax.set_title('Attrition Rate by Income Level')
ax.set_xlabel('Income Level')
ax.set_ylabel('Attrition Rate (%)')
ax.grid(axis='y', alpha=0.3)

# Add value labels on bars (bins without employees have no rate to print)
for bar in bars:
    height = bar.get_height()
    if np.isnan(height):
        continue
    ax.text(bar.get_x() + bar.get_width()/2., height + 0.5,
            f'{height:.1f}%', ha='center', va='bottom')

# Bottom-right: head count by income bin and attrition
ax = axes[1, 1]
income_counts = income_attrition_data.groupby(['INCOME_BIN', 'ATTRITION'], observed=False).size().unstack()
income_counts.plot(kind='bar', ax=ax, color=['lightblue', 'salmon'], alpha=0.7)
ax.set_title('Employee Count by Income Level and Attrition')
ax.set_xlabel('Income Level')
ax.set_ylabel('Employee Count')
ax.legend(title='Attrition')
ax.grid(axis='y', alpha=0.3)
plt.setp(ax.get_xticklabels(), rotation=45)

plt.tight_layout()
st.pyplot(fig)
plt.close(fig)

# Statistical summary
st.subheader("📈 Income-Attrition Statistical Summary")
col1, col2 = st.columns(2)
with col1:
    st.write("**Average Monthly Income by Attrition Status:**")
    avg_income = income_attrition_data.groupby('ATTRITION')['MONTHLY_INCOME'].agg(['mean', 'median', 'std'])
    st.dataframe(avg_income.round(2))
with col2:
    st.write("**Attrition Rate by Income Quartile:**")
    income_attrition_data['INCOME_QUARTILE'] = pd.qcut(income_attrition_data['MONTHLY_INCOME'],
                                                      q=4, labels=['Q1', 'Q2', 'Q3', 'Q4'])
    quartile_attrition = income_attrition_data.groupby('INCOME_QUARTILE', observed=False)['ATTRITION'].apply(
        lambda x: (x == 'Yes').mean() * 100
    ).round(2)
    st.dataframe(quartile_attrition.to_frame('Attrition Rate (%)'))

In [None]:
# 5. ATTRITION RATE BY POSITION (JOB ROLE)
st.subheader("🎯 Attrition Rate by Position")
st.markdown("*Identifying which job roles have the highest turnover risk*")

# Per-role head count, leavers and attrition rate; roles with fewer than
# 5 employees are excluded and the result is sorted by rate, highest first.
left_indicator = F.when(F.col("ATTRITION") == "Yes", 1).otherwise(0)
per_role = cleaned_df.group_by("JOB_ROLE").agg([
    F.count("*").alias("total_employees"),
    F.sum(left_indicator).alias("attritioned"),
    F.avg(left_indicator).alias("attrition_rate_decimal")
])
job_role_analysis = (
    per_role
    .with_column("attrition_rate_pct", F.col("attrition_rate_decimal") * 100)
    .filter(F.col("total_employees") >= 5)
    .order_by(F.col("attrition_rate_pct").desc())
)

job_role_results = job_role_analysis.collect()

# Wide chart column on the left, narrow text summary on the right
col1, col2 = st.columns([2, 1])

with col1:
    # Horizontal bars for the first ten rows of the sorted result
    charted_rows = job_role_results[:10]
    positions = [entry['JOB_ROLE'] for entry in charted_rows]
    attrition_rates = [entry['ATTRITION_RATE_PCT'] for entry in charted_rows]

    fig, ax = plt.subplots(figsize=(12, 8))
    bars = ax.barh(positions, attrition_rates, color='salmon')
    ax.set_xlabel('Attrition Rate (%)')
    ax.set_title('Top 10 Positions by Attrition Rate', fontsize=14, fontweight='bold')
    ax.grid(axis='x', alpha=0.3)

    # Print each rate just past the end of its bar
    for bar, rate in zip(bars, attrition_rates):
        width = bar.get_width()
        label_y = bar.get_y() + bar.get_height()/2
        ax.text(width + 0.5, label_y,
                f'{rate:.1f}%', ha='left', va='center', fontweight='bold')

    plt.tight_layout()
    st.pyplot(fig)
    plt.close()

with col2:
    st.markdown("**📊 Position Risk Summary**")
    # Text summary of the five highest-rate roles
    for rank, row in enumerate(job_role_results[:5], start=1):
        role, rate = row['JOB_ROLE'], row['ATTRITION_RATE_PCT']
        total, attritioned = row['TOTAL_EMPLOYEES'], row['ATTRITIONED']

        st.write(f"**{rank}. {role}**")
        st.write(f"   Rate: {rate:.1f}%")
        st.write(f"   ({attritioned}/{total} employees)")
        st.write("")
        st.write("")


In [None]:
# 6. JOB SATISFACTION ANALYSIS
st.subheader("😊 Job Satisfaction Analysis")
st.markdown("*Understanding the relationship between job satisfaction and employee retention*")

# Head count, leavers and attrition rate per satisfaction level, ordered by level
left_indicator = F.when(F.col("ATTRITION") == "Yes", 1).otherwise(0)
per_level = cleaned_df.group_by("JOB_SATISFACTION").agg([
    F.count("*").alias("total_employees"),
    F.sum(left_indicator).alias("attritioned"),
    F.avg(left_indicator).alias("attrition_rate_decimal")
])
satisfaction_analysis = (
    per_level
    .with_column("attrition_rate_pct", F.col("attrition_rate_decimal") * 100)
    .order_by("JOB_SATISFACTION")
)

satisfaction_results = satisfaction_analysis.collect()

# Series shared by the two charts below
satisfaction_levels = [entry['JOB_SATISFACTION'] for entry in satisfaction_results]
satisfaction_counts = [entry['TOTAL_EMPLOYEES'] for entry in satisfaction_results]
attrition_rates = [entry['ATTRITION_RATE_PCT'] for entry in satisfaction_results]

# Left: how many employees sit at each level. Right: how many of them left.
col1, col2 = st.columns(2)

with col1:
    st.markdown("**📊 Satisfaction Levels Distribution**")

    fig1, ax1 = plt.subplots(figsize=(8, 6))
    bars = ax1.bar(satisfaction_levels, satisfaction_counts, color='lightblue', alpha=0.8)
    ax1.set(
        xlabel='Job Satisfaction Level',
        ylabel='Number of Employees',
        title='Employee Distribution by Job Satisfaction Level',
    )
    ax1.grid(axis='y', alpha=0.3)

    # Head count printed slightly above each bar
    for bar, count in zip(bars, satisfaction_counts):
        height = bar.get_height()
        label_x = bar.get_x() + bar.get_width()/2.
        ax1.text(label_x, height + height*0.01,
                f'{count}', ha='center', va='bottom', fontweight='bold')

    st.pyplot(fig1)
    plt.close()

with col2:
    st.markdown("**⚠️ Attrition Rate by Satisfaction Level**")

    fig2, ax2 = plt.subplots(figsize=(8, 6))
    bars = ax2.bar(satisfaction_levels, attrition_rates, color='salmon', alpha=0.8)
    ax2.set(
        xlabel='Job Satisfaction Level',
        ylabel='Attrition Rate (%)',
        title='Attrition Rate by Job Satisfaction Level',
    )
    ax2.grid(axis='y', alpha=0.3)

    # Rate printed slightly above each bar
    for bar, rate in zip(bars, attrition_rates):
        height = bar.get_height()
        label_x = bar.get_x() + bar.get_width()/2.
        ax2.text(label_x, height + height*0.01,
                f'{rate:.1f}%', ha='center', va='bottom', fontweight='bold')

    st.pyplot(fig2)
    plt.close()

# Detailed satisfaction analysis table
st.subheader("📋 Detailed Job Satisfaction Analysis")

# Same numbers as the charts, reshaped into a reader-friendly table
satisfaction_display = [
    {
        'Satisfaction Level': f"Level {row['JOB_SATISFACTION']}",
        'Total Employees': row['TOTAL_EMPLOYEES'],
        'Employees Who Left': row['ATTRITIONED'],
        'Attrition Rate (%)': f"{row['ATTRITION_RATE_PCT']:.1f}%"
    }
    for row in satisfaction_results
]
satisfaction_df = pd.DataFrame(satisfaction_display)
st.dataframe(satisfaction_df, use_container_width=True)

In [None]:
# 7. PAIRWISE PLOTS FOR KEY FEATURES
st.subheader("🔗 Pairwise Relationships Analysis")

# Candidate columns for the pairwise analysis (tenure features plus the label)
cols = ['TOTAL_WORKING_YEARS', 'YEARS_AT_COMPANY', 'YEARS_IN_CURRENT_ROLE',
        'YEARS_SINCE_LAST_PROMOTION', 'ATTRITION', 'JOB_LEVEL']

# Resolve the schema once, then keep only the candidates that actually exist
# (the original rebuilt the full name list on every loop iteration).
schema_column_names = {field.name for field in cleaned_df.schema.fields}
available_features = [feature for feature in cols if feature in schema_column_names]

if len(available_features) >= 3:
    st.info(f"📊 Creating pairwise plots for: {', '.join(available_features)}")

    # seaborn needs a pandas frame; cap the pull at 500 rows to keep it light.
    # NOTE(review): limit() takes the first rows returned, not a random sample.
    pairwise_sample = cleaned_df.select(*available_features).limit(500).to_pandas()

    st.subheader("📈 Pairwise Feature Relationships")

    # Colour by attrition only when that column survived the schema check;
    # a hard-coded hue on a missing column would make pairplot fail.
    hue_column = 'ATTRITION' if 'ATTRITION' in available_features else None
    pair_plot = sns.pairplot(pairwise_sample, hue=hue_column)

    # Title sits slightly above the grid so it does not overlap the top row
    pair_plot.fig.suptitle('Pairwise Relationships - Tenure Features vs Attrition',
                          fontsize=16, fontweight='bold', y=1.02)

    st.pyplot(pair_plot.fig)
    plt.close(pair_plot.fig)
else:
    # Make the skip visible instead of silently rendering nothing
    st.warning(
        "Pairwise plots skipped: fewer than 3 of the expected columns are present "
        f"(found: {', '.join(available_features) or 'none'})."
    )

In [None]:
# # ========================================
# # FEATURE ENGINEERING & MODEL PREPARATION
# # ========================================
# st.header("🛠️ Feature Engineering & Model Preparation")
# st.markdown("---")
# st.markdown("*Preparing data for machine learning following Snowflake ML best practices*")

# # Display current dataset info
# total_rows = cleaned_df.count()
# total_cols = len(cleaned_df.columns)
# st.info(f"📊 **Starting Dataset**: {total_rows:,} rows × {total_cols} columns")

# # Check for EMPLOYEE_NUMBER column (to exclude from modeling)
# schema_fields = [field.name for field in cleaned_df.schema.fields]
# if 'EMPLOYEE_NUMBER' in schema_fields:
#     st.warning("📋 **Note**: EMPLOYEE_NUMBER will be kept for reference but excluded from model training")

# print("🚀 Starting Feature Engineering Pipeline...")
# print("=" * 60)


In [None]:
# Feature engineering pipeline: sort every column into an encoding bucket
print("Starting feature engineering...")

# Columns whose integer levels carry a natural order
ordinal_columns = [
    'EDUCATION', 'ENVIRONMENT_SATISFACTION', 'JOB_LEVEL',
    'JOB_SATISFACTION', 'RELATIONSHIP_SATISFACTION', 'WORK_LIFE_BALANCE'
]

# Buckets filled by the loop below
categorical_columns = []
ordinal_columns_present = []
numerical_columns = []
target_column = 'ATTRITION'

# EMPLOYEE_NUMBER is an identifier, not a feature; exclude it when present
schema_column_names = [f.name for f in cleaned_df.schema.fields]
exclude_columns = ['EMPLOYEE_NUMBER'] if 'EMPLOYEE_NUMBER' in schema_column_names else []

# Substrings of the Snowpark datatype repr that mark a numeric column
numeric_type_markers = ['LongType', 'IntegerType', 'FloatType', 'DoubleType', 'DecimalType']

for field in cleaned_df.schema.fields:
    col_name = field.name

    # The label and the identifier are never features
    if col_name == target_column or col_name in exclude_columns:
        continue

    # Known ordinal columns win over any type-based rule
    if col_name in ordinal_columns:
        ordinal_columns_present.append(col_name)
        continue

    datatype_str = str(field.datatype)
    is_numeric = any(num_type in datatype_str for num_type in numeric_type_markers)
    if not is_numeric:
        categorical_columns.append(col_name)
        continue

    # Numeric columns with at most 10 distinct values are treated as categorical.
    # This issues one COUNT(DISTINCT) query per numeric column.
    unique_count = cleaned_df.select(col_name).distinct().count()
    if unique_count <= 10:
        categorical_columns.append(col_name)
    else:
        numerical_columns.append(col_name)

print(f"Ordinal columns: {len(ordinal_columns_present)}")
print(f"Categorical columns: {len(categorical_columns)}")
print(f"Numerical columns: {len(numerical_columns)}")

In [None]:
# Apply feature encoding
feature_df = cleaned_df

# Ordinal encoding
if ordinal_columns_present:
    print("Applying ordinal encoding...")
    ordinal_encoder = OrdinalEncoder(
        input_cols=ordinal_columns_present,
        output_cols=[f"{col}_ORDINAL" for col in ordinal_columns_present]
    )
    ordinal_encoder.fit(feature_df)
    feature_df = ordinal_encoder.transform(feature_df)
    feature_df = feature_df.drop(*ordinal_columns_present)
    ordinal_encoded_columns = [f"{col}_ORDINAL" for col in ordinal_columns_present]
    print(f"Ordinal encoded: {len(ordinal_encoded_columns)} columns")
else:
    ordinal_encoded_columns = []

# One-hot encoding
if categorical_columns:
    print("Applying one-hot encoding...")
    ohe = OneHotEncoder(
        input_cols=categorical_columns,
        output_cols=[f"{col}_ONEHOT" for col in categorical_columns]
    )
    ohe.fit(feature_df)
    feature_df = ohe.transform(feature_df)
    feature_df = feature_df.drop(*categorical_columns)
    onehot_encoded_columns = [f"{col}_ONEHOT" for col in categorical_columns]
    print(f"One-hot encoded: {len(onehot_encoded_columns)} columns")
else:
    onehot_encoded_columns = []

# Standard scaling for numerical features
if numerical_columns:
    print("Applying standard scaling...")
    scaler = StandardScaler(
        input_cols=numerical_columns,
        output_cols=[f"{col}_SCALED" for col in numerical_columns]
    )
    scaler.fit(feature_df)
    feature_df = scaler.transform(feature_df)
    feature_df = feature_df.drop(*numerical_columns)
    scaled_columns = [f"{col}_SCALED" for col in numerical_columns]
    print(f"Scaled: {len(scaled_columns)} columns")
else:
    scaled_columns = []

print("Feature engineering complete!")

In [None]:
# Convert ATTRITION to numeric (0/1) for ML.
# The result is bound to a NEW name so this cell stays idempotent: feature_df
# keeps its original "Yes"/"No" label, and re-running the cell no longer
# compares an already-numeric column against the string "Yes".
# F and LongType come from the notebook's first (imports) cell.
labeled_df = feature_df.with_column(
    "ATTRITION",
    F.when(F.col("ATTRITION") == "Yes", 1).otherwise(0).cast(LongType()),
)

# Train/test split (80/20); the fixed seed keeps the split reproducible
print("Creating train/test split...")
train_df, test_df = labeled_df.random_split(weights=[0.8, 0.2], seed=42)

train_count = train_df.count()
test_count = test_df.count()
print(f"Training set: {train_count} samples")
print(f"Test set: {test_count} samples")

In [None]:
# Persist both splits as tables, replacing whatever a previous run left behind
for split_df, split_table in (
    (train_df, 'HR_ANALYTICS.ML_MODELING.HR_EMPLOYEE_ATTRITION_TRAIN_DF'),
    (test_df, 'HR_ANALYTICS.ML_MODELING.HR_EMPLOYEE_ATTRITION_TEST_DF'),
):
    split_df.write.save_as_table(split_table, mode = 'overwrite')

In [None]:
# Re-bind both splits to the persisted tables so later cells read stored data
train_df, test_df = (
    session.read.table(split_table)
    for split_table in (
        "HR_ANALYTICS.ML_MODELING.HR_EMPLOYEE_ATTRITION_TRAIN_DF",
        "HR_ANALYTICS.ML_MODELING.HR_EMPLOYEE_ATTRITION_TEST_DF",
    )
)

In [None]:
# Preview the first rows of the reloaded test split as a quick sanity check
test_df.show()

In [None]:
# Scale the current warehouse up before the grid search, printing its state
# before and after. SHOW ... LIKE needs the bare name, so the quotes are
# stripped for it; ALTER uses the identifier exactly as the session reports it.
wh = str(session.get_current_warehouse()).strip('"')
show_warehouse_sql = f"SHOW WAREHOUSES LIKE '{wh}';"
resize_warehouse_sql = f"alter warehouse {session.get_current_warehouse()} set WAREHOUSE_SIZE = LARGE WAIT_FOR_COMPLETION = TRUE"

print(f"Current warehouse: {wh}")
print(session.sql(show_warehouse_sql).collect())

session.sql(resize_warehouse_sql).collect()

print(session.sql(show_warehouse_sql).collect())

In [None]:
# ========================================
# XGBOOST MODEL WITH GRID SEARCH
# ========================================
st.header("🚀 XGBoost Classifier with Grid Search")
# XGBClassifier, GridSearchCV, the metric helpers and pandas are all imported
# in the notebook's first cell, so they are not re-imported here.

print("🚀 Training XGBoost with Grid Search...")

# Label / prediction column names are passed as lists to the modeling API.
# NOTE: target_column was a plain string in the feature-engineering cell and
# is rebound to a list here; the metric cells rely on the list form.
target_column = ['ATTRITION']
output_column = ['PRED_ATTRITION']
# EMPLOYEE_NUMBER is an identifier: carried through, never used as a feature
exclude_columns = ['EMPLOYEE_NUMBER'] if 'EMPLOYEE_NUMBER' in train_df.columns else []

# Every remaining column is a model feature
all_columns = train_df.columns
model_feature_columns = [
    column_name
    for column_name in all_columns
    if column_name not in target_column and column_name not in exclude_columns
]

print(f"📊 Using {len(model_feature_columns)} features for training")
print(f"🎯 Target column: {target_column}")


# 2 x 2 x 4 = 16 candidate configurations.
# No `scoring` is passed, so candidates are ranked by the estimator's default
# score (mean accuracy for scikit-learn style classifiers).
model_pipeline = GridSearchCV(
    estimator=XGBClassifier(),
    param_grid={
        'n_estimators': [50, 100],
        'learning_rate': [0.01, 0.1],
        'max_depth': [2, 3, 4, 5],
    },
    n_jobs=-1,
    input_cols=model_feature_columns,
    passthrough_cols=exclude_columns,
    label_cols=target_column,
    output_cols=output_column,
)

print("🔧 Fitting XGBoost model with Grid Search...")

# fit() returns the fitted search object; later cells use both names
fitted_model = model_pipeline.fit(train_df)

print("✅ XGBoost Grid Search training completed!")


# Score both splits; each result carries the PRED_ATTRITION column
print("🔮 Making predictions...")
xgb_gs_train = model_pipeline.predict(train_df)
xgb_gs_predictions = model_pipeline.predict(test_df)

st.success("✅ **XGBoost with Grid Search trained successfully!**")

In [None]:
# Accuracy compares the hard class predictions against the true label
train_accuracy = accuracy_score(df=xgb_gs_train, y_true_col_names=target_column, y_pred_col_names=output_column)
test_accuracy = accuracy_score(df=xgb_gs_predictions, y_true_col_names=target_column, y_pred_col_names=output_column)
print(f'Test Accuracy: {test_accuracy}')


def _positive_class_score_cols(scored_df):
    """Return the predict_proba column for class 1 as a one-element list, or []."""
    # Column identifiers may come back quoted and in either case, so normalise
    # before comparing; the name as it appears in the frame is what is returned.
    wanted = {'PREDICT_PROBA_1', 'PREDICT_PROBA_1.0'}
    return [c for c in scored_df.columns if c.strip('"').upper() in wanted][:1]


# ROC AUC is a ranking metric: it needs a continuous score. Feeding it the
# hard 0/1 predictions collapses the curve to a single threshold, so use the
# predicted probability of the positive class instead.
xgb_gs_train_proba = model_pipeline.predict_proba(train_df)
xgb_gs_test_proba = model_pipeline.predict_proba(test_df)
score_column = _positive_class_score_cols(xgb_gs_test_proba)

if score_column:
    train_auc = roc_auc_score(df=xgb_gs_train_proba, y_true_col_names=target_column, y_score_col_names=score_column)
    test_auc = roc_auc_score(df=xgb_gs_test_proba, y_true_col_names=target_column, y_score_col_names=score_column)
else:
    # Could not locate the probability column: keep the previous behaviour
    # rather than failing, but make the degradation visible.
    print("⚠️ Probability column not found; ROC AUC computed from hard class predictions")
    train_auc = roc_auc_score(df=xgb_gs_train, y_true_col_names=target_column, y_score_col_names=output_column)
    test_auc = roc_auc_score(df=xgb_gs_predictions, y_true_col_names=target_column, y_score_col_names=output_column)
print(f'Test AUC: {test_auc}')


In [None]:
# Grid Search Results Analysis and Feature Importance
st.header("📊 Grid Search Analysis & Feature Importance")
st.markdown("---")

print("📈 Analyzing Grid Search Results...")

# Use the fitted model
fitted_model = model_pipeline

# Unwrap the native scikit-learn search object once and reuse it below
sk_grid_search = fitted_model.to_sklearn()

# Name the metric that cross-validation actually optimised. When no `scoring`
# was configured, scikit-learn uses the estimator's own score(), which is mean
# accuracy for classifiers -- not AUC.
configured_scoring = getattr(sk_grid_search, "scoring", None)
if isinstance(configured_scoring, str):
    score_name = configured_scoring
elif configured_scoring is None:
    score_name = "accuracy"
else:
    score_name = "custom scorer"

# One row per candidate: every searched hyper-parameter plus its mean CV score
gs_results = sk_grid_search.cv_results_
gs_results_df = pd.DataFrame(list(gs_results["params"]))
gs_results_df["mean_test_score"] = gs_results["mean_test_score"]

st.subheader("🔍 Grid Search Parameter Performance")

# Display grid search results summary
col1, col2 = st.columns(2)

with col1:
    st.write("**📋 Grid Search Results Summary:**")
    st.dataframe(gs_results_df.sort_values('mean_test_score', ascending=False).head(10))

with col2:
    # Each line averages over the other searched parameters (e.g. max_depth)
    # for a given learning_rate / n_estimators pair.
    sns.set_context("notebook", font_scale=0.5)
    fig = sns.relplot(data=gs_results_df, x="learning_rate", y="mean_test_score",
                      hue="n_estimators", kind="line", height=3)
    fig.set_xlabels('Learning Rate')
    fig.set_ylabels(f'Mean Test Score ({score_name})')
    plt.title('Grid Search Results: Learning Rate vs Performance')
    st.pyplot(fig)
    plt.close()

# Display best parameters and results
st.subheader("🏆 Best Model Results")

print("Results from Grid Search")
print("\n The best estimator across ALL searched params:\n", sk_grid_search.best_estimator_)
print("\n The best score across ALL searched params:\n", sk_grid_search.best_score_)
print("\n The best parameters across ALL searched params:\n", sk_grid_search.best_params_)

# Display in Streamlit
col1, col2, col3 = st.columns(3)

with col1:
    st.metric(f"Best Score ({score_name})", f"{sk_grid_search.best_score_:.4f}")

with col2:
    best_params = sk_grid_search.best_params_
    st.write("**Best Parameters:**")
    for param, value in best_params.items():
        st.write(f"- {param}: {value}")

with col3:
    # Hold-out metrics computed in the evaluation cell
    st.write("**Model Performance:**")
    st.write(f"- Training AUC: {train_auc:.4f}")
    st.write(f"- Test AUC: {test_auc:.4f}")
    st.write(f"- Training Accuracy: {train_accuracy:.4f}")
    st.write(f"- Test Accuracy: {test_accuracy:.4f}")

# Feature Importance Analysis
st.subheader("🎯 Feature Importance Analysis")

# Importances live on the refit best estimator, not on the search wrapper
best_estimator = sk_grid_search.best_estimator_
feature_names = sk_grid_search.feature_names_in_
feature_importances = best_estimator.feature_importances_

# Ascending order so the most important feature ends up at the top of the barh
feat_importance = pd.DataFrame({
    'Feature': feature_names,
    'FeatImportance': feature_importances
}).sort_values('FeatImportance', ascending=True)

# Plot feature importance
fig, ax = plt.subplots(figsize=(10, 12))
feat_importance.plot.barh(x='Feature', y='FeatImportance', ax=ax, color='skyblue')
ax.set_title('Feature Importance - XGBoost Model', fontsize=14, fontweight='bold')
ax.set_xlabel('Feature Importance')
ax.grid(axis='x', alpha=0.3)

st.pyplot(fig)
plt.close(fig)

# Display top features
st.subheader("🔝 Top 10 Most Important Features")
top_features = feat_importance.tail(10).sort_values('FeatImportance', ascending=False)
st.dataframe(top_features, use_container_width=True)

print("✅ Grid Search Analysis and Feature Importance completed!")


In [None]:
# Scale the current warehouse back down after training, printing its state
# before and after. SHOW ... LIKE needs the bare name, so the quotes are
# stripped for it; ALTER uses the identifier exactly as the session reports it.
wh = str(session.get_current_warehouse()).strip('"')
show_warehouse_sql = f"SHOW WAREHOUSES LIKE '{wh}';"
resize_warehouse_sql = f"alter warehouse {session.get_current_warehouse()} set WAREHOUSE_SIZE = XSMALL WAIT_FOR_COMPLETION = TRUE"

print(f"Current warehouse: {wh}")
print(session.sql(show_warehouse_sql).collect())

session.sql(resize_warehouse_sql).collect()

print(session.sql(show_warehouse_sql).collect())

In [None]:
# FUNCTION used to iterate the model version so we can automatically 
# create the next version number
import ast
import builtins  # Import the builtins module
#from snowflake.snowpark import functions as F 

def get_next_version(reg, model_name) -> str:
    """
    Returns the next version name for a model based on the versions already in the registry.

    Version names are expected to end in "_<number>" (e.g. "V_3"). The result is
    "V_<highest existing number + 1>". Existing versions whose name has no
    numeric suffix are ignored rather than aborting the lookup.

    Args:
        reg: The registry object; its show_models() must return a pandas DataFrame
            with a "name" column and a "versions" column holding a list literal
            of version names as a string.
        model_name: The name of the model, compared exactly against the "name" column.

    Returns:
        str: The next version name in the format "V_<n>". "V_1" is returned when the
            registry is empty, the model is not registered yet, or the model has no
            numerically-suffixed versions.
    """
    models = reg.show_models()
    if models.empty or model_name not in models["name"].to_list():
        return "V_1"

    # "versions" is stored as text, e.g. '["V_1","V_2"]'
    existing_versions = ast.literal_eval(
        models.loc[models["name"] == model_name, "versions"].values[0]
    )

    version_numbers = []
    for version_name in existing_versions:
        suffix = str(version_name).split("_")[-1]
        # Skip custom names such as "BASELINE" instead of raising on int()
        if suffix.isdecimal():
            version_numbers.append(int(suffix))

    if not version_numbers:
        return "V_1"

    # builtins.max is used explicitly so a shadowed `max` (e.g. from a
    # star-import of SQL functions) cannot be picked up by mistake.
    return f"V_{builtins.max(version_numbers) + 1}"

In [None]:
# Open the model registry that will hold the cross-validated classifier.
# The database comes from the active session; the schema is fixed to ML_MODELING.
registry_database = session.get_current_database()
Reg = Registry(session=session, database_name=registry_database, schema_name='ML_MODELING')

In [None]:
model_name = 'EMPLOYEE_ATTRITION_XGBOOST'
model_version = get_next_version(Reg, model_name)

# Record the hyper-parameters the grid search actually selected.
# get_params() on the search object only exposes nested keys such as
# "estimator__n_estimators", so looking up "n_estimators" there yields None;
# best_params_ holds the winning value for every searched parameter.
model_params = model_pipeline.to_sklearn().best_params_
param_str = ", ".join(
    f"{param_name}={model_params.get(param_name)}"
    for param_name in ("n_estimators", "learning_rate", "max_depth")
)

mv = Reg.log_model(
    fitted_model,
    model_name=model_name,
    version_name=model_version,
    conda_dependencies=["snowflake-ml-python"],
    comment=f"XGBoost model - Params: {param_str}",
    # Hold-out and training metrics are stored alongside the model version
    metrics={"Test Acc": test_accuracy, "Test AUC": test_auc, "Train AUC": train_auc, "Train Acc": train_accuracy},
    options={"relax_version": False, "enable_explainability": True, 'case_sensitive': True},
)

# Make the freshly logged version the model's default
m = Reg.get_model(model_name)
m.default = model_version

In [None]:
# List the registered versions of the model that was just logged
Reg.get_model(model_name).show_versions()

In [None]:
# Look the model up by its registered name and keep the version returned by .last()
prod_model = Reg.get_model("EMPLOYEE_ATTRITION_XGBOOST").last() # alternatively, pick a specific one with .version()

In [None]:
# prod_model.run(test_df, function_name = 'PREDICT_PROBA')