This repository was archived by the owner on Jul 10, 2025. It is now read-only.

Commit fc56dbc

incorporating all the latest comments
1 parent 58d0bf8 commit fc56dbc

File tree

1 file changed: +32 additions, -12 deletions


rfcs/20180626-tensor-forest.md

Lines changed: 32 additions & 12 deletions
```diff
@@ -107,11 +107,11 @@ classifier = estimator.TensorForestClassifier(feature_columns=[feature_1, featur
     model_dir=None,
     n_classes=2,
     label_vocabulary=None,
+    head=None,
     n_trees=100,
     max_nodes=1000,
     num_splits_to_consider=10,
     split_after_samples=250,
-    base_random_seed=0,
     config=None)
 
 
```
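The constructor's `num_splits_to_consider` and `split_after_samples` parameters drive the online extremely randomized training the RFC describes later. A minimal pure-Python sketch of that node-splitting idea, purely illustrative (the function name, the balance-based score, and all details here are invented for this example, not taken from the TensorForest implementation):

```python
import math
import random

def pick_split(samples, num_features, split_after_samples=250,
               num_splits_to_consider=None, seed=0):
    """Illustrative sketch: choose a (feature, threshold) split for one node.

    Mirrors the two parameters above: wait until the node has accumulated
    split_after_samples examples, then evaluate only a random subset of
    candidate splits (sqrt(num_features) of them by default).
    """
    if len(samples) < split_after_samples:
        return None  # keep accumulating samples before committing to a split
    rng = random.Random(seed)
    if num_splits_to_consider is None:
        num_splits_to_consider = max(1, int(math.sqrt(num_features)))
    best, best_score = None, -1.0
    for _ in range(num_splits_to_consider):
        feature = rng.randrange(num_features)
        threshold = rng.choice(samples)[0][feature]  # threshold from an observed value
        left = sum(1 for x, _ in samples if x[feature] <= threshold)
        # crude proxy score: prefer splits that divide the node's samples evenly
        score = min(left, len(samples) - left) / len(samples)
        if score > best_score:
            best, best_score = (feature, threshold), score
    return best
```

With the default `split_after_samples=250`, a node holding only a handful of samples returns no split; lowering the threshold lets a split through.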
```diff
@@ -121,6 +121,12 @@ def input_fn_train():
 
 classifier.train(input_fn=input_fn_train)
 
+def input_fn_predict():
+  ...
+  return dataset
+
+classifier.predict(input_fn=input_fn_predict)
+
 def input_fn_eval():
   ...
   return dataset
```
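The `...` bodies elided in the hunk above would normally build a `tf.data.Dataset`. As a rough pure-Python stand-in for the pattern (no TensorFlow here; `make_batches` and the toy data are invented for illustration), the three input functions share one shape: train/eval batches carry labels, predict batches carry features only:

```python
def make_batches(features, labels=None, batch_size=2):
    """Illustrative stand-in for a tf.data pipeline: yield fixed-size batches.

    For predict-style input (labels=None) each batch is features only;
    otherwise each batch pairs features with labels, as train/eval expect.
    """
    batches = []
    for i in range(0, len(features), batch_size):
        fb = features[i:i + batch_size]
        if labels is None:
            batches.append(fb)
        else:
            batches.append((fb, labels[i:i + batch_size]))
    return batches

def input_fn_train():
    return make_batches([[0.1], [0.4], [0.9], [0.7]], labels=[0, 0, 1, 1])

def input_fn_predict():
    return make_batches([[0.2], [0.8]])
```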
```diff
@@ -133,11 +139,14 @@ Here are some explained details for the classifier parameters:
 - **feature_columns**: An iterable containing all the feature columns used by the model. All items in the set should be instances of classes derived from FeatureColumn.
 - **n_classes**: Defaults to 2. The number of classes in a classification problem.
 - **model_dir**: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into an estimator to continue training a previously saved model.
-- **label_vocabulary**: A list of strings representing all possible label values. If provided, labels must be of string type and their values must be present in label_vocabulary list. If label_vocabulary is omitted, it is assumed that the labels are already encoded as integer or float values within [0, 1] for n_classes=2, or encoded as integer values in {0, 1,..., n_classes-1} for n_classes>2 . If vocabulary is not provided and labels are of string, an error will be generated.
+- **label_vocabulary**: A list of strings representing all possible label values. If provided, labels must be of string type and their values must be present in the label_vocabulary list. If label_vocabulary is omitted, it is assumed that the labels are already encoded as integer values in {0, 1} for n_classes=2, or as integer values in {0, 1, ..., n_classes-1} for n_classes>2. If a vocabulary is not provided and the labels are strings, an error will be generated.
+- **head**: A `head_lib._Head` instance. The loss it computes is used only for metrics, not for training. If not provided, one will be created automatically based on the other parameters.
 - **n_trees**: The number of trees to create. Defaults to 100. There usually isn't any accuracy gain from using higher values (assuming deep enough trees are built).
 - **max_nodes**: Defaults to 10k. No tree is allowed to grow beyond max_nodes nodes, and training stops when all trees in the forest are this large.
 - **num_splits_to_consider**: Defaults to sqrt(num_features). In the extremely randomized tree training algorithm, only this many potential splits are evaluated for each tree node.
 - **split_after_samples**: Defaults to 250. In our online version of extremely randomized tree training, we pick a split for a node after it has accumulated this many training samples.
+- **config**: RunConfig object to configure the runtime settings.
+
 
 
 ### TensorForestRegressor
```
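The label_vocabulary rules in the classifier hunk above can be illustrated with a small pure-Python sketch (a hypothetical helper, not the estimator's real code): with a vocabulary, string labels map to indices {0, ..., n_classes-1}; without one, string labels are an error and integer labels pass through unchanged.

```python
def encode_labels(labels, n_classes=2, label_vocabulary=None):
    """Illustrative sketch of the label_vocabulary behavior described above."""
    if label_vocabulary is not None:
        vocab = {v: i for i, v in enumerate(label_vocabulary)}
        # KeyError if a label value is missing from the vocabulary
        return [vocab[l] for l in labels]
    if any(isinstance(l, str) for l in labels):
        raise ValueError("string labels require label_vocabulary")
    # labels are assumed to already be integers in {0, ..., n_classes - 1}
    return list(labels)
```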
```diff
@@ -149,11 +158,11 @@ feature_2 = numeric_column('feature_2')
 regressor = estimator.TensorForestRegressor(feature_columns=[feature_1, feature_2],
     model_dir=None,
     label_dimension=1,
+    head=None,
     n_trees=100,
     max_nodes=1000,
     num_splits_to_consider=10,
     split_after_samples=250,
-    base_random_seed=0,
     config=None)
 
 
```
```diff
@@ -163,6 +172,12 @@ def input_fn_train():
 
 regressor.train(input_fn=input_fn_train)
 
+def input_fn_predict():
+  ...
+  return dataset
+
+regressor.predict(input_fn=input_fn_predict)
+
 def input_fn_eval():
   ...
   return dataset
```
```diff
@@ -172,15 +187,15 @@ metrics = regressor.evaluate(input_fn=input_fn_eval)
 
 Here are some explained details for the regressor parameters:
 
-* **feature_columns:** An iterable containing all the feature columns used by the model. All items in the set should be instances of classes derived from `FeatureColumn`.
-* **model_dir:** Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model.
-* **label_dimension:** Defaults to 1. Number of regression targets per example.
-* **n_trees:** The number of trees to create. Defaults to 100. There usually isn't any accuracy gain from using higher values.
-* **max_nodes:** Defaults to 10,000. No tree is allowed to grow beyond max_nodes nodes, and training stops when all trees in the forest are this large.
-* **num_splits_to_consider:** Defaults to `sqrt(num_features)`. In the extremely randomized tree training algorithm, only this many potential splits are evaluated for each tree node.
-* **split_after_samples:** Defaults to 250. In our online version of extremely randomized tree training, we pick a split for a node after it has accumulated this many training samples.
-* **base_random_seed:** By default (base_random_seed = 0), the random number generator for each tree is seeded by a 64-bit random value when each tree is first created. Using a non-zero value causes tree training to be deterministic, in that the i-th tree's random number generator is seeded with the value base_random_seed + i.
-* **config:** `RunConfig` object to configure the runtime settings.
+- **feature_columns:** An iterable containing all the feature columns used by the model. All items in the set should be instances of classes derived from `FeatureColumn`.
+- **model_dir:** Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into an estimator to continue training a previously saved model.
+- **label_dimension:** Defaults to 1. Number of regression targets per example.
+- **head:** A `head_lib._Head` instance. The loss it computes is used only for metrics, not for training. If not provided, one will be created automatically based on the other parameters.
+- **n_trees:** The number of trees to create. Defaults to 100. There usually isn't any accuracy gain from using higher values.
+- **max_nodes:** Defaults to 10,000. No tree is allowed to grow beyond max_nodes nodes, and training stops when all trees in the forest are this large.
+- **num_splits_to_consider:** Defaults to `sqrt(num_features)`. In the extremely randomized tree training algorithm, only this many potential splits are evaluated for each tree node.
+- **split_after_samples:** Defaults to 250. In our online version of extremely randomized tree training, we pick a split for a node after it has accumulated this many training samples.
+- **config:** `RunConfig` object to configure the runtime settings.
 
 
 ### First version supported features
```
```diff
@@ -222,6 +237,11 @@ During inference, for every batch of data, we pass through the tree structure an
 
 Since the trees are independent, for the distributed version, we would distribute the number of trees required to train evenly among all the available workers. For every tree, they would have two tf.resources available for training.
 
+## Differences from the latest contrib version
+
+- Simplified code with only a limited subset of features (excluding all the experimental ones)
+- New estimator interface, support for new feature columns and losses
+
 ## Future Work
 
 Add sample importance; right now we don’t support sample importance, which is a widely used [feature](http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.ExtraTreesClassifier.html#sklearn.ensemble.ExtraTreesClassifier.fit).
```
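The distributed-training paragraph in the hunk above distributes trees evenly among workers; because the trees are independent, a simple round-robin assignment is enough to illustrate the idea. A toy sketch (the assignment policy and function name are assumptions for illustration, not taken from the implementation):

```python
def assign_trees(n_trees, n_workers):
    """Round-robin assignment of independent trees to workers; per-worker
    tree counts differ by at most one, giving the even split described above."""
    assignment = {w: [] for w in range(n_workers)}
    for tree_id in range(n_trees):
        assignment[tree_id % n_workers].append(tree_id)
    return assignment
```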
