## Anomaly Detection in Network Trac with K-means Clustering
- http://www.kdd.org/kdd-cup/view/kdd-cup-1999/Data

In [10]:
from pyspark.conf import SparkConf
from pyspark import StorageLevel

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
# from pyspark.sql.types import *
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StandardScaler
from pyspark.ml.clustering import KMeans, KMeansModel
from pyspark.ml import Pipeline

import random


In [2]:
spark = SparkSession.builder \
    .master("local") \
    .appName("Clustering") \
    .getOrCreate()

sc = spark.sparkContext
sc.setLogLevel("INFO")

In [3]:
df = spark.read.csv("../data/kddcup.data_10_percent.txt", inferSchema=True, sep=",").toDF(
    "duration", "protocol_type", "service", "flag", "src_bytes",
    "dst_bytes", "land", "wrong_fragment", "urgent",
    "hot", "num_failed_logins", "logged_in", "num_compromised",
    "root_shell", "su_attempted", "num_root", "num_file_creations",
    "num_shells", "num_access_files", "num_outbound_cmds",
    "is_host_login", "is_guest_login", "count", "srv_count",
    "serror_rate", "srv_serror_rate", "rerror_rate", "srv_rerror_rate",
    "same_srv_rate", "diff_srv_rate", "srv_diff_host_rate",
    "dst_host_count", "dst_host_srv_count",
    "dst_host_same_srv_rate", "dst_host_diff_srv_rate",
    "dst_host_same_src_port_rate", "dst_host_srv_diff_host_rate",
    "dst_host_serror_rate", "dst_host_srv_serror_rate",
    "dst_host_rerror_rate", "dst_host_srv_rerror_rate", "label")

In [4]:
print("Number of training data: {}".format(df.count()))
df.printSchema()

Number of training data: 494021
root
 |-- duration: integer (nullable = true)
 |-- protocol_type: string (nullable = true)
 |-- service: string (nullable = true)
 |-- flag: string (nullable = true)
 |-- src_bytes: integer (nullable = true)
 |-- dst_bytes: integer (nullable = true)
 |-- land: integer (nullable = true)
 |-- wrong_fragment: integer (nullable = true)
 |-- urgent: integer (nullable = true)
 |-- hot: integer (nullable = true)
 |-- num_failed_logins: integer (nullable = true)
 |-- logged_in: integer (nullable = true)
 |-- num_compromised: integer (nullable = true)
 |-- root_shell: integer (nullable = true)
 |-- su_attempted: integer (nullable = true)
 |-- num_root: string (nullable = true)
 |-- num_file_creations: string (nullable = true)
 |-- num_shells: string (nullable = true)
 |-- num_access_files: integer (nullable = true)
 |-- num_outbound_cmds: integer (nullable = true)
 |-- is_host_login: integer (nullable = true)
 |-- is_guest_login: integer (nullable = true)
 |-- co

In [7]:
df.show(3)

+--------+-------------+-------+----+---------+---------+----+--------------+------+---+-----------------+---------+---------------+----------+------------+--------+------------------+----------+----------------+-----------------+-------------+--------------+-----+---------+-----------+---------------+-----------+---------------+-------------+-------------+------------------+--------------+------------------+----------------------+----------------------+---------------------------+---------------------------+--------------------+------------------------+--------------------+------------------------+-------+
|duration|protocol_type|service|flag|src_bytes|dst_bytes|land|wrong_fragment|urgent|hot|num_failed_logins|logged_in|num_compromised|root_shell|su_attempted|num_root|num_file_creations|num_shells|num_access_files|num_outbound_cmds|is_host_login|is_guest_login|count|srv_count|serror_rate|srv_serror_rate|rerror_rate|srv_rerror_rate|same_srv_rate|diff_srv_rate|srv_diff_host_rate|dst_hos

In [5]:
df.select("label").groupBy("label").count().orderBy(desc("count")).show(25)

+----------------+------+
|           label| count|
+----------------+------+
|          smurf.|280790|
|        neptune.|107201|
|         normal.| 97277|
|           back.|  2203|
|          satan.|  1589|
|        ipsweep.|  1247|
|      portsweep.|  1040|
|    warezclient.|  1020|
|       teardrop.|   979|
|            pod.|   264|
|           nmap.|   231|
|   guess_passwd.|    53|
|buffer_overflow.|    30|
|           land.|    21|
|    warezmaster.|    20|
|           imap.|    12|
|        rootkit.|    10|
|     loadmodule.|     9|
|      ftp_write.|     8|
|       multihop.|     7|
|            phf.|     4|
|           perl.|     3|
|            spy.|     2|
|            0.00|     1|
+----------------+------+



In [8]:
# Only numeric
df = df.withColumn("num_root", df.num_root.cast("int"))
df = df.withColumn("num_file_creations", df.num_root.cast("int"))
df = df.withColumn("num_shells", df.num_root.cast("int"))

# Drop null
train = df.drop("protocol_type", "service", "flag").dropna().cache()
columns = train.columns
columns.remove('label')

In [11]:
# VectorAssembler
assembler = VectorAssembler(
    inputCols=columns,
    outputCol='features')

# StandardScaler
scaler = StandardScaler(
    inputCol='features',
    outputCol='scaled_features',
    withStd=True,
    withMean=True)

# KMeans
kmeans = KMeans(
    featuresCol='scaled_features',
    predictionCol='cluster',
    maxIter=30,
    seed=random.randrange(1,10))
kmeans.setK(3)

KMeans_480f98b472486f05f440

In [12]:
pipeline = Pipeline().setStages([assembler, scaler, kmeans])
pipelineModel = pipeline.fit(train)
kmeansModel = pipelineModel.stages[-1]

In [13]:
print(*kmeansModel.clusterCenters(), sep='\n')

[-0.06779167 -0.00211676 -0.02628241 -0.00667342 -0.04772019 -0.00257147
 -0.04413591 -0.00978218 -0.41718843 -0.00567868 -0.01055195 -0.00467567
 -0.00564001 -0.00564001 -0.00564001 -0.02763182  0.          0.
 -0.03726266  0.81625548  0.86634943 -0.46408985 -0.46352057 -0.24796049
 -0.2486313   0.53691515 -0.25511113 -0.20356242  0.34755699  0.62439493
  0.59842288 -0.2823769   0.82281462 -0.15862339 -0.46439018 -0.46320249
 -0.25193814 -0.24946402]
[ 0.19983821  0.00309251 -0.02605623 -0.004087   -0.02857926 -0.00257147
 -0.0439704  -0.00700191 -0.41473558 -0.0056115  -0.01055195 -0.00467567
 -0.00545136 -0.00545136 -0.00545136 -0.02763182  0.          0.
 -0.03726266 -0.71091765 -1.1473152   1.51034301  1.50794494  0.59904634
  0.59617813 -1.74350703  0.77326679 -0.20157484  0.33893612 -1.67921651
 -1.73370146  0.78457492 -1.13837302 -0.15722012  1.50946569  1.50869869
  0.59803101  0.60250379]
[-0.04217291  0.00245364  0.10768047  0.02431165  0.17307353  0.01056196
  0.18108398  0

In [14]:
withCluster = pipelineModel.transform(train)
clusterLabel = withCluster.select("cluster", "label") \
    .groupBy("cluster", "label").count() \
    .orderBy("cluster", "count")
clusterLabel.show(100)

+-------+----------------+------+
|cluster|           label| count|
+-------+----------------+------+
|      0|           nmap.|     3|
|      0|          satan.|    30|
|      0|         normal.|   609|
|      0|          smurf.|280774|
|      1|   guess_passwd.|     1|
|      1|           imap.|     2|
|      1|           land.|     2|
|      1|    warezclient.|     4|
|      1|        ipsweep.|    19|
|      1|       teardrop.|   101|
|      1|           nmap.|   103|
|      1|      portsweep.|  1037|
|      1|          satan.|  1539|
|      1|         normal.|  5873|
|      1|        neptune.|107197|
|      2|            spy.|     2|
|      2|      portsweep.|     3|
|      2|           perl.|     3|
|      2|            phf.|     4|
|      2|        neptune.|     4|
|      2|       multihop.|     7|
|      2|      ftp_write.|     8|
|      2|     loadmodule.|     9|
|      2|        rootkit.|    10|
|      2|           imap.|    10|
|      2|          smurf.|    16|
|      2|     

In [15]:
clusterLabel.groupBy('cluster').sum('count').orderBy('cluster').show()

+-------+----------+
|cluster|sum(count)|
+-------+----------+
|      0|    281416|
|      1|    115878|
|      2|     96726|
+-------+----------+



In [None]:
kmeansModel.summary.clusterSizes

In [16]:
scaled_features = withCluster.select('cluster', 'scaled_features')
scaled_features.show(3)

+-------+--------------------+
|cluster|     scaled_features|
+-------+--------------------+
|      2|[-0.0677917208490...|
|      2|[-0.0677917208490...|
|      2|[-0.0677917208490...|
+-------+--------------------+
only showing top 3 rows



In [18]:
spark.stop()