Permalink
Browse files

Add workaround for fused type attributes

  • Loading branch information...
yenchenlin committed Jul 21, 2016
1 parent 1a8ddbe commit 7c87f29f5869b4b757d6105e51404307df23d3d9
Showing with 12 additions and 0 deletions.
  1. +12 −0 sklearn/neighbors/binary_tree.pxi
@@ -152,6 +152,7 @@ from ..utils import check_array
from typedefs cimport DTYPE_t, ITYPE_t, DITYPE_t
from typedefs import DTYPE, ITYPE
+from cython cimport floating
from dist_metrics cimport (DistanceMetric, euclidean_dist, euclidean_rdist,
euclidean_dist_to_rdist, euclidean_rdist_to_dist)
@@ -219,6 +220,11 @@ cdef DTYPE_t[:, ::1] get_memview_DTYPE_2D(
return <DTYPE_t[:X.shape[0], :X.shape[1]:1]> (<DTYPE_t*> X.data)
+cdef float[:, ::1] get_memview_float_2D(
+ np.ndarray[float, ndim=2, mode='c'] X):
+ return <float[:X.shape[0], :X.shape[1]:1]> (<float*> X.data)
+
+
cdef DTYPE_t[:, :, ::1] get_memview_DTYPE_3D(
np.ndarray[DTYPE_t, ndim=3, mode='c'] X):
return <DTYPE_t[:X.shape[0], :X.shape[1], :X.shape[2]:1]>\
@@ -1006,6 +1012,7 @@ cdef class BinaryTree:
cdef np.ndarray node_bounds_arr
cdef readonly DTYPE_t[:, ::1] data
+ cdef readonly float[:, ::1] data_32
cdef public ITYPE_t[::1] idx_array
cdef public NodeData_t[::1] node_data
cdef public DTYPE_t[:, :, ::1] node_bounds
@@ -1051,6 +1058,11 @@ cdef class BinaryTree:
def __init__(self, data,
leaf_size=40, metric='minkowski', **kwargs):
+ self.data_arr = check_array(data, dtype=[np.float64, np.float32],
+ order='C')
+ if self.data_arr.dtype == np.float32:
+ self.data_32 = get_memview_float_2D(self.data_arr)
+
self.data_arr = np.asarray(data, dtype=DTYPE, order='C')
self.data = get_memview_DTYPE_2D(self.data_arr)

0 comments on commit 7c87f29

Please sign in to comment.