In [1]:
import warnings
from ipywidgets import Button, Layout
from datetime import datetime
# warnings.filterwarnings("ignore")
from umap import UMAP
import pandas as pd
import matplotlib.patches as mpatches
from pprint import pprint
import numpy as np
from ipywidgets import HBox, VBox, interactive
from collections import defaultdict
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.dates as mdates
from IPython.display import display, Markdown
import ipywidgets as widgets
import math
from sklearn.manifold import TSNE
import torch
from ipywidgets import interact, interact_manual
from IPython.display import set_matplotlib_formats
from importlib import reload
from pathlib import Path
from tqdm import tqdm

import sys
import os
sys.path.insert(0, os.path.abspath('../src/'))

import model_codebase as cb
import cicids2017 as cicids2017
import data_generator as generator

# Re-import the project modules so edits under ../src/ take effect without restarting the kernel.
for _mod in (cicids2017, generator):
    reload(_mod)

# Display options: up to 500 rows, floats rendered with 4 decimals.
pd.set_option('display.max_rows', 500,
              'float_format', '{:.4f}'.format)

# Model checkpoint (loaded into ts2vec via load_state_dict), dataset cache directory
# (passed to cicids2017.load_dataset) and the CICIDS2017 ntop dataset pickle.
MODELPATH = Path("../res/monday_train/ts2vec.torch")
DATASETPATH_CACHE = Path("../res/monday_train/cache")
DATASETPATH = Path("../dataset/CICIDS2017_ntop.pkl")

  import numba.targets


In [2]:
mMIN = 50
M = 254 - mMIN
# Largest valid 1-D index: each channel offset lies in [0, M), so the last index is M**3 - 1.
# (M + M**2 + M**3 overshoots and decodes to channel values above the intended range.)
MAX_ENC = M ** 3 - 1

def encdec(rgb=None, i=None):
    """Translate between a 1-D index and an (r, g, b) triple.

    Each channel is kept in ``[mMIN, mMIN + M)`` so generated colours avoid the
    darkest shades; the 1-D index stores the per-channel offsets from ``mMIN``
    as digits in base ``M`` (r least significant).

    Parameters
    ----------
    rgb : sequence of numbers, optional
        Colour to encode; only the first three items are read. Takes precedence
        over ``i`` when both are given.
    i : number, optional
        Index to decode, expected in ``[0, MAX_ENC]``.

    Returns
    -------
    int
        The 1-D index, when ``rgb`` is given.
    tuple of int
        ``(r, g, b)``, when only ``i`` is given.

    Raises
    ------
    TypeError
        If neither ``rgb`` nor ``i`` is given.

    References:
        https://coderwall.com/p/fzni3g/bidirectional-translation-between-1d-and-3d-arrays
    """
    if rgb is not None:
        # Remove the mMIN offset that decoding adds, so encdec(rgb=encdec(i=k)) == k.
        r, g, b = rgb[0] - mMIN, rgb[1] - mMIN, rgb[2] - mMIN
        return math.floor(r + g * M + b * M ** 2)
    if i is None:
        raise TypeError("encdec() requires either `rgb` or `i`")
    x = math.floor(i % M)
    y = math.floor((i / M) % M)
    z = math.floor(i / (M ** 2))
    return (x + mMIN, y + mMIN, z + mMIN)
           
# def ncolors(n, givehex=False):
#    cols = np.linspace(0, MAX_ENC, n)
#    cols = map(lambda x: encdec(i=x), cols)
#    
#    if givehex:
#        cols = map(lambda x: rgbtohex(*x), cols)
#    return list(cols)

def ncolors(n, seed=None):
    """Return ``n`` random colours as upper-case ``'#RRGGBB'`` strings.

    Parameters
    ----------
    n : int
        Number of colours to generate.
    seed : int or numpy.random.Generator, optional
        Source of randomness. ``None`` (default) uses numpy's global RNG, as
        before; pass an int (or a Generator) to get the same palette on every run.

    Notes
    -----
    Colours are drawn independently, so duplicates are possible (unlikely for small n).
    """
    hexl = list('0123456789ABCDEF')
    if seed is None:
        hexc = np.random.choice(hexl, size=(n, 6))
    else:
        rng = seed if isinstance(seed, np.random.Generator) else np.random.default_rng(seed)
        hexc = rng.choice(hexl, size=(n, 6))
    return ['#' + ''.join(x) for x in hexc]
    
def rgbtohex(r, g, b):
    """Format integer channels (r, g, b) as an upper-case '#RRGGBB' string."""
    digits = "".join("{:02x}".format(channel) for channel in (r, g, b))
    return "#" + digits.upper()

### Loading model and data

In [3]:
# Restore the trained encoder on the CPU, in eval mode.
# NOTE(review): torch.load unpickles the checkpoint - only load files from a trusted source.
ts2vec = cb.STC().eval()
ts2vec.load_state_dict(
    torch.load(str(MODELPATH), map_location='cpu')
)

dset, _, _ = cicids2017.load_dataset(DATASETPATH_CACHE)
X = torch.Tensor(dset["context"])
# Encoder input: the first cb.ACTIVITY_LEN entries along dim 1 of the context tensor.
actv = X[:, :cb.ACTIVITY_LEN]

# Inference only: no autograd graph is needed.
with torch.no_grad():
    embeddings = ts2vec.toembedding(actv)
dset["embedding"] = embeddings.numpy()

In [4]:
# 2-D projection of the embeddings used by the plotting widget.
# NOTE: the key stays "UMAP" because whandler reads dset["UMAP"], but the projection is t-SNE
# (UMAP is kept below as a commented-out alternative). random_state pins the layout so
# re-running the notebook reproduces the same figure.
# dset["UMAP"] = UMAP().fit_transform(dset["embedding"])
dset["UMAP"] = TSNE(n_components=2, random_state=0).fit_transform(dset["embedding"])


def timestamp2weekday(x):
    """Weekday index (Monday=0 ... Sunday=6) of a POSIX timestamp.

    NOTE(review): datetime.fromtimestamp converts to the kernel's *local* timezone,
    so day boundaries depend on the machine running the notebook - confirm intended.
    """
    return datetime.fromtimestamp(x).weekday()


dset_day = [timestamp2weekday(ts) for ts in dset["start_time"]]
dset["weekday"] = np.array(dset_day)

# Broadcast address fix :-) - give the subnet broadcast address its own category.
# np.where builds a new array whose dtype widens to fit "broadcast"; assigning in place
# would silently truncate the label if the column has a fixed-width string dtype.
dset["device_category"] = np.where(dset["host"] == "192.168.10.255",
                                   "broadcast",
                                   dset["device_category"])

### Widget creation

In [5]:
%matplotlib inline
set_matplotlib_formats('svg')

DAY2INT = {
    "Monday": 0,
    "Tuesday": 1,
    "Wednesday": 2,
    "Thursday": 3,
    "Friday": 4}
INT2DAY = {v: k for k, v in DAY2INT.items()}

devices = set([(c, h) for (c, h) in zip(dset["device_category"], dset["host"])])
devices_str = [f"{h} ({c})" for c, h in devices]
devices_str.sort()


# ----- ----- COLORMAPS ----- ----- #
# ----- ----- ------- ----- ----- #
hosts = [dev[1] for dev in devices]
host_cmap = dict(zip(hosts, ncolors(len(hosts))))

cats = [dev[0] for dev in devices]
cats_cmap = dict(zip(cats, ncolors(len(cats))))

attacks = np.unique(dset["attack"])
attack_cmap = dict(zip(attacks, ncolors(len(attacks))))


# ----- ----- WIDGETS ----- ----- #
# ----- ----- ------- ----- ----- #
device_w_list = widgets.Dropdown(options=devices_str,
                                 value="192.168.10.50 (server)")

available_days = list(map(lambda x: INT2DAY[x], np.unique(dset["weekday"])))
days_w_list = widgets.Dropdown(options=available_days)

netmap_checkbox = widgets.Checkbox(value=False, description='Show all network')
show_ip_checkbox = widgets.Checkbox(value=False, description='Show specific IP')
ts1_selector = HBox([device_w_list, netmap_checkbox])
ts2_selector = HBox([days_w_list, show_ip_checkbox])
wlist = VBox([ts1_selector, ts2_selector])


# ----- ----- INTERACTOR ----- ----- #
# ----- ----- ---------- ----- ----- #
def whandler(device, day, netmap, show_ip):
    """Scatter-plot the 2-D projection ``dset["UMAP"]`` for the current widget selection.

    Parameters
    ----------
    device : str
        Dropdown entry formatted as ``"<host> (<category>)"``; only the host part is used.
    day : str
        Weekday name; must be a key of ``DAY2INT``.
    netmap : bool
        False: plot only ``device``'s samples on ``day``, coloured by ``dset["attack"]``.
        True: plot every host's samples on ``day``, coloured by device category.
    show_ip : bool
        With ``netmap`` True, colour by host IP instead of device category.
        Ignored when ``netmap`` is False.
    """
    # Samples recorded on the chosen day ...
    idx_mask = dset["weekday"] == DAY2INT[day]

    if netmap:
        scope = "whole network"
        if show_ip:
            focus_label, cmap = "host", host_cmap
        else:
            focus_label, cmap = "device_category", cats_cmap
    else:
        # ... restricted to one host: the text before the first space of "<host> (<category>)".
        scope = device.split(" ", 1)[0].strip()
        idx_mask = idx_mask & (dset["host"] == scope)
        focus_label, cmap = "attack", attack_cmap

    # Apply the selection to both coordinates and labels so the device/day widgets
    # actually filter what is drawn.
    points = np.asarray(dset["UMAP"])[idx_mask]
    labels = np.asarray(dset[focus_label])[idx_mask]

    fig, ax = plt.subplots(figsize=(7, 7))
    for label in np.unique(labels):
        in_class = labels == label
        ax.scatter(points[in_class, 0], points[in_class, 1],
                   color=cmap[label], label=label)

    ax.set_title(f"{day} - {scope} (colour: {focus_label})")
    ax.set_xlabel("projection dim 1")
    ax.set_ylabel("projection dim 2")
    if labels.size:
        ax.legend(loc=(1.04, 0))
    else:
        # Nothing matched: say so instead of drawing an empty, unlabelled axis.
        ax.text(0.5, 0.5, "no samples for this selection",
                ha="center", va="center", transform=ax.transAxes)
    plt.show()


In [6]:
# Bind the custom controls to whandler; interactive() puts its Output widget last among its children.
_panel = widgets.interactive(whandler,
                             device=device_w_list,
                             day=days_w_list,
                             netmap=netmap_checkbox,
                             show_ip=show_ip_checkbox)
output = _panel.children[-1]
# Render the custom control layout followed by the plot output area.
display(wlist, output)

VBox(children=(HBox(children=(Dropdown(description='device', index=11, options=('192.168.10.1 (unknown device …

Output()