In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
# Load the raw stroke dataset from Google Drive.
# NOTE(review): this path only exists once Drive is mounted in Colab
# (e.g. drive.mount('/content/drive')); execution jumps from In [1] to In [3],
# so the mount cell appears to be missing from this notebook — confirm.
DATA_PATH = '/content/drive/My Drive/XAI/healthcare-dataset-stroke-data.csv'
df = pd.read_csv(DATA_PATH)

In [4]:
# Show the freshly loaded frame; IPython's display() renders the same rich table
# as a bare trailing expression.
display(df)

Unnamed: 0,id,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke
0,9046,Male,67.0,0,1,Yes,Private,Urban,228.69,36.6,formerly smoked,1
1,51676,Female,61.0,0,0,Yes,Self-employed,Rural,202.21,,never smoked,1
2,31112,Male,80.0,0,1,Yes,Private,Rural,105.92,32.5,never smoked,1
3,60182,Female,49.0,0,0,Yes,Private,Urban,171.23,34.4,smokes,1
4,1665,Female,79.0,1,0,Yes,Self-employed,Rural,174.12,24.0,never smoked,1
...,...,...,...,...,...,...,...,...,...,...,...,...
5105,18234,Female,80.0,1,0,Yes,Private,Urban,83.75,,never smoked,0
5106,44873,Female,81.0,0,0,Yes,Self-employed,Urban,125.20,40.0,never smoked,0
5107,19723,Female,35.0,0,0,Yes,Self-employed,Rural,82.99,30.6,never smoked,0
5108,37544,Male,51.0,0,0,Yes,Private,Rural,166.29,25.6,formerly smoked,0


In [5]:
# Columns with object dtype ('O') hold strings; collect their names in column order.
cat_col = [name for name, dtype in df.dtypes.items() if dtype == 'O']
cat_col

['gender', 'ever_married', 'work_type', 'Residence_type', 'smoking_status']

In [6]:
# gender, ever_married, work_type, Residence_type and smoking_status have dtype
# object: they hold categorical values, so they are one-hot encoded.
# A single pd.get_dummies(..., columns=...) call replaces the per-column
# concat/drop loop: the untouched columns keep their order, one block of
# `<column>_<category>` indicator columns is appended per entry of cat_col
# (in cat_col order), and the source columns are dropped.
df = pd.get_dummies(df, columns=cat_col, prefix_sep='_')
df

Unnamed: 0,id,age,hypertension,heart_disease,avg_glucose_level,bmi,stroke,gender_Female,gender_Male,gender_Other,ever_married_No,ever_married_Yes,work_type_Govt_job,work_type_Never_worked,work_type_Private,work_type_Self-employed,work_type_children,Residence_type_Rural,Residence_type_Urban,smoking_status_Unknown,smoking_status_formerly smoked,smoking_status_never smoked,smoking_status_smokes
0,9046,67.0,0,1,228.69,36.6,1,0,1,0,0,1,0,0,1,0,0,0,1,0,1,0,0
1,51676,61.0,0,0,202.21,,1,1,0,0,0,1,0,0,0,1,0,1,0,0,0,1,0
2,31112,80.0,0,1,105.92,32.5,1,0,1,0,0,1,0,0,1,0,0,1,0,0,0,1,0
3,60182,49.0,0,0,171.23,34.4,1,1,0,0,0,1,0,0,1,0,0,0,1,0,0,0,1
4,1665,79.0,1,0,174.12,24.0,1,1,0,0,0,1,0,0,0,1,0,1,0,0,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5105,18234,80.0,1,0,83.75,,0,1,0,0,0,1,0,0,1,0,0,0,1,0,0,1,0
5106,44873,81.0,0,0,125.20,40.0,0,1,0,0,0,1,0,0,0,1,0,0,1,0,0,1,0
5107,19723,35.0,0,0,82.99,30.6,0,1,0,0,0,1,0,0,0,1,0,1,0,0,0,1,0
5108,37544,51.0,0,0,166.29,25.6,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0


In [7]:
# `id` is presumably a record identifier rather than a predictive feature; remove it.
df = df.drop(columns=['id'])

In [8]:
# Split rows by whether bmi is observed: rows with a known bmi are used to fit
# the imputation model, rows with a missing bmi are the ones to impute.
# bmi is a float column, so test for missing values with isna() instead of
# comparing against the string "NaN" (NaN never compares equal to anything,
# and a float column never equals a string; any match depends on query-engine quirks).
bmi_missing = df['bmi'].isna()
train = df[~bmi_missing]
test = df[bmi_missing]

In [10]:
# (rows, columns) of the observed-bmi and missing-bmi subsets.
tuple(frame.shape for frame in (train, test))

((4909, 22), (201, 22))

In [11]:
# Target is bmi; features are every other remaining column.
# NOTE(review): `stroke` stays in the feature set; if stroke is the downstream
# prediction target, imputing bmi from it may leak label information — confirm.
y_train = train['bmi']
x_train = train.drop(columns=['bmi'])
x_train.shape, y_train.shape

((4909, 21), (4909,))

In [12]:
# Rows to impute, with the same feature columns as x_train.
x_test = test.drop(columns=['bmi'])
x_test.shape

(201, 21)

In [13]:
from sklearn.linear_model import LinearRegression

In [30]:
# Fit the bmi imputation model: a degree-2 polynomial expansion of every
# feature (squares and pairwise products, plus a bias column) followed by
# ordinary least squares. LinearRegression is imported in the previous cell.
# No train/test splitting happens here; the model is fit on all rows with a
# known bmi.
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures


model = Pipeline(steps=[('poly', PolynomialFeatures(degree=2)),
                        ('reg', LinearRegression())])
model.fit(x_train, y_train)

Pipeline(memory=None,
         steps=[('poly',
                 PolynomialFeatures(degree=2, include_bias=True,
                                    interaction_only=False, order='C')),
                ('reg',
                 LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None,
                                  normalize=False))],
         verbose=False)

In [31]:
# In-sample mean absolute error of the imputation model. It is measured on the
# same rows the model was fit on, so it is an optimistic estimate; no held-out
# rows are evaluated in this notebook.
from sklearn.metrics import mean_absolute_error

tt = model.predict(x_train)
# sklearn's documented order is mean_absolute_error(y_true, y_pred). MAE is
# symmetric, so the value is unchanged, but keep the conventional order.
print('Training loss MAE: ', mean_absolute_error(y_train, tt))


Traiing loss MAE:  4.888695083850724


In [32]:
# Impute: predicted bmi for every row whose bmi was missing.
yhat = model.predict(X=x_test)

In [33]:
# Summary of the imputed values: mean, (population) std, min, max.
yhat.mean(), yhat.std(), yhat.min(), yhat.max()

(30.213695490853258, 4.383901300921472, 17.601031992584467, 41.5135037638247)

In [34]:
# Hard-coded reference bmi estimates for the 201 rows with missing bmi.
# NOTE(review): their provenance is not shown in this notebook. Repeated values
# (e.g. 27.24722222, 30.23181818) suggest group means computed elsewhere;
# document how they were derived.
# The comparison below pairs manual[i] with yhat[i], so this list is assumed
# to be in the same row order as x_test — confirm.
manual = [29.87948718, 30.55609756, 27.24722222, 30.84186047, 33.14666667,
       31.95882353, 29.81794872, 30.77931034, 30.23181818, 27.24722222,
       29.81794872, 27.24722222, 27.75789474, 30.23181818, 31.60943396,
       31.81176471, 33.14666667, 33.14666667, 30.23181818, 29.76      ,
       27.24722222, 30.77931034, 31.58913043, 30.95652174, 27.93454545,
       30.23181818, 30.35925926, 18.525     , 27.68181818, 31.16923077,
       27.93454545, 28.35964912, 28.47692308, 29.62083333, 30.58064516,
       27.68181818, 29.55714286, 30.23181818, 29.55714286, 30.35925926,
       29.77391304, 30.23181818, 29.81794872, 30.58064516, 31.9972973 ,
       28.31363636, 32.78      , 30.58064516, 30.27916667, 30.93658537,
       18.94545455, 34.4969697 , 29.62083333, 28.5       , 30.77692308,
       28.59166667, 29.04888889, 28.31363636, 29.04888889, 30.77692308,
       21.51538462, 30.47619048, 18.43928571, 33.09333333, 32.05333333,
       30.37619048, 19.1875    , 32.04333333, 30.77692308, 32.11702128,
       28.47692308, 31.03777778, 29.83125   , 27.25238095, 31.13125   ,
       33.09333333, 30.82222222, 17.88333333, 29.87948718, 30.65882353,
       28.31578947, 30.35925926, 31.42      , 31.005     , 27.27714286,
       32.78      , 33.336     , 33.14666667, 32.4       , 31.01304348,
       31.95882353, 29.77391304, 30.26842105, 31.01304348, 32.78      ,
       30.575     , 31.28666667, 30.80588235, 32.00263158, 19.96      ,
       27.65      , 27.56923077, 29.94545455, 31.24      , 32.00263158,
       16.65      , 32.05333333, 31.77307692, 26.70714286, 29.62083333,
       29.94545455, 25.26666667, 34.75806452, 27.56923077, 32.60961538,
       32.01111111, 30.590625  , 32.66585366, 27.24722222, 31.005     ,
       32.9875    , 30.77692308, 27.68181818, 27.68181818, 31.9972973 ,
       30.37619048, 18.87096774, 31.58913043, 29.83125   , 24.92307692,
       31.77307692, 30.9       , 27.75789474, 30.93571429, 31.52424242,
       32.78387097, 31.44565217, 30.90555556, 30.35925926, 30.37619048,
       28.37826087, 31.0525    , 31.58913043, 29.41395349, 31.24      ,
       33.14666667, 30.17931034, 34.25652174, 27.24722222, 30.93571429,
       32.59666667, 31.28666667, 32.2627451 , 29.94545455, 32.60961538,
       33.08695652, 31.52424242, 33.08695652, 31.03777778, 34.4969697 ,
       32.21086957, 17.675     , 29.77391304, 25.25172414, 31.45932203,
       19.975     , 18.65      , 24.77241379, 29.87948718, 31.89166667,
       30.575     , 27.93454545, 19.70967742, 30.77931034, 27.93454545,
       30.590625  , 30.27916667, 28.090625  , 30.575     , 28.5       ,
       26.746875  , 31.95882353, 18.93125   , 25.26666667, 28.35964912,
       34.4969697 , 32.11351351, 31.95882353, 28.35964912, 31.9       ,
       31.9       , 32.07307692, 29.83125   , 21.35      , 33.14666667,
       30.27916667, 32.716     , 28.31363636, 31.45932203, 28.31363636,
       28.47692308]

In [35]:
# Side-by-side comparison: one row per imputed record.
dd = pd.DataFrame(
    data={
        'predict': yhat,    # model output for the rows with missing bmi
        'manual': manual,   # hard-coded reference estimates (same row order assumed)
    }
)

In [36]:
# Signed gap: positive means the model imputes a higher bmi than the manual estimate.
dd['diff'] = dd['predict'].sub(dd['manual'])
dd

Unnamed: 0,predict,manual,diff
0,34.872541,29.879487,4.993054
1,28.351646,30.556098,-2.204451
2,31.957969,27.247222,4.710747
3,31.887718,30.841860,1.045857
4,32.950065,33.146667,-0.196601
...,...,...,...
196,31.075325,32.716000,-1.640675
197,31.632696,28.313636,3.319060
198,36.014231,31.459322,4.554909
199,30.894247,28.313636,2.580611


In [37]:
# Distribution of predictions, manual estimates and their difference (quartiles shown).
dd.describe(percentiles=[0.25, 0.5, 0.75])

Unnamed: 0,predict,manual,diff
count,201.0,201.0,201.0
mean,30.213695,29.54491,0.668786
std,4.394847,3.550462,2.898905
min,17.601032,16.65,-6.190283
25%,28.351646,28.359649,-1.308371
50%,30.3666,30.37619,0.421145
75%,32.576088,31.609434,2.203632
max,41.513504,34.758065,10.938504
