In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import OneHotEncoder, StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, accuracy_score
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from imblearn.over_sampling import SMOTE
from xgboost import XGBClassifier
from collections import Counter

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.compose import ColumnTransformer

In [2]:
# Load the raw dataset from the Excel workbook (path is relative to the notebook's working directory)
data = pd.read_excel('./data/data_new.xlsx')

# Preview the first rows. A bare last expression uses the notebook's rich table
# rendering, whereas print() flattens the frame to wrapped plain text.
data.head()

         SourceType         SourceDescription        TargetName  \
0     System/Server    Payment Gateway System   Web Application   
1  Software/Program  User Management Software  Software/Program   
2            Device                IoT Device            Device   
3   Web Application   Online Shopping Website   Web Application   
4     System/Server              Email Server     System/Server   

         TargetType TargetDescription Relationship AuthRequired Encryption  \
0          Database          Database       Writes          Yes        Yes   
1          Database          Database        Reads          Yes         No   
2     System/Server     System/Server   Sends Data           No        Yes   
3          Database          Database        Reads          Yes        Yes   
4  Software/Program  Software/Program        Sends           No        Yes   

  EncryptionType  DataFormat     Frequency      DataIntegrity AccessType  \
0        TLS/SSL        JSON     Real-Time  Digital 

In [3]:
# List every column label, one per line (iterating a DataFrame yields its column labels)
for col in list(data):
    print(col)

SourceType
SourceDescription
TargetName
TargetType
TargetDescription
Relationship
AuthRequired
Encryption
EncryptionType
DataFormat
Frequency
DataIntegrity
AccessType
AccessTarget
NetworkProtocol
CommunicationChannel
CredentialStorage
Interactor
Threat


In [4]:
# Keep the modelling features plus the target. SourceDescription, TargetName,
# TargetDescription and Relationship are not selected and therefore dropped here.
data = data[[
    "SourceType",
    "TargetType",
    "AuthRequired",
    "Encryption",
    "EncryptionType",
    "DataFormat",
    "Frequency",
    "DataIntegrity",
    "AccessType",
    "AccessTarget",
    "NetworkProtocol",
    "CommunicationChannel",
    "CredentialStorage",
    "Interactor",
    "Threat",
]]

In [5]:
# Count the missing (NaN) cells in each remaining column
empty_cells_per_column = data.isnull().sum()
print("Number of empty cells per column:")
empty_cells_per_column

Number of empty cells per column:


SourceType                0
TargetType                0
AuthRequired              0
Encryption                0
EncryptionType          197
DataFormat                2
Frequency                 0
DataIntegrity           129
AccessType                4
AccessTarget             37
NetworkProtocol           5
CommunicationChannel     13
CredentialStorage       153
Interactor               11
Threat                   25
dtype: int64

In [6]:
# Keep only the rows that have a label: a row without 'Threat' cannot be used for supervised training
data = data.loc[data['Threat'].notna()]

In [7]:
# Canonical label for each spelling variant of a threat class.
# Both the DoS and the DDoS spellings are folded into the single 'DDoS Attack' class.
mapping = dict([
    ('Phishing Attacks', 'Phishing Attack'),
    ('Spear Phishing Attacks', 'Spear Phishing Attack'),
    ('Drive-by Download Attacks', 'Drive-by Download Attack'),
    ('Distributed Denial of Service (DDoS)', 'DDoS Attack'),
    ('Denial of Service (DoS)', 'DDoS Attack'),
    # Add other necessary mappings here
])

# Rewrite the target column; labels that are not keys of the mapping are left as they are
data['Threat'] = data['Threat'].replace(to_replace=mapping)

In [8]:
# Number of rows per threat class, most frequent first
class_counts = data['Threat'].value_counts(sort=True, ascending=False)
class_counts

Threat
DDoS Attack                          55
Cross-Site Scripting (XSS)           55
Phishing Attack                      53
Malware Attack                       53
SQL Injection                        52
Man-in-the-Middle (MitM) Attack      51
API Security Breach                  50
Malvertising                         47
Ransomware Attack                    46
Drive-by Download Attack             44
Cross-Site Request Forgery (CSRF)    44
Spear Phishing Attack                43
Zero-day Exploit                     43
Cryptojacking                        43
Credential Stuffing                  42
Remote Code Execution (RCE)          41
Directory Traversal                  36
Side-Channel Attack                  35
Password Attack                      30
Name: count, dtype: int64

In [9]:
# Keep only the classes that occur at least twice; rows of any rarer class are discarded
valid_classes = class_counts.loc[class_counts.ge(2)].index
data = data.loc[data['Threat'].isin(valid_classes)]


In [10]:
# Summarise dtypes and non-null counts of the frame after the row filtering above
data.info()

<class 'pandas.core.frame.DataFrame'>
Index: 863 entries, 0 to 886
Data columns (total 15 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   SourceType            863 non-null    object
 1   TargetType            863 non-null    object
 2   AuthRequired          863 non-null    object
 3   Encryption            863 non-null    object
 4   EncryptionType        672 non-null    object
 5   DataFormat            861 non-null    object
 6   Frequency             863 non-null    object
 7   DataIntegrity         742 non-null    object
 8   AccessType            859 non-null    object
 9   AccessTarget          826 non-null    object
 10  NetworkProtocol       858 non-null    object
 11  CommunicationChannel  851 non-null    object
 12  CredentialStorage     716 non-null    object
 13  Interactor            852 non-null    object
 14  Threat                863 non-null    object
dtypes: object(15)
memory usage: 107.9+ KB


In [11]:
# Fill NaN cells with a placeholder category. Note that "None" is the literal
# string, not Python's None, so it becomes an ordinary category value.
fill_plan = (
    ("None", ['EncryptionType', 'CredentialStorage']),
    ("No", ['AuthRequired', 'DataIntegrity', 'NetworkProtocol', 'CommunicationChannel']),
)
for fill_value, fill_columns in fill_plan:
    data[fill_columns] = data[fill_columns].fillna(fill_value)

In [12]:
# Display the frame after the fills above to spot-check the result
data

Unnamed: 0,SourceType,TargetType,AuthRequired,Encryption,EncryptionType,DataFormat,Frequency,DataIntegrity,AccessType,AccessTarget,NetworkProtocol,CommunicationChannel,CredentialStorage,Interactor,Threat
0,System/Server,Database,Yes,Yes,TLS/SSL,JSON,Real-Time,Digital Signature,Write,Database,HTTP/HTTPS,Wired,Secure Vault,User,SQL Injection
1,Software/Program,Database,Yes,No,,XML,Periodic,Hash,Read,Database,TCP/IP,Wireless,Encrypted,System,Cross-Site Scripting (XSS)
2,Device,System/Server,No,Yes,Asymmetric,Binary,Real-Time,Checksum,Write,Web Service/API,MQTT,Wired,Hashed,User,DDoS Attack
3,Web Application,Database,Yes,Yes,Hashing,JSON,Event-Driven,Digital Signature,Read,Database,HTTP/HTTPS,Virtual Private Network,Encrypted,User,Phishing Attack
4,System/Server,Software/Program,No,Yes,TLS/SSL,Plain Text,Batch,No,Write,Message Queue,FTP/SFTP,Wireless,Environment Variable,System,Ransomware Attack
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
882,Device,Database,Yes,Yes,Asymmetric,JSON,Real-Time,Checksum,Write,Database,HTTP/HTTPS,Wired,Encrypted,Device,Zero-day Exploit
883,Software/Program,Web Application,Yes,Yes,TLS/SSL,JSON,Real-Time,Digital Signature,Write,Web Service/API,FTP/SFTP,Virtual Private Network,Secure Vault,User,Cross-Site Request Forgery (CSRF)
884,Web Application,Service,Yes,Yes,Hashing,XML,Periodic,Checksum,Write,File System,HTTP/HTTPS,Wireless,Encrypted,User,Directory Traversal
885,Software/Program,Web Application,Yes,Yes,Symmetric,CSV,On-Demand,Hash,Read,Database,HTTP/HTTPS,Wired,Hashed,User,Drive-by Download Attack


In [13]:
# Coarse groups for the SourceType column: each key is a group name and each list
# holds the raw SourceType values that are folded into that group.
# NOTE(review): "Cloud Storage Service" is listed under both "Web and Application"
# and "Service"; see the reverse lookup below for which one takes effect.
# NOTE(review): "Payment gateway" (System and Server) and "Payment Gateway" (Service)
# differ only in case and are therefore distinct keys.
groupings = {
    "Database": [
        "Database", "Database System", "Backend Database", "CRM Database", "HR Database", 
        "Payment Database", "Medical Database", "Warehouse Management System", "E-commerce Database"
    ],
    "System and Server": [
        "System", "Internal System", "External System", "Server", "Platform", "Network", "API", 
        "Web Server", "Application Server", "LDAP server", "POS System", "Electronic Health Record (EHR) System", 
        "Banking System", "HR System", "Time Tracking System", "Telematics System", "Industrial Control System", 
        "Centralized System", "ERP Software", "Payment gateway", "CRM System"
    ],
    "Web and Application": [
        "Web Application", "Application", "Mobile Application", "Web Interface", "Web Platform", "Web Form", 
        "Website", "E-commerce Platform", "E-commerce Website", "Web Portal", "Webmail Application", 
        "Online Banking System", "Learning Management System", "Educational Website", "Social Networking Site", 
        "Messaging App", "Messaging Service", "Online Store", "Booking Website", "Streaming Platform", "Cloud Service", 
        "SaaS Platform", "Cloud Storage Service", "Survey Tool"
    ],
    "Device": [
        "Device", "IoT Device", "Edge Device", "Wearable Device", "Smart Home Device", "GPS Tracker", 
        "Barcode Scanner", "Sensor", "Time Clock System", "Environmental Sensor", "Industrial Control Device", 
        "Surveillance Device", "Smartphone", "Smart Thermostat", "Medical Device", "Telematics Device", 
        "Production Machines", "GPS Device", "Home Automation Gateway"
    ],
    "Service": [
        "Service", "External Service", "Internal Service", "Authentication Service", "Software as a Service (SaaS)", 
        "API Service", "Cloud Storage Service", "Payment Gateway", "Communication Tool", "Helpdesk Software", 
        "Logistics Software", "GPS Tracking Service", "Healthcare Service", "AI Chatbot", "AI Assistant"
    ],
    "Software/Program": [
        "Software", "Software Application", "Communication Software", "Educational Software", "Healthcare Software", 
        "Enterprise Software", "Enterprise Application", "Medical Software", "Security Equipment", "Messaging Software", 
        "Retail Software", "Business Management Software", "HR Software", "Warehouse Software", 
        "Inventory Management Software", "Team Workspace", "Market Data Provider", "Data Hosting Platform", 
        "Data Visualization Tool"
    ],
    "User": [
        "User", "External User", "Internal User", "Individual", "Person", "Customer", "Staff", "Worker", "Learner", 
        "Business", "User Input Form", "User Profiles", "External Entity", "Internal Department"
    ]
}

# Invert the table into {raw value: group name}. When a raw value appears in more
# than one group, the group defined later in `groupings` overwrites the earlier one.
reverse_lookup = {item: key for key, values in groupings.items() for item in values}

# Replace values in the SourceType column with the corresponding group name.
# Values missing from the lookup map to NaN and are then restored to their original value by fillna.
data['SourceType'] = data['SourceType'].map(reverse_lookup).fillna(data['SourceType'])


In [14]:
# # Groups all categories that have only a few instances together as 'Other'
# def shorten_categories(categories, cutoff):
#     categorical_map = {}
#     for i in range(len(categories)):
#         if categories.values[i] >= cutoff:
#             categorical_map[categories.index[i]] = categories.index[i]
#         else:
#             categorical_map[categories.index[i]] = 'Other'
#     return categorical_map

# for column in data.columns:
#     if data[column].dtype == 'object':  # Ensure the column is categorical
#         category_counts = data[column].value_counts()
#         category_map = shorten_categories(category_counts, 2)
#         data[column] = data[column].map(category_map)

In [15]:
def clean_Type(x):
    """Collapse a free-form type label into one of a few coarse categories.

    The label is tested for keywords in a fixed order and the first match
    wins, so e.g. 'Web Application' becomes 'Software/Program' and
    'System/Server' becomes 'Server'. Matching is case-sensitive.

    Parameters
    ----------
    x : str or any
        Raw type label. Non-string values (e.g. NaN or None from a missing
        cell) are returned unchanged instead of raising ``TypeError`` on the
        substring test.

    Returns
    -------
    The coarse category name, or ``x`` itself when no keyword matches.
    """
    # A missing cell arrives as float NaN / None; `'Device' in x` would raise on it.
    if not isinstance(x, str):
        return x

    # (keywords, category) pairs in precedence order; earlier rows shadow later ones.
    rules = (
        (('Device',), 'Device'),
        (('Data',), 'Database'),
        (('App', 'Software', 'Program'), 'Software/Program'),
        (('User',), 'User'),
        (('Service',), 'Service'),
        (('Server',), 'Server'),
    )
    for keywords, category in rules:
        if any(keyword in x for keyword in keywords):
            return category
    return x

# Normalise both endpoint-type columns with the same keyword rules
for type_column in ('SourceType', 'TargetType'):
    data[type_column] = data[type_column].apply(clean_Type)

In [16]:
# Cardinality of every column
unique_values_count = data.nunique()

# Frequency table of every column, keyed by column name
value_counts_per_column = {}
for col in data.columns:
    value_counts_per_column[col] = data[col].value_counts()

# Show both summaries as the cell output
unique_values_count, value_counts_per_column


(SourceType               6
 TargetType               6
 AuthRequired             2
 Encryption               2
 EncryptionType           6
 DataFormat               7
 Frequency                5
 DataIntegrity            5
 AccessType               4
 AccessTarget            11
 NetworkProtocol          7
 CommunicationChannel    11
 CredentialStorage        7
 Interactor               8
 Threat                  19
 dtype: int64,
 {'SourceType': SourceType
  Software/Program    295
  Device              154
  Server              146
  Database            117
  Service              95
  User                 56
  Name: count, dtype: int64,
  'TargetType': TargetType
  Software/Program    223
  Server              216
  Service             187
  Database            150
  Device               46
  User                 41
  Name: count, dtype: int64,
  'AuthRequired': AuthRequired
  Yes    685
  No     178
  Name: count, dtype: int64,
  'Encryption': Encryption
  Yes    682
  No     181
  

In [17]:

# Separate features and target variable
X = data.drop('Threat', axis=1)
y = data['Threat']

# Encode the target variable as integer class ids (decoded again for the reports)
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

# Identify categorical and numerical columns by dtype
categorical_features = X.select_dtypes(include=['object', 'category']).columns
numerical_features = X.select_dtypes(include=['int64', 'float64']).columns

# Numerical columns: impute missing values with the median, then standardise
numerical_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler())
])

# Categorical columns: impute with the most frequent value, then one-hot encode.
# handle_unknown='ignore' encodes a category unseen during fit as all zeros.
categorical_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='most_frequent')),
    ('onehot', OneHotEncoder(handle_unknown='ignore'))
])

# Combine preprocessing steps
preprocessor = ColumnTransformer(
    transformers=[
        ('num', numerical_transformer, numerical_features),
        ('cat', categorical_transformer, categorical_features)
    ])

# Split the RAW features first. Fitting the imputers/encoder before the split would
# let test rows influence the learned fill values and categories (data leakage).
# The same random_state and stratify vector give the same row assignment as before.
X_train_raw, X_test_raw, y_train, y_test = train_test_split(
    X, y_encoded, test_size=0.2, random_state=42, stratify=y_encoded
)

# Learn the preprocessing from the training rows only, then apply it to the test rows
X_train = preprocessor.fit_transform(X_train_raw)
X_test = preprocessor.transform(X_test_raw)

# Kept so code that expects the fully transformed matrix still finds this name;
# it is produced by transform() only, so nothing is fitted on the test rows.
X_preprocessed = preprocessor.transform(X)

# Handle class imbalance using SMOTE (training split only; the test split stays untouched)
# NOTE(review): SMOTE interpolates between samples, which yields fractional values
# on one-hot columns; a categorical variant such as SMOTEN may fit better. Confirm.
smote = SMOTE(random_state=42)
X_train_res, y_train_res = smote.fit_resample(X_train, y_train)



### Gradient Boosting

In [18]:

# Candidate XGBoost settings. The 'classifier__' prefix routes each setting to the
# pipeline step named 'classifier'.
param_grid = {
    f'classifier__{setting}': candidates
    for setting, candidates in {
        'n_estimators': [100, 200],
        'max_depth': [3, 6, 9],
        'learning_rate': [0.01, 0.1, 0.3],
        'subsample': [0.7, 1.0],
        'colsample_bytree': [0.7, 1.0],
    }.items()
}

# Single-step pipeline around the XGBoost model (features are already preprocessed)
model = XGBClassifier(random_state=42)
clf = Pipeline([('classifier', model)])

# Exhaustive 3-fold search over the grid, scored by accuracy, using all cores
grid_search = GridSearchCV(estimator=clf, param_grid=param_grid, scoring='accuracy', cv=3, n_jobs=-1)
grid_search.fit(X_train_res, y_train_res)

# Predict the held-out split with the refitted best configuration
y_pred = grid_search.predict(X_test)

# Map integer class ids back to the original threat names for readable reports
y_test_decoded, y_pred_decoded = (
    label_encoder.inverse_transform(encoded) for encoded in (y_test, y_pred)
)

# Report the winning settings and the held-out performance
print("Best parameters found: ", grid_search.best_params_)
print("Accuracy:", accuracy_score(y_test_decoded, y_pred_decoded))
print("Classification Report:\n", classification_report(y_test_decoded, y_pred_decoded))

Best parameters found:  {'classifier__colsample_bytree': 1.0, 'classifier__learning_rate': 0.1, 'classifier__max_depth': 6, 'classifier__n_estimators': 100, 'classifier__subsample': 0.7}
Accuracy: 0.15028901734104047
Classification Report:
                                    precision    recall  f1-score   support

              API Security Breach       0.10      0.10      0.10        10
              Credential Stuffing       0.10      0.12      0.11         8
Cross-Site Request Forgery (CSRF)       0.00      0.00      0.00         9
       Cross-Site Scripting (XSS)       0.33      0.45      0.38        11
                    Cryptojacking       0.09      0.11      0.10         9
                      DDoS Attack       0.11      0.09      0.10        11
              Directory Traversal       0.00      0.00      0.00         7
         Drive-by Download Attack       0.00      0.00      0.00         9
                     Malvertising       0.00      0.00      0.00         9
        

### Decision Tree

In [19]:
# Define the Decision Tree model
model = DecisionTreeClassifier(random_state=42)

# Single-step pipeline so the grid keys can use the 'classifier__' prefix
clf = Pipeline(steps=[('classifier', model)])

# Define parameter grid for hyperparameter tuning.
# max_features: 'auto' was removed in scikit-learn 1.3 (every candidate using it
# fails to fit) and was only an alias of 'sqrt' for classifiers. None (consider
# all features at each split) replaces it so the grid holds three distinct options.
param_grid = {
    'classifier__max_depth': [None, 10, 20, 30],
    'classifier__min_samples_split': [2, 5, 10],
    'classifier__min_samples_leaf': [1, 2, 4],
    'classifier__max_features': [None, 'sqrt', 'log2']
}

# Perform grid search with 3-fold cross-validation, scored by accuracy
grid_search = GridSearchCV(clf, param_grid, cv=3, scoring='accuracy', n_jobs=-1)

# Fit on the resampled training split
grid_search.fit(X_train_res, y_train_res)

# Predict on the test set with the refitted best configuration
y_pred = grid_search.predict(X_test)

# Decode the integer class ids back to the original labels
y_test_decoded = label_encoder.inverse_transform(y_test)
y_pred_decoded = label_encoder.inverse_transform(y_pred)

# Evaluate the model
print("Best parameters found: ", grid_search.best_params_)
print("Accuracy:", accuracy_score(y_test_decoded, y_pred_decoded))
print("Classification Report:\n", classification_report(y_test_decoded, y_pred_decoded))

Best parameters found:  {'classifier__max_depth': None, 'classifier__max_features': 'log2', 'classifier__min_samples_leaf': 1, 'classifier__min_samples_split': 2}
Accuracy: 0.09248554913294797
Classification Report:
                                    precision    recall  f1-score   support

              API Security Breach       0.09      0.10      0.10        10
              Credential Stuffing       0.00      0.00      0.00         8
Cross-Site Request Forgery (CSRF)       0.00      0.00      0.00         9
       Cross-Site Scripting (XSS)       0.20      0.18      0.19        11
                    Cryptojacking       0.00      0.00      0.00         9
                      DDoS Attack       0.00      0.00      0.00        11
              Directory Traversal       0.33      0.14      0.20         7
         Drive-by Download Attack       0.17      0.11      0.13         9
                     Malvertising       0.00      0.00      0.00         9
                   Malware Attac

### Random Forest

In [20]:
# Define the Random Forest model
model = RandomForestClassifier(random_state=42)

# Single-step pipeline so the grid keys can use the 'classifier__' prefix
clf = Pipeline(steps=[('classifier', model)])

# Define parameter grid for hyperparameter tuning.
# max_features: 'auto' was removed in scikit-learn 1.3 (every candidate using it
# fails to fit) and was only an alias of 'sqrt' for classifiers. None (consider
# all features at each split) replaces it so the grid holds three distinct options.
param_grid = {
    'classifier__n_estimators': [100, 200, 300],
    'classifier__max_depth': [None, 10, 20, 30],
    'classifier__min_samples_split': [2, 5, 10],
    'classifier__min_samples_leaf': [1, 2, 4],
    'classifier__max_features': [None, 'sqrt', 'log2']
}

# Perform grid search with 3-fold cross-validation, scored by accuracy
grid_search = GridSearchCV(clf, param_grid, cv=3, scoring='accuracy', n_jobs=-1)

# Fit on the resampled training split
grid_search.fit(X_train_res, y_train_res)

# Predict on the test set with the refitted best configuration
y_pred = grid_search.predict(X_test)

# Decode the integer class ids back to the original labels
y_test_decoded = label_encoder.inverse_transform(y_test)
y_pred_decoded = label_encoder.inverse_transform(y_pred)

# Evaluate the model
print("Best parameters found: ", grid_search.best_params_)
print("Accuracy:", accuracy_score(y_test_decoded, y_pred_decoded))
print("Classification Report:\n", classification_report(y_test_decoded, y_pred_decoded))

Best parameters found:  {'classifier__max_depth': 20, 'classifier__max_features': 'log2', 'classifier__min_samples_leaf': 1, 'classifier__min_samples_split': 2, 'classifier__n_estimators': 200}
Accuracy: 0.1329479768786127
Classification Report:
                                    precision    recall  f1-score   support

              API Security Breach       0.00      0.00      0.00        10
              Credential Stuffing       0.00      0.00      0.00         8
Cross-Site Request Forgery (CSRF)       0.00      0.00      0.00         9
       Cross-Site Scripting (XSS)       0.24      0.45      0.31        11
                    Cryptojacking       0.29      0.22      0.25         9
                      DDoS Attack       0.00      0.00      0.00        11
              Directory Traversal       0.00      0.00      0.00         7
         Drive-by Download Attack       0.00      0.00      0.00         9
                     Malvertising       0.00      0.00      0.00         9
  

In [21]:
import pickle

# Create the pipeline with the preprocessor and the model, so the saved object
# accepts raw feature rows and applies the same preprocessing before predicting
model = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('classifier', RandomForestClassifier(random_state=42))
])

# Define parameter grid for hyperparameter tuning.
# max_features: 'auto' was removed in scikit-learn 1.3 (every candidate using it
# fails to fit) and was only an alias of 'sqrt' for classifiers. None (consider
# all features at each split) replaces it so the grid holds three distinct options.
param_grid = {
    'classifier__n_estimators': [100, 200, 300],
    'classifier__max_depth': [None, 10, 20, 30],
    'classifier__min_samples_split': [2, 5, 10],
    'classifier__min_samples_leaf': [1, 2, 4],
    'classifier__max_features': [None, 'sqrt', 'log2']
}

# Perform grid search with 3-fold cross-validation, scored by accuracy
grid_search = GridSearchCV(model, param_grid, cv=3, scoring='accuracy', n_jobs=-1)

# Fit on the full labelled dataset (raw features; the pipeline preprocesses per fold)
grid_search.fit(X, y_encoded)

# Save the fitted search object (it wraps the preprocessor + classifier pipeline).
# NOTE(review): the model predicts integer class ids and label_encoder is not
# stored in this file, so a consumer cannot decode predictions from the pickle
# alone. Confirm how the loader recovers the class names.
with open('saved_steps.pkl', 'wb') as file:
    pickle.dump(grid_search, file)
