In [1]:
import pandas as pd
import matplotlib.pyplot as plt
# Apply matplotlib's 'fivethirtyeight' style sheet to figures drawn later in the notebook.
plt.style.use('fivethirtyeight')

# Simple marker that this setup cell ran to completion.
print("ready!")

In [2]:
# Download the "10 percent" KDD Cup 1999 archive into the local /tmp directory
# (-P sets wget's target directory).
!wget http://kdd.ics.uci.edu/databases/kddcup99/kddcup.data_10_percent.gz -P /tmp
# Move the archive from the local filesystem (file:/) into DBFS under a shorter name.
dbutils.fs.mv("file:/tmp/kddcup.data_10_percent.gz", "dbfs:/kdd/kddcup_data.gz")

# List the DBFS target directory to confirm the file is there.
display(dbutils.fs.ls("dbfs:/kdd"))

In [3]:
# DBFS location of the gzipped KDD archive.
data_file = "dbfs:/kdd/kddcup_data.gz"

# Read the archive as an RDD of text lines and mark it for caching,
# because the following cells reuse it several times.
raw_rdd = sc.textFile(data_file)
raw_rdd.cache()

# Peek at the first few raw lines.
raw_rdd.take(5)

In [4]:
# Total number of lines in the raw RDD (an action, so it runs a full pass over the data).
raw_rdd.count()

In [5]:
# Split every raw line on commas, giving one list of string fields per record.
csv_rdd = raw_rdd.map(lambda line: line.split(","))

# How many fields does the first record have?
first_record = csv_rdd.take(1)[0]
len(first_record)

In [6]:
from pyspark.sql import Row

# Build one Row per connection record from the split CSV fields.
# Only a subset of the positional fields is kept (positions 6, 11 and 13 are not
# extracted here). Numeric fields are cast to int so that the DataFrame built
# from these Rows gets numeric columns. su_attempted is cast as well: leaving it
# as the raw string would make it the only numeric field stored as text.
parsed_rdd = csv_rdd.map(lambda fields: Row(
    duration=int(fields[0]),
    protocol_type=fields[1],
    service=fields[2],
    flag=fields[3],
    src_bytes=int(fields[4]),
    dst_bytes=int(fields[5]),
    wrong_fragment=int(fields[7]),
    urgent=int(fields[8]),
    hot=int(fields[9]),
    num_failed_logins=int(fields[10]),
    num_compromised=int(fields[12]),
    su_attempted=int(fields[14]),
    num_root=int(fields[15]),
    num_file_creations=int(fields[16]),
    label=fields[-1]  # the last field carries the class label, e.g. "normal."
    )
)
# Inspect the first few parsed rows.
parsed_rdd.take(5)

In [7]:
from pyspark.sql import SparkSession

# Obtain the active SparkSession (or create one if none exists).
# The previous `sc.getOrCreate()` returned the SparkContext itself, so `spark`
# ended up bound to a SparkContext rather than the session that name conventionally holds.
spark = SparkSession.builder.getOrCreate()
type(spark)

In [8]:
# Convert the RDD of Rows into a DataFrame; the schema is inferred from the Row fields.
df = parsed_rdd.toDF()

# Show the first ten rows.
display(df.take(10))

dst_bytes,duration,flag,hot,label,num_compromised,num_failed_logins,num_file_creations,num_root,protocol_type,service,src_bytes,su_attempted,urgent,wrong_fragment
5450,0,SF,0,normal.,0,0,0,0,tcp,http,181,0,0,0
486,0,SF,0,normal.,0,0,0,0,tcp,http,239,0,0,0
1337,0,SF,0,normal.,0,0,0,0,tcp,http,235,0,0,0
1337,0,SF,0,normal.,0,0,0,0,tcp,http,219,0,0,0
2032,0,SF,0,normal.,0,0,0,0,tcp,http,217,0,0,0
2032,0,SF,0,normal.,0,0,0,0,tcp,http,217,0,0,0
1940,0,SF,0,normal.,0,0,0,0,tcp,http,212,0,0,0
4087,0,SF,0,normal.,0,0,0,0,tcp,http,159,0,0,0
151,0,SF,0,normal.,0,0,0,0,tcp,http,210,0,0,0
786,0,SF,1,normal.,0,0,0,0,tcp,http,212,0,0,0


In [9]:
# Print the column names and the types Spark inferred for the DataFrame.
df.printSchema()

In [11]:
# Expose the DataFrame to Spark SQL under the name "connections".
# createOrReplaceTempView supersedes registerTempTable (deprecated since Spark 2.0).
# It replaces any existing view of the same name, so this cell is safe to re-run.
df.createOrReplaceTempView("connections")

In [12]:
# Number of connections per protocol, smallest group first (DataFrame API).
x = (
    df.groupBy("protocol_type")
    .count()
    .orderBy("count", ascending=True)
)
display(x)

protocol_type,count
udp,20354
tcp,190065
icmp,283602


In [13]:
from pyspark.sql import SQLContext
# SQL entry point bound to the existing SparkContext.
sqlContext = SQLContext(sc)

# Connections per protocol, most frequent first
# (ORDER BY 2 sorts on the second selected column, freq).
protocol_query = """
                           SELECT protocol_type, count(*) as freq
                           FROM connections
                           GROUP BY protocol_type
                           ORDER BY 2 DESC
                           """
protocols = sqlContext.sql(protocol_query)
display(protocols)

protocol_type,freq
icmp,283602
tcp,190065
udp,20354


In [14]:
# Connections per label, most frequent first
# (ORDER BY 2 sorts on the second selected column, freq).
label_query = """
                           SELECT label, count(*) as freq
                           FROM connections
                           GROUP BY label
                           ORDER BY 2 DESC
                           """
labels = sqlContext.sql(label_query)
display(labels)

label,freq
smurf.,280790
neptune.,107201
normal.,97278
back.,2203
satan.,1589
ipsweep.,1247
portsweep.,1040
warezclient.,1020
teardrop.,979
pod.,264


In [15]:
# Check what kind of object sqlContext.sql() returned for the label query.
type(labels)

In [16]:
# Figure defaults must be set BEFORE the figure is created; previously they were
# assigned after plotting and so did not apply to this chart.
plt.rcParams["figure.figsize"] = (7, 5)
plt.rcParams.update({'font.size': 10})

# Bring the aggregated label counts to the driver as a pandas DataFrame indexed by label.
# toPandas() already returns a pandas DataFrame, so no extra pd.DataFrame() wrap is needed;
# the non-inplace set_index keeps the cell idempotent on re-run.
labels_df = labels.toPandas().set_index("label")

# Horizontal bar chart of connection counts per label; labels_fig is the matplotlib Axes.
labels_fig = labels_df.plot(kind='barh')

plt.tight_layout()
display(labels_fig.figure)