Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion skrules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .skope_rules import SkopeRules
from .rule import Rule
from .rule import Rule, replace_feature_name

# Public API of the package. replace_feature_name is imported into this
# namespace and imported from the package by the tests, so it is exported too.
__all__ = ['SkopeRules', 'Rule', 'replace_feature_name']
11 changes: 11 additions & 0 deletions skrules/rule.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
import re

def replace_feature_name(rule, replace_dict):
    """ Replace feature names appearing in a rule string.

    Every whole-word occurrence of a key of ``replace_dict`` found in
    ``rule`` is substituted by the corresponding value. The substitution is
    done in a single pass, so replacement text is never scanned again.

    Parameters
    ----------
    rule : str
        Rule expression, e.g. ``"__C__0 <= 3 and __C__1 > 4"``.

    replace_dict : dict
        Mapping from the names to look for to their replacement strings.

    Returns
    -------
    rule : str
        The rule with all mapped names substituted. ``rule`` is returned
        unchanged when ``replace_dict`` is empty.
    """
    # An empty mapping would produce an empty pattern, which matches at every
    # position and makes the lookup in replace() fail with KeyError('').
    if not replace_dict:
        return rule

    def replace(match):
        return replace_dict[match.group(0)]

    # Try longer names first so that a key which is a whole-word prefix of
    # another one (e.g. "a" and "a b") cannot shadow it in the alternation.
    names = sorted(replace_dict, key=len, reverse=True)
    pattern = '|'.join(r'\b%s\b' % re.escape(name) for name in names)
    return re.sub(pattern, replace, rule)

class Rule:
""" An object modeling a logical rule and providing factorization methods.
It is used to simplify rules and deduplicate them.
Expand Down Expand Up @@ -56,3 +66,4 @@ def __repr__(self):
[feature, symbol, str(self.agg_dict[(feature, symbol)])])
for feature, symbol in sorted(self.agg_dict.keys())
])

31 changes: 24 additions & 7 deletions skrules/skope_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from sklearn.externals import six
from sklearn.tree import _tree

from .rule import Rule
from .rule import Rule, replace_feature_name

# Integer types: Python integral numbers and NumPy integer scalars.
# NOTE(review): presumably used for isinstance checks on integer-valued
# parameters -- the usage is not visible in this hunk.
INTEGER_TYPES = (numbers.Integral, np.integer)

# Prefix of the generic column names built in fit() as
# BASE_FEATURE_NAME + <column index>, e.g. "__C__0".
BASE_FEATURE_NAME = "__C__"

class SkopeRules(BaseEstimator):
""" An easy-interpretable classifier optimizing simple logical rules.
Expand Down Expand Up @@ -249,11 +249,17 @@ def fit(self, X, y, sample_weight=None):
self.estimators_samples_ = []
self.estimators_features_ = []

# default columns names of the form ['c0', 'c1', ...]:
feature_names_ = (self.feature_names if self.feature_names is not None
else ['c' + x for x in
np.arange(X.shape[1]).astype(str)])
# default column names of the form ['__C__0', '__C__1', ...]:
feature_names_ = [BASE_FEATURE_NAME + x for x in
np.arange(X.shape[1]).astype(str)]
if self.feature_names is not None:
self.feature_dict_ = {BASE_FEATURE_NAME + str(i): feat
for i, feat in enumerate(self.feature_names)}
else:
self.feature_dict_ = {BASE_FEATURE_NAME + str(i): feat
for i, feat in enumerate(feature_names_)}
self.feature_names_ = feature_names_

clfs = []
regs = []

Expand Down Expand Up @@ -356,6 +362,10 @@ def fit(self, X, y, sample_weight=None):
for rule in
[Rule(r, args=args) for r, args in rules_]]





# keep only rules verifying precision_min and recall_min:
for rule, score in rules_:
if score[0] >= self.precision_min and score[1] >= self.recall_min:
Expand All @@ -377,7 +387,14 @@ def fit(self, X, y, sample_weight=None):
# Deduplicate the rule using semantic tree
if self.max_depth_duplication is not None:
self.rules_ = self.deduplicate(self.rules_)

self.rules_ = sorted(self.rules_, key=lambda x: - self.f1_score(x))
self.rules_without_feature_names_ = self.rules_

# Replace generic feature names by real feature names
self.rules_ = [(replace_feature_name(rule, self.feature_dict_), perf)
for rule, perf in self.rules_]

return self

def predict(self, X):
Expand Down Expand Up @@ -432,7 +449,7 @@ def decision_function(self, X):
% (X.shape[1], self.n_features_))

df = pandas.DataFrame(X, columns=self.feature_names_)
selected_rules = self.rules_
selected_rules = self.rules_without_feature_names_

scores = np.zeros(X.shape[0])
for (r, w) in selected_rules:
Expand Down
12 changes: 11 additions & 1 deletion skrules/tests/test_rule.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sklearn.utils.testing import assert_equal, assert_not_equal

from skrules import Rule
from skrules import Rule, replace_feature_name


def test_rule():
Expand Down Expand Up @@ -53,3 +53,13 @@ def test_equals_rule():

rule3 = "a < 3.0 and a == a"
assert_equal(rule3, str(Rule(rule3)))


def test_replace_feature_name():
    """Generic column names are replaced by the mapped feature names."""
    mapping = {"__C__0": "$b", "__C__1": "c(4)"}
    generic_rule = "__C__0 <= 3 and __C__1 > 4"
    expected = "$b <= 3 and c(4) > 4"

    actual = replace_feature_name(generic_rule, replace_dict=mapping)
    assert_equal(actual, expected)