In [None]:
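# If you use an Anaconda environment, the installations below should suffice; the cells below also import packages not listed here (e.g. tldextract, bs4/beautifulsoup4, dateutil, googlesearch, scikit-learn, seaborn) — install any that are missing.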
pip install shap
pip install tensorflow
pip install python-whois


In [None]:
#preprocessing and feature creation

import pandas as pd
import numpy as np
from urllib.parse import urlparse
import tldextract
import socket
import ssl
import re
import ipaddress
from math import log
from re import compile
from urllib.parse import urlparse
from socket import gethostbyname
from googlesearch import search
from requests import get
from json import dump
from string import ascii_lowercase
from numpy import array
import requests
import whois
from bs4 import BeautifulSoup
from datetime import datetime
from dateutil.relativedelta import relativedelta

# Read the raw URL dataset. The code below relies on a 'url' column and a
# 'label' column whose values include 'good' and 'bad'.
# NOTE(review): hard-coded, user-specific Windows path.
ds = pd.read_csv("C:\\Users\\dollie\\Downloads\\urldata.csv")

# Shuffle all rows reproducibly (random_state=42). This does NOT deduplicate
# URLs: duplicate rows in the CSV survive this step and the sampling below.
ds = ds.sample(frac=1, random_state=42)

# Balanced sample: 5,000 rows per class. sample(n=5000) raises ValueError if
# a class has fewer than 5,000 rows.
ds_good = ds[ds['label']=='good'].sample(n=5000, random_state=42)
ds_bad = ds[ds['label']=='bad'].sample(n=5000,random_state=42)

# Stack the two classes: all 'good' rows first, then all 'bad' rows (the
# combined frame is not re-shuffled here).
ds_sampled = pd.concat([ds_good,ds_bad])

# Rename the labels; a value missing from the dict would map to NaN, but only
# 'good'/'bad' rows were sampled above.
ds_sampled['label'] = ds_sampled['label'].map({'good':'benign', 'bad':'malicious'})

#feature extraction

# 1.Check if url is an IP address
def is_ip(url):
    """Return 1 if an IPv4-style dotted quad appears anywhere in *url*, else 0.

    The search is unanchored, so a dotted quad in the path or query counts as
    well. 1 is treated as the phishing indicator, 0 as legitimate.
    """
    ipv4_match = re.search(r'(([01]?\d\d?|2[0-4]\d|25[0-5])\.([01]?\d\d?|2[0-4]\d|25[0-5])\.([01]?\d\d?|2[0-4]\d|25[0-5])\.([01]?\d\d?|2[0-4]\d|25[0-5]))', url)
    return 0 if ipv4_match is None else 1

# Feature 1 column: dotted-quad IPv4 address present in the URL.
ds_sampled['using_ip'] = ds_sampled['url'].map(is_ip)

# 2. Check if url is long
def long_url(url):
    """Return 1 (suspicious) when *url* is 54 characters or longer, otherwise 0 (legit)."""
    is_long = len(url) >= 54
    return int(is_long)
# Feature 2 column: URL length >= 54.
ds_sampled['long_url'] = ds_sampled['url'].apply(long_url)

#3.url shortener
# Known URL-shortening domains (lower-case; duplicates removed).
_SHORTENER_DOMAINS = (
    'bit.ly', 'goo.gl', 'shorte.st', 'go2l.ink', 'x.co', 'ow.ly', 't.co', 'tinyurl',
    'tr.im', 'is.gd', 'cli.gs', 'yfrog.com', 'migre.me', 'ff.im', 'tiny.cc', 'url4.eu',
    'twit.ac', 'su.pr', 'twurl.nl', 'snipurl.com', 'short.to', 'budurl.com', 'ping.fm',
    'post.ly', 'just.as', 'bkite.com', 'snipr.com', 'fic.kr', 'loopt.us', 'doiop.com',
    'short.ie', 'kl.am', 'wp.me', 'rubyurl.com', 'om.ly', 'to.ly', 'bit.do', 'lnkd.in',
    'db.tt', 'qr.ae', 'adf.ly', 'bitly.com', 'cur.lv', 'tinyurl.com', 'ity.im', 'q.gs',
    'po.st', 'bc.vc', 'twitthis.com', 'u.to', 'j.mp', 'buzurl.com', 'cutt.us', 'u.bb',
    'yourls.org', 'prettylinkpro.com', 'scrnch.me', 'filoops.info', 'vzturl.com',
    'qr.net', '1url.com', 'tweez.me', 'v.gd', 'link.zip.net',
)
# A shortener must not be glued to surrounding host characters: the lookbehind
# rejects a preceding letter/digit/'_'/'-', the lookahead rejects a following
# letter/digit/'_'/'-'/'.'. Without this, 'microsoft.com' matches 't.co' and
# 'box.com' matches 'x.co'. Host names are case-insensitive, hence IGNORECASE.
_SHORTENER_RE = re.compile(
    r'(?<![\w-])(?:' + '|'.join(re.escape(d) for d in _SHORTENER_DOMAINS) + r')(?![\w.-])',
    re.IGNORECASE,
)


def shortening_service(url):
    """Return 1 (phishing) if *url* references a known URL-shortening service, else 0 (legit)."""
    return 1 if _SHORTENER_RE.search(url) else 0
# Force every URL to str so the string-based extractors below never receive
# NaN/float values (astype(str) turns them into 'nan').
ds_sampled['url'] = ds_sampled['url'].astype(str)
# Feature 3 column: URL-shortening service used ('using_ip' is already set above).
ds_sampled['shortening_service'] = ds_sampled['url'].apply(shortening_service)

# 4. Check if url has @ symbol
def have_at(url):
    """Return 1 (phishing) if *url* contains an '@' character, otherwise 0 (legit)."""
    return int('@' in url)
# Feature 4 column: '@' present in the URL.
ds_sampled['having_@_symbol'] = ds_sampled['url'].apply(have_at)

# 5. Check if url uses redirection "//"
def redirection(url):
    """Return 1 (phishing) if '//' occurs after the scheme separator, else 0 (legit).

    The last '//' in the whole URL is located. The scheme separator sits at
    index 5 ('http://') or 6 ('https://'), so a position beyond 7 means an
    embedded redirect such as 'http://a.com//evil' or '...?next=http://evil'.
    Searching only the path missed '//' at the start of the path and anything
    in the query string.
    """
    return 1 if url.rfind("//") > 7 else 0
# Feature 5 column: embedded '//' redirection.
ds_sampled['redirection_//_symbol'] = ds_sampled['url'].apply(redirection)

# 6.Adding Prefix or Suffix Separated by (-) to the Domain

def has_dash(url):
    """Return 1 (phishing) if the network location of *url* contains '-', else 0 (legit).

    Only urlparse's netloc is inspected. For a URL without a '//' authority
    (e.g. 'my-site.com/x') urlparse leaves netloc empty, so the result is 0.
    """
    netloc = urlparse(url).netloc
    return 1 if netloc.find('-') != -1 else 0
# Feature 6 column: '-' in the domain.
ds_sampled['has_dash_-_symbol'] = ds_sampled['url'].apply(has_dash)

# 7.Sub Domain and Multi Sub Domains

def sub_domains(url):
    """Return 0 (legit) if the host has exactly one '.', ignoring a leading 'www.'; else 1 (phish).

    'www.' is dropped before counting so that 'www.google.com' counts like
    'google.com'. A URL without a '//' authority has an empty netloc (no dots)
    and therefore yields 1.
    """
    domain = urlparse(url).netloc
    if domain.lower().startswith('www.'):
        domain = domain[len('www.'):]
    return 0 if domain.count('.') == 1 else 1
# Feature 7 column: multiple sub-domains (binary flag despite the column name).
ds_sampled['numof_sub_domains'] = ds_sampled['url'].apply(sub_domains)

#8. Using Non-Standard Port
def non_standard_port(url):
    """Return 1 (phishing) if *url* names an explicit port other than 80/443, else 0 (legit).

    Only the authority part is examined: the text after an optional
    'scheme://' prefix, up to the first '/', '?' or '#', minus any
    'user:pass@' credentials. Colon-digit sequences elsewhere (e.g. the time
    in '/events/12:30' or a numeric password) are therefore not mistaken for
    ports. Scheme-less URLs such as 'example.com:8080/x' are handled too.
    """
    # Drop an optional 'scheme://' prefix.
    rest = re.sub(r'^[A-Za-z][A-Za-z0-9+.-]*://', '', url, count=1)
    # The authority ends at the first path, query or fragment delimiter.
    authority = re.split(r'[/?#]', rest, maxsplit=1)[0]
    # Credentials may contain ':<digits>'; keep only host[:port].
    host_port = authority.rpartition('@')[2]
    match = re.search(r':(\d+)$', host_port)
    if match and int(match.group(1)) not in (80, 443):
        return 1
    return 0

# Feature 8 column: explicit non-standard port.
ds_sampled['non_std_port'] = ds_sampled['url'].map(non_standard_port)

#9.The Existence of “HTTPS” Token in the Domain Part of the URL

def has_https_token(url):
    """Return 1 (phishing) if 'https' appears inside the network location of *url*, else 0 (legit).

    Host names are case-insensitive, so the netloc is lower-cased first. A URL
    without a '//' authority has an empty netloc and yields 0.
    """
    domain = urlparse(url).netloc.lower()
    return 1 if 'https' in domain else 0

# Feature 9 column: 'https' token inside the domain part.
ds_sampled['https_token'] = ds_sampled['url'].apply(has_https_token)
# 10. hostname as substring
def abnormal_url(url):
    """Return 1 if urlparse's host name for *url* occurs literally in *url*, else 0.

    Note the polarity: 1 means the host WAS found. urlparse lower-cases the
    host name, so a URL whose host is written in upper case yields 0. When no
    host is parsed (e.g. a URL without '//'), the result is 0.
    """
    hostname = urlparse(url).hostname
    if not hostname:
        return 0
    # Plain substring test. A regex search would treat '.' as a wildcard and
    # raise re.error on host names containing characters such as '('.
    return 1 if hostname in url else 0
# Feature 10 column: host name found within the URL.
ds_sampled['abnormal_url'] = ds_sampled['url'].apply(abnormal_url)

# 11. is google indexed?
def google_index(url):
    """Return 1 if a Google search for *url* yields a result, 0 if none, -1 if the search fails.

    `search` may return a lazy generator, which is truthy even when it would
    yield nothing, so at least one item must actually be consumed.
    NOTE(review): the positional 5 is passed unchanged; its meaning depends on
    which `googlesearch` package is installed (result count vs. tld) — confirm.
    """
    try:
        first_hit = next(iter(search(url, 5)), None)
    except Exception:  # network / rate-limit errors from the search backend
        return -1
    return 1 if first_hit is not None else 0

ds_sampled['google_index'] = ds_sampled['url'].apply(google_index)

# 12. number of dots
def count_dot(url):
    """Return how many '.' characters *url* contains."""
    return sum(1 for ch in url if ch == '.')
# Feature 12 column: number of dots.
ds_sampled['count.'] = ds_sampled['url'].apply(count_dot)

# 13. www count
def count_www(url):
    """Return the number of non-overlapping occurrences of 'www' in *url*."""
    return url.count('www')
# Feature 13 column: 'www' occurrences.
ds_sampled['count-www'] = ds_sampled['url'].apply(count_www)

# 14. @count
def count_atrate(url):
    """Return the number of '@' characters in *url*."""
    at_signs = url.count('@')
    return at_signs
# Feature 14 column: '@' count.
ds_sampled['count@'] = ds_sampled['url'].apply(count_atrate)

# 15. Number of '/' characters in the URL path (directory depth)

def no_of_dir(url):
    """Return the number of '/' characters in the path component of *url*."""
    path = urlparse(url).path
    return sum(ch == '/' for ch in path)
# Feature 15 column: '/' count in the path.
ds_sampled['count_dir'] = ds_sampled['url'].apply(no_of_dir)

# 16. Number of '//' sequences in the URL path (embedded URLs)
def no_of_embed(url):
    """Return the number of non-overlapping '//' sequences in the path of *url*."""
    return urlparse(url).path.count('//')
# Feature 16 column: '//' count in the path.
ds_sampled['count_embed_domian'] = ds_sampled['url'].apply(no_of_embed)

# 17. count%
def count_per(url):
    """Return the number of '%' characters in *url*."""
    percent_signs = url.count('%')
    return percent_signs
# Feature 17 column: '%' count.
ds_sampled['count%'] = ds_sampled['url'].apply(count_per)

# 18. count ?
def count_ques(url):
    """Return the number of '?' characters in *url*."""
    question_marks = url.count('?')
    return question_marks
# Feature 18 column: '?' count.
ds_sampled['count?'] = ds_sampled['url'].apply(count_ques)

# 19. count -
def count_hyphen(url):
    """Return the number of '-' characters in *url*."""
    hyphens = url.count('-')
    return hyphens
# Feature 19 column: '-' count.
ds_sampled['count-'] = ds_sampled['url'].apply(count_hyphen)

# 20. count =
def count_equal(url):
    """Return the number of '=' characters in *url*."""
    equals_signs = url.count('=')
    return equals_signs
# Feature 20 column: '=' count.
ds_sampled['count='] = ds_sampled['url'].apply(count_equal)

# 21. url length
def url_length(url):
    """Return the length of *url* after converting it to str."""
    text = str(url)
    return len(text)
# Feature 21 column: URL length.
ds_sampled['url_length'] = ds_sampled['url'].apply(url_length)

# 22. hostname length
def hostname_length(url):
    """Return the length of urlparse's netloc for *url* (includes any port/credentials; 0 if absent)."""
    netloc = urlparse(url).netloc
    return len(netloc)
# Feature 22 column: netloc length.
ds_sampled['hostname_length'] = ds_sampled['url'].apply(hostname_length)
# 23. (no feature is defined under this number)
# NOTE: a bare expression in the middle of a notebook cell is not displayed,
# so this head() call has no visible effect; wrap it in print() to inspect rows.
ds_sampled.head()
# 24. Suspicious keywords

def suspicious_words(url):
    """Return 1 if *url* contains a common phishing-bait keyword, else 0.

    Matching is case-insensitive: URLs are mostly lower-case, so a
    case-sensitive 'PayPal' would never match 'paypal.example.com'.
    """
    keywords = r'PayPal|login|signin|bank|account|update|free|lucky|service|bonus|ebayisapi|webscr'
    return 1 if re.search(keywords, url, re.IGNORECASE) else 0
# Feature 24 column: suspicious keyword present.
ds_sampled['sus_url'] = ds_sampled['url'].apply(suspicious_words)

#25.http
def http_over_https(url):
    """Return 1 if the substring 'http://' occurs anywhere in *url*, otherwise 0."""
    return int(url.find('http://') >= 0)
# Feature 25 column: plain 'http://' present.
ds_sampled['has_http'] = ds_sampled['url'].apply(http_over_https)

#26. special chars
def special_chars_in_url(url):
    """Return 1 if *url* contains any of '<', '>', '#', '&', '*'; otherwise 0."""
    return 1 if any(ch in url for ch in ("<", ">", "#", "&", "*")) else 0
# Feature 26 column: special character present.
ds_sampled['has_special_char'] = ds_sampled['url'].apply(special_chars_in_url)

def getDomain(url):
    """Return the network location of *url* with one leading 'www.' removed.

    Returns '' when urlparse finds no network location (e.g. a URL without
    '//'). Only a literal 'www.' prefix is removed; 'www.' occurring later in
    the host (e.g. 'www.mywww.example.com') is kept.
    """
    domain = urlparse(url).netloc
    if not domain:
        return ""
    if domain.startswith("www."):
        domain = domain[len("www."):]
    return domain

# 27. misspell
# Look-alike brand fragments. 'appl' must not be followed by 'e', otherwise
# the genuine 'apple' brand itself would be flagged.
_MISSPELLED_BRAND_RE = re.compile(
    r"goggle|gooqle|googledocument|facebok|face-book|appl(?!e)|microsotf|micro-soft|microsoftsup|amazn"
)


def misspelled_brand_name(url):
    """Return 1 if the domain of *url* contains a misspelled brand name, else 0.

    The domain (see getDomain) is lower-cased first because host names are
    case-insensitive.
    """
    domain = getDomain(url).lower()
    return 1 if _MISSPELLED_BRAND_RE.search(domain) else 0
# Feature 27 column: misspelled brand in the domain.
ds_sampled['misspell_brand_name'] = ds_sampled['url'].apply(misspelled_brand_name)

#28. URL entropy
def get_entropy(url):
    """Return the Shannon entropy, in bits per character, of *url* with surrounding whitespace stripped.

    Probabilities are per-character relative frequencies; an empty string gives 0.
    """
    text = url.strip()
    counts = {}
    for ch in text:
        counts[ch] = counts.get(ch, 0) + 1
    probabilities = [float(n) / len(text) for n in counts.values()]
    return -sum(p * np.log2(p) for p in probabilities)

# Feature 28 column: character entropy of the URL.
ds_sampled['url_entropy'] = ds_sampled['url'].apply(get_entropy)

#29. Total number of digits in URL string
def count_digits(url):
    """Return how many characters of *url* are digits (str.isdigit)."""
    return len([ch for ch in url if ch.isdigit()])

# Feature 29 column: digit count.
ds_sampled['count_digits'] = ds_sampled['url'].apply(count_digits)

# 30. Total number of query parameters in URL
def count_parameters(url):
    """Return the number of '&'-separated pieces in the query string of *url*.

    Returns 0 when the first piece is empty, i.e. for an empty query and also
    for a query that starts with '&'.
    """
    pieces = urlparse(url).query.split('&')
    return 0 if pieces[0] == '' else len(pieces)

# Feature 30 column: query parameter count.
ds_sampled['count_parameters'] = ds_sampled['url'].apply(count_parameters)

# 31.Total number of fragments in URL
def count_fragments(url):
    """Return the number of '#'-separated pieces in the fragment of *url*.

    Returns 0 when the first piece is empty (no fragment, or a fragment that
    itself starts with '#').
    """
    pieces = urlparse(url).fragment.split('#')
    return 0 if pieces[0] == '' else len(pieces)

# Feature 31 column: fragment piece count.
ds_sampled['count_fragments'] = ds_sampled['url'].apply(count_fragments)

# 32. Check if page is online
def is_online(url):
    """Return 1 if an HTTP GET of *url* answers with status 200 within 5 s, otherwise 0.

    Any requests exception (connection error, timeout, missing scheme, ...) is
    treated as offline.
    NOTE(review): this sends live requests to URLs labelled malicious; consider
    running the extraction in a sandboxed environment.
    """
    try:
        status = requests.get(url, timeout=5).status_code
    except requests.RequestException:
        return 0
    return int(status == 200)

ds_sampled['is_online'] = ds_sampled['url'].apply(is_online)

# 33. Number of days since domain was registered
def _first_whois_date(value):
    """Reduce a python-whois date field to one value; None if it is missing or empty.

    The field may be a single value or a list, in which case the first element
    is used. Strings are parsed with '%Y-%m-%d' (ValueError otherwise).
    """
    if isinstance(value, list):
        value = value[0] if value else None
    if isinstance(value, str):
        value = datetime.strptime(value, "%Y-%m-%d")
    return value


def days_since_domain_registered(url):
    """Return whole days since the WHOIS creation date of the URL's netloc, or -1 if unavailable.

    Any exception during the WHOIS lookup or date handling yields -1.
    """
    domain_name = urlparse(url).netloc
    try:
        record = whois.whois(domain_name)
        creation_date = _first_whois_date(record.creation_date)
        if creation_date is None:
            return -1
        # datetime.now(tzinfo) is naive when tzinfo is None and aware otherwise,
        # so the subtraction never mixes naive and aware datetimes.
        return (datetime.now(creation_date.tzinfo) - creation_date).days
    except Exception:  # best effort: WHOIS failures are common and map to -1
        return -1

ds_sampled['days_since_domain_registered'] = ds_sampled['url'].apply(days_since_domain_registered)

#34 Number of days until domain expires
def days_until_domain_expires(url):
    """Return whole days until the WHOIS expiration date of the URL's netloc, or -1 if unavailable.

    Any exception during the WHOIS lookup or date handling yields -1. An
    already-expired domain gives a negative day count.
    """
    domain_name = urlparse(url).netloc
    try:
        record = whois.whois(domain_name)
        expiration_date = _first_whois_date(record.expiration_date)
        if expiration_date is None:
            return -1
        return (expiration_date - datetime.now(expiration_date.tzinfo)).days
    except Exception:  # best effort: WHOIS failures are common and map to -1
        return -1

ds_sampled['days_until_domain_expires'] = ds_sampled['url'].apply(days_until_domain_expires)

#35 Total number of characters in HTML page
def num_html_chars(url):
    """Return the length in bytes of the body fetched from *url* (5 s timeout), or -1 on a requests error."""
    try:
        body = requests.get(url, timeout=5).content
    except requests.RequestException:
        return -1
    return len(body)

ds_sampled['num_html_chars'] = ds_sampled['url'].apply(num_html_chars)

def get_html_content(url):
    """Fetch *url* (5 s timeout) and return it parsed by BeautifulSoup's 'html.parser'.

    Returns None if requests raises any RequestException.
    """
    try:
        raw = requests.get(url, timeout=5).content
    except requests.RequestException:
        return None
    return BeautifulSoup(raw, 'html.parser')

# 36. Total number of h1-h6 tags in HTML
def num_headings(url):
    """Return the number of <h1>..<h6> elements on the page at *url*, or -1 if it could not be fetched."""
    soup = get_html_content(url)
    if not soup:
        return -1
    heading_tags = ['h%d' % level for level in range(1, 7)]
    return len(soup.find_all(heading_tags))

ds_sampled['num_headings'] = ds_sampled['url'].apply(num_headings)

#37. Total number of images in HTML
def num_images(url):
    """Return the number of <img> elements on the page at *url*, or -1 if it could not be fetched."""
    soup = get_html_content(url)
    return len(soup.find_all('img')) if soup else -1

ds_sampled['num_images'] = ds_sampled['url'].apply(num_images)

# 38. Total number of links in HTML
def num_links(url):
    """Return the number of <a> elements on the page at *url*, or -1 if it could not be fetched."""
    soup = get_html_content(url)
    if soup:
        anchors = soup.find_all('a')
        return len(anchors)
    return -1

ds_sampled['num_links'] = ds_sampled['url'].apply(num_links)

# 39. Total number of characters in scripts
def num_script_chars(url):
    """Return the total text length of all <script> elements on the page at *url*, or -1 if unfetchable."""
    soup = get_html_content(url)
    if not soup:
        return -1
    total = 0
    for script_tag in soup.find_all('script'):
        total += len(script_tag.text)
    return total

ds_sampled['num_script_chars'] = ds_sampled['url'].apply(num_script_chars)

# 40.Total number of special characters
def num_special_chars(url):
    """Return how many punctuation characters occur in the serialised HTML of *url*, or -1 if unfetchable."""
    soup = get_html_content(url)
    if not soup:
        return -1
    markup = str(soup)
    return sum(map(markup.count, '!@#$%^&*()_+=-[]{}|;:,.<>?/~`'))

ds_sampled['num_special_chars'] = ds_sampled['url'].apply(num_special_chars)

# 41. Script to special character ratio
# Both counts are -1 when a page could not be fetched, and a page may contain
# no special characters at all. Plain division would then give a misleading
# 1.0 (-1 / -1) or inf/NaN (x / 0). Such rows get -1, the "unavailable" marker
# used by the other page-based features.
_ratio_valid = (ds_sampled['num_script_chars'] >= 0) & (ds_sampled['num_special_chars'] > 0)
ds_sampled['script_to_special_char_ratio'] = (
    ds_sampled['num_script_chars'] / ds_sampled['num_special_chars']
).where(_ratio_valid, -1)

#42.Script to body ratio
def script_to_body_ratio(url):
    """Return (characters inside <script> tags) / (characters of <body> text) for the page at *url*.

    Returns -1 when the page cannot be fetched, has no <body>, or its body has
    no text. The ratio is undefined in that last case, and dividing would raise
    ZeroDivisionError and abort the whole column.
    """
    soup = get_html_content(url)
    if not soup:
        return -1
    body = soup.find('body')
    if not body:
        return -1
    body_chars = len(body.text)
    if body_chars == 0:
        return -1
    script_chars = sum(len(script.text) for script in soup.find_all('script'))
    return script_chars / body_chars

ds_sampled['script_to_body_ratio'] = ds_sampled['url'].apply(script_to_body_ratio)

#43. Number of redirects before reaching the final page
def count_redirects(url):
    """Return how many redirects requests followed before the final response for *url*; 0 on failure.

    A 5 s timeout matches the other page fetches. Without it, a single
    unresponsive host could block the extraction indefinitely.
    """
    try:
        response = requests.get(url, timeout=5)
    except requests.RequestException:  # includes timeouts and TooManyRedirects
        return 0
    return len(response.history)

ds_sampled['count_redirects'] = ds_sampled['url'].apply(count_redirects)

#44. Number of links to external domains in the HTML.
def count_external_links(html, base_url=None):
    """Count <a> tags in *html* whose href is an absolute http(s) URL.

    Args:
        html: HTML markup (str or bytes) — not a URL.
        base_url: optional URL of the page. When given, links to the same host
            as base_url are excluded, so only links to other domains count.
    """
    bs = BeautifulSoup(html, 'html.parser')
    links = bs.find_all('a', href=re.compile('^http[s]?://'))
    if base_url is None:
        return len(links)
    own_host = urlparse(base_url).netloc.lower()
    return sum(1 for a in links if urlparse(a['href']).netloc.lower() != own_host)


def _external_links_for_url(url):
    """Fetch *url* (5 s timeout) and count its external links; -1 if the request fails."""
    try:
        response = requests.get(url, timeout=5)
    except requests.RequestException:
        return -1
    return count_external_links(response.content, base_url=url)

# The column must be computed from each page's HTML; parsing the URL string
# itself as HTML would always find zero links.
ds_sampled['count_external_links'] = ds_sampled['url'].apply(_external_links_for_url)

#45.Number of iframe tags in the HTML.
def count_iframes(html):
    """Return the number of <iframe> elements in *html* (HTML markup as str or bytes, not a URL)."""
    bs = BeautifulSoup(html, 'html.parser')
    return len(bs.find_all('iframe'))


def _iframes_for_url(url):
    """Fetch *url* (5 s timeout) and count its <iframe> tags; -1 if the request fails."""
    try:
        response = requests.get(url, timeout=5)
    except requests.RequestException:
        return -1
    return count_iframes(response.content)

# The column must be computed from each page's HTML; parsing the URL string
# itself as HTML would always find zero iframes.
ds_sampled['count_iframes'] = ds_sampled['url'].apply(_iframes_for_url)

#46.has_privacy_policy - Binary feature indicating whether the website has a privacy policy page.
def has_privacy_policy(url):
    """Return True if the page at *url* links to an href containing 'privacy' (case-insensitive).

    Returns False when the page cannot be fetched within 5 s. The timeout
    matches the other page fetches and prevents a stalled host from blocking
    the run.
    """
    try:
        response = requests.get(url, timeout=5)
    except requests.RequestException:
        return False
    soup = BeautifulSoup(response.text, 'html.parser')
    return any('privacy' in a['href'].lower() for a in soup.select('a[href]'))

ds_sampled['has_privacy_policy'] = ds_sampled['url'].apply(has_privacy_policy)

#47. domain age
def domain_registered_long_ago(url):
    """Return 1 if the WHOIS creation date is more than 5 years ago, 0 if not, -1 if unknown.

    python-whois may return a single date or a list; the first list entry is
    used. Any lookup/comparison failure yields -1.
    """
    try:
        w = whois.whois(url)
        creation_date = w.creation_date
        if isinstance(creation_date, list):
            creation_date = creation_date[0] if creation_date else None
        if creation_date is None:
            return -1
        # Use the date's own tzinfo so naive and aware datetimes are never mixed
        # (that raises TypeError, which would silently turn into -1).
        threshold = datetime.now(creation_date.tzinfo) - relativedelta(years=5)  # Change the years as per requirement
        return 1 if creation_date < threshold else 0
    except Exception:
        return -1
ds_sampled['domain_registered_long_ago'] = ds_sampled['url'].apply(domain_registered_long_ago)

# 48. Domain does not expire within the next year
def domain_expiration_far_future(url):
    """Return 1 if the WHOIS expiration date is more than 1 year away, 0 if not, -1 if unknown.

    python-whois may return a single date or a list; the first list entry is
    used. Any lookup/comparison failure yields -1.
    """
    try:
        w = whois.whois(url)
        expiration_date = w.expiration_date
        if isinstance(expiration_date, list):
            expiration_date = expiration_date[0] if expiration_date else None
        if expiration_date is None:
            return -1
        # Same naive/aware guard as above.
        threshold = datetime.now(expiration_date.tzinfo) + relativedelta(years=1)  # Change the years as per requirement
        return 1 if expiration_date > threshold else 0
    except Exception:
        return -1

ds_sampled['domain_expiration_far_future'] = ds_sampled['url'].apply(domain_expiration_far_future)

#49.favicon loaded from external domain?
def favicon_external(url):
    """Return 1 if a favicon <link> on the page at *url* resolves to a different host, 0 if not, -1 if unfetchable.

    Each href is resolved against the page URL first. This way relative paths
    (even ones containing 'http' in their text) stay on the page host, and
    protocol-relative hrefs ('//cdn.example.com/x.ico') are recognised as
    external. Non-http(s) hrefs, such as inline data: URIs, are ignored.
    """
    from urllib.parse import urljoin

    soup = get_html_content(url)
    if not soup:
        return -1
    page_host = urlparse(url).netloc.lower()
    for tag in soup.find_all('link', rel=lambda x: x and 'icon' in x):
        favicon_url = tag.get('href')
        if not favicon_url:
            continue
        resolved = urlparse(urljoin(url, favicon_url))
        if resolved.scheme in ('http', 'https') and resolved.netloc.lower() != page_host:
            return 1
    return 0
ds_sampled['favicon_external'] = ds_sampled['url'].apply(favicon_external)

# 50. Page uses <iframe> tags (1 = phishing indicator)
def using_iframe(url):
    """Return 1 if the page at *url* contains at least one <iframe>, 0 if none, -1 if it could not be fetched."""
    soup = get_html_content(url)
    if not soup:
        return -1
    iframe_tags = soup.find_all('iframe')
    return 1 if iframe_tags else 0

# Feature 50 column: page uses iframes.
ds_sampled['using_iframe'] = ds_sampled['url'].apply(using_iframe)

print(ds_sampled.head())
# Persist the engineered features (the 'url' and 'label' columns are included;
# index dropped). NOTE(review): the training cells below read
# 'processed_urldata_new.csv', not this 'processed_urldata_50.csv' — confirm
# which file is intended.
ds_sampled.to_csv('C:\\Users\\dollie\\Downloads\\processed_urldata_50.csv', index=False)




In [None]:
#decision trees - final code - includes feature importance

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.preprocessing import LabelEncoder

# Load the processed dataset
df = pd.read_csv('C:\\Users\\dollie\\Downloads\\processed_urldata_new.csv')

# Separate the features (X) from the target variable (y). The raw 'url' string
# is also dropped when present. Label-encoding it would hand the model each
# URL's alphabetical rank as a "feature": that rank identifies rows rather than
# describing them, and it can correlate with the label (e.g. through the
# scheme prefix).
X = df.drop(columns=['label']).drop(columns=['url'], errors='ignore')
y = df['label']

# Label-encode any remaining text columns so the model receives numbers
le = LabelEncoder()
for col in X.columns:
    if X[col].dtype == 'object':
        X[col] = le.fit_transform(X[col])

# Decision tree with default hyper-parameters, refitted on every fold.
model = DecisionTreeClassifier()

# 10 stratified folds, shuffled reproducibly.
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

scores = []                                          # accuracy per fold
feature_importances = pd.DataFrame(index=X.columns)  # one column per fold
cm_sum = np.zeros((2,2))                             # confusion matrix summed over folds

for fold, (train_index, test_index) in enumerate(skf.split(X, y), start=1):
    X_train, X_test = X.iloc[train_index], X.iloc[test_index]
    y_train, y_test = y.iloc[train_index], y.iloc[test_index]

    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    accuracy = accuracy_score(y_test, y_pred)
    scores.append(accuracy)
    print(f"\nFold: {fold}")
    print(f"Accuracy: {accuracy}")

    cm = confusion_matrix(y_test, y_pred)
    cm_sum += cm

    print("\nClassification Report:\n", classification_report(y_test, y_pred))

    # Impurity-based importances of this fold's tree.
    feature_importances[f'fold_{fold}'] = model.feature_importances_

# Plot the confusion matrix summed over all folds
plt.figure(figsize=(5,5))
sns.heatmap(cm_sum, annot=True, fmt=".1f")
plt.title('Summed confusion matrix')
plt.ylabel('Actual label')
plt.xlabel('Predicted label')
plt.show()

# Print the feature importance DataFrame (one column per fold)
print("\nFeature Importances:\n", feature_importances)

# Mean and standard deviation of each feature's importance across folds.
# Both use only the per-fold columns: computing std after adding 'mean' would
# count the mean itself as an extra observation and understate the spread.
fold_columns = [c for c in feature_importances.columns if str(c).startswith('fold_')]
feature_importances['mean'] = feature_importances[fold_columns].mean(axis=1)
feature_importances['std'] = feature_importances[fold_columns].std(axis=1)

# Sort features by mean importance
feature_importances = feature_importances.sort_values(by='mean', ascending=False)

# Plot the mean feature importance with its across-fold std as error bars
plt.figure(figsize=(12,8))
plt.title('Feature Importances')
sns.barplot(x=feature_importances['mean'], y=feature_importances.index, xerr=feature_importances['std'], color='b')
plt.xlabel('Relative Importance')
plt.show()

print("Cross-validation scores: ", scores)
print("Average cross-validation score: ", sum(scores)/len(scores))


In [None]:
#SVM - final code - includes feature importance

import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Load the processed dataset
df = pd.read_csv('C:\\Users\\dollie\\Downloads\\processed_urldata_new.csv')

# Separate the features (X) from the target variable (y). The raw 'url' string
# is also dropped when present. Label-encoding it would hand the model each
# URL's alphabetical rank as a "feature": that rank identifies rows rather than
# describing them, and it can correlate with the label.
X = df.drop(columns=['label']).drop(columns=['url'], errors='ignore')
y = df['label']

# Label-encode any remaining text columns so the model receives numbers
le = LabelEncoder()
for col in X.columns:
    if X[col].dtype == 'object':
        X[col] = le.fit_transform(X[col])

# Linear-kernel SVM behind a StandardScaler. SVMs are not scale invariant:
# unscaled columns with large ranges (e.g. num_html_chars) would dominate the
# margin, and |coef_| would reflect each feature's units rather than its
# importance. The scaler is part of the pipeline, so it is fitted on each
# fold's training split only.
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

model = make_pipeline(StandardScaler(), SVC(kernel='linear'))

# Feature importances: one column per fold
feature_importances = pd.DataFrame(index=X.columns)

# 10-fold stratified cross-validation
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
scores = []

# Cumulative confusion matrix; the (2, 2) shape assumes a binary label
cumulative_cm = np.zeros((2,2))

fold = 0
for train_index, test_index in skf.split(X, y):
    fold += 1
    X_train, X_test = X.iloc[train_index], X.iloc[test_index]
    y_train, y_test = y.iloc[train_index], y.iloc[test_index]

    # Train the model (scaler + SVM)
    model.fit(X_train, y_train)

    # Test the model
    y_pred = model.predict(X_test)

    # Calculate accuracy and append to scores list
    accuracy = accuracy_score(y_test, y_pred)
    scores.append(accuracy)

    print(f"\nFold: {fold}")
    print(f"Accuracy: {accuracy}")

    # Add the confusion matrix to the cumulative one
    cm = confusion_matrix(y_test, y_pred)
    cumulative_cm += cm

    # |coefficient| of the linear SVM (last pipeline step) on standardised features
    feature_importances[f'fold_{fold}'] = abs(model[-1].coef_[0])

# Plot the cumulative confusion matrix
plt.figure(figsize=(5,5))
sns.heatmap(cumulative_cm, annot=True, fmt=".1f")
plt.title('Cumulative Confusion matrix')
plt.ylabel('Actual label')
plt.xlabel('Predicted label')
plt.show()

# Mean and standard deviation of each feature's importance across folds.
# Both use only the per-fold columns: computing std after adding 'mean' would
# count the mean itself as an extra observation and understate the spread.
fold_columns = [c for c in feature_importances.columns if str(c).startswith('fold_')]
feature_importances['mean'] = feature_importances[fold_columns].mean(axis=1)
feature_importances['std'] = feature_importances[fold_columns].std(axis=1)

# Sort features by mean importance
feature_importances = feature_importances.sort_values(by='mean', ascending=False)

# Plot the mean feature importance with its across-fold std as error bars
plt.figure(figsize=(12,8))
plt.title('Feature Importances')
sns.barplot(x=feature_importances['mean'], y=feature_importances.index, xerr=feature_importances['std'], color='b')
plt.xlabel('Relative Importance')
plt.show()

print("Cross-validation scores: ", scores)
print("Average cross-validation score: ", sum(scores)/len(scores))


In [None]:
#neural networks - includes feature importance
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
import seaborn as sns
import shap  

# Load the processed dataset
df = pd.read_csv('C:\\Users\\dollie\\Downloads\\processed_urldata_new.csv')

# Separate the features (X) from the target variable (y). The raw 'url' string
# is also dropped when present. Label-encoding it would hand the model each
# URL's alphabetical rank as a "feature": that rank identifies rows rather than
# describing them, and it can correlate with the label.
X = df.drop(columns=['label']).drop(columns=['url'], errors='ignore')
y = df['label']

# Label-encode any remaining text columns of X
le = LabelEncoder()
for col in X.columns:
    if X[col].dtype == 'object':
        X[col] = le.fit_transform(X[col])

# Encode the class labels as integers, then one-hot them for the softmax output
y = to_categorical(le.fit_transform(y))

# Keras needs numeric input: cast every feature column to float.
X = X.astype(float)

# 10 shuffled, stratified folds, stratified on each one-hot row's class index.
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

scores = []                      # accuracy per fold
cumulative_cm = np.zeros((2,2))  # binary confusion matrix summed over folds

for fold, (train_index, test_index) in enumerate(skf.split(X, y.argmax(1)), start=1):
    X_train, X_test = X.iloc[train_index], X.iloc[test_index]
    y_train, y_test = y[train_index], y[test_index]

    # Fresh network per fold: 32 -> 16 ReLU units, softmax over the 2 classes.
    model = Sequential()
    model.add(Dense(32, input_dim=X_train.shape[1], activation='relu'))
    model.add(Dense(16, activation='relu'))
    model.add(Dense(2, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    model.fit(X_train, y_train, epochs=50, batch_size=32, verbose=0)

    # Predicted and true class indices for this fold's test rows.
    y_pred = np.argmax(model.predict(X_test), axis=1)
    y_true = y_test.argmax(1)

    accuracy = accuracy_score(y_true, y_pred)
    scores.append(accuracy)
    print(f"\nFold: {fold}")
    print(f"Accuracy: {accuracy}")

    cm = confusion_matrix(y_true, y_pred)
    cumulative_cm += cm

print("Cross-validation scores: ", scores)
print("Average cross-validation score: ", sum(scores)/len(scores))

# Plot the cumulative confusion matrix
plt.figure(figsize=(5,5))
sns.heatmap(cumulative_cm, annot=True, fmt=".1f")
plt.title('Cumulative Confusion matrix')
plt.ylabel('Actual label')
plt.xlabel('Predicted label')
plt.show()

# SHAP attributions for the network trained on the LAST fold, evaluated on that
# fold's test rows. KernelExplainer is model-agnostic and calls the prediction
# function repeatedly, so Keras' per-call progress bar is disabled (verbose=0).
# The 100-row background sample comes from a seeded generator, so reruns
# explain against the same background.
rng = np.random.default_rng(42)
background = X_train.iloc[rng.choice(X_train.shape[0], 100, replace=False)]


def _predict_proba_quiet(data):
    """Class probabilities from the Keras model, without progress output."""
    return model.predict(data, verbose=0)


explainer = shap.KernelExplainer(_predict_proba_quiet, background)
shap_values = explainer.shap_values(X_test, nsamples=100)
shap.summary_plot(shap_values, X_test, plot_type="bar")