[OSCP] Implement the random forest algorithm using SPU #752
Conversation
All contributors have signed the CLA ✍️ ✅
I have read the CLA Document and I hereby sign the CLA
sml/forest/BUILD.bazel
Outdated
@@ -0,0 +1,22 @@
# Copyright 2023 Ant Group Co., Ltd.
Please move forest.py into the ensemble directory, and move the corresponding tests, emulation, and build files along with it.
sml/forest/forest.py
Outdated
self,
n_estimators,
max_features,
n_features,
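I suggest deriving n_features from the dataset inside fit; the validation of max_features, or the computation of its concrete value, can also be moved into fit.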
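A minimal sketch of what this suggestion could look like, not the PR's actual code. The class name `ForestSketch`, the helper `resolve_max_features`, and the accepted `max_features` forms (`None`, `"sqrt"`, `"log2"`, an int, or a float fraction) are assumptions made for illustration; the constructor only stores the raw settings, and `fit` reads `n_features` from `X.shape` and validates there.

```python
import math


def resolve_max_features(max_features, n_features):
    """Hypothetical helper: turn a max_features setting into a feature count."""
    if max_features is None:
        return n_features
    if isinstance(max_features, str):
        if max_features == "sqrt":
            return max(1, int(math.sqrt(n_features)))
        if max_features == "log2":
            return max(1, int(math.log2(n_features)))
        raise ValueError(f"unsupported max_features string: {max_features!r}")
    if isinstance(max_features, int):
        if not 1 <= max_features <= n_features:
            raise ValueError("integer max_features must be in [1, n_features]")
        return max_features
    if isinstance(max_features, float):
        if not 0.0 < max_features <= 1.0:
            raise ValueError("float max_features must be in (0, 1]")
        return max(1, int(max_features * n_features))
    raise TypeError("max_features must be None, str, int, or float")


class ForestSketch:
    """Illustrative stand-in for the forest class (assumed name)."""

    def __init__(self, n_estimators, max_features):
        # Only store the raw hyperparameters; nothing here depends on the data.
        self.n_estimators = n_estimators
        self.max_features = max_features

    def fit(self, X, y):
        # n_features comes from the data, so it is not a constructor argument.
        n_samples, n_features = X.shape
        self.n_features_ = n_features
        self.max_features_ = resolve_max_features(self.max_features, n_features)
        return self
```

With this layout an invalid `max_features` is reported when `fit` is called, at the point where the feature count is actually known.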
sml/forest/forest.py
Outdated
self.splitter = splitter
self.max_depth = max_depth
self.bootstrap = bootstrap
self.max_samples = max_samples
Same as for n_features and max_features: defer the concrete validation and value computation to fit.
sml/forest/forest.py
Outdated
bootstrap,
max_samples,
n_labels,
seed,
The seed parameter is not needed.
sml/forest/forest.py
Outdated
X_sample, y_sample = self._bootstrap_sample(X, y)
features = self._select_features()
# selected_indices = self._shuffle_indices(n_features)
print(y_sample)
Please do not use print.
sml/forest/forest.py
Outdated
for i, tree in enumerate(self.trees):
    features = self.features_indices[i]
    print(features)
    tree_predictions = tree_predictions.at[:, i].set(
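Use .set as little as possible. This kind of JAX update method copies all of the data again; you can compute the values first and then build the array in one step with jnp.array.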
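One way to follow this advice is sketched below. It assumes each tree exposes a `predict(X)` method returning a 1-D array of labels and that `features_indices[i]` holds the column indices used by tree `i`; the function name `collect_tree_predictions` is hypothetical. The per-tree results are gathered in a Python list and combined once with `jnp.stack`, which plays the same role as a single `jnp.array` call.

```python
import jax.numpy as jnp


def collect_tree_predictions(trees, features_indices, X):
    """Gather per-tree predictions and build the result array once.

    Assumes tree.predict returns an array of shape (n_samples,).
    """
    per_tree = [
        tree.predict(X[:, feats])
        for tree, feats in zip(trees, features_indices)
    ]
    # One array construction instead of one .at[...].set(...) per tree.
    # Resulting shape: (n_samples, n_estimators).
    return jnp.stack(per_tree, axis=1)
```

The Python loop over trees remains, but no intermediate full-size array is rebuilt on each iteration.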
sml/forest/forest.py
Outdated
return y_pred.ravel()


def jax_mode_row(data):
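This does not need to be so complicated. Labels take the values 0, 1, 2, ... (a requirement of the decision tree). For binary classification, for example, just count how many trees predict 0 and how many predict 1, and return the label with the larger count. (Please avoid loops where possible and make good use of vectorization.)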
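The counting approach described here can be sketched as follows. This is an illustration under stated assumptions, not the PR's final code: `tree_predictions` is taken to be an integer array of shape `(n_samples, n_estimators)` with labels in `0..n_labels-1`, and the function name `majority_vote` is hypothetical.

```python
import jax.numpy as jnp


def majority_vote(tree_predictions, n_labels):
    """Return the most frequent label per sample without Python loops."""
    labels = jnp.arange(n_labels)
    # Compare every prediction against every label by broadcasting:
    # (n_samples, n_estimators, 1) == (1, 1, n_labels)
    matches = tree_predictions[:, :, None] == labels[None, None, :]
    # Sum over the tree axis to get per-label vote counts: (n_samples, n_labels).
    counts = jnp.sum(matches, axis=1)
    # argmax picks the label with the most votes; ties go to the smaller label.
    return jnp.argmax(counts, axis=1)
```

For binary classification this reduces to comparing the number of trees voting 0 with the number voting 1, as the comment suggests.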
sml/forest/forest.py
Outdated


class RandomForestClassifier:
    """A random forest classifier."""

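The hyperparameters in __init__ need to be explained;
the requirements on the data format also need to be documented (you can refer to the explanatory comments in the decision tree model).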
sml/forest/forest.py
Outdated
from sml.tree.tree import DecisionTreeClassifier as sml_dtc


# from functools import partial
# from jax import jit
Unrelated comments can be cleaned up.
sml/forest/forest.py
Outdated
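Description: The functions are mostly written and the test results so far are largely correct; emul and test still need to be completed.
bootstrap is problematic: after bootstrapping, predict never outputs 1, and the bootstrap sample contains no 1 (because the jax.random API is not supported).

! Final note: the bootstrap parameter is unusable: bootstrap sampling is correct in plaintext, but in forest_test.py the label 1 is never sampled,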
This descriptive text can also be cleaned up.
Pull Request
What problem does this PR solve?
Issue Number: Fixed #254
Possible side effects?
Performance:
Backward compatibility: