In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import xgboost as xgb

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,

)
from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import RobustScaler

from imblearn.under_sampling import RandomUnderSampler
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline

import time
import psutil
import threading
from memory_profiler import memory_usage

import joblib


Note: You have installed the 'manylinux2014' variant of XGBoost. Certain features such as GPU algorithms or federated learning are not available. To use these features, please upgrade to a recent Linux distro with glibc 2.28+, and install the 'manylinux_2_28' variant.


In [2]:
df = pd.read_csv('/home/wahba/Documents/dataset/processed/3_cicids2017_attack_only.csv')
df.drop(df[df['Label'] == 'Bot'].index, inplace=True)

selected_features = [
    # Flow-level
    'Flow Duration',
    'Flow Packets/s',
    'Flow Bytes/s',
    'Flow IAT Mean',
    'Flow IAT Max',
    'Flow IAT Std',
    
    # Forward features
    'Fwd Header Length',
    'Fwd IAT Total',
    'Fwd IAT Mean',
    'Fwd IAT Max',
    'Fwd IAT Std',
    'Fwd Packet Length Min',
    'Fwd Packet Length Max',
    'Fwd Packet Length Mean',
    'Fwd Packet Length Std',
    'Subflow Fwd Bytes',
    'Total Fwd Packets',
    'Total Length of Fwd Packets',
    
    # Backward features
    'Bwd Header Length',
    'Bwd Packet Length Min',
    'Bwd Packet Length Max',
    'Bwd Packet Length Std',
    'Bwd Packets/s',
    'Init_Win_bytes_backward',
    
    # Packet-level
    'Packet Length Mean',
    'Packet Length Std',
    'Packet Length Variance',
    'Average Packet Size',
    'PSH Flag Count',
    'Init_Win_bytes_forward',
    'Max Packet Length',

    'Label',
]

# Keep only the selected features
df = df[selected_features]

# 1.0 Dataset Preparation

## 1.1 Training and Testing Dataset Splitting

In [3]:
# splitting df for training and testing using stratified split
X = df.drop('Label', axis=1) # features
y = df['Label'] # target

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)

for train_index, test_index in split.split(X, y):
    strat_train_set = df.loc[train_index]
    strat_test_set = df.loc[test_index]

X_train = strat_train_set.drop("Label", axis=1)
y_train = strat_train_set["Label"]

X_test = strat_test_set.drop("Label", axis=1)
y_test = strat_test_set["Label"]

label_mapping = {'DoS': 0, 'PortScan': 1, 'BruteForce': 2, 'WebAttack': 3}
y_train = y_train.map(label_mapping)
y_test  = y_test.map(label_mapping)

print(pd.DataFrame({
    "count (df)": df["Label"].value_counts(),
    "count (train_set)": strat_train_set["Label"].value_counts(),
    "count (test_set)": strat_test_set["Label"].value_counts(),
    "proportion": strat_train_set["Label"].value_counts(normalize=True),
})
)

KeyError: '[205713, 206277, 206633, 205822, 206037, 206389, 206273, 205316, 205932, 205348, 205511, 205816, 205265, 205633, 206226, 205579, 205334, 205043, 205635, 206576, 206583, 205625, 205141, 205185, 205087, 205950, 205440, 206981, 205602, 206298, 206544, 206463, 206379, 206630, 206883, 205549, 205243, 205177, 206851, 205444, 205779, 205104, 206816, 206823, 206102, 205380, 205156, 206478, 206812, 205655, 206329, 205897, 205391, 206095, 205206, 206015, 205193, 206788, 205634, 206854, 205753, 206176, 206845, 205854, 206111, 205460, 206335, 206777, 206943, 205144, 206539, 206871, 205450, 206938, 205062, 205987, 205055, 206307, 205327, 205464, 205889, 205434, 206646, 205946, 205447, 205387, 206818, 205772, 205954, 206328, 205927, 205810, 205114, 205446, 206564, 205088, 205080, 205735, 205955, 206803, 205559, 206136, 206978, 205938, 205358, 205419, 206436, 206003, 206838, 206817, 205913, 205085, 205693, 206513, 205926, 206061, 205379, 206682, 205841, 205988, 205759, 206837, 205410, 205309, 206154, 205127, 206797, 205849, 205620, 205541, 205423, 206849, 206147, 206721, 205899, 206245, 205442, 206515, 205136, 205201, 205949, 205805, 205155, 205421, 206055, 206963, 205838, 206210, 205323, 206941, 205591, 206459, 206587, 205175, 205828, 206286, 205250, 205813, 205858, 206320, 206376, 205801, 205766, 205538, 205853, 205931, 206593, 205181, 206255, 205878, 206811, 205281, 205728, 205073, 206832, 206127, 205866, 205884, 206475, 206609, 206705, 206872, 206730, 206193, 205900, 205328, 206204, 206241, 205253, 206310, 206613, 205755, 205551, 205376, 206985, 206897, 206294, 205489, 206179, 206577, 205063, 205868, 205100, 205800, 206538, 206631, 205512, 206578, 206034, 206092, 206732, 205860, 205378, 205076, 205499, 205793, 205369, 205101, 205947, 205694, 206342, 206197, 206173, 205433, 206025, 206370, 206043, 206709, 206239, 205978, 205257, 205971, 206324, 206284, 206962, 206168, 206009, 206839, 205184, 205596, 206397, 205435, 206924, 205630, 206784, 205890, 206770, 205098, 205814, 205520, 206016, 205228, 205827, 206650, 206093, 205295, 205139, 206725, 205952, 206465, 206323, 206563, 206660, 206020, 206235, 206341, 205361, 205537, 206736, 205174, 205363, 206984, 206125, 206211, 206746, 205339, 205922, 206741, 206912, 205832, 206225, 206160, 206685, 206990, 206088, 206267, 205298, 205510, 205869, 206847, 205322, 205420, 205287, 205223, 206455, 206922, 205563, 206357, 205628, 205560, 206337, 205190, 206534, 206184, 205057, 206582, 206693, 205743, 206738, 206124, 206571, 205706, 206494, 205143, 206388, 206510, 206041, 206027, 206890, 205367, 206222, 206011, 206540, 206710, 205640, 206487, 206439, 205318, 206830, 206507, 205293, 205916, 205592, 205736, 205342, 206669, 206858, 205067, 205289, 206054, 205908, 206608, 205130, 206649, 206557, 206010, 205836, 205084, 205482, 205732, 205272, 205519, 206076, 205481, 205872, 205455, 205961, 206006, 206547, 206296, 206551, 206401, 206913, 205613, 205626, 206586, 206185, 205324, 206559, 206680, 206623, 205662, 205173, 206449, 205665, 205977, 206780, 206767, 205639, 206665, 206558, 206570, 205873, 205226, 206804, 206889, 205484, 205585, 205266, 205218, 206291, 205784, 206207, 205871, 206142, 206900, 206440, 206572, 205306, 205439, 206152, 206233, 205594, 205258, 206772, 206752, 206704, 205350, 206139, 205382, 206876, 205756, 206792, 205657, 206358, 206688, 205437, 205861, 206404, 206190, 205090, 206595, 205716, 205209, 206246, 206182, 206412, 206899, 206661, 206238, 206771, 206174, 206266, 205390, 205365, 206703, 205937, 206548, 205684, 205167, 206375, 206261, 206916, 206180, 206427, 206719, 206256, 206532, 206901, 206761, 205531, 205748, 205804, 205500, 205219, 206529, 205163, 206078, 205518, 206581, 206019, 205891, 205762, 205981, 205454, 206417, 206244, 205480, 205953, 205064, 206247, 205856, 205457, 206597, 206918, 205831, 206292, 206895, 206527, 205564, 206264, 205494, 206159, 205811, 206275, 205255, 206221, 205208, 205951, 206339, 205385, 206497, 205065, 206424, 205757, 206863, 206115, 205742, 205863, 205944, 205279, 205608, 206002, 206082, 205964, 205112, 205565, 205653, 205443, 206611, 206394, 205676, 206120, 205540, 206672, 206935, 206450, 206462, 206482, 206171, 206336, 206865, 206454, 205364, 206989, 205188, 206374, 206707, 205056, 206312, 206149, 205782, 206129, 205110, 206356, 205723, 206085, 206317, 205074, 205945, 205126, 206399, 205528, 206504, 205131, 206301, 206278, 206659, 206643, 206698, 206598, 206828, 205763, 206367, 206500, 206528, 205993, 205069, 205317, 206209, 206396, 205221, 206470, 206065, 206809, 205765, 206306, 205695, 206763, 205658, 206813, 205277, 206330, 205451, 206008, 205453, 205183, 206930, 205384, 206334, 206213, 205393, 206070, 206361, 206878, 205817, 205708, 206437, 205929, 206086, 205079, 205269, 205851, 205159, 206205, 205892, 205242, 205166, 206488, 206071, 205283, 206305, 205837, 206119, 205310, 205864, 206200, 206755, 206580, 206757, 206203, 206393, 205252, 205095, 205504, 205700, 206331, 206945, 206068, 205928, 206369, 205709, 205697, 206893, 206346, 206035, 205465, 206810, 205475, 205158, 206481, 206107, 205340, 205259, 205497, 205372, 205307, 206158, 205638, 205825, 205345, 206947, 205587, 205672, 205611, 206524, 205870, 206786, 205332, 205898, 206490, 206354, 205581, 206421, 205068, 206561, 206815, 206150, 206614, 205600, 205343, 206157, 205119, 206908, 205985, 206880, 206875, 205405, 206773, 206933, 206776, 205524, 206407, 206030, 206795, 206345, 206606, 206332, 205719, 205646, 206448, 206687, 206951, 205195, 205285, 205142, 205412, 206471, 206785, 206321, 206166, 205714, 206683, 205729, 205116, 206765, 206080, 206894, 205061, 206274, 206955, 206058, 206549, 205601, 206920, 206970, 205656, 206373, 205583, 205204, 205711, 205548, 205120, 205957, 206228, 206505, 206300, 206716, 205187, 206796, 206227, 206621, 205794, 206118, 205983, 205787, 205436, 205491, 206060, 205199, 205790, 205569, 205648, 205580, 205260, 205299, 205991, 205749, 205883, 205523, 205280, 205774, 206923, 205920, 205052, 205826, 206198, 206062, 205730, 205533, 205397, 205529, 205802, 205161, 205377, 206518, 206569, 206326, 205456, 206350, 206670, 205775, 206308, 206676, 205263, 205294, 205875, 205992, 205865, 205234, 206972, 206313, 206057, 205717, 206554, 205493, 205989, 205609, 205577, 206822, 205798, 205616, 206378, 205487, 205375, 205539, 206536, 206385, 205162, 205619, 206364, 206304, 206541, 206416, 205089, 206749, 206434, 206635, 206565, 205320, 206155, 205466, 206216, 205355, 206387, 205552, 206841, 206726, 206191, 206594, 206677, 206944, 206496, 205106, 206787, 206014, 205696, 206745, 205165, 206091, 205220, 205117, 205593, 206668, 205770, 206048, 206444, 206964, 205862, 206864, 206756, 206430, 206466, 206799, 206914, 206467, 206087, 205956, 206075, 205882, 205135, 205867, 206169, 205850, 206355, 206446, 206706, 205368, 206137, 205483, 206299, 206982, 205857, 206469, 206965, 205354, 205821, 206568, 206135, 206090, 205366, 206100, 205623, 205614, 205276, 206415, 205909, 206442, 205738, 206694, 206302, 206543, 205685, 205852, 206276, 206004, 206114, 206737, 206195, 205093, 206612, 206647, 206711, 205921, 206898, 206919, 205251, 205566, 205903, 205415, 205582, 205751, 205129, 206592, 206122, 205319, 205470, 206604, 205574, 206691, 205267, 206782, 205844, 206156, 205147, 205933, 205525, 206798, 206123, 206457, 205160, 206186, 206153, 205780, 205886, 206859, 205803, 205396, 206001, 206934, 206831, 205123, 206873, 206254, 205930, 205210, 206347, 205326, 205643, 205224, 205402, 205300, 205388, 205202, 205842, 206903, 205650, 206956, 205571, 205704, 206051, 206781, 206428, 206240, 205970, 206262, 206853, 206259, 205773, 205261, 206556, 206974, 206287, 206884, 205216, 205153, 205776, 206432, 205148, 206800, 205152, 205534, 206408, 206667, 206024, 206867, 206162, 206480, 205795, 206381, 205478, 206626, 206289, 205472, 205321, 206148, 206545, 206697, 206522, 205911, 206096, 206957, 206574, 206251, 206827, 206192, 206314, 206365, 205536, 205544, 205236, 205314, 206848, 206769, 205352, 205754, 205132, 205413, 206652, 205288, 205965, 206579, 206902, 206794, 206316, 206411, 205172, 206888, 205771, 206573, 206398, 206750, 206146, 205386, 205313, 205632, 205122, 206363, 205815, 205896, 206214, 206021, 206402, 205546, 206610, 206472, 205618, 206380, 206645, 205767, 205273, 205848, 206937, 206728, 206366, 206165, 205459, 205923, 205303, 206870, 206520, 206131, 205726, 206656, 205603, 205692, 205680, 206013, 206022, 206047, 205701, 205904, 205270, 205936, 206456, 206861, 206700, 205935, 206243, 206856, 206202, 206053, 206194, 206729, 205513, 206258, 206928, 205060, 206032, 206887, 206654, 205855, 206917, 206063, 206530, 206484, 205431, 205389, 205211, 206886, 206138, 206764, 206925, 206181, 205315, 206386, 206910, 205575, 205918, 205562, 205760, 205823, 205137, 205567, 206651, 206188, 206491, 205557, 206624, 206748, 205651, 205409, 205230, 205145, 205429, 205492, 206988, 206368, 205925, 205740, 205516, 205998, 205924, 205734, 206759, 205895, 205094, 206039, 206550, 206151, 206969, 205347, 206438, 206253, 206268, 205113, 206109, 206133, 206220, 206882, 206881, 205543, 206161, 205747, 206525, 205225, 206584, 205588, 205274, 206603, 205505, 205607, 205833, 205940, 205522, 205462, 206270, 205428, 206860, 206212, 206007, 206991, 206066, 206620, 205416, 206178, 206451, 205761, 206689, 205171, 205406, 205371, 206701, 205237, 206325, 205329, 206435, 205115, 206477, 205468, 205374, 206961, 206599, 206619, 205698, 206629, 206775, 206410, 206349, 205191, 205789, 205972, 205212, 205417, 206555, 206196, 205876, 206384, 206383, 206121, 206607, 205845, 205550, 205906, 206106, 206533, 206690, 206223, 206260, 206458, 206242, 206575, 205637, 205670, 205351, 206869, 206846, 205103, 205239, 206779, 205178, 205967, 205669, 205078, 206879, 205725, 205373, 206655, 205975, 205338, 206281, 206425, 205086, 205644, 206170, 205133, 205722, 206074, 205458, 205467, 206498, 206201, 205400, 205843, 206348, 205621, 206585, 205517, 206094, 206641, 205463, 206250, 205304, 205912, 206911, 205683, 205425, 206907, 206224, 206473, 205839, 205671, 206359, 205526, 205471, 206257, 206874, 205096, 206966, 205959, 206714, 206059, 206833, 206351, 205997, 206671, 206023, 206589, 206072, 205203, 206495, 206600, 205752, 206679, 205189, 206950, 205948, 206980, 206952, 205134, 206546, 205169, 206852, 205452, 206303, 206040, 206352, 206283, 206049, 206983, 206056, 206105, 206995, 206519, 206163, 206285, 206638, 206280, 205874, 206766, 205333, 205976, 205186, 206077, 206187, 206596, 205631, 205901, 205479, 205996, 206862, 205570, 206265, 205232, 206036, 205586, 206089, 205240, 205530, 205360, 205758, 206960, 206921, 206492, 206344, 206553, 206362, 206249, 205092, 205973, 205555, 205474, 206648, 206406, 205349, 205652, 205128, 206390, 206033, 206338, 205488, 206954, 206735, 205915, 206616, 205707, 206315, 206739, 205820, 205414, 206939, 206844, 205514, 205687, 205503, 205724, 206319, 206272, 206891, 205647, 205721, 206774, 206279, 206791, 206333, 205070, 205712, 205151, 205578, 206664, 205227, 205710, 206474, 206042, 206110, 206892, 205235, 206395, 206617, 205958, 205508, 206602, 206476, 206263, 205629, 206409, 206751, 205182, 205381, 206523, 205477, 206929, 206083, 206840, 206986, 206820, 206026, 206940, 205249, 205797, 205245, 205049, 206064, 206099, 206615, 205667, 205297, 205448, 205778, 206684, 205745, 206814, 205881, 206189, 205407, 205099, 206199, 205808, 205527, 205083, 206501, 205241, 205275, 206045, 205660, 206318, 205196, 206103, 205659, 206052, 206743, 205939, 205200, 205370, 206420, 205051, 206215, 205573, 205264, 205599, 205893, 205125, 206806, 205091, 205066, 205572, 205362, 206971, 206229, 206134, 205495, 205302, 205654, 205783, 205673, 206658, 206371, 205847, 206206, 206542, 206012, 206483, 205146, 206720, 205715, 205622, 206734, 206392, 205999, 206521, 205731, 206642, 206464, 206805, 206391, 205598, 205192, 205642, 205124, 205532, 205486, 206931, 205690, 205764, 205059, 206712, 206727, 205335, 205168, 205248, 205097, 206715, 206640, 205271, 206877, 206327, 205325, 205980, 205214, 205398, 205545, 205768, 205942, 206108, 206636, 205233, 206069, 206808, 206885, 206309, 206460, 205995, 206994, 206433, 205485, 206506, 205691, 206489, 205792, 205506, 205605, 205330, 205426, 206904, 206443, 205282, 205424, 205595, 205617, 205399, 206175, 205356, 206073, 206686, 206219, 205934, 205561, 205859, 206104, 205111, 206018, 206948, 205887, 205796, 206502, 206372, 205403, 206145, 205521, 205777, 206552, 205331, 205720, 206269, 205809, 206183, 205229, 205829, 206031, 205739] not in index'

## 1.2 Feature Scaling for KNN using Robust Scaler

In [None]:
rbscaler = RobustScaler()

# fit and transform training data, transform testing data
X_train_scaled = rbscaler.fit_transform(X_train)
X_test_scaled = rbscaler.transform(X_test)

print(pd.DataFrame({
    "count": y_train.value_counts(),
    "proportion": y_train.value_counts(normalize=True)
})
)

joblib.dump(rbscaler, '/home/wahba/Documents/model/multi_class_classification/robust_scaler.joblib')

## 1.3 Dataset Resampling

In [None]:
over = SMOTE(sampling_strategy={
    2: 15000,
    3: 8000,
    4: 8000,
})

under = RandomUnderSampler(sampling_strategy={
    0: 90000,
    1: 30000,
})

pipeline = Pipeline([
    ('over', over),
    ('under', under)
])

X_train_resampled, y_train_resampled = pipeline.fit_resample(X_train, y_train)
X_train_scaled_resampled, y_train_scaled_resampled = pipeline.fit_resample(X_train_scaled, y_train)

print(pd.DataFrame({
    "count": y_train_resampled.value_counts(),
    "proportion": y_train_resampled.value_counts(normalize=True)
})
)

# 2.0 Machine Learning Training

## 2.1. Random Forest

In [None]:

measurement_rf = {}
cpu_usage = []
stop_flag = threading.Event()
rf_model = RandomForestClassifier(n_estimators=150, min_samples_split=2, min_samples_leaf=2, max_features='sqrt', max_depth=30, random_state=42, n_jobs=-1)

def monitor_cpu():
        while not stop_flag.is_set():
            cpu_usage.append(psutil.cpu_percent(interval=0.1))

try :
    cpu_thread = threading.Thread(target=monitor_cpu)
    cpu_thread.start()
    start_time = time.time()

    train_memory_rf = max(memory_usage(lambda: rf_model.fit(X_train_resampled, y_train_resampled)))
    training_time = time.time() - start_time
    
    stop_flag.set()
    cpu_thread.join()
    
    # Add measurements
    measurement_rf['Memory Usage (MB)'] = train_memory_rf
    measurement_rf['Training Time (s)'] = training_time
    measurement_rf['Peak CPU Usage (%)'] = max(cpu_usage)
    measurement_rf['Average CPU Usage (%)'] = sum(cpu_usage) / len(cpu_usage) if cpu_usage else 0
    
    cv_scores_rf = cross_val_score(rf_model, X_train_resampled, y_train_resampled, cv=5, n_jobs=-1)

except Exception as e:
    print(f"An error occurred during training or monitoring: {e}")

joblib.dump(rf_model, '/home/wahba/Documents/model/multi_class_classification/rf_multi_class.joblib')

In [None]:
y_pred_rf = rf_model.predict(X_test)

# Evaluating the model performance on the cross validation set vs accuracy on the test set
cv_scores_mean_rf = np.mean(cv_scores_rf)
print(f'Cross validation average score: {cv_scores_mean_rf:.4f} +/- standard deviation: {np.std(cv_scores_rf):.4f}')

accuracy_rf = accuracy_score(y_test, y_pred_rf)
print(f'Accuracy on the test set: {accuracy_rf:.4f}')
# Evaluating the model via confusion matrix
cm_rf = confusion_matrix(y_test, y_pred_rf)

plt.figure(figsize=(10, 7))
sns.heatmap(cm_rf, annot=True, fmt='d', xticklabels=rf_model.classes_, yticklabels=rf_model.classes_, cmap='Blues', cbar=False)
plt.xlabel('Predicted')
plt.ylabel('Truth')
plt.title('Random Forest Confusion Matrix')
plt.show()

print(classification_report(y_test, y_pred_rf))

## 2.2. XGBoost

In [None]:
measurement_xgb = {}
cpu_usage = []
stop_flag = threading.Event()
xgb_model = xgb.XGBClassifier(
    subsample=1.0, 
    n_estimators=100, 
    min_child_weight=1, 
    max_depth=6, 
    learning_rate=0.2, 
    colsample_bytree=1.0, 
	objective='multi:softmax',
    random_state=42,
    n_jobs=-1)

def monitor_cpu():
        while not stop_flag.is_set():
            cpu_usage.append(psutil.cpu_percent(interval=0.1))

try :
    cpu_thread = threading.Thread(target=monitor_cpu)
    cpu_thread.start()
    start_time = time.time()

    train_memory_xgb = max(memory_usage(lambda: xgb_model.fit(X_train_resampled, y_train_resampled)))
    training_time = time.time() - start_time
    
    stop_flag.set()
    cpu_thread.join()
    
    # Add measurements
    measurement_xgb['Memory Usage (MB)'] = train_memory_xgb
    measurement_xgb['Training Time (s)'] = training_time
    measurement_xgb['Peak CPU Usage (%)'] = max(cpu_usage)
    measurement_xgb['Average CPU Usage (%)'] = sum(cpu_usage) / len(cpu_usage) if cpu_usage else 0
    
    cv_scores_xgb = cross_val_score(xgb_model, X_train_resampled, y_train_resampled, cv=5, n_jobs=-1)

except Exception as e:
    print(f"An error occurred during training or monitoring: {e}")

joblib.dump(xgb_model, '/home/wahba/Documents/model/multi_class_classification/xgb_multi_class.joblib')

In [None]:
y_pred_xgb = xgb_model.predict(X_test)

# Evaluating the model performance on the cross validation set vs accuracy on the test set
cv_scores_mean_xgb = np.mean(cv_scores_xgb)
print(f'Cross validation average score: {cv_scores_mean_xgb:.4f} +/- standard deviation: {np.std(cv_scores_xgb):.4f}')

accuracy_xgb = accuracy_score(y_test, y_pred_xgb)
print(f'Accuracy on the test set: {accuracy_xgb:.4f}')
# Evaluating the model via confusion matrix
cm_xgb = confusion_matrix(y_test, y_pred_xgb)

plt.figure(figsize=(10, 7))
sns.heatmap(cm_xgb, annot=True, fmt='d', xticklabels=xgb_model.classes_, yticklabels=xgb_model.classes_, cmap='Blues', cbar=False)
plt.xlabel('Predicted')
plt.ylabel('Truth')
plt.title('XGBoost Confusion Matrix')
plt.show()

print(classification_report(y_test, y_pred_xgb))

## 2.3. K-Nearest Neighbour (KNN)

In [None]:
measurement_knn = {}
cpu_usage = []
stop_flag = threading.Event()
knn_model = KNeighborsClassifier(weights='distance', n_neighbors=3, n_jobs=-1)

def monitor_cpu():
        while not stop_flag.is_set():
            cpu_usage.append(psutil.cpu_percent(interval=0.1))

try :
    cpu_thread = threading.Thread(target=monitor_cpu)
    cpu_thread.start()
    start_time = time.time()

    train_memory_knn = max(memory_usage(lambda: knn_model.fit(X_train_scaled_resampled, y_train_scaled_resampled)))
    training_time = time.time() - start_time
    
    stop_flag.set()
    cpu_thread.join()
    
    # Add measurements
    measurement_knn['Memory Usage (MB)'] = train_memory_knn
    measurement_knn['Training Time (s)'] = training_time
    measurement_knn['Peak CPU Usage (%)'] = max(cpu_usage)
    measurement_knn['Average CPU Usage (%)'] = sum(cpu_usage) / len(cpu_usage) if cpu_usage else 0
    
    cv_scores_knn = cross_val_score(knn_model, X_train_scaled_resampled, y_train_scaled_resampled, cv=5, n_jobs=-1)

except Exception as e:
    print(f"An error occurred during training or monitoring: {e}")

joblib.dump(knn_model, '/home/wahba/Documents/model/multi_class_classification/knn_multi_class.joblib')

In [None]:
y_pred_knn = knn_model.predict(X_test_scaled)

# Evaluating the model performance on the cross validation set vs accuracy on the test set
cv_scores_mean_knn = np.mean(cv_scores_knn)
print(f'Cross validation average score: {cv_scores_mean_knn:.4f} +/- standard deviation: {np.std(cv_scores_knn):.4f}')
accuracy_knn = accuracy_score(y_test, y_pred_knn)
print(f'Accuracy on the test set: {accuracy_knn:.4f}')
# Evaluating the model via confusion matrix
cm_knn = confusion_matrix(y_test, y_pred_knn)

plt.figure(figsize=(10, 7))
sns.heatmap(cm_knn, annot=True, fmt='d', xticklabels=knn_model.classes_, yticklabels=knn_model.classes_, cmap='Blues', cbar=False)
plt.xlabel('Predicted')
plt.ylabel('Truth')
plt.title('K-Nearest Neighbors Confusion Matrix')
plt.show()

print(classification_report(y_test, y_pred_knn))

# 3.0 Model Evaluation

In [None]:
# Calculating precision, recall, and F1 score for each model
precision_rf = precision_score(y_test, y_pred_rf, average='weighted')
recall_rf = recall_score(y_test, y_pred_rf, average='weighted')
f1_rf = f1_score(y_test, y_pred_rf, average='weighted')

precision_xgb = precision_score(y_test, y_pred_xgb, average='weighted')
recall_xgb = recall_score(y_test, y_pred_xgb, average='weighted')
f1_xgb = f1_score(y_test, y_pred_xgb, average='weighted')

precision_knn = precision_score(y_test, y_pred_knn, average='weighted')
recall_knn = recall_score(y_test, y_pred_knn, average='weighted')
f1_knn = f1_score(y_test, y_pred_knn, average='weighted')

supervised_results = pd.DataFrame({
    'Model': ['Random Forest', 'XGBoost', 'KNN'],
    'Accuracy': [accuracy_rf, accuracy_xgb, accuracy_knn],
    'Cross Validation Mean': [cv_scores_mean_rf, cv_scores_mean_xgb, cv_scores_mean_knn],
    'Precision': [precision_rf, precision_xgb, precision_knn],
    'Recall': [recall_rf, recall_xgb, recall_knn],
    'F1 Score': [f1_rf, f1_xgb, f1_knn],
    'Memory Usage (MB)': [measurement_rf['Memory Usage (MB)'], measurement_xgb['Memory Usage (MB)'], measurement_knn['Memory Usage (MB)']],
    'Training Time (s)': [measurement_rf['Training Time (s)'], measurement_xgb['Training Time (s)'], measurement_knn['Training Time (s)']],
    'Peak CPU Usage (%)': [measurement_rf['Peak CPU Usage (%)'], measurement_xgb['Peak CPU Usage (%)'], measurement_knn['Peak CPU Usage (%)']],
    'Average CPU Usage (%)': [measurement_rf['Average CPU Usage (%)'], measurement_xgb['Average CPU Usage (%)'], measurement_knn['Average CPU Usage (%)']],
})

In [None]:
# Plotting the comparison for accuracy, cross-validation, and metrics
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Plotting Accuracy and Cross Validation Mean
supervised_results.set_index('Model')[['Accuracy', 'Cross Validation Mean']].plot(kind='bar', ax=axes[0, 0], color=['skyblue', 'lightgreen'], legend=True)
axes[0, 0].set_title('Model Comparison: Accuracy and Cross Validation Mean')
axes[0, 0].set_ylabel('Score')
axes[0, 0].set_xlabel('Model')
axes[0, 0].set_ylim(0.95, 1.0)
axes[0, 0].legend(loc='lower left')

# Plotting Precision, Recall, F1 Score
supervised_results.set_index('Model')[['Precision', 'Recall', 'F1 Score']].plot(kind='bar', ax=axes[0, 1], color=['orange', 'lightcoral', 'yellowgreen'], legend=True)
axes[0, 1].set_title('Model Comparison: Precision, Recall, F1 Score')
axes[0, 1].set_ylabel('Score')
axes[0, 1].set_xlabel('Model')
axes[0, 1].set_ylim(0.95, 1.0)
axes[0, 1].legend(loc='lower left')

# Plotting Memory Usage and Training Time
ax1 = axes[1, 0]

supervised_results.set_index('Model')['Memory Usage (MB)'].plot(
    kind='bar', ax=ax1, color='lightblue', label='Memory Usage (MB)', width=0.6
)
ax1.set_ylabel('Memory Usage (MB)', color='lightblue')
ax1.tick_params(axis='y', labelcolor='lightblue')

ax2 = ax1.twinx() 
supervised_results.set_index('Model')['Training Time (s)'].plot(
    ax=ax2, color='lightpink', marker='o', label='Training Time (s)'
)

ax2.set_ylabel('Training Time (s)', color='lightpink')
ax2.tick_params(axis='y', labelcolor='lightpink')

ax1.set_title('Model Comparison: Memory Usage and Training Time')
ax1.set_xlabel('Model')

lines, labels = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(lines + lines2, labels + labels2, loc='upper right')

# Plotting Peak and Average CPU Usage
supervised_results.set_index('Model')[['Peak CPU Usage (%)', 'Average CPU Usage (%)']].plot(kind='bar', ax=axes[1, 1], color=['lightgreen', 'salmon'], legend=True)
axes[1, 1].set_title('Model Comparison: CPU Usage')
axes[1, 1].set_ylabel('Percentage')
axes[1, 1].set_xlabel('Model')
axes[1, 1].legend(loc='lower left')

plt.tight_layout()
plt.show()


In [None]:
print(f'''
    {measurement_rf}
    {measurement_xgb}
    {measurement_knn}
    ''')