In [51]:
import pandas as ps
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import classification_report

In [46]:
# Load the flow records; the "proto" column holds the protocol label of each flow.
data = ps.read_csv("flows.csv")
print("Количество записей каждого протокола:")
# Protocols are listed in order of first appearance in the file.
for proto_name in data["proto"].unique():
    n_records = int((data["proto"] == proto_name).sum())
    print("{}{}".format(proto_name.ljust(20), n_records))

Количество записей каждого протокола:
DNS                 6507
HTTP                2626
SSL                 768
SSL_No_Cert         90
Apple               25
NTP                 17
Quic                111
BitTorrent          93
Skype               283
Unknown             22
Unencryped_Jabber   1


## Обработка данных

In [47]:
# Protocols listed in drop_protos are removed from the data set entirely
# (each of them has 25 or fewer flows in the counts printed above).
drop_protos = ["Unknown", "Unencryped_Jabber", "NTP", "Apple"]
# (old label, new label) pairs: flows with the old label are merged into the new one.
replace_protos = [("SSL_No_Cert", "SSL")]
# .copy() gives an independent frame, so the column assignment below does not
# write into a slice of the original frame.
data = data[~data["proto"].isin(drop_protos)].copy()
for old_proto, new_proto in replace_protos:
    # Relabel only the "proto" column; a frame-wide replace would also rewrite
    # any other column that happens to contain the same value.
    data["proto"] = data["proto"].replace(old_proto, new_proto)

In [48]:
print("Новое количество записей каждого протокола:")
# Same report as after loading, now on the filtered and relabelled frame.
for proto_name in data["proto"].unique():
    n_records = int((data["proto"] == proto_name).sum())
    print("{}{}".format(proto_name.ljust(12), n_records))

Новое количество записей каждого протокола:
DNS         6507
HTTP        2626
SSL         858
Quic        111
BitTorrent  93
Skype       283


## Разделение данных на обучающую и тестовую выборки

In [40]:
# Split every protocol separately so that both samples contain each protocol
# in the same proportion as the full data set.
proto_clusters = [data[data["proto"] == proto] for proto in data["proto"].unique()]
train_clusters = []
test_clusters = []
for cluster in proto_clusters:
    # A private generator seeded with 42 produces the same permutation as
    # reseeding the global NumPy RNG would, without altering global RNG state.
    rng = np.random.RandomState(42)
    shuffled = cluster.iloc[rng.permutation(len(cluster))]
    # The first third of the shuffled rows is used for training,
    # the remaining two thirds for testing.
    split_index = len(shuffled) // 3
    train_clusters.append(shuffled.iloc[:split_index])
    test_clusters.append(shuffled.iloc[split_index:])

In [41]:
# Stitch the per-protocol pieces back into one training and one test frame.
train_data = ps.concat(train_clusters)
test_data = ps.concat(test_clusters)
print(
    "Обучающая выборка: {} записей\nПроверочная выборка: {} записей".format(
        train_data.shape[0], test_data.shape[0]
    )
)

Обучающая выборка: 3492 записей
Проверочная выборка: 6986 записей


In [43]:
print("Количество записей каждого протокола:")
# One line per protocol: "<name><train count>/<test count>".
for proto_name in data["proto"].unique():
    n_train = int((train_data["proto"] == proto_name).sum())
    n_test = int((test_data["proto"] == proto_name).sum())
    print("{}{}/{}".format(proto_name.ljust(12), n_train, n_test))

Количество записей каждого протокола:
DNS         2169/4338
HTTP        875/1751
SSL         286/572
Quic        37/74
BitTorrent  31/62
Skype       94/189


## Обучение и проверка модели

In [49]:
# Every column except the two label columns is used as a model feature.
label_columns = ["proto", "subproto"]
train_features = train_data.drop(label_columns, axis=1)
test_features = test_data.drop(label_columns, axis=1)
# The scaler is fitted on the training features only and then applied to both samples.
scaler = StandardScaler().fit(train_features)
X_train = scaler.transform(train_features)
X_test = scaler.transform(test_features)

In [56]:
# Encode protocol names as integers; the mapping is learned from the training labels.
labeler = LabelEncoder().fit(train_data["proto"])
y_train, y_test = (
    labeler.transform(sample["proto"]) for sample in (train_data, test_data)
)

In [57]:
# Hyper-parameters are passed by keyword: criterion and max_depth are
# keyword-only in current scikit-learn, so the positional form raises TypeError.
model = RandomForestClassifier(
    n_estimators=27,
    criterion="entropy",
    max_depth=9,
    random_state=42,
)
model.fit(X_train, y_train)
y_predicted = model.predict(X_test)
# Convert the integer codes back to protocol names so the report is readable.
true_labels = labeler.inverse_transform(y_test)
predicted_labels = labeler.inverse_transform(y_predicted)
print(classification_report(true_labels, predicted_labels))

             precision    recall  f1-score   support

 BitTorrent       1.00      0.97      0.98        62
        DNS       1.00      1.00      1.00      4338
       HTTP       1.00      1.00      1.00      1751
       Quic       1.00      1.00      1.00        74
        SSL       0.99      0.99      0.99       572
      Skype       0.98      0.94      0.96       189

avg / total       1.00      1.00      1.00      6986



In [59]:
def cross_class_report(y, p):
    """Build a confusion table of true classes against predicted classes.

    Parameters
    ----------
    y : sequence
        True class label of every sample.
    p : sequence
        Predicted class label of every sample, aligned with ``y``.

    Returns
    -------
    pandas.DataFrame
        Square table whose rows (true class) and columns (predicted class) are
        the sorted unique values of ``y``. Cell ``[t, q]`` is the number of
        samples whose true class is ``t`` and predicted class is ``q``.
        Predictions whose label never occurs in ``y`` are not counted.
    """
    classes = np.unique(y)
    res = ps.DataFrame({"y": y, "p": p})
    # crosstab counts every (true, predicted) pair in one pass; this replaces
    # cell-by-cell chained assignment, which may write to a temporary copy and
    # leave the table unfilled.
    table = ps.crosstab(res["y"], res["p"])
    # Force the same square layout for rows and columns; pairs that never
    # occurred become 0 and predicted labels absent from y are dropped.
    table = table.reindex(index=classes, columns=classes, fill_value=0)
    # Drop the axis names added by crosstab so the table has plain labels only.
    table.index.name = None
    table.columns.name = None
    return table

In [60]:
# Rows are true protocols, columns are predicted protocols.
confusion_table = cross_class_report(true_labels, predicted_labels)
print(confusion_table)

           BitTorrent   DNS  HTTP Quic  SSL Skype
BitTorrent         60     0     0    0    2     0
DNS                 0  4334     0    0    0     4
HTTP                0     0  1750    0    1     0
Quic                0     0     0   74    0     0
SSL                 0     0     4    0  568     0
Skype               0     8     2    0    1   178
