# Survival Analysis: IBM HR Analytics Employee Attrition
## Analyzing Time in Current Role

**Objective**: Understand how long employees stay in their current role and identify factors affecting role tenure.

**Dataset**: IBM HR Analytics Employee Attrition & Performance
- Available on Kaggle: https://www.kaggle.com/datasets/pavansubhasht/ibm-hr-analytics-attrition-dataset

**Survival Analysis Setup**:
- **Time variable**: YearsInCurrentRole
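- **Event**: Attrition (`Attrition == 'Yes'`), used as a proxy for leaving the current role; this notebook does not model promotions or lateral moves as events
- **Censoring**: Employees who have not left the company (`Attrition == 'No'`), i.e. still in their current role when observed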

In [None]:
# Install required packages
!pip install lifelines pandas numpy matplotlib seaborn scikit-learn

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.statistics import logrank_test, multivariate_logrank_test
from lifelines.plotting import plot_lifetimes
import warnings
warnings.filterwarnings('ignore')

# Set style
# Global plotting defaults for every figure in this notebook:
# a matplotlib style sheet, then seaborn's 'husl' colour palette.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')
# Render figures inline in the notebook output (IPython magic, not plain Python).
%matplotlib inline

## 1. Data Loading and Exploration

In [None]:
# Read the IBM HR attrition CSV from the working directory.
# Source: https://www.kaggle.com/datasets/pavansubhasht/ibm-hr-analytics-attrition-dataset
df = pd.read_csv(
    'WA_Fn-UseC_-HR-Employee-Attrition.csv',
)

# Quick sanity check: dimensions and the full column list.
print(f"Dataset shape: {df.shape}")
print(f"\nColumns: {list(df.columns)}")

# Last expression of the cell, so the notebook renders the first rows.
df.head(5)

In [None]:
# Inspect the columns the survival analysis is built on.
missing_counts = df.isnull().sum()  # per-column NaN count, computed once

print("Key columns for survival analysis:")
print(f"- YearsInCurrentRole: {df['YearsInCurrentRole'].describe()}")
print(f"\n- Attrition distribution: \n{df['Attrition'].value_counts()}")
# Only columns that actually contain missing values are listed.
print(f"\n- Missing values: \n{missing_counts[missing_counts > 0]}")

## 2. Survival Data Preparation

In [None]:
# Build the survival dataset as a copy of `df` with two extra columns:
#   duration - YearsInCurrentRole, with 0 recoded to 0.5 (half a year) so
#              no subject has a zero duration
#   event    - 1 if Attrition == 'Yes', 0 otherwise (treated as censored)
survival_df = df.assign(
    duration=df['YearsInCurrentRole'].replace(0, 0.5),
    event=df['Attrition'].eq('Yes').astype(int),
)

# Report how many observations are events vs. censored.
n_events = survival_df['event'].sum()
n_censored = (1 - survival_df['event']).sum()

print(f"Total employees: {len(survival_df)}")
print(f"Events (attrition): {n_events}")
print(f"Censored (still employed): {n_censored}")
print(f"\nDuration statistics:")
print(survival_df['duration'].describe())

In [None]:
# One-hot encode the string-valued columns so they can enter the Cox model.
# drop_first=True drops one level per variable, which then acts as the
# reference category.
categorical_cols = [
    'BusinessTravel',
    'Department',
    'EducationField',
    'Gender',
    'JobRole',
    'MaritalStatus',
    'OverTime',
]

survival_df_encoded = pd.get_dummies(
    data=survival_df,
    columns=categorical_cols,
    drop_first=True,
)

print(f"Encoded dataset shape: {survival_df_encoded.shape}")

## 3. Kaplan-Meier Analysis: Overall Survival

In [None]:
# Kaplan-Meier estimate over the whole workforce.
kmf = KaplanMeierFitter()
kmf.fit(
    durations=survival_df['duration'],
    event_observed=survival_df['event'],
    label='All Employees',
)

# Survival curve with its confidence band.
fig, ax = plt.subplots(figsize=(12, 6))
kmf.plot_survival_function(ax=ax, ci_show=True)
ax.set_title('Kaplan-Meier Survival Curve: Time in Current Role', fontsize=14, fontweight='bold')
ax.set_xlabel('Years in Current Role', fontsize=12)
ax.set_ylabel('Probability of Staying in Role', fontsize=12)
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

# Median survival time, then the estimated survival probability at a few
# reference horizons (in years).
print(f"Median time in current role: {kmf.median_survival_time_:.2f} years")
print(f"\nSurvival probabilities at key time points:")
for year in (1, 3, 5, 7, 10):
    print(f"  {year} years: {kmf.predict(year):.1%}")

## 4. Stratified Kaplan-Meier: By Department

In [None]:
# Compare survival across departments: one Kaplan-Meier curve per
# department on a shared axis, followed by a multivariate log-rank test.
fig, ax = plt.subplots(figsize=(12, 6))

departments = survival_df['Department'].unique()
for dept in departments:
    mask = survival_df['Department'] == dept
    kmf_dept = KaplanMeierFitter()
    # .loc selects rows and column in one step (no chained indexing).
    kmf_dept.fit(survival_df.loc[mask, 'duration'],
                 survival_df.loc[mask, 'event'],
                 label=dept)
    kmf_dept.plot_survival_function(ax=ax, ci_show=False)

plt.title('Survival Curves by Department', fontsize=14, fontweight='bold')
plt.xlabel('Years in Current Role', fontsize=12)
plt.ylabel('Probability of Staying in Role', fontsize=12)
plt.legend(title='Department')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Log-rank test for department comparison.
# The test takes the grouping column directly, so no groupby object is needed.
result = multivariate_logrank_test(
    survival_df['duration'],
    survival_df['Department'],
    survival_df['event']
)
print(f"\nLog-rank test for Department:")
print(f"Test statistic: {result.test_statistic:.4f}")
print(f"p-value: {result.p_value:.4f}")
print(f"Significant difference: {'Yes' if result.p_value < 0.05 else 'No'}")

## 5. Stratified Analysis: By Job Satisfaction

In [None]:
# One Kaplan-Meier curve per JobSatisfaction level (ascending order),
# drawn on a shared axis, followed by a multivariate log-rank test.
fig, ax = plt.subplots(figsize=(12, 6))

satisfaction_levels = sorted(survival_df['JobSatisfaction'].unique())
for level in satisfaction_levels:
    mask = survival_df['JobSatisfaction'] == level
    level_rows = survival_df[mask]
    kmf_sat = KaplanMeierFitter()
    kmf_sat.fit(
        level_rows['duration'],
        level_rows['event'],
        label=f'Satisfaction {level}',
    )
    kmf_sat.plot_survival_function(ax=ax, ci_show=False)

ax.set_title('Survival Curves by Job Satisfaction', fontsize=14, fontweight='bold')
ax.set_xlabel('Years in Current Role', fontsize=12)
ax.set_ylabel('Probability of Staying in Role', fontsize=12)
ax.legend(title='Job Satisfaction')
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

# Test whether the curves differ across satisfaction levels.
result_sat = multivariate_logrank_test(
    survival_df['duration'],
    survival_df['JobSatisfaction'],
    survival_df['event'],
)
print(f"\nLog-rank test for Job Satisfaction:")
print(f"Test statistic: {result_sat.test_statistic:.4f}")
print(f"p-value: {result_sat.p_value:.4f}")

## 6. Cox Proportional Hazards Model

In [None]:
# Select features for Cox model: numeric covariates taken as-is...
cox_features = ['Age', 'DistanceFromHome', 'Education', 'EnvironmentSatisfaction',
                'JobInvolvement', 'JobLevel', 'JobSatisfaction', 'MonthlyIncome',
                'NumCompaniesWorked', 'PercentSalaryHike', 'PerformanceRating',
                'RelationshipSatisfaction', 'StockOptionLevel', 'TotalWorkingYears',
                'TrainingTimesLastYear', 'WorkLifeBalance', 'YearsAtCompany',
                'YearsSinceLastPromotion', 'YearsWithCurrManager']

# ...plus the dummy columns produced by pd.get_dummies, which are named
# '<original column>_<level>'. Match on that exact prefix rather than on a
# substring, so an unrelated column whose name merely contains a categorical
# column's name can never be pulled in by accident.
dummy_prefixes = tuple(f'{cat}_' for cat in categorical_cols)
encoded_cols = [col for col in survival_df_encoded.columns
                if col.startswith(dummy_prefixes)]
cox_features.extend(encoded_cols)

# Create Cox dataset: covariates plus the duration/event columns.
cox_data = survival_df_encoded[cox_features + ['duration', 'event']].copy()

# Handle any missing values (rows with a NaN in any selected column are dropped).
cox_data = cox_data.dropna()

print(f"Cox model dataset shape: {cox_data.shape}")
print(f"Features: {len(cox_features)}")

In [None]:
# Fit Cox Proportional Hazards model.
# penalizer=0.1 adds a penalty on coefficient size (L2 with lifelines'
# default l1_ratio of 0).
cph = CoxPHFitter(penalizer=0.1)  # L2 regularization
cph.fit(cox_data, duration_col='duration', event_col='event')

# Display model summary.
print("Cox Proportional Hazards Model Summary:")
print(f"Concordance Index: {cph.concordance_index_:.4f}")
print(f"Log-likelihood: {cph.log_likelihood_:.4f}")
# The semi-parametric Cox model maximises a *partial* likelihood, so lifelines
# exposes the information criterion as `AIC_partial_`; `AIC_` is only defined
# for fully parametric models and is not available on this fitter.
print(f"AIC (partial): {cph.AIC_partial_:.4f}")

In [None]:
# Coefficient table with an explicit hazard-ratio column (exp of the
# coefficient), ordered from smallest to largest p-value.
summary = cph.summary.assign(hazard_ratio=lambda table: np.exp(table['coef']))
summary_sorted = summary.sort_values(by='p')

# Keep at most 15 covariates that pass the 5% significance threshold.
is_significant = summary_sorted['p'] < 0.05
significant = summary_sorted.loc[is_significant].head(15)

print("\nTop 15 Most Significant Factors (p < 0.05):")
print(significant[['coef', 'hazard_ratio', 'p']].to_string())

In [None]:
# Horizontal bar chart of (hazard ratio - 1) for the ten most significant
# covariates: bars right of zero raise the hazard, bars left of zero lower it.
significant_top = significant.head(10)
labels = significant_top.index
hazard_ratios = significant_top['hazard_ratio'].values
y_pos = np.arange(len(significant_top))

# Red when the hazard ratio exceeds 1, green otherwise.
colors = np.where(hazard_ratios > 1, 'red', 'green').tolist()

fig, ax = plt.subplots(figsize=(10, 8))
ax.barh(y_pos, hazard_ratios - 1, color=colors, alpha=0.6)
ax.axvline(0, color='black', linestyle='--', linewidth=1)
ax.set_yticks(y_pos)
ax.set_yticklabels(labels)
ax.set_xlabel('Hazard Ratio - 1 (relative to baseline)', fontsize=12)
ax.set_title(
    'Top 10 Factors Affecting Role Tenure\n(Red: Increases attrition, Green: Decreases attrition)',
    fontsize=14,
    fontweight='bold',
)
ax.grid(True, alpha=0.3, axis='x')
fig.tight_layout()
plt.show()

## 7. Model Validation: Proportional Hazards Assumption

In [None]:
# Test the proportional hazards assumption on the training data, flagging
# covariates at the 5% level and drawing the diagnostic plots.
cph.check_assumptions(
    cox_data,
    show_plots=True,
    p_value_threshold=0.05,
)

## 8. Survival Predictions for Individual Employees

In [None]:
# Select a few example employees (positional row numbers in cox_data).
example_indices = [0, 10, 50, 100, 200]
example_employees = cox_data.iloc[example_indices].copy()

# Predict survival curves for all selected employees in one call; the result
# has one column per employee, in the same order as `example_indices`.
survival_funcs = cph.predict_survival_function(example_employees)

# DataFrame.plot labels each line with its column name and ignores a `label=`
# keyword, so rename the columns to get readable legend entries.
survival_funcs.columns = [f'Employee {idx}' for idx in example_indices]

fig, ax = plt.subplots(figsize=(12, 6))
survival_funcs.plot(ax=ax)

plt.title('Predicted Survival Curves for Sample Employees', fontsize=14, fontweight='bold')
plt.xlabel('Years in Current Role', fontsize=12)
plt.ylabel('Probability of Staying in Role', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 9. Risk Stratification

In [None]:
# Score every employee with the fitted model's partial hazard, then split the
# scores into three groups at the 33rd and 66th percentiles.
risk_scores = cph.predict_partial_hazard(cox_data)
risk_percentiles = np.percentile(risk_scores, [33, 66])

risk_labels = ['Low Risk', 'Medium Risk', 'High Risk']
risk_groups = pd.cut(
    risk_scores,
    bins=[0, *risk_percentiles, np.inf],
    labels=risk_labels,
)

# NOTE(review): this adds a non-numeric column to the frame the Cox model was
# fitted on; re-running the fitting cell afterwards would see this column too.
cox_data['risk_group'] = risk_groups

# Observed Kaplan-Meier curve for each risk group on a shared axis.
fig, ax = plt.subplots(figsize=(12, 6))

for group in risk_labels:
    mask = cox_data['risk_group'] == group
    group_rows = cox_data[mask]
    kmf_risk = KaplanMeierFitter()
    kmf_risk.fit(group_rows['duration'], group_rows['event'], label=group)
    kmf_risk.plot_survival_function(ax=ax, ci_show=False)

ax.set_title('Survival Curves by Risk Stratification', fontsize=14, fontweight='bold')
ax.set_xlabel('Years in Current Role', fontsize=12)
ax.set_ylabel('Probability of Staying in Role', fontsize=12)
ax.legend(title='Risk Group')
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

print("Risk Group Distribution:")
print(cox_data['risk_group'].value_counts())

## 10. Key Insights and Recommendations

In [None]:
# Text report: headline retention figures, the most significant covariates on
# either side of a hazard ratio of 1, model fit, and fixed recommendations.
banner = "=" * 80


def _print_factors(factors):
    """Print one numbered line per row: name, hazard ratio and p-value."""
    for idx, (factor, row) in enumerate(factors.iterrows(), 1):
        print(f"   {idx}. {factor}: HR={row['hazard_ratio']:.3f} (p={row['p']:.4f})")


print(banner)
print("KEY FINDINGS")
print(banner)

print("\n1. Overall Retention:")
print(f"   - Median time in current role: {kmf.median_survival_time_:.2f} years")
print(f"   - 3-year retention rate: {kmf.predict(3):.1%}")
print(f"   - 5-year retention rate: {kmf.predict(5):.1%}")

# summary_sorted is ordered by p-value, so head(5) takes the five smallest
# p-values on each side of HR = 1.
print("\n2. Top Risk Factors (increasing attrition):")
top_risk = summary_sorted[summary_sorted['hazard_ratio'] > 1].head(5)
_print_factors(top_risk)

print("\n3. Top Protective Factors (decreasing attrition):")
top_protect = summary_sorted[summary_sorted['hazard_ratio'] < 1].head(5)
_print_factors(top_protect)

high_risk_count = (risk_groups == 'High Risk').sum()
print("\n4. Model Performance:")
print(f"   - Concordance Index: {cph.concordance_index_:.4f}")
print(f"   - High-risk employees: {high_risk_count} ({high_risk_count/len(risk_groups):.1%})")

print("\n" + banner)
print("RECOMMENDATIONS")
print(banner)
for recommendation in (
    "1. Focus retention efforts on high-risk employees identified by the model",
    "2. Monitor employees with low job satisfaction and environment satisfaction",
    "3. Consider career development programs to reduce stagnation in roles",
    "4. Implement regular check-ins for employees approaching critical timepoints",
    "5. Address overtime and work-life balance issues proactively",
):
    print(recommendation)
print(banner)

## Next Steps

1. **Time-varying covariates**: Incorporate changes in satisfaction, performance over time
2. **Competing risks**: Model different exit types (promotion, lateral move, attrition)
3. **Machine learning**: Try Random Survival Forests or gradient boosting methods
4. **Intervention analysis**: Test impact of retention programs using before/after analysis
5. **Segmentation**: Build separate models for different departments or job levels