Skip to content

Commit

Permalink
Explicit feature parameters for generator.
Browse files Browse the repository at this point in the history
  • Loading branch information
xehivs committed Jan 23, 2020
1 parent 57edae7 commit e155d86
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
15 changes: 15 additions & 0 deletions strlearn/streams/StreamGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ def __init__(
n_drifts=0,
concept_sigmoid_spacing=None,
n_classes=2,
n_features=20,
n_informative=2,
n_redundant=2,
n_repeated=0,
n_clusters_per_class=2,
recurring=False,
weights=None,
incremental=False,
Expand All @@ -95,6 +100,11 @@ def __init__(
self.incremental = incremental
self.y_flip = y_flip
self.classes_ = np.array(range(self.n_classes))
self.n_features = n_features
self.n_redundant = n_redundant
self.n_informative = n_informative
self.n_repeated = n_repeated
self.n_clusters_per_class = n_clusters_per_class

def is_dry(self):
"""Checking if we have reached the end of the stream."""
Expand Down Expand Up @@ -144,6 +154,11 @@ def _make_classification(self):
**self.make_classification_kwargs,
n_samples=self.n_chunks * self.chunk_size,
n_classes=self.n_classes,
n_features=self.n_features,
n_informative=self.n_informative,
n_redundant=self.n_redundant,
n_repeated=self.n_repeated,
n_clusters_per_class=self.n_clusters_per_class,
random_state=self.random_state + i,
weights=weights.tolist(),
)[0].T
Expand Down
5 changes: 4 additions & 1 deletion strlearn/tests/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
sys.path.insert(0, "../..")
from sklearn.naive_bayes import GaussianNB


"""
def test_download_arff():
url = "http://156.17.43.89/Toyset.arff"
r = requests.get(url)
Expand All @@ -20,6 +20,7 @@ def test_download_arff():
r = requests.get(url)
with open("Elec.arff", "wb") as f:
f.write(r.content)
"""


def test_generator_same():
Expand Down Expand Up @@ -118,6 +119,7 @@ def test_generator_str():
assert str(stream) == "gr_css999_rs1410_nd0_ln50_50_d50_50000"


"""
def test_arff_parser():
stream = sl.streams.ARFFParser("Toyset.arff")
assert str(stream) == "Toyset.arff"
Expand All @@ -132,3 +134,4 @@ def test_arff_parser():
evaluator = sl.evaluators.TestThenTrain(metrics=(accuracy_score))
evaluator.process(stream, clf)
stream.reset()
"""

0 comments on commit e155d86

Please sign in to comment.