In [1]:
import sklearn
from sklearn_porter import Porter
import pickle
from sklearn.externals import joblib
import joblib
import numpy as np
import os
import re
import json

import ccx as cx

In [2]:
# The pickled models below can only be opened with scikit-learn 0.20.2;
# later versions fail to load the pickle files. Run this notebook inside the
# `venv` virtual environment (which presumably provides that version) and
# confirm the displayed version before continuing.
sklearn.__version__

'0.20.2'

In [3]:
MODEL_PATH = '../model/'
OUT_DIR = '../website/js/extraTrees/'
# Two model collections are stored; only the 3 -> 2 collection is exported
# to JavaScript below (model_2_3 is only used for a quick predict check).
model_3_2_path = os.path.join(MODEL_PATH, 'model_3_2.pkl')
model_2_3_path = os.path.join(MODEL_PATH, 'model_2_3.pkl')

In [4]:
# NOTE: joblib.load unpickles its input, which can execute arbitrary code;
# only load model files that come from a trusted source.
model_2_3, model_3_2 = (
    joblib.load(pickle_path) for pickle_path in (model_2_3_path, model_3_2_path)
)

In [5]:
# Number of models in each collection, one per line.
print(len(model_2_3), len(model_3_2), sep='\n')

198
242


In [6]:
# Sanity check: predict a single random sample of 211 integer features drawn
# from 0..5 (the upper bound of randint is exclusive).
model_2_3[0].predict(np.random.randint(low=0, high=6, size=211).reshape(1, -1))

array([1])

In [19]:
# Transpile the first 3 -> 2 model to JavaScript and save it as extraTrees.js.
porter = Porter(model_3_2[0], language='js')
js_source = porter.export()
with open(OUT_DIR + 'extraTrees.js', 'w') as js_file:
    js_file.write(js_source)

In [20]:
integrity_samples = np.random.randint(0, 6, 211 * 1000).reshape(1000, -1)
porter.integrity_score(integrity_samples)

0.921

In [7]:
# Convenience handle on the first 3 -> 2 model (the same model exported to
# extraTrees.js above), used for the question-number experiments below.
model = model_3_2[0]

In [8]:
# NOT NEEDED: kept for reference only; it is not called elsewhere in this notebook.
def get_question_numbers(model):
    """Collect the question numbers visited while walking all trees of a model.

    The walk is driven by ``cx.getNodeid`` with ``cx.ask`` as the question
    callback and stops once every estimator's node id has become negative.

    Args:
        model: extra trees model (must expose ``n_estimators``)

    Returns:
        list of question numbers, i.e. the keys of the responses dict that
        ``cx.getNodeid`` accumulated, in insertion order
    """
    # One current node id per estimator, all starting at 0 (presumably the root).
    current_nodes = np.zeros(model.n_estimators).astype(int)
    answers = {}
    while any(current_nodes >= 0):
        current_nodes, answers = cx.getNodeid(
            model, nodeid=current_nodes, responses=answers, ask=cx.ask)
    return list(answers.keys())

In [21]:
def extra_tree_to_js(model, out_file):
    """Transpile one extra trees model to a JavaScript file with sklearn-porter.

    Args:
        model: extra trees model to export
        out_file: path of the .js file to (over)write

    Returns:
        the Porter instance used for the export (e.g. for integrity_score)
    """
    porter = Porter(model, language='js')
    # Export before opening the file so a failed export does not truncate it.
    js_source = porter.export()
    with open(out_file, 'w') as js_file:
        js_file.write(js_source)
    return porter

# Example: extra_tree_to_js(model, OUT_DIR + 'extraTrees.js')

In [22]:
def extra_trees_to_js(models, out_dir):
    """Export every model to ``<out_dir>/<index>.js`` via extra_tree_to_js.

    Args:
        models: iterable of extra trees models
        out_dir: existing directory that receives the .js files

    Returns:
        list of Porter instances, in the same order as ``models``
    """
    return [
        extra_tree_to_js(model, os.path.join(out_dir, '{}.js'.format(index)))
        for index, model in enumerate(models)
    ]

porters = extra_trees_to_js(model_3_2, OUT_DIR)

In [27]:
def calc_integrity(porters, num_samples, n_features=211, max_value=6,
                   random_state=None):
    """Calculate the mean integrity score of a list of porters.

    A porter's integrity score is the accuracy of the ported classifier
    measured on random samples, a value between 0 and 1. This function
    averages that score over all porters.

    Args:
        porters: non-empty iterable of sklearn-porter ``Porter`` objects
        num_samples: number of random samples each porter is tested with
        n_features: number of features per sample (default 211, the width
            the models in this notebook are called with)
        max_value: exclusive upper bound of the random integer feature
            values; features are drawn from ``0 .. max_value - 1``
        random_state: optional seed for reproducible samples; ``None`` keeps
            using numpy's global random state, as before

    Returns:
        mean integrity score over all porters

    Raises:
        ValueError: if ``porters`` is empty (the mean would be undefined)
    """
    porters = list(porters)
    if not porters:
        raise ValueError('calc_integrity needs at least one porter')

    # The numpy.random module and RandomState share the legacy `randint` API,
    # so the default path draws from the global stream exactly as before.
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    scores = [
        porter.integrity_score(
            rng.randint(0, max_value, size=(num_samples, n_features)))
        for porter in porters
    ]
    return np.mean(scores)

# Mean integrity over all exported porters, testing each with 10 random samples.
calc_integrity(porters, 10)

0.859504132231405

In [11]:
def parse_extra_tree_js(file_name):
    """Parse a sklearn-porter JavaScript export for its split points.

    Every comparison of the form ``features[<index>] <= <threshold>`` in the
    file is collected, in the order it appears in the source text.

    Args:
        file_name: path of the exported .js file

    Returns:
        dict with two parallel lists:
            'q_numbers': feature (question) index of each split, as int
            'thresholds': threshold of each split, rounded to 3 decimals
    """
    # Accept signed, integer and exponent-notation thresholds ("-0.5", "2",
    # "1e-05"). The previous pattern only matched "digits.digits", so such
    # splits were silently dropped (or "2.5e-3" was read as 2.5).
    split_pattern = re.compile(
        r'features\[(\d+)\]\s*<=\s*([-+]?\d+(?:\.\d*)?(?:[eE][-+]?\d+)?)')

    with open(file_name, 'r') as f:
        content = f.read()

    q_numbers = []
    thresholds = []
    for match in split_pattern.finditer(content):
        q_numbers.append(int(match.group(1)))
        thresholds.append(round(float(match.group(2)), 3))

    return {'q_numbers': q_numbers, 'thresholds': thresholds}

# Split points of the single-model export (extraTrees.js) written earlier.
# NOTE: `split_dict` is reassigned below by the multi-file parse, and that
# later value is what gets saved to tree_splits.json.
split_dict = parse_extra_tree_js(OUT_DIR + 'extraTrees.js')
split_dict

{'q_numbers': [143, 180, 186, 201, 184, 137, 60, 68, 192, 138, 186, 96, 160],
 'thresholds': [2.886,
  1.427,
  2.942,
  4.005,
  2.714,
  3.89,
  2.103,
  3.464,
  2.977,
  1.85,
  1.52,
  1.816,
  2.988]}

In [16]:
def parse_multiple_extra_tree_js(file_dir, num_splits=14):
    """Parse every JavaScript export in a directory for its split points.

    Only files with exactly ``num_splits`` splits are kept (14 by default,
    matching the previous hard-coded behavior); other files are skipped.

    Args:
        file_dir: directory containing the exported .js files
        num_splits: exact number of splits a file needs in order to be included

    Returns:
        dict mapping ``'f' + <file name up to its first '.'>`` (e.g. ``'f12'``
        for ``12.js``) to that file's parse_extra_tree_js result, i.e. a dict
        with 'q_numbers' and 'thresholds' lists.

    NOTE(review): every .js file in ``file_dir`` is parsed, including a
    single-model export such as extraTrees.js if it shares the directory;
    it would appear under 'fextraTrees' whenever it has ``num_splits`` splits.
    """
    file_name_to_split_points = {}
    # os.listdir order is arbitrary; sort so the parse order (and the key
    # order of the result) is deterministic.
    for js_file in sorted(os.listdir(file_dir)):
        if not js_file.endswith('.js'):
            continue
        split_points = parse_extra_tree_js(os.path.join(file_dir, js_file))
        if len(split_points['q_numbers']) == num_splits:
            file_name_to_split_points['f' + js_file.split('.')[0]] = split_points
    return file_name_to_split_points

# Split points of every per-model export in OUT_DIR; only files with exactly
# 14 splits are kept by parse_multiple_extra_tree_js.
split_dict = parse_multiple_extra_tree_js(OUT_DIR)

In [18]:
# Save the split points for the website as compact JSON; sorted keys keep
# the file's key order stable across runs.
with open('../website/js/tree_splits.json', 'w') as splits_file:
    json.dump(split_dict, fp=splits_file, sort_keys=True)

In [10]:
# Scratch copy of the traversal loop inside get_question_numbers, kept
# commented out for manual experiments with `model`. To run it, first
# initialise `nodeid` as get_question_numbers does.
# responses={}
# while not all(nodeid<0):
#     [nodeid,responses]=cx.getNodeid(model,nodeid=nodeid,responses=responses,ask=cx.ask)
# # return list(responses.keys())