In [1]:
#https://github.com/jadianes/spark-py-notebooks/blob/master/nb9-mllib-trees/nb9-mllib-trees.ipynb

In [1]:
import os
#os.getcwd()

In [3]:
import urllib.request
# Local file name of the gzipped training data.
data_file = "kddcup.data.gz"

# Download the archive only when no local copy exists in the working directory.
if not os.path.isfile(data_file):
    f = urllib.request.urlretrieve(
        "http://kdd.ics.uci.edu/databases/kddcup99/kddcup.data.gz",
        "kddcup.data.gz",
    )

In [5]:
# findspark.init() is called before importing pyspark so that the Spark
# installation's Python package can be found; keep this order.
import findspark
findspark.init()
import pyspark

# Reuse the already-running SparkContext when this cell is executed again.
# Constructing a second SparkContext in the same kernel raises an error, which
# made the cell impossible to re-run without restarting the kernel.
sc = pyspark.SparkContext.getOrCreate(pyspark.SparkConf().setAppName("test"))

# Read the gzipped training file from the notebook's working directory through
# an explicit file:// URI.
raw_data = sc.textFile('file://' + os.getcwd() + "/" + data_file)

print("Train data size is {}".format(raw_data.count()))

Train data size is 4898431


In [6]:
# Local file name of the gzipped test data.
test_data_file = "corrected.gz"

# Guard on the test archive itself (not on the training archive `data_file`),
# so the test set is fetched even when the training data is already on disk.
if not os.path.isfile(test_data_file):
    ft = urllib.request.urlretrieve(
        "http://kdd.ics.uci.edu/databases/kddcup99/corrected.gz",
        test_data_file,
    )

# Read the gzipped test file from the notebook's working directory through an
# explicit file:// URI.
test_raw_data = sc.textFile('file://' + os.getcwd() + "/" + test_data_file)

print("Test data size is {}".format(test_raw_data.count()))

Test data size is 311029


In [7]:
# Not enough resources to process the full data: continue with a 30% sample of
# each data set, drawn without replacement using a fixed seed.
# NOTE: both names are rebound to their own sample, so running this cell a
# second time samples the already-sampled data again.
raw_data = raw_data.sample(withReplacement=False, fraction=0.3, seed=1234)
test_raw_data = test_raw_data.sample(withReplacement=False, fraction=0.3, seed=1234)

In [8]:
from pyspark.mllib.regression import LabeledPoint
from numpy import array

# Split every comma-separated record into its list of string fields.
csv_data = raw_data.map(lambda record: record.split(","))
test_csv_data = test_raw_data.map(lambda record: record.split(","))


def _distinct_column_values(rows, column):
    """Collect the distinct values found at position `column` of each row in `rows`."""
    return rows.map(lambda fields: fields[column]).distinct().collect()


# Distinct values of columns 1, 2 and 3, gathered from the training data only.
protocols = _distinct_column_values(csv_data, 1)
services = _distinct_column_values(csv_data, 2)
flags = _distinct_column_values(csv_data, 3)

In [9]:
def _category_index(categories, value):
    """Return the position of `value` in the list `categories`, or `len(categories)` if absent."""
    try:
        return categories.index(value)
    except ValueError:
        # list.index raises ValueError for a missing value. Catching only that
        # keeps the "unknown category" fallback without hiding unrelated errors,
        # which the previous bare `except:` did.
        return len(categories)


def create_labeled_point(line_split):
    """Convert one split CSV record into a LabeledPoint with a binary label.

    Args:
        line_split: list of the record's string fields. Positions 0-40 are
            used as features and position 41 as the label.

    Returns:
        A LabeledPoint whose label is 0.0 when field 41 equals 'normal.' and
        1.0 otherwise, and whose features are the first 41 fields converted to
        floats. Fields 1, 2 and 3 are first replaced by their index in the
        module-level `protocols`, `services` and `flags` lists; a value that
        is missing from its list is mapped to the length of that list.

    Raises:
        IndexError: if the record has fewer than 42 fields.
        ValueError: if one of the remaining feature fields is not a valid float.
    """
    # Work on a copy of the 41 feature fields; the label (field 41) is left out.
    clean_line_split = line_split[:41]

    # Encode the three string-valued columns as numeric category indices.
    clean_line_split[1] = _category_index(protocols, clean_line_split[1])
    clean_line_split[2] = _category_index(services, clean_line_split[2])
    clean_line_split[3] = _category_index(flags, clean_line_split[3])

    # Binary label: 'normal.' is the negative class, anything else is an attack.
    attack = 0.0 if line_split[41] == 'normal.' else 1.0

    return LabeledPoint(attack, array([float(x) for x in clean_line_split]))

# Run both data sets through the same encoder so that training and test
# records share one category-to-index mapping.
training_data, test_data = [
    rows.map(create_labeled_point) for rows in (csv_data, test_csv_data)
]

In [10]:
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from time import time

# Build the model
# Feature columns 1, 2 and 3 are declared categorical, each with as many
# categories as distinct values were collected for it.
categorical_features = {1: len(protocols), 2: len(services), 3: len(flags)}

t0 = time()
tree_model = DecisionTree.trainClassifier(
    training_data,
    numClasses=2,
    categoricalFeaturesInfo=categorical_features,
    impurity='gini',
    maxDepth=4,
    maxBins=100,
)
tt = time() - t0

# Report the wall-clock training time, rounded to milliseconds.
elapsed_seconds = round(tt, 3)
print("Classifier trained in {} seconds".format(elapsed_seconds))

Classifier trained in 79.713 seconds


In [11]:
# Predict with the trained tree on the feature vectors of the test points.
test_features = test_data.map(lambda point: point.features)
predictions = tree_model.predict(test_features)

# Pair each test label with its prediction as (label, prediction).
test_labels = test_data.map(lambda point: point.label)
labels_and_preds = test_labels.zip(predictions)

In [12]:
t0 = time()
# Accuracy = pairs whose label equals the prediction / all test points.
correct_count = labels_and_preds.filter(lambda pair: pair[0] == pair[1]).count()
total_count = test_data.count()
test_accuracy = correct_count / float(total_count)
tt = time() - t0

# Report the elapsed time (3 decimals) and the accuracy (4 decimals).
print(
    "Prediction made in {} seconds. Test accuracy is {}".format(
        round(tt, 3), round(test_accuracy, 4)
    )
)

Prediction made in 7.093 seconds. Test accuracy is 0.9191


In [14]:
# Two-line header, then the textual dump of the learned decision tree.
print(
    "Learned classification tree model:",
    "==================================",
    sep="\n",
)
print(tree_model.toDebugString())

Learned classification tree model:
DecisionTreeModel classifier of depth 4 with 25 nodes
  If (feature 22 <= 45.0)
   If (feature 38 <= 0.88)
    If (feature 36 <= 0.4)
     If (feature 34 <= 0.91)
      Predict: 0.0
     Else (feature 34 > 0.91)
      Predict: 1.0
    Else (feature 36 > 0.4)
     If (feature 2 in {0.0,24.0,25.0,14.0,20.0,1.0,6.0,21.0,13.0,2.0,22.0,12.0,7.0,3.0,16.0,26.0,23.0,8.0,19.0,4.0})
      Predict: 0.0
     Else (feature 2 not in {0.0,24.0,25.0,14.0,20.0,1.0,6.0,21.0,13.0,2.0,22.0,12.0,7.0,3.0,16.0,26.0,23.0,8.0,19.0,4.0})
      Predict: 1.0
   Else (feature 38 > 0.88)
    If (feature 3 in {0.0,2.0})
     Predict: 0.0
    Else (feature 3 not in {0.0,2.0})
     If (feature 30 <= 0.67)
      Predict: 1.0
     Else (feature 30 > 0.67)
      Predict: 0.0
  Else (feature 22 > 45.0)
   If (feature 5 <= 0.0)
    If (feature 2 in {28.0})
     Predict: 0.0
    Else (feature 2 not in {28.0})
     If (feature 11 <= 0.0)
      Predict: 1.0
     Else (feature 11 > 0.0)
     