In [None]:
# Reference notebook: https://github.com/jadianes/spark-py-notebooks/blob/master/nb10-sql-dataframes/nb10-sql-dataframes.ipynb

In [1]:
import os
#os.getcwd()

In [3]:
import urllib.request
# Fetch the kddcup99 archive only when it is not already in the working directory.
data_file = "kddcup.data_10_percent.gz"
if not os.path.isfile(data_file):
    # Download under a temporary name and rename on success, so an interrupted
    # transfer cannot leave a truncated file that the isfile() check would accept
    # on the next run.
    partial_file = data_file + ".part"
    urllib.request.urlretrieve(
        "http://kdd.ics.uci.edu/databases/kddcup99/kddcup.data_10_percent.gz",
        partial_file,
    )
    os.replace(partial_file, data_file)

In [4]:
import findspark
findspark.init()
import pyspark

# Reuse the already-running SparkContext when this cell is executed again;
# constructing a second SparkContext in the same kernel raises an error.
sc = pyspark.SparkContext.getOrCreate(pyspark.SparkConf().setAppName("test"))
# Read the downloaded file from the current working directory as an RDD of
# text lines and mark it for caching, since later cells reuse it.
raw_data = sc.textFile('file://' + os.getcwd() + "/" + data_file).cache()

In [5]:
from pyspark.sql import SQLContext
# SQL entry point bound to the notebook's SparkContext; later cells use it to
# build DataFrames and run SQL queries.
sqlContext = SQLContext(sparkContext=sc)

In [6]:
from pyspark.sql import Row

# Split every text line on commas, then keep the first six fields as a Row,
# converting the three numeric ones (indices 0, 4, 5) to int.
csv_data = raw_data.map(lambda line: line.split(","))
row_data = csv_data.map(
    lambda fields: Row(
        duration=int(fields[0]),
        protocol_type=fields[1],
        service=fields[2],
        flag=fields[3],
        src_bytes=int(fields[4]),
        dst_bytes=int(fields[5]),
    )
)

In [7]:
# Build a DataFrame from the Row RDD (schema is inferred from the Row fields).
interactions_df = sqlContext.createDataFrame(row_data)
# Expose the DataFrame to SQL queries under the name "interactions".
# createOrReplaceTempView is the supported replacement for the deprecated
# registerTempTable and is safe to re-run.
interactions_df.createOrReplaceTempView("interactions")  # table name

In [8]:
# Select tcp interactions with duration > 1000 and no bytes returned from the
# destination (dst_bytes = 0), then display the first rows of the result.
# NOTE(review): the unit of `duration` is not visible here - presumably
# seconds; confirm against the dataset documentation.
tcp_interactions = sqlContext.sql("""
    SELECT duration, dst_bytes FROM interactions WHERE protocol_type = 'tcp' AND duration > 1000 AND dst_bytes = 0
""")
tcp_interactions.show()

+--------+---------+
|duration|dst_bytes|
+--------+---------+
|    5057|        0|
|    5059|        0|
|    5051|        0|
|    5056|        0|
|    5051|        0|
|    5039|        0|
|    5062|        0|
|    5041|        0|
|    5056|        0|
|    5064|        0|
|    5043|        0|
|    5061|        0|
|    5049|        0|
|    5061|        0|
|    5048|        0|
|    5047|        0|
|    5044|        0|
|    5063|        0|
|    5068|        0|
|    5062|        0|
+--------+---------+
only showing top 20 rows



In [9]:
# Show every column of the tcp interactions (show() prints only the first rows).
sqlContext.sql("""
    SELECT * FROM interactions WHERE protocol_type = 'tcp'
""").show()

+---------+--------+----+-------------+-------+---------+
|dst_bytes|duration|flag|protocol_type|service|src_bytes|
+---------+--------+----+-------------+-------+---------+
|     5450|       0|  SF|          tcp|   http|      181|
|      486|       0|  SF|          tcp|   http|      239|
|     1337|       0|  SF|          tcp|   http|      235|
|     1337|       0|  SF|          tcp|   http|      219|
|     2032|       0|  SF|          tcp|   http|      217|
|     2032|       0|  SF|          tcp|   http|      217|
|     1940|       0|  SF|          tcp|   http|      212|
|     4087|       0|  SF|          tcp|   http|      159|
|      151|       0|  SF|          tcp|   http|      210|
|      786|       0|  SF|          tcp|   http|      212|
|      624|       0|  SF|          tcp|   http|      210|
|     1985|       0|  SF|          tcp|   http|      177|
|      773|       0|  SF|          tcp|   http|      222|
|     1169|       0|  SF|          tcp|   http|      256|
|      259|   

In [10]:
# List the distinct protocol_type values present in the table.
sqlContext.sql("""
    SELECT DISTINCT protocol_type FROM interactions
""").show()

+-------------+
|protocol_type|
+-------------+
|          tcp|
|          udp|
|         icmp|
+-------------+



In [11]:
# List the distinct service values present in the table (show() prints only the first rows).
sqlContext.sql("""
    SELECT DISTINCT service FROM interactions
""").show()

+--------+
| service|
+--------+
|  telnet|
|     ftp|
|    auth|
|iso_tsap|
|  systat|
|    name|
| sql_net|
|   ntp_u|
|     X11|
|   pop_3|
|    ldap|
| discard|
|  tftp_u|
|  Z39_50|
| daytime|
|domain_u|
|   login|
|    smtp|
|     mtp|
|  domain|
+--------+
only showing top 20 rows



In [14]:
# Format each (duration, dst_bytes) row as a line of text, then collect all
# lines to the driver and print them one by one.
tcp_interactions_out = tcp_interactions.rdd.map(
    lambda row: "Duration: {}, Dest. bytes: {}".format(row.duration, row.dst_bytes)
)
for ti_out in tcp_interactions_out.collect():
    print(ti_out)

Duration: 5057, Dest. bytes: 0
Duration: 5059, Dest. bytes: 0
Duration: 5051, Dest. bytes: 0
Duration: 5056, Dest. bytes: 0
Duration: 5051, Dest. bytes: 0
Duration: 5039, Dest. bytes: 0
Duration: 5062, Dest. bytes: 0
Duration: 5041, Dest. bytes: 0
Duration: 5056, Dest. bytes: 0
Duration: 5064, Dest. bytes: 0
Duration: 5043, Dest. bytes: 0
Duration: 5061, Dest. bytes: 0
Duration: 5049, Dest. bytes: 0
Duration: 5061, Dest. bytes: 0
Duration: 5048, Dest. bytes: 0
Duration: 5047, Dest. bytes: 0
Duration: 5044, Dest. bytes: 0
Duration: 5063, Dest. bytes: 0
Duration: 5068, Dest. bytes: 0
Duration: 5062, Dest. bytes: 0
Duration: 5046, Dest. bytes: 0
Duration: 5052, Dest. bytes: 0
Duration: 5044, Dest. bytes: 0
Duration: 5054, Dest. bytes: 0
Duration: 5039, Dest. bytes: 0
Duration: 5058, Dest. bytes: 0
Duration: 5051, Dest. bytes: 0
Duration: 5032, Dest. bytes: 0
Duration: 5063, Dest. bytes: 0
Duration: 5040, Dest. bytes: 0
Duration: 5051, Dest. bytes: 0
Duration: 5066, Dest. bytes: 0
Duration

In [15]:
# Print the inferred column names, types and nullability of the DataFrame.
interactions_df.printSchema()

root
 |-- dst_bytes: long (nullable = true)
 |-- duration: long (nullable = true)
 |-- flag: string (nullable = true)
 |-- protocol_type: string (nullable = true)
 |-- service: string (nullable = true)
 |-- src_bytes: long (nullable = true)



In [16]:
from time import time

# Time a row count per protocol_type; the measured interval includes show().
t0 = time()
(
    interactions_df
    .select("protocol_type", "duration", "dst_bytes")
    .groupBy("protocol_type")
    .count()
    .show()
)
tt = time() - t0

print("Query performed in {} seconds".format(round(tt, 3)))

+-------------+------+
|protocol_type| count|
+-------------+------+
|          tcp|190065|
|          udp| 20354|
|         icmp|283602|
+-------------+------+

Query performed in 6.811 seconds


In [22]:
# Time a per-protocol count restricted to rows with duration > 1000 and
# dst_bytes == 0; the measured interval includes show().
t0 = time()
(
    interactions_df
    .select("protocol_type", "duration", "dst_bytes")
    .filter(interactions_df.duration > 1000)
    .filter(interactions_df.dst_bytes == 0)
    .groupBy("protocol_type")
    .count()
    .show()
)
tt = time() - t0

print("Query performed in {} seconds".format(round(tt, 3)))

+-------------+-----+
|protocol_type|count|
+-------------+-----+
|          tcp|  139|
+-------------+-----+

Query performed in 7.264 seconds


In [23]:
def get_label_type(label):
    """Collapse a raw label into one of two classes.

    Returns "normal" only when ``label`` equals the exact string "normal."
    (trailing dot included); any other value yields "attack".
    """
    return "normal" if label == "normal." else "attack"
    
# Build labeled rows: six typed fields plus a normal/attack label derived from
# the field at index 41, then turn the Row RDD into a DataFrame.
row_labeled_data = csv_data.map(
    lambda fields: Row(
        duration=int(fields[0]),
        protocol_type=fields[1],
        service=fields[2],
        flag=fields[3],
        src_bytes=int(fields[4]),
        dst_bytes=int(fields[5]),
        label=get_label_type(fields[41]),
    )
)
interactions_labeled_df = sqlContext.createDataFrame(row_labeled_data)

In [24]:
# Time a row count per label; the measured interval includes show().
t0 = time()
(
    interactions_labeled_df
    .select("label")
    .groupBy("label")
    .count()
    .show()
)
tt = time() - t0

print("Query performed in {} seconds".format(round(tt, 3)))

+------+------+
| label| count|
+------+------+
|normal| 97278|
|attack|396743|
+------+------+

Query performed in 6.525 seconds


In [26]:
# Time a row count per (label, protocol_type) pair; the measured interval
# includes show().
t0 = time()
(
    interactions_labeled_df
    .select("label", "protocol_type")
    .groupBy("label", "protocol_type")
    .count()
    .show()
)
tt = time() - t0

print("Query performed in {} seconds".format(round(tt, 3)))

+------+-------------+------+
| label|protocol_type| count|
+------+-------------+------+
|normal|          udp| 19177|
|normal|         icmp|  1288|
|normal|          tcp| 76813|
|attack|         icmp|282314|
|attack|          tcp|113252|
|attack|          udp|  1177|
+------+-------------+------+

Query performed in 6.929 seconds


In [39]:
import pyspark.sql.functions as F
# Share of all interactions that falls into each (label, protocol_type) group.
# The 'percentage' column holds count/total rounded to 3 decimals, i.e. a
# fraction between 0 and 1 rather than a value out of 100.
total = interactions_labeled_df.count()
(
    interactions_labeled_df
    .select("label", "protocol_type")
    .groupBy("label", "protocol_type")
    .count()
    .withColumn('total', F.lit(total))
    .withColumn('percentage', F.round(F.expr('count/total'), 3))
    .show()
)

+------+-------------+------+------+----------+
| label|protocol_type| count| total|percentage|
+------+-------------+------+------+----------+
|normal|          udp| 19177|494021|     0.039|
|normal|         icmp|  1288|494021|     0.003|
|normal|          tcp| 76813|494021|     0.155|
|attack|         icmp|282314|494021|     0.571|
|attack|          tcp|113252|494021|     0.229|
|attack|          udp|  1177|494021|     0.002|
+------+-------------+------+------+----------+



In [40]:
# Time a row count per (label, protocol_type, dst_bytes == 0) combination; the
# measured interval includes show().
t0 = time()
(
    interactions_labeled_df
    .select("label", "protocol_type", "dst_bytes")
    .groupBy("label", "protocol_type", interactions_labeled_df.dst_bytes == 0)
    .count()
    .show()
)
tt = time() - t0

print("Query performed in {} seconds".format(round(tt, 3)))

+------+-------------+---------------+------+
| label|protocol_type|(dst_bytes = 0)| count|
+------+-------------+---------------+------+
|normal|          udp|          false| 15583|
|attack|          udp|          false|    11|
|attack|          tcp|           true|110583|
|normal|          tcp|          false| 67500|
|attack|         icmp|           true|282314|
|attack|          tcp|          false|  2669|
|normal|          tcp|           true|  9313|
|normal|          udp|           true|  3594|
|normal|         icmp|           true|  1288|
|attack|          udp|           true|  1166|
+------+-------------+---------------+------+

Query performed in 8.844 seconds
