# Decision Tree
Decision trees are among the simplest and fastest classification algorithms.  
Classification happens through a series of decisions, and the procedure should be clear to you by the end of this notebook.
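
As a rough illustration of what "a series of decisions" means, the sketch below fits a tree on a tiny made-up dataset and prints the learned rules as text. The toy values and the feature name `age` are assumptions chosen for illustration only; they are not taken from the stroke dataset used later.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

# Toy data (made up for illustration): a single feature, two classes
X_toy = np.array([[20], [35], [45], [55], [60], [80]])
y_toy = np.array([0, 0, 0, 1, 1, 1])

toy_clf = DecisionTreeClassifier(random_state=0).fit(X_toy, y_toy)

# export_text renders the fitted tree as nested "feature <= threshold" rules
rules = export_text(toy_clf, feature_names=["age"])
print(rules)
```

Each line of the printed output is one decision: a sample follows the branch whose condition it satisfies until it reaches a leaf holding the predicted class.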

In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.impute import KNNImputer
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import f1_score

# Loading Dataset

In [None]:
data = pd.read_csv('../input/stroke-prediction-dataset/healthcare-dataset-stroke-data.csv')
data.head()

# Histogram Plots

In [None]:
data.hist(figsize=(10,10))

# Other Properties

In [None]:
data.nunique()

In [None]:
data.isna().sum()

In [None]:
data.stroke.value_counts()

As the dataset is highly imbalanced, we'll use a stratified split rather than a plain train/test split.  
A stratified split keeps the ratio of target classes in each train and test fold the same as in the original dataset, so that every fold represents the population.
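
To make the idea concrete, here is a minimal sketch on synthetic labels (95 negatives and 5 positives, an assumed ratio chosen for illustration, not the real stroke counts). It counts how many positives land in each test fold.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Synthetic, imbalanced labels: 95 negatives and 5 positives (illustrative only)
y_demo = np.array([0] * 95 + [1] * 5)
X_demo = np.zeros((100, 1))  # feature values do not matter for the split itself

skf_demo = StratifiedKFold(n_splits=5)
positives_per_fold = [
    int(y_demo[test_idx].sum()) for _, test_idx in skf_demo.split(X_demo, y_demo)
]
print(positives_per_fold)
```

Every test fold receives its share of the rare class, which an unstratified split does not guarantee.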

In [None]:
data.describe()

# One Hot Encoding Categorical Variables

In [None]:
cat = ['gender','ever_married','Residence_type','smoking_status','work_type']
for i in cat:
    dummy = pd.get_dummies(data[i],drop_first=True,prefix=i)  # get_dummies already joins prefix and value with '_'
    data = pd.concat([data,dummy],axis=1)
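
To see what `drop_first=True` does, consider this small sketch on a made-up column. The three category labels are placeholders for illustration, not the dataset's actual values.

```python
import pandas as pd

# Hypothetical categorical column with three distinct values
toy = pd.DataFrame({"smoking_status": ["never", "smokes", "formerly", "never"]})

# drop_first=True drops the first category in sorted order ("formerly"),
# which then acts as the baseline implied by all-zero dummy columns
encoded = pd.get_dummies(toy["smoking_status"], drop_first=True, prefix="smoking_status")
print(list(encoded.columns))
```

A column with k categories therefore produces k - 1 indicator columns.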

In [None]:
data.head()

In [None]:
data = data.drop([*cat,'id'],axis=1)

In [None]:
data.head()

In [None]:
data.corrwith(data['stroke'])

# Train and Test

In [None]:
X = data.drop('stroke',axis=1).values
y = data['stroke'].values

In [None]:
skf = StratifiedKFold(n_splits=5)
skf.get_n_splits(X, y)

for train_index, test_index in skf.split(X, y):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    
    imputer = KNNImputer(n_neighbors=2)
    # Fit the imputer on the training fold only, then reuse it on the test fold
    X_train = imputer.fit_transform(X_train)
    X_test = imputer.transform(X_test)
    
    clf = DecisionTreeClassifier()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    f = f1_score(y_true = y_test , y_pred = y_pred,average = 'weighted')
    
    print(f)
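
The same per-fold imputation and fitting can also be written more compactly with a scikit-learn pipeline. This is a sketch under assumptions: the data below is randomly generated with a few missing values, and names such as `X_demo` and `pipe` are illustrative, not part of the notebook above.

```python
import numpy as np
from sklearn.impute import KNNImputer
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data: 200 rows, 4 features, an imbalanced binary target
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 4))
y_demo = (X_demo[:, 0] > 1.0).astype(int)
X_demo[rng.random(X_demo.shape) < 0.05] = np.nan  # inject some missing values

# The pipeline refits the imputer on each training fold before fitting the tree
pipe = make_pipeline(KNNImputer(n_neighbors=2), DecisionTreeClassifier(random_state=0))
scores = cross_val_score(
    pipe, X_demo, y_demo, cv=StratifiedKFold(n_splits=5), scoring="f1_weighted"
)
print(scores)
```

Wrapping the imputer and the classifier together means the test fold is never used to fit the imputer.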

**Note:** Additional preprocessing techniques are deliberately omitted here to show a plain decision tree.  
They can be applied to improve overall accuracy.
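
When judging any such improvement on imbalanced data, it may help to look at per-class scores next to the weighted F1 used above. The sketch below uses made-up labels (95 negatives, 5 positives) and a model that always predicts the majority class, purely for illustration.

```python
import numpy as np
from sklearn.metrics import f1_score

# Made-up labels: 95 negatives, 5 positives; the "model" predicts 0 for everything
y_true = np.array([0] * 95 + [1] * 5)
y_pred = np.zeros(100, dtype=int)

# zero_division=0 silences the warning for the class that is never predicted
weighted = f1_score(y_true, y_pred, average="weighted", zero_division=0)
per_class = f1_score(y_true, y_pred, average=None, zero_division=0)
print(weighted, per_class)
```

The weighted score stays high because the majority class dominates the average, even though the minority class is never predicted correctly.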

In [None]:
import matplotlib.pyplot as plt
plt.figure(figsize=(20,10))
plot_tree(clf,feature_names=list(data.drop('stroke',axis=1).columns),max_depth = 2,filled=True,class_names = True)
plt.show()

You can learn about this graph and its interpretation [here](https://towardsdatascience.com/understanding-decision-trees-once-and-for-all-2d891b1be579).  
Happy learning!