In [1]:
dbutils.help()

In [2]:
import urllib
import urllib.request
urllib.request.urlretrieve("http://kdd.ics.uci.edu/databases/kddcup99/kddcup.data_10_percent.gz", "/tmp/kddcup_data.gz")
dbutils.fs.mv("file:/tmp/kddcup_data.gz", "dbfs:/kdd/kddcup_data.gz")
display(dbutils.fs.ls("dbfs:/kdd"))

path,name,size
dbfs:/kdd/kddcup_data.gz,kddcup_data.gz,2144903


In [3]:
data_file = "dbfs:/kdd/kddcup_data.gz"
raw_rdd = sc.textFile(data_file).cache()
raw_rdd.take(5)

In [4]:
type(raw_rdd)

In [5]:
from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)
sqlContext 

In [6]:
csv_rdd = raw_rdd.map(lambda row: row.split(","))
print(csv_rdd.take(2))
print(type(csv_rdd))

In [7]:
len(csv_rdd.take(1)[0]) # Check the total number of features (columns)

In [8]:
"""
We will be extracting the following columns based on their positions in each data point (row) and build a new RDD as follows.
"""

from pyspark.sql import Row

parsed_rdd = csv_rdd.map(lambda r: Row(
    duration=int(r[0]),
    protocol_type=r[1],
    service=r[2],
    flag=r[3],
    src_bytes=int(r[4]),
    dst_bytes=int(r[5]),
    wrong_fragment=int(r[7]),
    urgent=int(r[8]),
    hot=int(r[9]),
    num_failed_logins=int(r[10]),
    num_compromised=int(r[12]),
    su_attempted=r[14],
    num_root=int(r[15]),
    num_file_creations=int(r[16]),
    label=r[-1]
    )
)
parsed_rdd.take(5)

In [9]:
# Construct the DataFrame
df = sqlContext.createDataFrame(parsed_rdd)
display(df.head(10))

dst_bytes,duration,flag,hot,label,num_compromised,num_failed_logins,num_file_creations,num_root,protocol_type,service,src_bytes,su_attempted,urgent,wrong_fragment
5450,0,SF,0,normal.,0,0,0,0,tcp,http,181,0,0,0
486,0,SF,0,normal.,0,0,0,0,tcp,http,239,0,0,0
1337,0,SF,0,normal.,0,0,0,0,tcp,http,235,0,0,0
1337,0,SF,0,normal.,0,0,0,0,tcp,http,219,0,0,0
2032,0,SF,0,normal.,0,0,0,0,tcp,http,217,0,0,0
2032,0,SF,0,normal.,0,0,0,0,tcp,http,217,0,0,0
1940,0,SF,0,normal.,0,0,0,0,tcp,http,212,0,0,0
4087,0,SF,0,normal.,0,0,0,0,tcp,http,159,0,0,0
151,0,SF,0,normal.,0,0,0,0,tcp,http,210,0,0,0
786,0,SF,1,normal.,0,0,0,0,tcp,http,212,0,0,0


In [10]:
df.printSchema()

### Build a temporary table

We can leverage the registerTempTable() function to build a temporary table to run SQL commands on our DataFrame at scale! A point to remember is that the lifetime of this temp table is tied to the session. It creates an in-memory table that is scoped to the cluster in which it was created. The data is stored using Hive's highly optimized, in-memory columnar format.

You can also check out saveAsTable(), which creates a permanent, physical table stored in S3 using the Parquet format. This table is accessible to all clusters. The table metadata, including the location of the file(s), is stored within the Hive metastore.

In [12]:
help(df.registerTempTable)

In [13]:
df.registerTempTable("connections")

#### Connections based on the protocol type

Let's look at how we can get the total number of connections based on the type of connectivity protocol. First, we will get this information using normal DataFrame DSL syntax to perform aggregations.

In [15]:
display(df.groupBy('protocol_type').count().orderBy('count', ascending=False))

protocol_type,count
icmp,283602
tcp,190065
udp,20354


In [16]:
protocols = sqlContext.sql("""
                           SELECT protocol_type, count(*) as freq
                           FROM connections
                           GROUP BY protocol_type
                           ORDER BY 2 DESC
                           """)
display(protocols)

protocol_type,freq
icmp,283602
tcp,190065
udp,20354


#### Connections based on good or bad (attack types) signatures

We will now run a simple aggregation to check the total number of connections based on good (normal) or bad (intrusion attacks) types.

In [18]:
labels = sqlContext.sql("""
                           SELECT label, count(*) as freq
                           FROM connections
                           GROUP BY label
                           ORDER BY 2 DESC
                           """)
display(labels)

label,freq
smurf.,280790
neptune.,107201
normal.,97278
back.,2203
satan.,1589
ipsweep.,1247
portsweep.,1040
warezclient.,1020
teardrop.,979
pod.,264


In [19]:
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('fivethirtyeight')

In [20]:
labels_df = pd.DataFrame(labels.toPandas())
labels_df.set_index("label", drop=True,inplace=True)
labels_fig = labels_df.plot(kind='barh')

plt.rcParams["figure.figsize"] = (7, 5)
plt.rcParams.update({'font.size': 10})
plt.tight_layout()
display(labels_fig.figure)

### Connections based on protocols and attacks

In [22]:
attack_protocol = sqlContext.sql("""
                           SELECT
                             protocol_type,
                             CASE label
                               WHEN 'normal.' THEN 'no attack'
                               ELSE 'attack'
                             END AS state,
                             COUNT(*) as freq
                           FROM connections
                           GROUP BY protocol_type, state
                           ORDER BY 3 DESC
                           """)
display(attack_protocol)

protocol_type,state,freq
icmp,attack,282314
tcp,attack,113252
tcp,no attack,76813
udp,no attack,19177
icmp,no attack,1288
udp,attack,1177


#### Connection stats based on protocols and attacks

Let's take a look at some statistical measures pertaining to these protocols and attacks for our connection requests.

In [24]:
attack_stats = sqlContext.sql("""
                           SELECT
                             protocol_type, 
                             CASE label
                                WHEN 'normal' THEN 'no attack'
                                ELSE 'attack'
                             END AS state,
                              COUNT(*) as total_freq,
                             ROUND(AVG(src_bytes), 2) as mean_src_bytes,
                             ROUND(AVG(dst_bytes), 2) as mean_dst_bytes,
                             ROUND(AVG(duration), 2) as mean_duration,
                             SUM(num_failed_logins) as total_failed_logins,
                             SUM(num_compromised) as total_compromised,
                             SUM(num_file_creations) as total_file_creations,
                             SUM(su_attempted) as total_root_attempts,
                             SUM(num_root) as total_root_acceses
                           FROM connections
                           GROUP BY protocol_type, state
                           ORDER BY 3 DESC
                           """)
display(attack_stats)
                            
                             
                             

protocol_type,state,total_freq,mean_src_bytes,mean_dst_bytes,mean_duration,total_failed_logins,total_compromised,total_file_creations,total_root_attempts,total_root_acceses
icmp,attack,283602,928.32,0.0,0.0,0,0,0,0.0,0
tcp,attack,190065,6469.0,2248.44,18.3,75,5045,535,18.0,5608
udp,attack,20354,93.94,84.71,993.65,0,0,0,0.0,0


### Filtering connection stats based on the TCP protocol by service and attack type

Let's take a closer look at TCP attacks, given that we have more relevant data and statistics for the same. We will now aggregate different types of TCP attacks based on service and attack type and observe different metrics.

In [26]:
tcp_attack_stats = sqlContext.sql("""
                                   SELECT
                                     service,
                                     label as attack_type,
                                     COUNT(*) as total_freq,
                                     ROUND(AVG(duration), 2) as mean_duration,
                                     SUM(num_failed_logins) as total_failed_logins,
                                     SUM(num_file_creations) as total_file_creations,
                                     SUM(su_attempted) as total_root_attempts,
                                     SUM(num_root) as total_root_acceses
                                   FROM connections
                                   WHERE protocol_type = 'tcp'
                                   AND label != 'normal.'
                                   GROUP BY service, attack_type
                                   ORDER BY total_freq DESC
                                   """)
display(tcp_attack_stats)

service,attack_type,total_freq,mean_duration,total_failed_logins,total_file_creations,total_root_attempts,total_root_acceses
private,neptune.,101317,0.0,0,0,0.0,0
http,back.,2203,0.13,0,0,0.0,0
other,satan.,1221,0.0,0,0,0.0,0
private,portsweep.,725,1915.81,0,0,0.0,0
ftp_data,warezclient.,708,403.71,0,0,0.0,0
ftp,warezclient.,307,1063.79,0,0,0.0,0
other,portsweep.,260,1058.22,0,0,0.0,0
telnet,neptune.,197,0.0,0,0,0.0,0
http,neptune.,192,0.0,0,0,0.0,0
finger,neptune.,177,0.0,0,0,0.0,0


### Filtering connection stats based on the TCP protocol by service and attack type

We will now filter some of these attack types by imposing some constraints in our query based on duration, file creations, and root accesses.

In [28]:
tcp_attack_stats = sqlContext.sql("""
                                   SELECT
                                     service,
                                     label as attack_type,
                                     COUNT(*) as total_freq,
                                     ROUND(AVG(duration), 2) as mean_duration,
                                     SUM(num_failed_logins) as total_failed_logins,
                                     SUM(num_file_creations) as total_file_creations,
                                     SUM(su_attempted) as total_root_attempts,
                                     SUM(num_root) as total_root_acceses
                                   FROM connections
                                   WHERE (protocol_type = 'tcp'
                                          AND label != 'normal.')
                                   GROUP BY service, attack_type
                                   HAVING (mean_duration >= 50
                                           AND total_file_creations >= 5
                                           AND total_root_acceses >= 1)
                                   ORDER BY total_freq DESC
                                   """)
display(tcp_attack_stats)

service,attack_type,total_freq,mean_duration,total_failed_logins,total_file_creations,total_root_attempts,total_root_acceses
telnet,buffer_overflow.,21,130.67,0,15,0.0,5
telnet,loadmodule.,5,63.8,0,9,0.0,3
telnet,multihop.,2,458.0,0,8,0.0,93


### Subqueries to filter TCP attack types based on service

Let's try to get all the TCP attacks based on service and attack type such that the overall mean duration of these attacks is greater than zero (> 0). For this, we can do an inner query with all aggregation statistics and extract the relevant queries and apply a mean duration filter in the outer query, as shown below.

In [30]:
tcp_attack_stats = sqlContext.sql("""
                                   SELECT
                                     t.service,
                                     t.attack_type,
                                     t.total_freq
                                   FROM
                                   (SELECT
                                     service,
                                     label as attack_type,
                                     COUNT(*) as total_freq,
                                     ROUND(AVG(duration), 2) as mean_duration,
                                     SUM(num_failed_logins) as total_failed_logins,
                                     SUM(num_file_creations) as total_file_creations,
                                     SUM(su_attempted) as total_root_attempts,
                                     SUM(num_root) as total_root_acceses
                                   FROM connections
                                   WHERE protocol_type = 'tcp'
                                   AND label != 'normal.'
                                   GROUP BY service, attack_type
                                   ORDER BY total_freq DESC) as t
                                     WHERE t.mean_duration > 0
                                   """)
display(tcp_attack_stats)

service,attack_type,total_freq
http,back.,2203
private,portsweep.,725
ftp_data,warezclient.,708
ftp,warezclient.,307
other,portsweep.,260
private,satan.,170
telnet,guess_passwd.,53
telnet,buffer_overflow.,21
ftp_data,warezmaster.,18
imap4,imap.,12


#### Build a pivot table from aggregated data

We will build upon the previous DataFrame object where we aggregated attacks based on type and service. For this, we can leverage the power of Spark DataFrames and the DataFrame DSL.

In [32]:
display((tcp_attack_stats.groupby('service')
                         .pivot('attack_type')
                         .agg({'total_freq':'max'})
                         .na.fill(0))
)

service,back.,buffer_overflow.,ftp_write.,guess_passwd.,imap.,ipsweep.,loadmodule.,multihop.,perl.,phf.,portsweep.,rootkit.,satan.,spy.,warezclient.,warezmaster.
telnet,0,21,0,53,0,1,5,2,3,0,0,5,1,2,0,0
ftp,0,1,2,0,0,1,1,2,0,0,0,1,1,0,307,2
pop_3,0,0,0,0,0,0,0,0,0,0,3,0,1,0,0,0
discard,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0
login,0,0,2,0,0,0,0,0,0,0,0,0,0,0,0,0
smtp,0,0,0,0,0,0,0,0,0,0,0,0,2,0,0,0
domain,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0
http,2203,0,0,0,0,3,0,0,0,4,3,0,0,0,0,0
courier,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0
other,0,0,0,0,0,0,0,0,0,0,260,0,0,0,5,0
