# Library Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn, optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader as DL
from torch.nn.utils import weight_norm as WN
import torch.nn.functional as F

from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score

import os
import gc
from time import time

import warnings
# Notebook-wide settings.
# NOTE(review): this silences *every* warning (library deprecation notices
# included); consider narrowing the filter by category or module.
warnings.filterwarnings("ignore")

# Shared random seed for the data split and the torch RNG calls below.
seed = 0
# Column-wise mean imputation of NaN entries.
si_mean = SimpleImputer(strategy="mean", missing_values=np.nan)

# Helpers

In [None]:
def breaker():
    """Print a horizontal rule of 50 dashes framed by blank lines."""
    rule = "-" * 50
    print("\n{}\n".format(rule))
    
def head(x=None, no_of_ele=5):
    """Print the leading ``no_of_ele`` items of any sliceable ``x``."""
    leading = x[slice(no_of_ele)]
    print(leading)
    
def getCol(x=None):
    """Return the labels in ``x.columns`` as a plain Python list."""
    return list(x.columns)

def getAccuracy(y_pred=None, y_true=None):
    """Return the fraction of thresholded predictions that equal the targets.

    Args:
        y_pred: raw model outputs (logits). A sigmoid is applied and values
            strictly greater than 0.5 count as class 1, the rest as class 0.
        y_true: targets holding the same number of elements as ``y_pred``.

    Returns:
        Accuracy as a Python float. Neither input tensor is modified.
    """
    # Flatten both sides so an (N, 1) prediction compared with an (N,) target
    # is an element-wise comparison rather than an (N, N) broadcast, and
    # divide by the element count rather than len() (which only counts rows).
    predicted = (torch.sigmoid(y_pred.detach()) > 0.5).float().reshape(-1)
    target = y_true.detach().reshape(-1)
    return torch.count_nonzero(predicted == target).item() / predicted.numel()

def preprocess(x=None, *args):
    """Return a copy of ``x`` with five categorical columns integer-encoded.

    ``args`` names, in this order, the gender, ever-married, work-type,
    residence-type and smoking-status columns. Values missing from the
    corresponding mapping become NaN (``Series.map`` behaviour); ``x`` itself
    is left untouched.
    """
    encodings = (
        {"Other" : 0, "Female" : 1, "Male" : 2},
        {"No" : 0, "Yes" : 1},
        {"children" : 0, "Never_worked" : 1, "Govt_job" : 2, "Self-employed" : 3, "Private" : 4},
        {"Urban" : 0, "Rural" : 1},
        {"smokes" : 0, "formerly smoked" : 1, "never smoked" : 2, "Unknown" : 3},
    )
    encoded = x.copy()
    for position, mapping in enumerate(encodings):
        column = args[position]
        encoded[column] = encoded[column].map(mapping)
    return encoded

# Data Handling and Analysis

**Data Input**

In [None]:
data = pd.read_csv(
    "../input/stroke-prediction-dataset/healthcare-dataset-stroke-data.csv",
    engine="python",
)

# Summary: overall table shape, then the number of distinct values per column.
breaker()
print("Dataset Shape : {}".format(data.shape))
breaker()
for col in getCol(data):
    print("{} - {!r}".format(col, data[col].nunique()))
breaker()

**Simple EDA**

In [None]:
# 2x5 grid of univariate plots: histograms with KDE for age, glucose level and
# BMI; count plots for every other column. Order fixes the subplot position.
eda_panels = [
    ("gender", "count"),
    ("age", "hist"),
    ("hypertension", "count"),
    ("heart_disease", "count"),
    ("ever_married", "count"),
    ("work_type", "count"),
    ("Residence_type", "count"),
    ("avg_glucose_level", "hist"),
    ("bmi", "hist"),
    ("smoking_status", "count"),
]

plt.figure(figsize=(20, 9))
for panel_no, (column, kind) in enumerate(eda_panels, start=1):
    plt.subplot(2, 5, panel_no)
    if kind == "hist":
        sns.histplot(data=data, x=column, kde=True)
    else:
        sns.countplot(data=data, x=column)
plt.show()

# Distribution of the target column on its own figure.
sns.countplot(data=data, x="stroke")
plt.show()

**Preprocessing**

In [None]:
data = preprocess(data, "gender", "ever_married", "work_type", "Residence_type", "smoking_status")

# Positional hold-out: the first n_train_rows rows are used for
# training/validation and the remaining rows for testing.
# NOTE(review): the split follows row order; if the CSV is ordered by the
# target, the test rows may contain few or no positives. Consider a shuffled,
# stratified split. TODO confirm against the data.
n_train_rows = 5000

# Fit the mean imputer on the training rows only so test-set statistics do not
# leak into the imputed training features, then impute every row with it.
si_mean.fit(data.iloc[:n_train_rows])
data = si_mean.transform(data)

# Column 0 is excluded from the features (presumably the row id -- confirm);
# the last column is used as the label.
test_features,   test_labels = data[n_train_rows:, 1:-1], data[n_train_rows:, -1]
train_features, train_labels = data[:n_train_rows, 1:-1], data[:n_train_rows, -1]

**Custom PyTorch Dataset Template**

In [None]:
class DS(Dataset):
    """Map-style dataset over in-memory feature (and optional label) arrays.

    In ``"train"`` and ``"valid"`` mode each item is ``(features, label)``;
    in any other mode (e.g. ``"test"``) each item is ``features`` only.
    Items are returned as freshly copied float32 tensors.
    """

    # Modes in which labels are stored and returned alongside the features.
    _LABELLED_MODES = ("train", "valid")

    def __init__(self, X=None, y=None, mode="train"):
        self.mode = mode
        self.X = X
        if self.mode in self._LABELLED_MODES:
            self.y = y

    def __len__(self):
        return self.X.shape[0]

    @staticmethod
    def _to_float_tensor(values):
        # torch.tensor always copies its data and never interprets a number
        # as a size (unlike the legacy torch.FloatTensor constructor).
        # atleast_1d turns a scalar per-sample label (1-D y) into shape (1,),
        # so collated label batches are (batch, 1).
        return torch.tensor(np.atleast_1d(values), dtype=torch.float32)

    def __getitem__(self, idx):
        features = self._to_float_tensor(self.X[idx])
        if self.mode in self._LABELLED_MODES:
            return features, self._to_float_tensor(self.y[idx])
        return features

# ANN Config and Setup

**Config**

In [None]:
class CFG:
    """Hyper-parameter bundle for one ANN run.

    Class attributes are shared by every run; the constructor stores the
    per-run settings.
    """

    # Shared settings.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    IL = 10               # input layer width
    OL = 1                # output layer width
    ts_batch_size = 128   # batch size for the test/prediction loader

    def __init__(self, HL=None, epochs=None, n_folds=None, batch_size=None, tr_va_split=0.2):
        # Architecture and schedule.
        self.HL, self.epochs, self.n_folds = HL, epochs, n_folds
        # One batch size serves both the training and the validation loader.
        self.tr_batch_size = self.va_batch_size = batch_size
        # Fraction of the training rows held out for validation.
        self.tr_va_split = tr_va_split

**Setup**

In [None]:
class Classifier(nn.Module):
    """Feed-forward network built from BatchNorm + weight-normalised Linear layers.

    Hidden layers apply ``BatchNorm1d -> Linear -> [Dropout] -> ReLU``; the
    output layer applies ``BatchNorm1d -> Linear`` and returns raw logits
    (no sigmoid). Layers are registered as ``BN1``/``FC1`` ... ``BN{k}``/``FC{k}``
    with ``k = len(HL) + 1``, so any number of hidden layers is supported
    while the state_dict naming stays ``BN{i}``/``FC{i}``.

    Args:
        IL: number of input features.
        HL: sequence of hidden-layer widths.
        OL: number of output units.
        use_DP: apply dropout after every hidden Linear layer when True.
        DP: dropout probability used when ``use_DP`` is True.
    """

    def __init__(self, IL=None, HL=None, OL=None, use_DP=False, DP=0.5):

        super(Classifier, self).__init__()

        self.use_DP = use_DP
        if self.use_DP:
            self.DP_ = nn.Dropout(p=DP)

        self.HL = HL

        # Consecutive layer widths: input, hidden..., output. Layers are
        # created in BN1, FC1, BN2, FC2, ... order so parameter registration
        # (and random initialisation) order follows the layer order.
        widths = [IL, *HL, OL]
        self.n_layers = len(widths) - 1
        for i in range(self.n_layers):
            setattr(self, "BN{}".format(i + 1),
                    nn.BatchNorm1d(num_features=widths[i], eps=1e-5))
            setattr(self, "FC{}".format(i + 1),
                    WN(nn.Linear(in_features=widths[i], out_features=widths[i + 1])))

    def getOptimizer(self, lr=1e-3, wd=0):
        """Return an Adam optimizer over all parameters of this model."""
        return optim.Adam(self.parameters(), lr=lr, weight_decay=wd)

    def getPlateauLR(self, optimizer=None, patience=5, eps=1e-6):
        """Return a ReduceLROnPlateau scheduler wrapping ``optimizer``."""
        return optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, patience=patience, eps=eps, verbose=True)

    def forward(self, x):
        """Map a (batch, IL) input to (batch, OL) logits."""
        # Hidden layers: BN -> Linear -> [Dropout] -> ReLU.
        for i in range(1, self.n_layers):
            x = getattr(self, "FC{}".format(i))(getattr(self, "BN{}".format(i))(x))
            if self.use_DP:
                x = self.DP_(x)
            x = F.relu(x)
        # Output layer: BN -> Linear, no activation.
        last = self.n_layers
        return getattr(self, "FC{}".format(last))(getattr(self, "BN{}".format(last))(x))

**ANN Helpers**

In [None]:
def _run_epoch(model, optimizer, loaders, criterion, device):
    """Run one "train" pass then one "valid" pass; return per-phase mean loss and accuracy.

    Args:
        model: module being trained / evaluated.
        optimizer: optimizer stepped once per training batch.
        loaders: dict with "train" and "valid" iterables of ``(X, y)`` batches.
        criterion: loss called as ``criterion(output, y)``.
        device: device each batch is moved to.

    Returns:
        ``(epochLoss, epochAccs)``: dicts keyed by "train"/"valid" holding the
        unweighted mean over batches of the loss and of ``getAccuracy``.
    """
    epochLoss = {"train": 0.0, "valid": 0.0}
    epochAccs = {"train": 0.0, "valid": 0.0}

    for phase in ["train", "valid"]:
        # train() enables dropout and batch-norm statistic updates; eval() disables them.
        if phase == "train":
            model.train()
        else:
            model.eval()

        lossPerPass = []
        accsPerPass = []

        for X, y in loaders[phase]:
            X = X.to(device)
            # int64 targets are flattened to 1-D; other dtypes keep their shape.
            if y.dtype == torch.int64:
                y = y.to(device).view(-1)
            else:
                y = y.to(device)

            optimizer.zero_grad()
            # Build the autograd graph only during the training phase.
            with torch.set_grad_enabled(phase == "train"):
                output = model(X)
                loss = criterion(output, y)
                if phase == "train":
                    loss.backward()
                    optimizer.step()
            lossPerPass.append(loss.item())
            accsPerPass.append(getAccuracy(output, y))

        epochLoss[phase] = np.mean(np.array(lossPerPass))
        epochAccs[phase] = np.mean(np.array(accsPerPass))

    return epochLoss, epochAccs


def fit_(model=None, optimizer=None, scheduler=None, epochs=None, early_stopping_patience=None,
         trainloader=None, validloader=None,
         criterion=None, device=None,
         save_to_file=False,
         path=None, verbose=False):
    """Train ``model`` for up to ``epochs`` epochs, validating after each one.

    After every epoch a checkpoint ``Epoch_<n>.pt`` holding the model and
    optimizer state dicts is written into ``path``, so ``path`` must be a
    directory even when ``save_to_file`` is False.

    Args:
        model, optimizer: network and its optimizer.
        scheduler: optional; ``scheduler.step(valid_loss)`` is called at the end
            of every epoch that does not trigger early stopping.
        epochs: maximum number of epochs.
        early_stopping_patience: if truthy, stop once the validation loss has
            not improved for more than this many consecutive epochs.
        trainloader, validloader: iterables of ``(X, y)`` batches.
        criterion: loss function applied as ``criterion(model(X), y)``.
        device: device for the model and the batches.
        save_to_file: also write the per-epoch log and the summary to
            ``<path>/Metrics.txt`` (the file is always closed on exit).
        path: output directory.
        verbose: print the per-epoch log line.

    Returns:
        ``(Losses, Accuracies, bestLossEpoch, bestAccsEpoch)``: per-epoch dicts
        of mean train/valid loss and accuracy, and the 1-based epochs with the
        lowest validation loss and the highest validation accuracy (``None``
        if no epoch qualified, e.g. when every validation loss is NaN).
    """

    breaker()
    print("Training ...")
    breaker()

    model.to(device)

    DLS = {"train": trainloader, "valid": validloader}
    bestLoss = {"train": np.inf, "valid": np.inf}
    # -inf so the first epoch always registers, even at 0.0 accuracy.
    bestAccs = {"train": -np.inf, "valid": -np.inf}
    bestLossEpoch = None
    bestAccsEpoch = None
    early_stopping_step = 0

    Losses = []
    Accuracies = []

    log_file = open(os.path.join(path, "Metrics.txt"), "w") if save_to_file else None
    try:
        start_time = time()
        for e in range(epochs):
            e_st = time()

            epochLoss, epochAccs = _run_epoch(model, optimizer, DLS, criterion, device)
            Losses.append(epochLoss)
            Accuracies.append(epochAccs)

            torch.save({"model_state_dict": model.state_dict(),
                        "optim_state_dict": optimizer.state_dict()},
                        os.path.join(path, "Epoch_{}.pt".format(e + 1)))

            # Track the best validation loss; with early stopping enabled,
            # count consecutive non-improving epochs and stop past the patience.
            if epochLoss["valid"] < bestLoss["valid"]:
                bestLoss = epochLoss
                bestLossEpoch = e + 1
                early_stopping_step = 0
            elif early_stopping_patience:
                early_stopping_step += 1
                if early_stopping_step > early_stopping_patience:
                    print("Early Stopping at Epoch {}".format(e + 1))
                    break

            if epochAccs["valid"] > bestAccs["valid"]:
                bestAccs = epochAccs
                bestAccsEpoch = e + 1

            line = ("Epoch: {} | Train Loss: {:.5f} | Valid Loss: {:.5f} | Train Accs : {:.5f} | "
                    "Valid Accs : {:.5f} | Time: {:.2f} seconds").format(
                        e + 1,
                        epochLoss["train"], epochLoss["valid"],
                        epochAccs["train"], epochAccs["valid"],
                        time() - e_st)
            if verbose:
                print(line)
            if log_file is not None:
                log_file.write(line + "\n")

            if scheduler:
                scheduler.step(epochLoss["valid"])

        # Report the number of epochs actually run (fewer than `epochs` after early stopping).
        epochs_run = len(Losses)
        total_minutes = (time() - start_time) / 60

        breaker()
        print("-----> Best Validation Loss at Epoch {}".format(bestLossEpoch))
        breaker()
        print("-----> Best Validation Accs at Epoch {}".format(bestAccsEpoch))
        breaker()
        print("Time Taken [{} Epochs] : {:.2f} minutes".format(epochs_run, total_minutes))
        breaker()
        print("Training Complete")
        breaker()

        if log_file is not None:
            log_file.write("\n-----> Best Validation Loss at Epoch {}\n".format(bestLossEpoch))
            log_file.write("-----> Best Validation Accs at Epoch {}\n".format(bestAccsEpoch))
            log_file.write("Time Taken [{} Epochs] : {:.2f} minutes\n".format(epochs_run, total_minutes))
    finally:
        if log_file is not None:
            log_file.close()

    return Losses, Accuracies, bestLossEpoch, bestAccsEpoch


def predict_(model=None, dataloader=None, device=None, path=None):
    """Return sigmoid probabilities for every batch yielded by ``dataloader``.

    Args:
        model: network returning logits.
        dataloader: iterable yielding feature tensors only (no labels).
        device: device the inputs are moved to (should match the model's
            parameters); when None, inputs are used as yielded.
        path: optional checkpoint path; when given, its ``"model_state_dict"``
            entry is loaded into ``model`` first.

    Returns:
        numpy array with one row of probabilities per input sample, in loader
        order. An empty dataloader yields an array of shape (0, 1).
    """
    if path:
        model.load_state_dict(torch.load(path, map_location=device)["model_state_dict"])

    model.eval()

    # Collect batch outputs and concatenate once, instead of re-copying the
    # accumulated predictions for every batch.
    outputs = []
    with torch.no_grad():
        for X in dataloader:
            # Use the device argument, not any global configuration.
            if device is not None:
                X = X.to(device)
            outputs.append(torch.sigmoid(model(X)))

    if not outputs:
        return np.zeros((0, 1), dtype=np.float32)
    return torch.cat(outputs, dim=0).cpu().numpy()

# Training and Validation

In [None]:
cfg = CFG(HL=[128], epochs=10, batch_size=512, tr_va_split=0.2)

# Shuffled train/validation split of the training rows.
X_train, X_valid, y_train, y_valid = train_test_split(train_features, train_labels, test_size=cfg.tr_va_split, shuffle=True, random_state=seed)

# Labels reshaped to (N, 1) so each sample's target matches the single output unit.
tr_data_setup = DS(X=X_train, y=y_train.reshape(-1, 1), mode="train")
va_data_setup = DS(X=X_valid, y=y_valid.reshape(-1, 1), mode="valid")
# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU.
use_pin_memory = cfg.device.type == "cuda"
tr_data = DL(tr_data_setup, batch_size=cfg.tr_batch_size, shuffle=True, pin_memory=use_pin_memory, generator=torch.manual_seed(seed),)
va_data = DL(va_data_setup, batch_size=cfg.va_batch_size, shuffle=False, pin_memory=use_pin_memory)

torch.manual_seed(seed)
model = Classifier(IL=cfg.IL, HL=cfg.HL, OL=cfg.OL, use_DP=True, DP=0.5)
optimizer = model.getOptimizer(lr=1e-3, wd=0)

L, A, BLE, BAE = fit_(model=model, optimizer=optimizer, scheduler=None, epochs=cfg.epochs,
                      trainloader=tr_data, validloader=va_data, device=cfg.device,
                      criterion=nn.BCEWithLogitsLoss(),
                      save_to_file=True,
                      path="./", verbose=True)

# Per-epoch metric curves.
TL = [epoch_loss["train"] for epoch_loss in L]
VL = [epoch_loss["valid"] for epoch_loss in L]
TA = [epoch_accs["train"] for epoch_accs in A]
VA = [epoch_accs["valid"] for epoch_accs in A]

# Epochs numbered from 1, matching the training log and checkpoint names.
x_Axis = np.arange(1, len(L) + 1)
plt.figure(figsize=(12, 6))
# suptitle titles the whole figure; plt.title before any subplot would attach
# to an implicit full-figure axes that the subplots then replace or overlap.
plt.suptitle("Metrics for M1")
plt.subplot(1, 2, 1)
plt.plot(x_Axis, TL, "r", label="Training Loss")
plt.plot(x_Axis, VL, "b--", label="validation Loss")
plt.legend()
plt.grid()
plt.subplot(1, 2, 2)
plt.plot(x_Axis, TA, "r", label="Training Accuracy")
plt.plot(x_Axis, VA, "b--", label="validation Accuracy")
plt.legend()
plt.grid()
plt.show()

# Predictions

In [None]:
ts_data_setup = DS(X=test_features, y=None, mode="test")
ts_data = DL(ts_data_setup, batch_size=cfg.ts_batch_size, shuffle=False)

y_pred = predict_(model=model, dataloader=ts_data, device=cfg.device, path="./Epoch_{}.pt".format(BLE))

# Threshold at 0.5 (> 0.5 -> 1) and flatten (N, 1) -> (N,) to match test_labels.
y_pred = (y_pred > 0.5).astype(np.float64).ravel()

# sklearn metrics take (y_true, y_pred); the reverse order swaps precision and
# recall. labels=[0.0, 1.0] keeps the per-class arrays aligned even when a
# class is absent from the test rows, and zero_division=0 reports 0 for a
# class that is never predicted.
precision, recall, f1, _ = precision_recall_fscore_support(
    test_labels, y_pred, labels=[0.0, 1.0], zero_division=0)

breaker()
print("Accuracy  : {:.5f}".format(accuracy_score(test_labels, y_pred)))
for cls in (0, 1):
    breaker()
    print("Class {}".format(cls))
    print("F1 Score  : {:.5f}".format(f1[cls]))
    print("Precision : {:.5f}".format(precision[cls]))
    print("Recall    : {:.5f}".format(recall[cls]))
breaker()