# Survival Analysis Lab

Complete the following exercises to solidify your knowledge of survival analysis.

In [1]:
import pandas as pd
import plotly as py  # NOTE(review): `py` is not referenced in any visible cell
import cufflinks as cf
import numpy as np

from lifelines import KaplanMeierFitter
# NOTE(review): logrank_test only appears in a commented-out cell; multivariate_logrank_test is not used in any visible cell.
from lifelines.statistics import logrank_test, multivariate_logrank_test

# Route DataFrame.plot() calls through plotly instead of the default matplotlib backend.
pd.options.plotting.backend = 'plotly'


# Put cufflinks in offline mode (presumably so plotly figures render without an online account — confirm).
cf.go_offline()

In [2]:
# Load the employee attrition table and show which columns are available.
data = pd.read_csv(filepath_or_buffer='../data/attrition.csv')

data.columns

Index(['Age', 'Attrition', 'BusinessTravel', 'DailyRate', 'Department',
       'DistanceFromHome', 'Education', 'EducationField', 'EmployeeCount',
       'EmployeeNumber', 'EnvironmentSatisfaction', 'Gender', 'HourlyRate',
       'JobInvolvement', 'JobLevel', 'JobRole', 'JobSatisfaction',
       'MaritalStatus', 'MonthlyIncome', 'MonthlyRate', 'NumCompaniesWorked',
       'Over18', 'OverTime', 'PercentSalaryHike', 'PerformanceRating',
       'RelationshipSatisfaction', 'StandardHours', 'StockOptionLevel',
       'TotalWorkingYears', 'TrainingTimesLastYear', 'WorkLifeBalance',
       'YearsAtCompany', 'YearsInCurrentRole', 'YearsSinceLastPromotion',
       'YearsWithCurrManager'],
      dtype='object')

In [3]:
# Order rows by employee age; the displayed column keeps the original row labels.
data = data.sort_values(by='Age')
data.loc[:, 'MonthlyRate']

1311     8018
457      8059
972     19305
301      9724
296     25233
        ...  
536     11924
427      2845
411      3854
879     10893
1209    20467
Name: MonthlyRate, Length: 1470, dtype: int64

## 1. Generate and plot a survival function that shows how employee retention rates vary by gender and employee age.

*Tip: If your lines have gaps in them, you can fill them in by using the `fillna(method='ffill')` and `fillna(method='bfill')` methods (in pandas 2.1 and later, use the equivalent `ffill()` and `bfill()` methods) and then taking the average. We have provided you with a revised survival function below that you can use for the exercises in this lab.*

In [4]:
def survival(data, group_field, time_field, event_field, fill_method='average'):
    """Fit one Kaplan-Meier survival curve per group and align them on one timeline.

    Parameters
    ----------
    data : pd.DataFrame
        Source table.
    group_field : str
        Column whose distinct non-null values define the groups. Each group
        becomes one output column labelled ``str(value)``, in order of first
        appearance in ``data``.
    time_field : str
        Column holding the duration for each row.
    event_field : str
        Column flagging whether the event (here, attrition) was observed.
    fill_method : {'average', 'ffill', None}, default 'average'
        How to fill the gaps created when per-group timelines are aligned:

        - ``'average'``: mean of the forward- and back-filled values (the lab's
          original smoothing). Leaves NaN where only one side exists, e.g.
          after a group's last observed time.
        - ``'ffill'``: carry the last estimate forward, which matches the
          step-function shape of a Kaplan-Meier estimate.
        - ``None``: return the aligned curves without filling.

    Returns
    -------
    pd.DataFrame
        Survival probabilities indexed by time (ascending), one column per group.

    Raises
    ------
    ValueError
        If ``fill_method`` is not recognised, or if ``group_field`` has no
        non-null values (nothing to fit).
    """
    if fill_method not in ('average', 'ffill', None):
        raise ValueError(
            f"fill_method must be 'average', 'ffill' or None, got {fill_method!r}"
        )

    curves = []
    # Drop null keys: `column == NaN/None` never matches, so a null key would
    # select an empty group instead of the rows that have a missing value.
    for key in data[group_field].dropna().unique():
        group = data[data[group_field] == key]
        fitter = KaplanMeierFitter()  # fresh fitter per group; no state shared between fits
        fitter.fit(group[time_field], group[event_field], label=str(key))
        curves.append(fitter.survival_function_)

    if not curves:
        raise ValueError(f"no non-null groups found in column {group_field!r}")

    # Outer-aligning the curves leaves NaN wherever one group has no row at a
    # time another group does; sort so that ffill/bfill walk forward in time.
    combined = pd.concat(curves, axis=1).sort_index()
    if fill_method is None:
        return combined

    forward = combined.ffill()
    if fill_method == 'ffill':
        return forward
    return (forward + combined.bfill()) / 2

In [5]:
# Q1: retention by gender, with employee age as the time axis.
gender_age_attrition = survival(data, group_field='Gender', time_field='Age', event_field='Attrition')
print(gender_age_attrition)
gender_age_attrition.plot(kind='line')



            Female      Male
timeline                    
0.0       1.000000  1.000000
18.0      0.998299  0.996599
19.0      0.994880  0.992058
20.0      0.988019  0.989786
21.0      0.979413  0.988645
22.0      0.975964  0.985201
23.0      0.975964  0.980564
24.0      0.969018  0.977050
25.0      0.969018  0.969918
26.0      0.963684  0.959034
27.0      0.960054  0.957799
28.0      0.954396  0.943770
29.0      0.935192  0.933211
30.0      0.925201  0.927689
31.0      0.910549  0.911794
32.0      0.903950  0.899657
33.0      0.890079  0.890069
34.0      0.880378  0.881656
35.0      0.872470  0.868931
36.0      0.869628  0.858920
37.0      0.866500  0.848020
38.0      0.866500  0.843335
39.0      0.859062  0.833144
40.0      0.855104  0.822217
41.0      0.837381  0.816172
42.0      0.837381  0.809563
43.0      0.832212  0.805832
44.0      0.826589  0.785585
45.0      0.814610  0.785585
46.0      0.801145  0.775249
47.0      0.793587  0.763678
48.0      0.784961  0.757519
49.0      0.77

In [6]:
# NOTE(review): commented-out variant of `survival` that also runs a log-rank test.
# It only compares the first two groups (con_exp[0] vs con_exp[1]); for group
# fields with more than two levels, `multivariate_logrank_test` (imported in the
# first cell) would be needed. If enabled, it would also replace the `survival`
# helper defined earlier, and its return type differs (tuple instead of DataFrame).
# def survival(data: pd.DataFrame, group_field: str,
#              time_field: str, event_field: str) -> tuple:
#     """
#       Function that receives a DF and filters it by column based on the
#       grouping field.

#       Returns
#         Survival function (for plotting)
#         Logrank test result
#       """

#     model = KaplanMeierFitter()
#     results = []
#     con_exp = []

#     for i in data[group_field].unique():
#         group = data[data[group_field] == i]
#         T = group[time_field]
#         E = group[event_field]
#         con_exp.append([T, E])
#         model.fit(T, E, label=str(i))
#         results.append(model.survival_function_)

#     survival = pd.concat(results, axis=1)

#     log_result = logrank_test(con_exp[0][0], con_exp[1][0], 
#                             event_observed_A=con_exp[0][1], 
#                             event_observed_B=con_exp[1][1])
#     return survival, log_result

In [7]:

# NOTE(review): if uncommented, the first line rebinds the name `survival` to a
# DataFrame, so every later cell that calls `survival(...)` would fail. Use a
# different variable name for the result.
# survival, log_result = survival(data, 'Gender', 'Age', 'Attrition')
# print(log_result)
# survival.plot(kind='line')

## 2. Compare the plot above with one that plots employee retention rates by gender over the number of years the employee has been working for the company.

In [8]:
# Q2: retention by gender, with tenure at the company as the time axis.
gender_nyears_attrition = survival(data, group_field='Gender', time_field='YearsAtCompany', event_field='Attrition')
print(gender_nyears_attrition)
gender_nyears_attrition.plot(kind='line')

            Female      Male
timeline                    
0.0       0.993197  0.986395
1.0       0.956797  0.942452
2.0       0.938000  0.920975
3.0       0.923879  0.902917
4.0       0.912585  0.881528
5.0       0.890624  0.861027
6.0       0.881315  0.848334
7.0       0.864235  0.834195
8.0       0.848663  0.820565
9.0       0.830314  0.808084
10.0      0.796878  0.763398
11.0      0.789826  0.757658
12.0      0.789826  0.757658
13.0      0.781241  0.750643
14.0      0.781241  0.734671
15.0      0.770391  0.734671
16.0      0.770391  0.725005
17.0      0.756634  0.725005
18.0      0.756634  0.714020
19.0      0.756634  0.702119
20.0      0.756634  0.689354
21.0      0.756634  0.672940
22.0      0.714598  0.672940
23.0      0.682117  0.647058
24.0      0.649635  0.647058
25.0      0.649635  0.647058
26.0      0.649635  0.647058
27.0      0.649635  0.647058
29.0      0.649635  0.647058
30.0      0.541362  0.647058
31.0      0.541362  0.597284
32.0      0.433090  0.597284
33.0      0.21

## 3. Let's look at retention rate by gender from a third perspective - the number of years since the employee's last promotion. Generate and plot a survival curve showing this.

In [9]:
# Q3: retention by gender, measured in years since the last promotion.
gender_yearspromotion_attrition = survival(
    data,
    group_field='Gender',
    time_field='YearsSinceLastPromotion',
    event_field='Attrition',
)
print(gender_yearspromotion_attrition)
gender_yearspromotion_attrition.plot(kind='line')

            Female      Male
timeline                    
0.0       0.931973  0.920635
1.0       0.882785  0.868425
2.0       0.854697  0.812757
3.0       0.828797  0.797127
4.0       0.817041  0.783841
5.0       0.809874  0.778473
6.0       0.792074  0.753361
7.0       0.720068  0.689396
8.0       0.720068  0.689396
9.0       0.720068  0.637366
10.0      0.695238  0.637366
11.0      0.695238  0.605498
12.0      0.695238  0.605498
13.0      0.695238  0.544948
14.0      0.695238  0.503029
15.0      0.397279  0.503029


## 4. Let's switch to looking at retention rates from another demographic perspective: marital status. Generate and plot survival curves for the different marital statuses by number of years at the company.

In [10]:
# Q4: retention by marital status over years at the company.
marital_years_attrition = survival(
    data,
    group_field='MaritalStatus',
    time_field='YearsAtCompany',
    event_field='Attrition',
)
print(marital_years_attrition)
marital_years_attrition.plot(kind='line')

            Single   Married  Divorced
timeline                              
0.0       0.970213  0.997028  1.000000
1.0       0.904361  0.965353  0.975232
2.0       0.873585  0.950602  0.958063
3.0       0.858123  0.938070  0.932069
4.0       0.826763  0.932018  0.912153
5.0       0.786051  0.916410  0.907767
6.0       0.762818  0.910856  0.902129
7.0       0.753999  0.895776  0.876895
8.0       0.723432  0.885786  0.876895
9.0       0.699713  0.870177  0.876895
10.0      0.650733  0.823891  0.865652
11.0      0.641302  0.817026  0.865652
12.0      0.641302  0.817026  0.865652
13.0      0.629426  0.808603  0.865652
14.0      0.629426  0.789350  0.865652
15.0      0.614788  0.789350  0.865652
16.0      0.614788  0.776821  0.865652
17.0      0.597223  0.776821  0.865652
18.0      0.597223  0.776821  0.836797
19.0      0.576629  0.776821  0.836797
20.0      0.576629  0.759933  0.836797
21.0      0.540590  0.759933  0.836797
22.0      0.540590  0.731788  0.836797
23.0      0.506803  0.691

## 5. Let's also look at the marital status curves by employee age. Generate and plot the survival curves showing retention rates by marital status and age.

In [11]:
# Q5: retention by marital status, with employee age as the time axis.
marital_age_attrition = survival(
    data,
    group_field='MaritalStatus',
    time_field='Age',
    event_field='Attrition',
)
print(marital_age_attrition)
marital_age_attrition.plot(kind='line')

            Single   Married  Divorced
timeline                              
0.0       1.000000  1.000000  1.000000
18.0      0.991489  0.997771  1.000000
19.0      0.978613  0.997771  1.000000
20.0      0.965651  0.997771  1.000000
21.0      0.952543  0.997771  1.000000
22.0      0.948102  0.995542  1.000000
23.0      0.945871  0.992539  0.996923
24.0      0.939115  0.986496  0.996923
25.0      0.934578  0.980360  0.996923
26.0      0.920801  0.975676  0.987337
27.0      0.918458  0.974079  0.984089
28.0      0.899019  0.967464  0.977372
29.0      0.876544  0.957280  0.966975
30.0      0.863615  0.951992  0.963284
31.0      0.842092  0.935451  0.959352
32.0      0.819785  0.929665  0.959352
33.0      0.796109  0.923549  0.955011
34.0      0.783670  0.917135  0.945828
35.0      0.773448  0.910257  0.925811
36.0      0.765902  0.905200  0.915290
37.0      0.757532  0.899629  0.904128
38.0      0.757532  0.893711  0.904128
39.0      0.747694  0.883961  0.897431
40.0      0.726331  0.880

## 6. Now that we have looked at the retention rates by gender and marital status individually, let's look at them together. 

Create a new field in the data set that concatenates marital status and gender, and then generate and plot a survival curve that shows the retention by this new field over the age of the employee.

In [12]:
# Combine marital status and gender into one grouping key, e.g. "Single_Male".
data['gender_marital'] = data['MaritalStatus'] + '_' + data['Gender']

gender_marital = survival(data, 'gender_marital', 'Age', 'Attrition')
gender_marital.plot()


## 6. Let's find out how job satisfaction affects retention rates. Generate and plot survival curves for each level of job satisfaction by number of years at the company.

In [13]:
# Retention for each job-satisfaction level over years at the company.
jobsatisfaction_years_attrition = survival(
    data,
    group_field='JobSatisfaction',
    time_field='YearsAtCompany',
    event_field='Attrition',
)
jobsatisfaction_years_attrition.plot(kind='line')

## 7. Let's investigate whether the department the employee works in has an impact on how long they stay with the company. Generate and plot survival curves showing retention by department and years the employee has worked at the company.

In [14]:
# Retention by department over years at the company.
department_years_attrition = survival(
    data,
    group_field='Department',
    time_field='YearsAtCompany',
    event_field='Attrition',
)
department_years_attrition.plot(kind='line')

## 8. From the previous example, it looks like the sales department has the highest attrition. Let's drill down on this and look at what the survival curves for specific job roles within that department look like.

Filter the data set for just the sales department and then generate and plot survival curves by job role and the number of years at the company.

In [15]:
# Restrict to the Sales department, then compare its job roles.
sales_department = data.loc[data['Department'].eq('Sales')]

jobrole_years_attrition = survival(
    sales_department,
    group_field='JobRole',
    time_field='YearsAtCompany',
    event_field='Attrition',
)
jobrole_years_attrition.plot(kind='line')

## 9. Let's examine how compensation affects attrition.

- Use the `pd.qcut` method to bin the HourlyRate field into 5 different pay grade categories (Very Low, Low, Moderate, High, and Very High).
- Generate and plot survival curves showing employee retention by pay grade and age.

In [16]:
# Bin HourlyRate into 5 equal-frequency bins (quintiles), lowest pay first.
# Labels are capitalised consistently ('Very Low' ... 'Very High'), as the task specifies.
data['PayGrade'] = pd.qcut(
    data['HourlyRate'],
    q=5,
    labels=['Very Low', 'Low', 'Moderate', 'High', 'Very High'],
)

paygrade_age_attrition = survival(data, 'PayGrade', 'Age', 'Attrition')
paygrade_age_attrition.plot()

## 10. Finally, let's take a look at how the demands of the job impact employee attrition.

- Create a new field whose values are 'Overtime' or 'Regular Hours' depending on whether there is a Yes or a No in the OverTime field.
- Create a new field that concatenates that field with the BusinessTravel field.
- Generate and plot survival curves showing employee retention based on these conditions and employee age.

In [17]:
# 'Overtime' where OverTime is exactly 'Yes'; every other value (including missing) maps to 'Regular_Hours'.
data['OverTimeRegular'] = data['OverTime'].eq('Yes').map({True: 'Overtime', False: 'Regular_Hours'})


In [18]:
# Join the hours label with travel frequency using '_' (missing values propagate as NaN).
data['OverTimeRegular_BusinessTravel'] = data['OverTimeRegular'].str.cat(data['BusinessTravel'], sep='_')
OverTimeRegular_BusinessTravel = survival(
    data,
    group_field='OverTimeRegular_BusinessTravel',
    time_field='Age',
    event_field='Attrition',
)
OverTimeRegular_BusinessTravel.plot(kind='line')