Skip to content
Browse files

predict neighbors refactored

  • Loading branch information...
1 parent 9c1dcc7 commit 2cdfa812c18ae3e6551b2b058b5485afdefe9620 @paolo-losi committed
View
817 python/extratrees/cbindings.c
397 additions, 420 deletions not shown because the diff is too large. Please use a local Git client to view these changes.
View
18 python/extratrees/cbindings.pyx
@@ -13,7 +13,6 @@ from cextratrees cimport (ET_problem, ET_problem_destroy, ET_load_libsvm_file,
ET_forest_neighbors, ET_params,
ET_forest_predict_class_bayes,
class_probability_vec, class_probability,
- neighbour_weight, neighbour_weight_vec,
double_vec, ET_forest_feature_importance)
@@ -172,13 +171,12 @@ cdef class Forest:
@cython.wraparound(False)
def neighbors(self, np.ndarray[np.float32_t, ndim=2] X, curtail=1):
cdef float *vector
+ cdef double *weights
cdef int sample_idx, feature_idx
cdef uint32_t _curtail = curtail
- cdef neighbour_weight_vec *nwv
- cdef neighbour_weight *nw
cdef np.ndarray[np.float64_t, ndim=2] adiacency
- adiacency = numpy.zeros(shape=(X.shape[0], self._forest.n_samples),
+ adiacency = numpy.empty(shape=(X.shape[0], self._forest.n_samples),
dtype=numpy.float64)
vector = <float *> malloc(sizeof(float) * X.shape[1])
@@ -189,16 +187,14 @@ cdef class Forest:
for feature_idx in xrange(X.shape[1]):
vector[feature_idx] = X[sample_idx, feature_idx]
- nwv = ET_forest_neighbors(self._forest, vector, _curtail)
- if not nwv:
+ weights = ET_forest_neighbors(self._forest, vector, _curtail)
+ if not weights:
raise MemoryError()
- for i in xrange(nwv.n):
- nw = &nwv.a[i]
- adiacency[sample_idx, nw.key] = nw.weight
+ for feature_idx in xrange(self._forest.n_samples):  # typed index reused to walk training samples: weights and the adiacency row both have n_samples entries, not X.shape[1]
+ adiacency[sample_idx, feature_idx] = weights[feature_idx]
- free(nwv.a)
- free(nwv)
+ free(weights)
free(vector)
return adiacency
View
2 python/extratrees/cextratrees.c
@@ -1,4 +1,4 @@
-/* Generated by Cython 0.15.1 on Thu Mar 29 13:49:06 2012 */
+/* Generated by Cython 0.15.1 on Sun Apr 1 00:10:29 2012 */
#define PY_SSIZE_T_CLEAN
#include "Python.h"
View
11 python/extratrees/cextratrees.pxd
@@ -51,16 +51,7 @@ cdef extern from "extratrees.h":
float *vector,
uint32_t curtail_min_size,
bool smooth)
-
- ctypedef struct neighbour_weight:
- uint32_t key
- double weight
-
- ctypedef struct neighbour_weight_vec:
- size_t n, m
- neighbour_weight *a
-
- cdef neighbour_weight_vec *ET_forest_neighbors(ET_forest *forest,
+ cdef double *ET_forest_neighbors(ET_forest *forest,
float *vector,
uint32_t curtail_min_size)
View
10 src/extratrees.h
@@ -95,14 +95,6 @@ typedef struct {
// --- predict types ---
typedef struct {
- uint32_t key;
- double weight;
-} neighbour_weight;
-
-typedef kvec_t(neighbour_weight) neighbour_weight_vec;
-
-
-typedef struct {
double label;
double probability;
} class_probability;
@@ -127,7 +119,7 @@ double ET_forest_predict_class_majority(ET_forest *forest, float *v,
uint32_t curtail_min_size);
double ET_forest_predict_class_bayes(ET_forest *forest, float *v,
uint32_t curtail_min_size, bool smooth);
-neighbour_weight_vec *ET_forest_neighbors(ET_forest *forest, float *vector,
+double *ET_forest_neighbors(ET_forest *forest, float *vector,
uint32_t curtail_min_size);
class_probability_vec *ET_forest_predict_probability(ET_forest *forest,
float *vector,
View
28 src/predict.c
@@ -189,15 +189,14 @@ uint_vec **ET_forest_neighbors_detail(ET_forest *forest, float *vector,
}
-neighbour_weight_vec *ET_forest_neighbors(ET_forest *forest, float *vector,
- uint32_t curtail_min_size) {
+double *ET_forest_neighbors(ET_forest *forest, float *vector,
+ uint32_t curtail_min_size) {
uint_vec **neigh_detail;
- neighbour_weight_vec *nwvec;
size_t n_trees = kv_size(forest->trees);
+ double *nwa = NULL;
- nwvec = malloc(sizeof(neighbour_weight_vec));
- check_mem(nwvec);
- kv_init(*nwvec);
+ nwa = calloc(sizeof(double), forest->n_samples);
+ check_mem(nwa);
neigh_detail = ET_forest_neighbors_detail(forest, vector, curtail_min_size);
check_mem(neigh_detail);
@@ -207,16 +206,8 @@ neighbour_weight_vec *ET_forest_neighbors(ET_forest *forest, float *vector,
double incr = 1.0 / (double) (kv_size(*tree_neighs) * n_trees);
for(size_t j = 0; j < kv_size(*tree_neighs); j++) {
- neighbour_weight *nw = NULL;
uint32_t sample_idx = kv_A(*tree_neighs, j);
-
- kal_getp(*nwvec, sample_idx, nw);
- if (nw == NULL) {
- kv_push(neighbour_weight, *nwvec,
- ((neighbour_weight) { sample_idx, incr }));
- } else {
- nw->weight += incr;
- }
+ nwa[sample_idx] += incr;
}
kv_destroy(*tree_neighs);
free(tree_neighs);
@@ -224,7 +215,7 @@ neighbour_weight_vec *ET_forest_neighbors(ET_forest *forest, float *vector,
free(neigh_detail);
exit:
- return nwvec;
+ return nwa;
}
@@ -319,7 +310,6 @@ class_probability_vec *ET_forest_predict_probability(ET_forest *forest,
bool smooth) {
bool error = true;
class_probability_vec *prob_vec = NULL;
- neighbour_weight_vec *nwvec = NULL;
double n_trees = kv_size(forest->trees);
prob_vec = malloc(sizeof(class_probability_vec));
@@ -380,10 +370,6 @@ class_probability_vec *ET_forest_predict_probability(ET_forest *forest,
kv_destroy(*prob_vec);
free(prob_vec);
}
- if (nwvec != NULL) {
- kv_destroy(*nwvec);
- free(nwvec);
- }
return prob_vec;
}
View
55,742 tests/test_importance.out
27,871 additions, 27,871 deletions not shown because the diff is too large. Please use a local Git client to view these changes.
View
9 tests/test_predict.c
@@ -18,6 +18,7 @@ void test_predict() {
float vector2[3] = {2, 1, 1};
float vector3[3] = {2.1, 1, 1};
double prediction;
+ double *neighbor_weights;
class_probability_vec *cpv;
problem_init(&prob, vectors, labels);
@@ -64,8 +65,16 @@ void test_predict() {
prediction = ET_forest_predict_class_bayes(forest, vector3, 1, false);
fprintf(stderr, "class prediction vector3 (bayes): %g\n", prediction);
+ neighbor_weights = ET_forest_neighbors(forest, vector3, 1);
+ fprintf(stderr, "neighbor weights for vector3:\n");
+ for(size_t i = 0; i < forest->n_samples; i++) { /* one weight per training sample */
+ fprintf(stderr, " - sample_idx: %zu. weight: %g\n",
+ i, neighbor_weights[i]);
+ }
+
ET_forest_destroy(forest);
free(forest);
+ free(neighbor_weights);
}
View
16,811 tests/test_predict.out
8,512 additions, 8,299 deletions not shown because the diff is too large. Please use a local Git client to view these changes.
View
2,432 tests/test_serialization.out
1,311 additions, 1,121 deletions not shown because the diff is too large. Please use a local Git client to view these changes.
View
1,564 tests/test_train.out
782 additions, 782 deletions not shown because the diff is too large. Please use a local Git client to view these changes.

0 comments on commit 2cdfa81

Please sign in to comment.
Something went wrong with that request. Please try again.