In [1]:
import warnings
from mpl_toolkits.mplot3d import Axes3D
from ipywidgets import Button, Layout
from datetime import datetime
# warnings.filterwarnings("ignore")
from umap import UMAP
import pandas as pd
import matplotlib.patches as mpatches
from pprint import pprint
import numpy as np
from ipywidgets import HBox, VBox, interactive
from collections import defaultdict
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.dates as mdates
from IPython.display import display, Markdown
import ipywidgets as widgets
import math
from sklearn.manifold import TSNE
import torch
from ipywidgets import interact, interact_manual
from IPython.display import set_matplotlib_formats
from importlib import reload
from pathlib import Path
from tqdm import tqdm

import sys
import os
# Put the sibling ../src directory first on the import path so the
# project-local modules below can be imported from this notebook.
sys.path.insert(0, os.path.abspath('../src/'))

# Project-local modules (presumably located in ../src — verify).
import model_codebase as cb
import cicids2017 as cicids2017
import data_generator as generator

# Re-execute the two modules so that edits made on disk are picked up without
# restarting the kernel. Note that model_codebase is not reloaded here.
reload(cicids2017)
reload(generator)

# pandas display settings: allow up to 500 rows and render floats with
# four decimal places.
pd.set_option('display.max_rows', 500)
pd.set_option('float_format', '{:.4f}'.format)

# Filesystem locations, all relative to the notebook's working directory.
# RES_PATH selects the sub-directory of ../res/ that holds the model
# checkpoint and the dataset cache.
RES_PATH = "."
MODELPATH = Path(f"../res/{RES_PATH}/ts2vec.torch")
DATASETPATH_CACHE = Path(f"../res/{RES_PATH}/cache")
# NOTE(review): DATASETPATH is not referenced in the cells visible here.
DATASETPATH = Path("../dataset/CICIDS2017_ntop.pkl")

In [2]:
# Lowest value a decoded colour channel can take, and the number of distinct
# values per channel (the radix of the integer <-> triple translation).
mMIN = 50
M = 254 - mMIN
# Sum of M**1 + M**2 + M**3, kept as a float like the rest of the math here.
MAX_ENC = sum(math.pow(M, power) for power in (1, 2, 3))


def encdec(rgb=None, i=None):
    """Translate between a base-``M`` triple and a flat integer index.

    When ``rgb`` is given it takes precedence: the triple is folded into
    ``rgb[0] + rgb[1] * M + rgb[2] * M**2`` and returned as an int.
    Otherwise ``i`` is unfolded into three base-``M`` digits and each digit
    is shifted up by ``mMIN`` before being returned as an ``(x, y, z)`` tuple.

    NOTE(review): the encode direction does not subtract ``mMIN`` while the
    decode direction adds it, so ``encdec(rgb=encdec(i=k))`` is not ``k`` —
    confirm whether callers pass zero-based triples.

    References:
        https://coderwall.com/p/fzni3g/bidirectional-translation-between-1d-and-3d-arrays
    """
    plane = math.pow(M, 2)

    if rgb is None:
        # Decode: peel off the three base-M digits, then apply the offset.
        digits = (
            math.floor(i % M),
            math.floor((i / M) % M),
            math.floor(i / plane),
        )
        return tuple(digit + mMIN for digit in digits)

    # Encode: weighted sum of the three components.
    return math.floor(rgb[0] + rgb[1] * M + rgb[2] * plane)
           
def ncolors(n):
    """Return ``n`` random colours as ``'#RRGGBB'`` strings.

    Each of the six hex digits is drawn independently through NumPy's
    global random state, so the result is not reproducible unless that
    state is seeded beforehand, and colours may repeat.
    """
    digits = list('0123456789ABCDEF')
    draws = np.random.choice(digits, size=(n, 6))
    palette = []
    for row in draws:
        palette.append('#' + ''.join(row))
    return palette
    
def rgbtohex(r, g, b):
    """Format three integer channels as an upper-case ``'#RRGGBB'`` string.

    Each channel is zero-padded to at least two hex digits; values above
    255 are not clamped and simply produce more digits.
    """
    return f"#{r:02X}{g:02X}{b:02X}"

### Loading model and data

In [8]:
# Which split of the cached dataset to visualise.
TARGET_DSET = "testing_attacks"

# Constructor arguments for the encoder. They are expected to match the
# architecture of the checkpoint stored at MODELPATH.
model_args = {
    "sigma": -0.25,
    "input_size": 19,
    "rnn_size": 128,
    "rnn_layers": 3,
    "latent_size": 128
}

# Build the model in evaluation mode and restore its weights on the CPU.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
ts2vec = cb.STC(**model_args).eval()
ts2vec.load_state_dict(torch.load(str(MODELPATH), map_location=torch.device('cpu')))

# Load the selected split and keep the first ACTIVITY_LEN steps of every
# context window as the model input.
dset = cicids2017.load_dataset(DATASETPATH_CACHE)[TARGET_DSET]
X = torch.Tensor(dset["context"])
actv = X[:, :cicids2017.ACTIVITY_LEN]

# Inference only: no gradients are needed to compute the embeddings.
with torch.no_grad():
    dset["embedding"] = ts2vec.toembedding(actv).numpy()

# Project the embeddings to 2-D for plotting. The seed is fixed so that the
# layout (and the KMeans clustering computed on it later) is the same on
# every Restart & Run All instead of changing from run to run.
dset["ebs2D"] = TSNE(n_components=2, random_state=0).fit_transform(dset["embedding"])

In [9]:
def timestamp2weekday(x):
    """Return the weekday index (Monday=0 ... Sunday=6) of POSIX timestamp ``x``.

    The conversion uses the machine's local timezone.
    """
    return datetime.fromtimestamp(x).weekday()


# One weekday index per sample, derived from its start time.
dset_day = [timestamp2weekday(start) for start in dset["start_time"]]
dset["weekday"] = np.array(dset_day)

# Broadcast address fix :-)
# Relabel the category of every sample whose host is the broadcast address.
dset["device_category"][dset["host"] == "192.168.10.255"] = "broadcast"

In [10]:
# Collapse every attack label that mentions "scan" (case-insensitively)
# into the single label "NMap".
lower_type = np.char.lower(dset["attack_type"])
# np.char.find yields -1 where the substring is absent, so >= 0 means "found".
scan_mask = np.char.find(lower_type, "scan") >= 0
dset["attack_type"][scan_mask] = "NMap"

### What about K-Means?

In [42]:
from sklearn.cluster import KMeans

# Cluster the 2-D t-SNE projection (not the raw embeddings) into 17 groups
# with a fixed seed, and store each sample's cluster id next to the data.
kmeans = KMeans(n_clusters=17, random_state=0)
kmeans.fit(dset["ebs2D"])
dset["kmeans"] = kmeans.labels_

### Widget creation

In [47]:
%matplotlib inline
set_matplotlib_formats('svg')

DAY2INT = {
    "Monday": 0,
    "Tuesday": 1,
    "Wednesday": 2,
    "Thursday": 3,
    "Friday": 4}
INT2DAY = {v: k for k, v in DAY2INT.items()}

devices = set([(c, h) for (c, h) in zip(dset["device_category"], dset["host"])])
devices_str = [f"{h} ({c})" for c, h in devices]
devices_str.sort()


# ----- ----- COLORMAPS ----- ----- #
# ----- ----- ------- ----- ----- #
# Gather every label that may be drawn (attack types, device categories and
# hosts) and give each one a random colour. Repeated labels simply keep the
# last colour assigned to them.
everythingAcolor = np.concatenate((
    np.unique(dset["attack_type"]),
    [category for category, _ in devices],
    [host for _, host in devices],
))
everythingAcolor = dict(zip(everythingAcolor, ncolors(everythingAcolor.size)))


# ----- ----- WIDGETS ----- ----- #
# ----- ----- ------- ----- ----- #
# Device picker; entries look like "192.168.10.51 (server)".
device_w_list = widgets.Dropdown(options=devices_str)

# Only offer the weekdays that actually occur in the data.
available_days = [INT2DAY[day_index] for day_index in np.unique(dset["weekday"])]
days_w_list = widgets.Dropdown(options=available_days)

# Every attack label in the data plus the two pseudo-entries "All" and
# "Only"; the "none" label is selected initially.
available_attacks = [*np.unique(dset["attack_type"]), "All", "Only"]
show_attacks_dropdown = widgets.Dropdown(options=available_attacks, value="none")

netmap_checkbox = widgets.Checkbox(description='Show all network', value=False)
# NOTE(review): all_week is not placed in any layout below;
# all_week_checkbox is the one that is actually displayed.
all_week = widgets.Checkbox(description='Show all week traffic', value=False)
show_category_checkbox = widgets.Checkbox(description='Show category', value=False)
all_week_checkbox = widgets.Checkbox(description='Show all week traffic data', value=False)

# Three rows of controls stacked vertically.
ts1_selector = HBox([device_w_list, days_w_list])
ts2_selector = HBox([netmap_checkbox, show_category_checkbox])
ts3_selector = HBox([all_week_checkbox, show_attacks_dropdown])
wlist = VBox([ts1_selector, ts2_selector, ts3_selector])


# ----- ----- INTERACTOR ----- ----- #
# ----- ----- ---------- ----- ----- #
def plotme3d(x1, x2, z, colors, labels):
    """Scatter-plot 3-D points on the current figure, one series per label.

    Args:
        x1, x2, z: coordinate arrays, indexed by sample.
        colors: array of per-sample colours, aligned with the coordinates.
        labels: array of per-sample labels; each distinct value becomes one
            scatter series carrying that value as its legend label.

    The figure's existing 3-D axes is reused when there is one, so repeated
    calls draw onto the same axes; otherwise a new 3-D axes is added.
    """
    fig = plt.gcf()
    # plt.gca(projection='3d') stopped accepting keyword arguments in recent
    # Matplotlib releases, so look the 3-D axes up (or create it) explicitly.
    ax = next((a for a in fig.axes if getattr(a, "name", "") == "3d"), None)
    if ax is None:
        ax = fig.add_subplot(111, projection='3d')
    for l in np.unique(labels):
        # Row indices of the samples carrying this label.
        l_idxs = np.where(labels==l)[0]
        ax.scatter(x1[l_idxs], x2[l_idxs], z[l_idxs], color=colors[l_idxs], label=l)
    # Set the title on the 3-D axes itself rather than on whichever axes
    # happens to be current.
    ax.set_title("t-SNE validation data, attacks excluded")
    ax.view_init(0, 90)
    
def plotme2d(x1, x2, colors, labels):
    """Scatter-plot 2-D points on the current axes, one series per label.

    ``x1``/``x2`` hold the coordinates, ``colors`` the per-sample colours and
    ``labels`` the per-sample labels; every distinct label gets its own
    scatter call so that it shows up as a separate legend entry.
    """
    axes = plt.gca()
    for label in np.unique(labels):
        # Row indices of the samples that carry this label.
        rows = np.nonzero(labels == label)[0]
        axes.scatter(x1[rows], x2[rows], color=colors[rows], label=label)
    axes.set_title("t-SNE validation data, attacks excluded")
    
def _colors_for(keys):
    """Return the ``everythingAcolor`` colour of every entry in ``keys``.

    ``np.ravel`` is used rather than ``squeeze`` so that a selection holding
    exactly one sample stays a 1-D array; ``squeeze`` would collapse it to a
    0-d array, which cannot be iterated.
    """
    return np.array([everythingAcolor[key] for key in np.ravel(keys)])


def _scatter_subset(subset, label_key):
    """Plot the 2-D projection of ``subset``, coloured and grouped by ``label_key``."""
    labels = subset[label_key]
    points = subset["ebs2D"]
    plotme2d(points[:, 0], points[:, 1], _colors_for(labels), labels)


def whandler(device, day, netmap, show_category, all_week, attack):
    """Redraw the t-SNE scatter plot for the current widget selection.

    Args:
        device: dropdown entry of the form ``"<host> (<category>)"``; only the
            host part is used, and only when ``netmap`` is false.
        day: weekday name, looked up in ``DAY2INT``; ignored when ``all_week``
            is true.
        netmap: when true, show every host instead of just ``device``.
        show_category: when true together with ``netmap``, group the
            non-attack points by device category instead of by host.
        all_week: when true, do not filter by weekday.
        attack: ``"none"`` plots non-attack samples only, ``"Only"`` plots
            attack samples only, ``"All"`` plots both, and any other value
            plots non-attack samples plus the attacks of that type.

    Side effects: draws on the current figure, writes it to
    ``tSNE_clustering.svg`` and shows it.
    """
    # Weekday filter.
    if all_week:
        idx_mask = np.full(len(dset["weekday"]), True)
    else:
        idx_mask = (dset["weekday"] == DAY2INT[day])

    # Host filter: the host is the first space-separated token of the entry.
    if not netmap:
        host = device.split(" ")[0].strip()
        idx_mask = idx_mask & (dset["host"] == host)
    masked_dset = {k: v[idx_mask] for k, v in dset.items()}

    # Non-attack samples, unless only attacks were requested.
    if attack != "Only":
        no_attacks_mask = (masked_dset["attack_type"] == "none")
        no_attacks_dset = {k: v[no_attacks_mask] for k, v in masked_dset.items()}
        focus_label = "device_category" if netmap and show_category else "host"
        _scatter_subset(no_attacks_dset, focus_label)

    # Attack samples, grouped by attack type.
    if attack != "none":
        if attack in ("All", "Only"):
            attack_mask = masked_dset["attack_type"] != "none"
        else:
            attack_mask = masked_dset["attack_type"] == attack
        attack_dset = {k: v[attack_mask] for k, v in masked_dset.items()}
        _scatter_subset(attack_dset, "attack_type")

    plt.legend()
    plt.gcf().set_size_inches(7, 7)
    plt.savefig("tSNE_clustering.svg")
    plt.show()

In [48]:
# Bind the existing widgets to whandler's parameters. Only the Output area
# (the last child of the interactive container) is kept, because the
# controls themselves are laid out by wlist.
output = widgets.interactive(
    whandler,
    device=device_w_list,
    day=days_w_list,
    netmap=netmap_checkbox,
    show_category=show_category_checkbox,
    all_week=all_week_checkbox,
    attack=show_attacks_dropdown,
).children[-1]
display(wlist, output)

VBox(children=(HBox(children=(Dropdown(description='device', options=('192.168.10.50 (server)',), value='192.1…

Output()