In [1]:
import findspark
findspark.init()
from pyspark.sql import SparkSession
from pyspark.sql import Row

In [2]:
spark = SparkSession.builder\
    .config("spark.warehouse.dir","file:///c:/tmp")\
    .appName("SparkSQL").getOrCreate()

In [3]:
#http://kdd.ics.uci.edu/databases/kddcup99/kddcup99.html
sc = spark.sparkContext
raw_data =sc.textFile('kddcup.data_10_percent.gz').cache()

In [4]:
raw_data.toDebugString()#可以看到rdd之间是如何转换的

b'(1) kddcup.data_10_percent.gz MapPartitionsRDD[1] at textFile at <unknown>:0 [Memory Serialized 1x Replicated]\n |  kddcup.data_10_percent.gz HadoopRDD[0] at textFile at <unknown>:0 [Memory Serialized 1x Replicated]'

In [5]:
raw_data.take(2)

['0,tcp,http,SF,181,5450,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,8,8,0.00,0.00,0.00,0.00,1.00,0.00,0.00,9,9,1.00,0.00,0.11,0.00,0.00,0.00,0.00,0.00,normal.',
 '0,tcp,http,SF,239,486,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,8,8,0.00,0.00,0.00,0.00,1.00,0.00,0.00,19,19,1.00,0.00,0.05,0.00,0.00,0.00,0.00,0.00,normal.']

In [8]:
csv_data = raw_data.map(lambda l:l.split(","))
#csv_data.take(2)

In [9]:
row_data = csv_data.map(lambda  p:Row(
    duration =int(p[0]),
    protocol_type =p[1],
    service =p[2],
    flag =p[3],
    src_bytes =int(p[4]),
    dst_bytes =int(p[5]))                       
)

In [10]:
interactions_df =spark.createDataFrame(row_data)
interactions_df.registerTempTable("interactions")

In [11]:
interactions_df.printSchema()

root
 |-- dst_bytes: long (nullable = true)
 |-- duration: long (nullable = true)
 |-- flag: string (nullable = true)
 |-- protocol_type: string (nullable = true)
 |-- service: string (nullable = true)
 |-- src_bytes: long (nullable = true)



In [13]:
#查询tcp 超过1秒的，同时没有数据传输的
tcp_interactions = spark.sql("select duration,dst_bytes from interactions where protocol_type ='tcp' and duration >1000 and dst_bytes =0 ")
tcp_interactions.show()

+--------+---------+
|duration|dst_bytes|
+--------+---------+
|    5057|        0|
|    5059|        0|
|    5051|        0|
|    5056|        0|
|    5051|        0|
|    5039|        0|
|    5062|        0|
|    5041|        0|
|    5056|        0|
|    5064|        0|
|    5043|        0|
|    5061|        0|
|    5049|        0|
|    5061|        0|
|    5048|        0|
|    5047|        0|
|    5044|        0|
|    5063|        0|
|    5068|        0|
|    5062|        0|
+--------+---------+
only showing top 20 rows



In [17]:
#按照协议类型进行分组统计
from time import time
t0 =time()
interactions_df.select("protocol_type","duration","dst_bytes").groupBy("protocol_type").count().show()
tt =time()-t0
print("querytime  in {} seconds".format(round(tt,3)))

+-------------+------+
|protocol_type| count|
+-------------+------+
|          tcp|190065|
|          udp| 20354|
|         icmp|283602|
+-------------+------+

querytime  in 10.942 seconds


In [18]:
#按照协议类型进行分组统计
from time import time
t0 =time()
interactions_df.select("protocol_type","duration","dst_bytes")\
    .filter((interactions_df.dst_bytes ==0)&(interactions_df.duration >1000))\
    .groupBy("protocol_type").count().show()
tt =time()-t0
print("querytime  in {} seconds".format(round(tt,3)))

+-------------+-----+
|protocol_type|count|
+-------------+-----+
|          tcp|  139|
+-------------+-----+

querytime  in 16.074 seconds


In [26]:
def get_label_type(label):
    if label!="normal":
        return "attack"
    else:
        return "normal"
    
raw_labeled_data =csv_data.map(lambda p:Row(
    duration =int(p[0]),
    protocol_type =p[1],
    service =p[2],
    flag =p[3],
    src_bytes =int(p[4]),
    dst_bytes =int(p[5]),
    label =get_label_type(p[41]))
)
interactions_labeled_df =spark.createDataFrame(raw_labeled_data)

In [27]:
interactions_labeled_df.select("label").groupBy("label").count().show()

+------+------+
| label| count|
+------+------+
|attack|494021|
+------+------+

