In [None]:
# 📊 Exploratory Data Analysis: IBM HR Employee Attrition
**Prepared by:** Sanika Jadhav  
**Objective:** Understand factors influencing employee attrition using data-driven insights.


In [None]:
# -----------------------------------------
# Import Libraries
# -----------------------------------------
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings

# Silence UserWarnings (e.g. noisy plotting-library notices) so plot cells stay readable.
# NOTE(review): this hides *every* UserWarning for the whole session; narrow it
# with module=/message= if a genuine problem could be masked.
warnings.filterwarnings("ignore", category=UserWarning)

# set_theme() is seaborn's documented styling entry point; sns.set() is only a
# legacy alias for it. histplot (used later) already requires seaborn >= 0.11,
# the release that introduced set_theme, so this adds no new version requirement.
sns.set_theme(style="whitegrid")
plt.rcParams["figure.figsize"] = (10, 5)  # default size for figures created without figsize=

print("Libraries imported successfully.")


In [None]:
# -----------------------------------------
# Load Dataset (Auto-detect CSV file)
# -----------------------------------------
# Resolution order, searching only the current working directory:
#   1. the canonical IBM file name, if present;
#   2. otherwise the first CSV whose name contains an HR-related keyword;
#   3. otherwise the first CSV of any name (with a printed warning).
# Path.glob() yields entries in filesystem order, which differs between
# machines, so candidates are sorted to make steps 2 and 3 reproducible.

cwd = Path.cwd()
csv_files = sorted(cwd.glob("*.csv"))

dataset_path = None

preferred_name = "WA_Fn-UseC_-HR-Employee-Attrition.csv"

if (cwd / preferred_name).exists():
    dataset_path = cwd / preferred_name
else:
    # Case-insensitive substring match on the file name.
    # NOTE(review): "hr" is a loose keyword — it also matches e.g. "three.csv".
    keywords = ["attrition", "employee", "hr", "WA_Fn"]
    dataset_path = next(
        (f for f in csv_files if any(k.lower() in f.name.lower() for k in keywords)),
        None,
    )

# Fallback — use the first CSV if nothing matches, but say so: an unrelated
# file would otherwise be analysed without any hint that it was a guess.
if dataset_path is None and csv_files:
    dataset_path = csv_files[0]
    print("Warning: no HR-named CSV found; falling back to", dataset_path.name)

if dataset_path is None:
    raise FileNotFoundError("CSV dataset not found in this folder.")

df = pd.read_csv(dataset_path)

# Later cells index these columns directly; fail here with a clear message
# rather than with a KeyError inside a plotting call.
required_columns = {"Attrition", "Age", "MonthlyIncome", "JobRole"}
absent_columns = required_columns - set(df.columns)
if absent_columns:
    raise ValueError(
        f"{dataset_path.name} lacks expected columns: {sorted(absent_columns)}"
    )

print("Dataset Loaded:", dataset_path.name)
df.head()


In [None]:
# -----------------------------------------
# Dataset Overview
# -----------------------------------------
# Row/column counts, the dtype pandas inferred for every column, then a preview.

print(f"Shape: {df.shape}")
print("\nData Types:", df.dtypes, sep="\n")

print("\nSample Rows:")
df.head()


In [None]:
# -----------------------------------------
# Missing Values & Duplicate Rows
# -----------------------------------------
# Null counts per column (listing only columns that have at least one null),
# and the number of rows pandas flags as repeats of an earlier row.

missing = df.isna().sum()
duplicates = df.duplicated().sum()

print("Missing Values:")
if missing.sum() > 0:
    print(missing[missing > 0])
else:
    print("No missing values found.")

print("\nDuplicate Rows:", duplicates)


In [None]:
# -----------------------------------------
# Statistical Summary (Numeric Columns)
# -----------------------------------------
# numeric_df keeps only numeric columns (the correlation heatmap reuses it);
# the summary is transposed so each feature is one row of statistics.

numeric_df = df.select_dtypes(include="number")
numeric_df.describe().transpose()


In [None]:
# -----------------------------------------
# Attrition Distribution
# -----------------------------------------
# Bar per Attrition value, with its exact count written above the bar.
# NOTE(review): seaborn >= 0.13 warns when palette= is given without hue=.

fig, ax = plt.subplots()
sns.countplot(data=df, x="Attrition", palette="Set2", ax=ax)
ax.set_title("Employee Attrition Count")

for bar in ax.patches:
    bar_height = bar.get_height()
    bar_center = bar.get_x() + bar.get_width() / 2
    ax.annotate(int(bar_height), (bar_center, bar_height), ha="center", va="bottom")

plt.show()


In [None]:
# -----------------------------------------
# Age Distribution
# -----------------------------------------
# 20-bin histogram of employee age with a kernel-density curve overlaid.

fig, ax = plt.subplots()
sns.histplot(df["Age"], bins=20, kde=True, ax=ax)
ax.set_title("Age Distribution of Employees")
ax.set_xlabel("Age")
ax.set_ylabel("Count")
plt.show()


In [None]:
# -----------------------------------------
# Age vs Attrition
# -----------------------------------------
# One age boxplot per Attrition group, drawn side by side for comparison.

fig, ax = plt.subplots()
sns.boxplot(data=df, x="Attrition", y="Age", palette="Set3", ax=ax)
ax.set_title("Attrition vs Age")
plt.show()


In [None]:
# -----------------------------------------
# Monthly Income Distribution
# -----------------------------------------
# 30-bin histogram of MonthlyIncome with a kernel-density curve overlaid.

fig, ax = plt.subplots()
sns.histplot(df["MonthlyIncome"], bins=30, kde=True, ax=ax)
ax.set_title("Monthly Income Distribution")
ax.set_xlabel("Income")
ax.set_ylabel("Frequency")
plt.show()


In [None]:
# -----------------------------------------
# Monthly Income vs Attrition
# -----------------------------------------
# MonthlyIncome boxplot per Attrition group.

fig, ax = plt.subplots()
sns.boxplot(data=df, x="Attrition", y="MonthlyIncome", palette="coolwarm", ax=ax)
# Symmetric-log y-axis: linear within ±linthresh (matplotlib default 2) and
# logarithmic beyond it, which compresses a long upper tail so both boxes stay
# readable. It does not remove or down-weight outliers.
ax.set_yscale("symlog")
ax.set_title("Attrition vs Monthly Income")
plt.show()


In [None]:
# -----------------------------------------
# Attrition by Job Role
# -----------------------------------------
# Raw counts are driven by headcount: a large role can show many leavers while
# a small role loses a far larger share of its staff. The chart shows counts
# (roles ordered largest first); the table below gives each role's attrition rate.

fig, ax = plt.subplots(figsize=(12, 5))
order = df["JobRole"].value_counts().index

sns.countplot(x="JobRole", hue="Attrition", data=df, order=order, palette="husl", ax=ax)
plt.xticks(rotation=45, ha="right")
ax.set_title("Job Role vs Attrition")
plt.show()

# Per-role headcount, leavers and share who left, highest rate first.
# Assumes leavers are labelled "Yes" (as in the IBM file) — confirm for other data.
attrition_rate_by_role = (
    df["Attrition"]
    .eq("Yes")
    .groupby(df["JobRole"])
    .agg(headcount="count", leavers="sum", attrition_rate="mean")
    .sort_values("attrition_rate", ascending=False)
)
attrition_rate_by_role


In [None]:
# -----------------------------------------
# Correlation Heatmap
# -----------------------------------------
# Adjustments to the raw numeric frame before correlating:
#  * Columns with a single distinct value have zero variance, so every Pearson
#    correlation involving them is NaN and renders as an empty band — drop them.
#  * EmployeeNumber is presumably a row identifier (verify); its correlations
#    carry no meaning, so it is dropped if present.
#  * Attrition is stored as text and therefore absent from numeric_df. Encode it
#    as 0/1 so the heatmap shows which features move with attrition. Assumes the
#    IBM "Yes"/"No" labels; any other label maps to NaN and is skipped by .corr().

corr_source = numeric_df.loc[:, numeric_df.nunique() > 1]
corr_source = corr_source.drop(columns=["EmployeeNumber"], errors="ignore")
if "Attrition" in df.columns and "Attrition" not in corr_source.columns:
    corr_source = corr_source.assign(
        Attrition=df["Attrition"].map({"Yes": 1, "No": 0})
    )

fig, ax = plt.subplots(figsize=(14, 8))
# Fixed [-1, 1] range so colour intensity means the same thing in every run.
sns.heatmap(corr_source.corr(), cmap="coolwarm", center=0, vmin=-1, vmax=1, ax=ax)
ax.set_title("Correlation Heatmap (Numeric Features)")
plt.show()


In [None]:
## 🔍 Key Insights Summary

- Employees aged **25–35** show the highest attrition levels.  
- Lower income groups (first quartile) have a **significantly higher attrition rate**.  
- Roles like **Sales Representative** and **Laboratory Technician** face the most turnover.  
- MonthlyIncome, Age, and YearsAtCompany show the strongest association with attrition (an association observed in this data, not a demonstrated cause).  

This EDA helps HR identify risk groups and plan targeted retention strategies.


In [None]:
# -----------------------------------------
# Optional: Save Plots Programmatically
# -----------------------------------------
# Writes a plain attrition countplot to figures/attrition_count.png (relative
# to the working directory) and closes it so it is not also displayed inline.

fig_dir = Path("figures")
fig_dir.mkdir(exist_ok=True)

fig, ax = plt.subplots()
sns.countplot(data=df, x="Attrition", ax=ax)
ax.set_title("Attrition Count")
fig.savefig(fig_dir / "attrition_count.png")
plt.close(fig)

print("Figures saved to 'figures' folder.")
