In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input directory tree and print the full path of every file in it.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for name in files:
        full_path = os.path.join(root, name)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from plotly.offline import plot
import plotly.express as px

In [None]:
from plotly.offline import init_notebook_mode, iplot
# Initialise Plotly's offline notebook mode before any figure is shown.
# NOTE(review): presumably needed for inline rendering in this environment - confirm for the installed Plotly version.
init_notebook_mode(connected=True)

In [None]:
from collections import Counter
import warnings
# Silence every warning for the rest of the notebook.
# NOTE(review): this also hides pandas chained-assignment / deprecation warnings; consider narrowing the filter.
warnings.filterwarnings('ignore')

In [None]:
# Load the stroke dataset from the Kaggle input directory; `df` is the raw frame used for both visualisation and modelling.
df = pd.read_csv('/kaggle/input/stroke-prediction-dataset/healthcare-dataset-stroke-data.csv')

In [None]:
# Preview the first rows of the raw data.
df.head()

# Data Visualization

In [None]:
# Work on a copy for visualisation so the raw `df` stays untouched for the modelling section.
dataviz = df.copy()

In [None]:
# Number of missing values in each column.
missing_mask = dataviz.isna()
missing_mask.sum()

<p style="font-size: 15pt;">Replacing missing values with mode</p>

In [None]:
# Replace missing BMI values with the most frequent BMI.
# Assign the result back to the column: an in-place fillna on `dataviz['bmi']` operates on an
# intermediate object and does not update `dataviz` when pandas copy-on-write is active.
dataviz['bmi'] = dataviz['bmi'].fillna(dataviz['bmi'].mode()[0])

<h1>Relationship between Stroke and various attributes</h1>

In [None]:
# Work-type mix of stroke-negative vs stroke-positive individuals, as two side-by-side pies.
counts0 = Counter(dataviz[dataviz['stroke']==0]['work_type'])
counts1 = Counter(dataviz[dataviz['stroke']==1]['work_type'])
fig = make_subplots(rows=1, cols=2, specs=[[{"type": "pie"}, {"type": "pie"}]])

for col_idx, (panel_title, counts) in enumerate(
        [('Stroke Negative', counts0), ('Stroke Positive', counts1)], start=1):
    # Sort once so every label stays paired with its own count and both pies list categories in the same order.
    pairs = sorted(counts.items())
    # row/col places the pie in its subplot cell, so no explicit `domain` is passed.
    fig.add_trace(go.Pie(
        labels=[label for label, _ in pairs],
        values=[count for _, count in pairs],
        name=panel_title, title=panel_title),
        row=1, col=col_idx)

# No axis titles: a pie figure has no x/y axes to label.
fig.update_layout(
    title={'text':'Stroke vs Worktype','xanchor':'left','yanchor': 'top','y':0.9,'x':0.35},
    legend_title="Worktype",
    font=dict(size=18)
)

fig.show()

<p style="font-size: 15pt;">26% of the stroke positive individuals are self employed compared to only 15% for stroke negative individuals.</p>

In [None]:
# Marital-status mix of stroke-negative vs stroke-positive individuals, as two side-by-side pies.
counts0 = Counter(dataviz[dataviz['stroke']==0]['ever_married'])
counts1 = Counter(dataviz[dataviz['stroke']==1]['ever_married'])
fig = make_subplots(rows=1, cols=2, specs=[[{"type": "pie"}, {"type": "pie"}]])

for col_idx, (panel_title, counts) in enumerate(
        [('Stroke Negative', counts0), ('Stroke Positive', counts1)], start=1):
    # Sort once so every label stays paired with its own count and both pies list categories in the same order.
    pairs = sorted(counts.items())
    # row/col places the pie in its subplot cell, so no explicit `domain` is passed.
    fig.add_trace(go.Pie(
        labels=[label for label, _ in pairs],
        values=[count for _, count in pairs],
        name=panel_title, title=panel_title),
        row=1, col=col_idx)

# No axis titles: a pie figure has no x/y axes to label.
fig.update_layout(
    title={'text':'Stroke vs Marital Status','xanchor':'left','yanchor': 'top','y':0.9,'x':0.35},
    legend_title="Ever Married?",
    font=dict(size=18)
)

fig.show()

<p style="font-size: 15pt;">90% of the stroke positive individuals have been married compared to only 65% for stroke negative individuals.</p>

In [None]:
# Share of stroke / non-stroke cases among rural and urban residents, as two side-by-side pies.
counts0 = Counter(dataviz[dataviz['Residence_type']=='Rural']['stroke'])
counts1 = Counter(dataviz[dataviz['Residence_type']=='Urban']['stroke'])
fig = make_subplots(rows=1, cols=2, specs=[[{"type": "pie"}, {"type": "pie"}]])

for col_idx, (residence, counts) in enumerate(
        [('Rural', counts0), ('Urban', counts1)], start=1):
    # Look counts up by class (1 = stroke, 0 = no stroke) rather than relying on Counter
    # insertion order, which depends on row order and could swap the two labels.
    fig.add_trace(go.Pie(
        values=[counts[1], counts[0]],
        labels=['Stroke Positive','Stroke Negative'],
        name=residence, title=residence),
        row=1, col=col_idx)

# No axis titles: a pie figure has no x/y axes to label.
fig.update_layout(
    title={'text':'Residence Type vs Stroke','xanchor':'left','yanchor': 'top','y':0.9,'x':0.35},
    legend_title="Stroke Status",
    font=dict(size=18)
)

fig.show()

<p style="font-size: 15pt;">Incidence of strokes is slightly higher in urban areas</p>

In [None]:
# Bin edges: glucose 0-280 in steps of 10, BMI 0-100 in steps of 2, age 0-90 in steps of 10.
glucose_bins = np.linspace(0, 280, 29)
bmi_bins = np.linspace(0, 100, 51)
age_bins = np.linspace(0, 90, 10)

# Each value is mapped to the lower edge of its left-closed bin, stored as an integer column.
binning_plan = (
    ('binned_glucose', 'avg_glucose_level', glucose_bins),
    ('binned_bmi', 'bmi', bmi_bins),
    ('binned_age', 'age', age_bins),
)
for target_col, source_col, edges in binning_plan:
    dataviz[target_col] = pd.cut(dataviz[source_col], edges, labels=edges[:-1], right=False)
for target_col, _source_col, _edges in binning_plan:
    dataviz[target_col] = dataviz[target_col].astype('int')

In [None]:
# Stroke vs non-stroke case counts per age bin.
figure_defaults = {'figure.figsize': (9, 5)}
sns.set(rc=figure_defaults)
ax = sns.countplot(data=dataviz, x='binned_age', hue='stroke')
ax.set_title('Stroke/Non stroke cases at various age groups', fontsize=22)
ax.set_xlabel('Age', fontsize=14)
ax.set_ylabel('Count', fontsize=14)

<p style="font-size: 15pt;">Risk of stroke starts in 30-40 age group and becomes more severe in older age groups.</p>

In [None]:
# Stroke vs non-stroke case counts per glucose bin.
figure_defaults = {'figure.figsize': (9, 5)}
sns.set(rc=figure_defaults)
ax = sns.countplot(data=dataviz, x='binned_glucose', hue='stroke')
ax.set_title('Stroke/Non stroke cases at various glucose Levels', fontsize=22)
ax.set_xlabel('Glucose level', fontsize=14)
ax.set_ylabel('Count', fontsize=14)

<p style="font-size: 15pt;">Strokes tend to happen in the 50-110 and the 160 to 240 glucose level ranges. It is worthwhile to note that though there are a higher number of cases of stroke (bigger orange bars) in the lower range they are a smaller proportion of total cases. Let's plot a line chart below to observe the actual risk of stroke at each glucose level.</p>

In [None]:
# Stroke rate (%) within each glucose bin, across all ages.
# groupby keeps every bin's stroke share tied to that bin; dividing two separately built,
# positionally matched count lists misaligns (or errors) as soon as a bin has no stroke case.
glucose_risk = 100 * dataviz.groupby('binned_glucose')['stroke'].mean()
labels = list(glucose_risk.index)
stroke_pct = glucose_risk.to_numpy()

# The 'seaborn' style sheet is named 'seaborn-v0_8' in newer Matplotlib; use whichever is installed.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
plt.figure(figsize=(9,5))
plt.ylim(0,30)
plt.title('Stroke risk vs Glucose level (All age groups)',fontsize=20)
plt.ylabel('% of strokes',fontsize=15)
plt.xlabel('Average glucose level',fontsize=15)
# The last (highest) glucose bin is left off the line, as in the original chart.
plt.plot(labels[:-1],stroke_pct[:-1],'b')

<p style="font-size: 15pt;">The risk of stroke shoots up from 2-5% to 10-20% as glucose levels go beyond 150. <br> Let's split this into 2 groups - one below the age of 40 and the other above 40 - and see if there are any differences.</p>

In [None]:
# Stroke rate (%) per glucose bin, split into two panels: age bins below 40 and age bins from 40 up.
fig, axes = plt.subplots(1,2,figsize=(18,5))

# (row mask, line colour, panel title, y-axis label or None)
panels = [
    (dataviz['binned_age']<40, 'g', 'Stroke risk vs Glucose level (Age below 40)', '% of strokes'),
    (dataviz['binned_age']>=40, 'r', 'Stroke risk vs Glucose level (Age above 40)', None),
]
for panel_ax, (age_mask, colour, panel_title, y_label) in zip(axes, panels):
    # groupby keeps each bin's rate aligned with its label; bins without a stroke case come out as 0%.
    risk = 100 * dataviz[age_mask].groupby('binned_glucose')['stroke'].mean()
    labels = list(risk.index)
    stroke_pct = risk.to_numpy()
    panel_ax.set_ylim(0,30)
    panel_ax.set_xlim(50,280)
    # The last (highest) glucose bin is left off the line, as in the original chart.
    panel_ax.plot(labels[:-1],stroke_pct[:-1],colour)
    panel_ax.set_title(panel_title,fontsize=20)
    if y_label is not None:
        panel_ax.set_ylabel(y_label,fontsize=15)
    panel_ax.set_xlabel('Average glucose level',fontsize=15)

# Tighten the layout after the panels are drawn so titles and labels are taken into account.
fig.tight_layout()

<p style="font-size: 15pt;">The plot on the left shows that, in age groups below 40, strokes are very rare irrespective of the glucose levels but tend to happen in the low to normal glucose ranges.</p><br>
<p style="font-size: 15pt;">For age groups above 40, the risk of stroke is much more severe than that seen in the aggregated plot of all age groups  - ranging from 5-10% (below glucose levels of 150) to 10-25% (above 150).</p>

In [None]:
# Hypertension share among stroke-negative vs stroke-positive individuals, as two side-by-side pies.
counts0 = Counter(dataviz[dataviz['stroke']==0]['hypertension'])
counts1 = Counter(dataviz[dataviz['stroke']==1]['hypertension'])
fig = make_subplots(rows=1, cols=2, specs=[[{"type": "pie"}, {"type": "pie"}]])

for col_idx, (panel_title, counts) in enumerate(
        [('Stroke Negative', counts0), ('Stroke Positive', counts1)], start=1):
    # Look counts up by flag value (0 = No, 1 = Yes) so a label can never be paired with the
    # other flag's count, even if one flag value is absent from a group.
    fig.add_trace(go.Pie(
        values=[counts[0], counts[1]],
        labels=['No','Yes'],
        name=panel_title, title=panel_title),
        row=1, col=col_idx)

# No axis titles: a pie figure has no x/y axes to label.
fig.update_layout(
    title={'text':'Stroke vs Hypertension','xanchor':'left','yanchor': 'top','y':0.9,'x':0.35},
    legend_title="Hypertension",
    font=dict(size=18)
)

fig.show()

<p style="font-size: 15pt;">In both the stroke positive and stroke negative groups, the majority are not hypertensive, but 1 in 4 is hypertensive in the stroke positive group compared to 1 in 10 in the stroke negative group.</p>

In [None]:
# Heart-disease share among stroke-negative vs stroke-positive individuals, as two side-by-side pies.
counts0 = Counter(dataviz[dataviz['stroke']==0]['heart_disease'])
counts1 = Counter(dataviz[dataviz['stroke']==1]['heart_disease'])
fig = make_subplots(rows=1, cols=2, specs=[[{"type": "pie"}, {"type": "pie"}]])

for col_idx, (panel_title, counts) in enumerate(
        [('Stroke Negative', counts0), ('Stroke Positive', counts1)], start=1):
    # Look counts up by flag value (0 = No, 1 = Yes) so a label can never be paired with the
    # other flag's count, even if one flag value is absent from a group.
    fig.add_trace(go.Pie(
        values=[counts[0], counts[1]],
        labels=['No','Yes'],
        name=panel_title, title=panel_title),
        row=1, col=col_idx)

# No axis titles: a pie figure has no x/y axes to label.
fig.update_layout(
    title={'text':'Stroke vs Heart Disease','xanchor':'left','yanchor': 'top','y':0.9,'x':0.35},
    legend_title="Heart Disease",
    font=dict(size=18)
)

fig.show()

<p style="font-size: 15pt;">Close to 19% of the stroke positive population has heart disease compared to only 5% for the stroke negative group.</p>

In [None]:
# Smoking-status mix of stroke-negative vs stroke-positive individuals, as two side-by-side pies.
counts0 = Counter(dataviz[dataviz['stroke']==0]['smoking_status'])
counts1 = Counter(dataviz[dataviz['stroke']==1]['smoking_status'])
fig = make_subplots(rows=1, cols=2, specs=[[{"type": "pie"}, {"type": "pie"}]])

for col_idx, (panel_title, counts) in enumerate(
        [('Stroke Negative', counts0), ('Stroke Positive', counts1)], start=1):
    # Sort once so every label stays paired with its own count and both pies list categories in the same order.
    pairs = sorted(counts.items())
    # row/col places the pie in its subplot cell, so no explicit `domain` is passed.
    fig.add_trace(go.Pie(
        labels=[label for label, _ in pairs],
        values=[count for _, count in pairs],
        name=panel_title, title=panel_title),
        row=1, col=col_idx)

# No axis titles: a pie figure has no x/y axes to label.
fig.update_layout(
    title={'text':'Stroke vs Smoking Status','xanchor':'left','yanchor': 'top','y':0.9,'x':0.35},
    legend_title="Smoking Status",
    font=dict(size=18)
)

fig.show()

<p style="font-size: 15pt;">With a significant chunk of the samples in 'unknown' status, it is difficult to tell with any certainty if smoking has any impact on the risk of stroke. But as the extremes of 'smokes' and 'never smoked' are having more or less similar proportions in both groups, it may be safe to assume that smoking is a weak predictor of stroke.</p>

In [None]:
# Gender mix of stroke-negative vs stroke-positive individuals, as two side-by-side pies.
counts0 = Counter(dataviz[dataviz['stroke']==0]['gender'])
counts1 = Counter(dataviz[dataviz['stroke']==1]['gender'])
fig = make_subplots(rows=1, cols=2, specs=[[{"type": "pie"}, {"type": "pie"}]])

for col_idx, (panel_title, counts) in enumerate(
        [('Stroke Negative', counts0), ('Stroke Positive', counts1)], start=1):
    # Sort once so every label stays paired with its own count and both pies list categories in the same order.
    pairs = sorted(counts.items())
    # row/col places the pie in its subplot cell, so no explicit `domain` is passed.
    fig.add_trace(go.Pie(
        labels=[label for label, _ in pairs],
        values=[count for _, count in pairs],
        name=panel_title, title=panel_title),
        row=1, col=col_idx)

# No axis titles: a pie figure has no x/y axes to label.
fig.update_layout(
    title={'text':'Stroke vs Gender','xanchor':'left','yanchor': 'top','y':0.9,'x':0.35},
    legend_title="Gender",
    font=dict(size=18)
)

fig.show()

<p style="font-size: 15pt;">The proportion of males and females in both groups is almost the same.</p>

In [None]:
# Stroke rate (%) within each BMI bin, across all ages.
# groupby yields one aligned rate per bin (0% where a bin has no stroke case), replacing the
# manual zero-padding of count lists through a side-effect list comprehension.
bmi_risk = 100 * dataviz.groupby('binned_bmi')['stroke'].mean()
labels = list(bmi_risk.index)
stroke_pct = bmi_risk.to_numpy()

plt.style.use('ggplot')
plt.figure(figsize=(9,5))
plt.ylim(0,30)
plt.title('Stroke risk vs BMI (All age groups)',fontsize=20)
plt.ylabel('% of strokes',fontsize=15)
plt.xlabel('BMI',fontsize=15)
plt.plot(labels,stroke_pct,'b')

<p style="font-size: 15pt;">BMI does not seem to have a strong relationship with the risk of stroke.</p>

<p style="font-size: 15pt;">Based on our analysis so far, glucose levels and age are the strongest predictors of stroke. Let us study their relationship more closely using violin charts</p>

In [None]:
# Distribution of glucose bins per age bin, split by stroke outcome (blue = 0, orange = 1).
plt.figure(figsize=(18, 9))
stroke_palette = {0: "b", 1: "orange"}
ax = sns.violinplot(
    data=dataviz,
    x='binned_age',
    y='binned_glucose',
    hue='stroke',
    scale='count',
    palette=stroke_palette,
    cut=0,
)
ax.set_ylabel('Glucose level', fontsize=14)
ax.set_xlabel('Age', fontsize=14)
ax.set_title('Age, Glucose vs Stroke', fontsize=20)

<h2>Analysis of the violin plot:</h2>
<p style="font-size: 15pt;">The plot corroborates some of our earlier conclusions that,<br> - people below 40 rarely have strokes even at higher levels of glucose (longer blue violins up to glucose levels of 250 and absence of or thinner orange violins below the age of 40)<br> - the risk of strokes starts around 30-40 years (orange violins get thicker progressively)<br><br>
In higher age groups, orange violins have slightly higher median glucose levels (white dots inside the violins) but the range of glucose levels for stroke and non stroke cases is more or less similar. This implies that even in higher age groups there is not much difference between stroke and non stroke cases in terms of glucose levels. It means that in addition to glucose levels there could be other causative factors.<br><br>
Let's try incorporating some of the other known risk factors in the dataset, namely - hypertension, heart disease, BMI>25 and smoking - and create a new attribute called "no_of_risk_factors"
</p>



In [None]:
# Binary flag: 1 for current smokers (smoking_status == 'smokes'), 0 for every other status.
# Built with one whole-column assignment; chained indexing (dataviz['smokes'][mask] = ...) writes
# to an intermediate object and leaves the column unchanged under pandas copy-on-write.
dataviz['smokes'] = (dataviz['smoking_status']=='smokes').astype('int')

In [None]:
# Flag BMI >= 25 as 1.0 and BMI < 25 as 0.0; a missing BMI stays NaN.
# Built with one whole-column assignment; chained indexing (dataviz['high_bmi'][mask] = ...) writes
# to an intermediate object and leaves the column unchanged under pandas copy-on-write.
dataviz['high_bmi'] = (dataviz['bmi']>=25).astype('float').where(dataviz['bmi'].notna())
# Number of risk factors present per person: hypertension, heart disease, smoking, high BMI.
dataviz['no_of_risk_factors'] = dataviz['hypertension']+dataviz['heart_disease']+dataviz['smokes']+dataviz['high_bmi']

In [None]:
# Distribution of the number of risk factors per age bin, split by stroke outcome,
# restricted to people whose average glucose level is above 150.
plt.figure(figsize=(18,9))
# Filter on the raw glucose value so the subset matches the title; `binned_glucose > 150`
# compares bin lower edges and silently drops everyone in the 150-160 bin.
high_glucose = dataviz[dataviz['avg_glucose_level']>150]
ax = sns.violinplot(x='binned_age',y='no_of_risk_factors',hue='stroke',data=high_glucose,scale='count',palette={0: "b", 1: "orange"},cut=0)
ax.set_title('Age, No. of risk factors vs Stroke (Glucose level>150)',fontsize=20)
ax.set_xlabel('Age',fontsize=14)
ax.set_ylabel('No. of risk factors',fontsize=14)

<p style="font-size: 15pt;">Incorporating the number of risk factors, we see that the orange violins are placed one risk factor above the blue violins which reveals that the stroke group has 1 additional risk factor at the median and a higher range of no. of risk factors compared to the non stroke group.<br><br> In fact, it turns out that the number of risk factors does a good job of predicting strokes on its own as the below line plot shows -  as the risk factors rise the corresponding lines rise up and hit higher peaks.</p>

In [None]:
# Stroke rate (%) per age bin, drawn as one line per number of risk factors (0 to 4).
# The 'seaborn' style sheet is named 'seaborn-v0_8' in newer Matplotlib; use whichever is installed.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
plt.figure(figsize=(12,7))
plt.title('Age vs Stroke at various risk levels',fontsize=20)
plt.ylabel('% of strokes',fontsize=15)
plt.xlabel('Age',fontsize=15)

# Shared x axis: every age bin present in the data, so all five lines have the same length.
labels = sorted(dataviz['binned_age'].unique())

# (number of risk factors, line colour, legend label)
risk_lines = [
    (0, 'b', '0 risk factors'),
    (1, 'g', '1 risk factor'),
    (2, 'r', '2 risk factors'),
    (3, 'orange', '3 risk factors'),
    (4, 'purple', '4 risk factors'),
]
for n_risks, colour, legend_label in risk_lines:
    # Select the risk level first, then average the 0/1 stroke flag per age bin. This replaces
    # `(risk == n) & stroke == 1`, which parses as `((risk == n) & stroke) == 1`.
    at_risk_level = dataviz[dataviz['no_of_risk_factors']==n_risks]
    stroke_pct = 100 * at_risk_level.groupby('binned_age')['stroke'].mean()
    # Age bins with nobody at this risk level are drawn at 0%, as in the original chart.
    stroke_pct = stroke_pct.reindex(labels, fill_value=0)
    plt.plot(labels,stroke_pct.to_numpy(),colour,label=legend_label)
plt.legend()

# Data Preparation

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from imblearn.over_sampling import RandomOverSampler

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC

<p style="font-size: 15pt;">Looking at the distribution of strokes in the dataset using a pie chart shows that the data is highly imbalanced with only 5% of the cases being strokes. An imbalanced dataset implies that metrics like accuracy and ROC AUC score will be misleading. We will need to look at precision, recall and f1 score to judge the performance of our models. Since this is a medical application, our focus should be on recall as it is necessary to identify as many risky cases as possible even if it is at the expense of precision.<br><br>I will be balancing the dataset using RandomOverSampler from imblearn library.</p>

In [None]:
# Overall class balance: stroke vs non-stroke cases in the raw data.
stroke_counts = Counter(df['stroke'])
# Look counts up by class (1 = stroke, 0 = no stroke) so each hard-coded name is paired with its
# own count; Counter insertion order depends on which class appears first in the file.
fig = px.pie(values=[stroke_counts[1], stroke_counts[0]], names=['Stroke Positive','Stroke Negative'], title='Strokes and Non stroke cases')
fig.show()

<p style="font-size: 15pt;">Train test split will be done before over sampling to prevent data leakage by way of duplicates. I will only oversample the train set and leave the test set as it is.</p>

In [None]:
# Features: every column except the row identifier and the label. Target: the stroke flag.
non_feature_cols = ['id', 'stroke']
X = df.drop(columns=non_feature_cols)
Y = df['stroke']

In [None]:
# Hold out 20% of the rows as the test set. stratify=Y and random_state=42 are passed so the
# split is stratified on the label and repeatable across runs.
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2,stratify=Y,random_state=42)

In [None]:
# Impute missing BMI with the mode of the training set only, so nothing from the test set leaks in.
bmi_mode = X_train['bmi'].mode()
# `assign` returns new frames; an in-place fillna on `X_train['bmi']` operates on an intermediate
# object and does not update the frame under pandas copy-on-write.
X_train = X_train.assign(bmi=X_train['bmi'].fillna(bmi_mode[0]))
X_test = X_test.assign(bmi=X_test['bmi'].fillna(bmi_mode[0]))

<p style="font-size: 15pt;">I am adding 2 additional columns - one is the number of risk factors which we saw in data visualization and the other one is a value that is directly proportional to both average glucose level and the number of risk factors, (Avg. Glucose Level) * (1 + No. of risk factors)</p>

In [None]:
def add_additional_cols(df):
    """Add the engineered risk-factor columns to ``df`` in place.

    Reads ``smoking_status``, ``bmi``, ``hypertension``, ``heart_disease`` and
    ``avg_glucose_level`` and adds:

    - ``smokes``: 1 where ``smoking_status == 'smokes'``, otherwise 0.
    - ``high_bmi``: 1.0 where ``bmi >= 25``, 0.0 where ``bmi < 25``, NaN where ``bmi`` is missing.
    - ``no_of_risk_factors``: ``hypertension + heart_disease + smokes + high_bmi``.
    - ``attrib1``: ``avg_glucose_level * (1 + no_of_risk_factors)``.

    Returns None; the frame passed in is modified.
    """
    # Whole-column assignments are used throughout: chained indexing (df[col][mask] = value)
    # writes to an intermediate object and leaves ``df`` unchanged under pandas copy-on-write.
    df['smokes'] = (df['smoking_status']=='smokes').astype('int')

    df['high_bmi'] = (df['bmi']>=25).astype('float').where(df['bmi'].notna())

    df['no_of_risk_factors'] = df['hypertension']+df['heart_disease']+df['smokes']+df['high_bmi']
    df['attrib1'] = df['avg_glucose_level'] * (1 + df['no_of_risk_factors'])

In [None]:
# Add the engineered columns to the training split first, then to the test split.
for feature_frame in (X_train, X_test):
    add_additional_cols(feature_frame)

In [None]:
# Balance the training set only, by randomly duplicating minority-class (stroke) rows.
# random_state is fixed (same seed as the train/test split) so the resampled set, and every
# metric reported below, is reproducible across runs.
oversampler = RandomOverSampler(sampling_strategy='minority', random_state=42)
X_train_bl, Y_train_bl = oversampler.fit_resample(X_train,Y_train)

In [None]:
# Preprocessing for the full feature set: standardise numeric columns, one-hot encode categoricals.
num_attribs = [
    'age', 'hypertension', 'heart_disease', 'avg_glucose_level',
    'bmi', 'no_of_risk_factors', 'attrib1',
]
cat_attribs = [
    'gender', 'ever_married', 'work_type', 'Residence_type', 'smoking_status',
]
num_pipeline = Pipeline(steps=[('scaler', StandardScaler())])
categorical_encoder = OneHotEncoder(handle_unknown='ignore')
full_pipeline = ColumnTransformer(transformers=[
    ('num', num_pipeline, num_attribs),
    ('cat', categorical_encoder, cat_attribs),
])

# Fit on the balanced training data; apply the fitted transform unchanged to the test data.
X_train_prepared = full_pipeline.fit_transform(X_train_bl)
X_test_prepared = full_pipeline.transform(X_test)

<h1>Training and prediction</h1>

In [None]:
# Logistic regression on all prepared features; the heatmap shows the test-set confusion matrix.
log_reg = LogisticRegression()
log_reg.fit(X_train_prepared, Y_train_bl)
Y_pred = log_reg.predict(X_test_prepared)
test_confusion = confusion_matrix(Y_test, Y_pred)
sns.heatmap(test_confusion, annot=True, fmt='d', cmap='Blues')

In [None]:
# Classification report for the predictions currently held in Y_pred, evaluated against Y_test.
print(classification_report(Y_test, Y_pred))

In [None]:
# Linear SVC on all prepared features; the heatmap shows the test-set confusion matrix.
# random_state is fixed so the solver's random choices, and therefore the reported metrics,
# are the same on every run.
lin_svc = LinearSVC(random_state=42)
lin_svc.fit(X_train_prepared, Y_train_bl)
Y_pred = lin_svc.predict(X_test_prepared)
sns.heatmap(confusion_matrix(Y_test, Y_pred),annot=True,fmt='d',cmap='Blues')

In [None]:
# Classification report for the predictions currently held in Y_pred, evaluated against Y_test.
print(classification_report(Y_test, Y_pred))

<h2>Using only a single attribute - Age</h2>

In [None]:
# Preprocessing for the single-feature experiment: only the standardised age column is kept.
num_attribs = ['age']
num_pipeline = Pipeline(steps=[('scaler', StandardScaler())])
full_pipeline = ColumnTransformer(transformers=[
    ('num', num_pipeline, num_attribs),
])

# Fit on the balanced training data; apply the fitted transform unchanged to the test data.
X_train_prepared = full_pipeline.fit_transform(X_train_bl)
X_test_prepared = full_pipeline.transform(X_test)

In [None]:
# Linear SVC trained on whatever X_train_prepared currently holds; the heatmap shows the
# test-set confusion matrix.
# random_state is fixed so the solver's random choices, and therefore the reported metrics,
# are the same on every run.
lin_svc = LinearSVC(random_state=42)
lin_svc.fit(X_train_prepared, Y_train_bl)
Y_pred = lin_svc.predict(X_test_prepared)
sns.heatmap(confusion_matrix(Y_test, Y_pred),annot=True,fmt='d',cmap='Blues')

In [None]:
# Classification report for the predictions currently held in Y_pred, evaluated against Y_test.
print(classification_report(Y_test, Y_pred))

<p style="font-size: 15pt;">Surprisingly, using only age as the input to the Linear SVC is giving the best recall performance of 84% with minor worsening of precision and f1 score compared to the models trained on all attributes. It means that for the machine learning algorithms all the other attributes in their present state are merely noise or just redundant.<br><br> This makes a strong case for feature engineering/selection of the other attributes for improving the performance.</p>

<h2>Other feature combinations:</h2>
<p style="font-size: 15pt;">After some trial and error, I found that heart disease and marital status give an even better recall of 90%. Let's look at it below.</p>

In [None]:
# Preprocessing for the two-feature experiment: standardised heart_disease plus one-hot ever_married.
num_attribs = ['heart_disease']
cat_attribs = ['ever_married']
num_pipeline = Pipeline(steps=[('scaler', StandardScaler())])
categorical_encoder = OneHotEncoder(handle_unknown='ignore')
full_pipeline = ColumnTransformer(transformers=[
    ('num', num_pipeline, num_attribs),
    ('cat', categorical_encoder, cat_attribs),
])

# Fit on the balanced training data; apply the fitted transform unchanged to the test data.
X_train_prepared = full_pipeline.fit_transform(X_train_bl)
X_test_prepared = full_pipeline.transform(X_test)

In [None]:
# Linear SVC trained on whatever X_train_prepared currently holds; predictions go to Y_pred_2
# and the heatmap shows the test-set confusion matrix.
# random_state is fixed so the solver's random choices, and therefore the reported metrics,
# are the same on every run.
lin_svc = LinearSVC(random_state=42)
lin_svc.fit(X_train_prepared, Y_train_bl)
Y_pred_2 = lin_svc.predict(X_test_prepared)
sns.heatmap(confusion_matrix(Y_test, Y_pred_2),annot=True,fmt='d',cmap='Blues')

In [None]:
# Classification report for the predictions held in Y_pred_2, evaluated against Y_test.
print(classification_report(Y_test, Y_pred_2))

<p style="font-size: 15pt;">But using these 2 attributes, precision of the stroke class and recall of the non stroke class are severely compromised. This model looks very poor, so the single attribute model is the best one so far.</p>

<h1>Conclusion</h1>

<p style="font-size: 15pt;">The reason for the underperformance of the model when all features are included could be the fact that important risk factors like hypertension and heart disease arise with the onset of aging, i.e., being correlated with age, they do not add any useful information. Absent any feature engineering, the only way to improve the performance would be to incorporate more useful features - not just any features but those which don't have any correlation with age. <br><br>The current dataset is missing several important risk factors for stroke, namely, <br>
    <br>1. Stress
    <br>2. Physical inactivity
    <br>3. Family history - genetic disorders like CADASIL
    <br>4. Prior stroke history
    <br>5. Drug and alcohol abuse
    <br>6. Atrial Fibrillation
    <br><br>All the above risk factors are uncorrelated with age and are likely to improve model performance.
</p>