# Stroke Prediction

1. Importing Libraries

In [None]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

# Setting visualisation parameters
sns.set_style('darkgrid')
cmap = sns.cm.mako_r

%matplotlib inline

# Preventing warnings from libraries especially scikit learn
import warnings
warnings.filterwarnings('ignore')

2. Importing data and viewing basic details

In [None]:
# Load the raw stroke dataset from the Kaggle input directory
stroke = pd.read_csv(
    filepath_or_buffer='../input/stroke-prediction-dataset/healthcare-dataset-stroke-data.csv'
)

In [None]:
# Preview the first five rows
stroke.head(n=5)

In [None]:
# Size of the data as a (number of rows, number of columns) tuple
(len(stroke.index), len(stroke.columns))

In [None]:
# Column dtypes and non-null counts, printed to stdout (buf=None)
stroke.info(buf=None)

3. Preprocessing Data before exploration

In [None]:
# Summary statistics, leaving out the row identifier column
stroke.drop('id', axis='columns').describe()

1. Using **round()** to round off age.


2. Setting BMI values that are less than 12 or greater than 60 to NaN. According to the dataset description, these values should be considered **outliers** and therefore should not be used when building a model.


3. We will sort the dataframe based on **gender** and then on **age** and use **forward filling** to fill out those missing BMI values.

In [None]:
# Round off age values (round-half-to-even, same as Python's round()) and store as int64
stroke['age'] = stroke['age'].round().astype('int64')

# BMI below 12 or above 60 is an outlier -> NaN. The boundaries 12 and 60 themselves are kept;
# existing NaNs stay NaN because between() is False for them.
stroke['bmi'] = stroke['bmi'].where(stroke['bmi'].between(12, 60))

# Sort by gender then age, then fill missing BMI from the nearest younger person of the SAME
# gender (ffill within each gender group, so a value never crosses the gender boundary).
# bfill covers a NaN sitting at the very start of a group, which ffill alone cannot fill.
# The result is assigned back explicitly: a chained `stroke['bmi'].ffill(inplace=True)` does
# not modify the DataFrame under pandas Copy-on-Write.
stroke = stroke.sort_values(['gender', 'age']).reset_index(drop=True)
stroke['bmi'] = stroke.groupby('gender')['bmi'].transform(lambda group: group.ffill().bfill())

# Invariant the rest of the notebook relies on: no missing BMI values remain
assert stroke['bmi'].notna().all(), "some BMI values are still missing after filling"

In [None]:
# Re-check dtypes and non-null counts after preprocessing (printed to stdout)
stroke.info(buf=None)

The age column has now been converted to int64, and the bmi column no longer has any missing values.

4. Exploratory data analysis

1. Check if the data is balanced


2. Plotting various graphs to check for any relation between each column

    - Age vs BMI
    - BMI vs AVG glucose level
    - Percentage of people who had a stroke in each category

In [None]:
# Checking if the data is balanced
stroke_counts = stroke['stroke'].value_counts()
xs = stroke_counts.index
ys = stroke_counts.values

# seaborn >= 0.12 only accepts the data vectors as keyword arguments (positional x/y raise TypeError)
ax = sns.barplot(x=xs, y=ys)
ax.set_xlabel("Stroke")
ax.set_ylabel("Number of people")
plt.show()

As the plot shows, the data is heavily imbalanced: there are far fewer stroke cases than non-stroke cases. A model trained on it as-is would be biased towards predicting "no stroke". To address this, we will use SMOTE to oversample the minority class before fitting the model.

In [None]:
# Age vs BMI: non-stroke cases drawn faintly underneath, stroke cases drawn opaque on top.
# Each layer gets a label so the legend tells the two groups apart.
fig, ax = plt.subplots(figsize=(12, 8))
sns.scatterplot(x="bmi", y="age", alpha=0.4, data=stroke[stroke['stroke'] == 0], ax=ax, label="No stroke")
sns.scatterplot(x="bmi", y="age", alpha=1, data=stroke[stroke['stroke'] == 1], ax=ax, label="Stroke")
ax.set_title("Age vs BMI")
ax.legend()
plt.show()

From the Age vs BMI plot above, we can see that once people reach the age of 40, the chance of having a stroke increases. After age 60, it increases even more. Furthermore, people with a BMI above roughly 20–25 appear to have a much greater chance of having a stroke.

So, from this plot we can conclude that people who are over 40 and have a BMI above roughly 20–25 have a greater probability of having a stroke.

In [None]:
# Average glucose level vs BMI: non-stroke cases drawn faintly underneath, stroke cases opaque on top.
# Each layer gets a label so the legend tells the two groups apart.
fig, ax = plt.subplots(figsize=(12, 8))
sns.scatterplot(x="bmi", y="avg_glucose_level", alpha=0.4, data=stroke[stroke['stroke'] == 0], ax=ax, label="No stroke")
sns.scatterplot(x="bmi", y="avg_glucose_level", alpha=1, data=stroke[stroke['stroke'] == 1], ax=ax, label="Stroke")
ax.set_title("Average glucose level vs BMI")
ax.legend()
plt.show()

In [None]:
# Percentage of people
def plot_percent_of_stroke_in_each_category(df, column, axis):
    x_axis = []
    y_axis = []
    
    unique_values = df[column].unique()
    
    for value in unique_values:
        stroke_yes = len(df[(df[column] == value) & (df['stroke'] ==1)])
        total = len(df[df[column] == value])
        percentage = (stroke_yes/total) * 100
        x_axis.append(value)
        y_axis.append(percentage)
        
    sns.barplot(x_axis, y_axis, ax=axis)
    
columns = ['gender', 'hypertension', 'heart_disease', 'ever_married', 'work_type', 'Residence_type', 'smoking_status']

# Human-readable x-axis label for each plotted column
column_labels = {
    'gender': "Gender",
    'hypertension': "Hypertension",
    'heart_disease': "Heart Disease",
    'ever_married': "Ever Married",
    'work_type': "Work Type",
    'Residence_type': "Residence Type",
    'smoking_status': "Smoking Status",
}

fig, axes = plt.subplots(4, 2, figsize=(16,18))
# 7 columns on a 4x2 grid: the last panel is unused
axes[3, 1].remove()

# axes.flat walks the grid row by row; zip stops after the 7 columns, so the removed
# panel is never drawn on.
for panel_index, (panel_ax, column) in enumerate(zip(axes.flat, columns)):
    plot_percent_of_stroke_in_each_category(stroke, column, panel_ax)
    panel_ax.set_xlabel(column_labels[column])
    # Only the left-hand column of panels carries the y-axis label
    if panel_index % 2 == 0:
        panel_ax.set_ylabel("Percentage")

fig.suptitle("Percentage of people who had a stroke in each category")
plt.show()

**Insights drawn from above plots**

1. Both genders have around a 5% chance


2. People with a history of hypertension or heart disease show a higher rate of stroke: around 12.5% and 16.5% respectively.


3. People who have ever been married (which also includes divorced people) have around a 6.5% chance of a stroke.


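4. Self-employed people have a higher chance than people with private or government jobs. This could be stress-related, but the plot alone cannot tell us why (for example, age differences between the groups may also play a role).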


5. Rural and urban residency doesn't seem to show much of a difference.


6. Former smokers have a higher chance than people who have never smoked or who currently smoke.

**5. Preparing the data for prediction**

1. Converting the categorical columns into numerical by mapping each category to an integer value using **map()** on pandas series object

2. As we saw earlier, the data is imbalanced. To balance it we will use SMOTE (Synthetic Minority Over-sampling Technique). Other techniques are also available, such as the NearMiss undersampling algorithm.



3. Splitting the data into training and testing samples.

In [None]:
# Converting categorical data to numerical

gender_dict = {'Male': 0, 'Female': 1, 'Other': 2}
ever_married_dict = {'No': 0, 'Yes': 1}
work_type_dict = {'children': 0, 'Never_worked': 1, 'Govt_job': 2, 'Private': 3, 'Self-employed': 4}
residence_type_dict = {'Rural': 0, 'Urban': 1}
smoking_status_dict = {'Unknown': 0, 'never smoked': 1, 'formerly smoked':2, 'smokes': 3}

# Column -> (category -> integer code)
categorical_mappings = {
    'gender': gender_dict,
    'ever_married': ever_married_dict,
    'work_type': work_type_dict,
    'Residence_type': residence_type_dict,
    'smoking_status': smoking_status_dict,
}

# Series.map silently turns any category missing from the dict into NaN, and re-running this
# cell on already-encoded columns would turn them all into NaN. So every column is encoded and
# checked first, and the DataFrame is only modified once all columns pass.
encoded_columns = {column: stroke[column].map(mapping) for column, mapping in categorical_mappings.items()}

for column, encoded in encoded_columns.items():
    unmapped = stroke.loc[encoded.isna() & stroke[column].notna(), column].unique()
    if len(unmapped) > 0:
        raise ValueError(
            f"Column {column!r} has values with no numeric code: {list(unmapped)} "
            "(has this cell already been run?)"
        )

for column, encoded in encoded_columns.items():
    stroke[column] = encoded


In [None]:
# Target vector first, then the feature matrix without the identifier and the target itself
y = stroke['stroke']
X = stroke.drop(['id', 'stroke'], axis=1)

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2)

sns.barplot(x=['0', '1'], y =[sum(y == 0), sum(y == 1)], ax = ax1)
ax1.set_title("Before Oversampling")
ax1.set_xlabel('Stroke')

#Using SMOTE to balance the Data
from imblearn.over_sampling import SMOTE

sm = SMOTE(random_state = 2) 
X, y = sm.fit_resample(X, y) 

sns.barplot(x=['0', '1'], y =[sum(y == 0), sum(y == 1)], ax = ax2)
ax2.set_title("After Oversampling")
ax2.set_xlabel('Stroke')

plt.tight_layout()
plt.show()

In [None]:
# Splitting data into train and test
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=40)

**6. Creating a model for stroke prediction**

In [None]:
# Importing neccessary libraries
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report, plot_confusion_matrix

pipeline = make_pipeline(StandardScaler(), RandomForestClassifier())
pipeline.fit(X_train, y_train)
prediction = pipeline.predict(X_test)

print(f"Accuracy Score : {round(accuracy_score(y_test, prediction) * 100, 2)}%")

In [None]:
print(classification_report(y_test, prediction))

In [None]:
plot_confusion_matrix(pipeline, X_test, y_test, cmap=cmap)
plt.grid(False)
plt.show()