In [16]:
from os.path import expanduser, join
import pathlib

# Absolute, symlink-resolved path of the directory the notebook is run from;
# every data and library path below is built relative to it.
root = pathlib.Path.cwd().resolve()

In [17]:
import sys
# Make the project-local modules under <root>/lib importable.
# Guarded so that re-running this cell does not keep stacking the same
# entry into sys.path.
lib_dir = join(root, 'lib')
if lib_dir not in sys.path:
  sys.path.insert(1, lib_dir)

import config
import functions
import data
import model
import export

In [18]:
# Data sets under evaluation: one dict per data set with a display 'label'
# and the loaded 'data' (assumes data.read returns a pandas DataFrame, since
# .sample() is called on it -- TODO confirm).
#
# The SLD data set is reduced to 100 randomly drawn rows. The draw is seeded
# (random_state=0, the same seed the random forest below uses) so that a
# Restart & Run All reproduces the same subset and therefore the same results.
# NOTE(review): only the SLD data set is down-sampled; presumably a speed
# measure during development -- confirm before a full run.
datasets = [
  {
    'label': 'HTTP/S Graph (SLDs)',
    'data': data.read(join(root, 'data', 'graph-data-sld.csv')).sample(n=100, random_state=0)
  },
  {
    'label': 'HTTP/S Graph (FQDN)',
    'data': data.read(join(root, 'data', 'graph-data-fqdn.csv'))
  }
]

In [19]:
# Column names of the first data set, minus the columns that are not used as
# model inputs (matched case-insensitively).
non_feature_columns = ['id', 'weight', 'tracker']
features = [
  col
  for col in datasets[0].get('data').columns
  if col.lower() not in non_feature_columns
]

In [20]:
from sklearn.preprocessing import LabelEncoder

# Replace each data set's 'tracker' column, in place, with the integer labels
# produced by a LabelEncoder fitted on that same column (one fresh encoder
# per data set).
for dataset in datasets:
  frame = dataset.get('data')
  encoder = LabelEncoder()
  frame['tracker'] = encoder.fit_transform(frame['tracker'])

In [21]:
# For every data set add a companion entry labelled '<label> 50/50' whose data
# comes from data.sample_equal_distribution(..., 'tracker') (presumably an
# equal number of rows per 'tracker' class -- confirm in the data module).
#
# The guards make the cell idempotent: without them a second execution would
# append the companions again and also derive ' 50/50 50/50' entries from the
# companions themselves. On a first run the result is unchanged.
balanced_suffix = ' 50/50'
existing_labels = {dataset.get('label') for dataset in datasets}

extension = [
  {
    'label': dataset.get('label') + balanced_suffix,
    'data': data.sample_equal_distribution(dataset.get('data'), 'tracker')
  }
  for dataset in datasets
  if not dataset.get('label').endswith(balanced_suffix)
  and dataset.get('label') + balanced_suffix not in existing_labels
]

datasets.extend(extension)

In [22]:
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.tree import DecisionTreeClassifier

# Estimators to evaluate, grouped under the keys 'continuous' (regressors)
# and 'category' (classifiers).
#
# DecisionTreeClassifier is seeded like the random forest: scikit-learn
# permutes the candidate features at every split, so an unseeded tree can
# differ between runs and make the exported results irreproducible.
models = {
  'continuous': [
    LinearRegression(n_jobs=-1),
    RandomForestRegressor(n_estimators=200, random_state=0, n_jobs=-1)
  ],
  'category': [
    DecisionTreeClassifier(random_state=0),
    LogisticRegression(solver='lbfgs', max_iter=1000, n_jobs=-1)
  ]
}

In [23]:
%matplotlib agg

# Run the project's model.compute_results helper.
# Only the first data set is passed on (datasets[:1]); 'weight' and 'tracker'
# are handed over as the fourth argument, separate from the feature columns.
selected_datasets = datasets[:1]
target_columns = ['weight', 'tracker']

results = model.compute_results(selected_datasets, models, features, target_columns)

In [24]:
# Hand the computed results and the project root to the export module.
# NOTE(review): presumably writes per-data-set classification results to
# files under `root` -- confirm in the export module.
export.classification_results(results, root)

In [25]:
# Takes only the project root, not `results`.
# NOTE(review): presumably aggregates what the previous export step wrote
# under `root`, so it would depend on that cell having run -- confirm.
export.aggregated_classification_results(root)

In [26]:
# TODO: for each data set one fig 2 x rows, n = columns with n number of models
# TODO: save pdf for each data set
# Plotting sketch kept for the TODOs above (not executed):
#   fig, ax = plt.subplots()
#   importances.plot.bar(yerr=result.importances_std, ax=ax)
#   fig.tight_layout()
#   return plt

# Print the 'feature_importance' entry of every model, grouped under the
# label of the data set it was computed for.
# The inner loop variable is deliberately not named `model`: that name is the
# imported project module used for model.compute_results, and rebinding it
# here would break any later re-run of that cell.
for dataset_label, model_results in results.items():
  print(dataset_label)
  for model_name, model_result in model_results.items():
    print(model_result.get('feature_importance'))

HTTP/S Graph (SLDs)
{'result': {'importances_mean': array([ 8.45910609e+08,  7.99048480e+02,  4.16435308e+05,  2.40677111e+02,
        1.04092241e+05,  1.00518092e+05,  1.60448253e-06, -1.21111440e+06,
       -6.16874809e+02, -1.78517466e+05, -2.83122063e-08,  0.00000000e+00,
        2.71946192e-08,  0.00000000e+00,  1.52244819e+02,  8.30529014e+08]), 'importances_std': array([5.75334499e+06, 4.48652703e+03, 7.29112789e+03, 5.33074763e+02,
       1.19430239e+05, 1.22252967e+05, 1.58935484e-07, 1.16473899e+04,
       2.02069823e+03, 1.53641549e+03, 2.12831152e-08, 0.00000000e+00,
       2.37397760e-08, 0.00000000e+00, 8.91717114e+02, 1.02246611e+07]), 'importances': array([[ 8.44994351e+08,  8.41276243e+08,  8.35547209e+08,
         8.51281070e+08,  8.49044031e+08,  8.53788734e+08,
         8.51975806e+08,  8.47412781e+08,  8.45716621e+08,
         8.38069242e+08],
       [ 8.70158126e+03, -1.89644937e+03,  1.05821371e+04,
        -1.93325981e+03, -1.37961258e+03, -1.04964202e+03,
     