Skip to content

Commit

Permalink
Conversion to arrays or backing stores should trim store to correct l…
Browse files Browse the repository at this point in the history
…ength.
  • Loading branch information
Chris Nuernberger committed Feb 16, 2019
1 parent 02766de commit ca6ee2c
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 16 deletions.
4 changes: 2 additions & 2 deletions src/tech/compute.clj
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ intended usage of the buffer."
new-max-length (- original-size offset)]
(when-not (<= length new-max-length)
(throw (ex-info "Sub buffer out of range."
{:required new-max-length
:current length})))
{:desired-length length
:ecount-minus-offset new-max-length})))
(drv/sub-buffer device-buffer offset length)))
([buffer offset]
(sub-buffer buffer offset (- (dtype/ecount buffer) offset))))
Expand Down
8 changes: 6 additions & 2 deletions src/tech/compute/cpu/typed_buffer.clj
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@
(.limit buf# (+ offset# len#))
buf#))
:array-backing-store (fn [item#]
(dtype/->array item#))
(let [item# (to-nio-buf ~datatype item#)]
(when (.hasArray item#)
(.array item#))))
:nio-offset (fn [item#]
(let [buf# (to-nio-buf ~datatype item#)]
(.position buf#)))
Expand Down Expand Up @@ -78,7 +80,9 @@
(and (and lhs-ary rhs-ary)
(identical? lhs-ary rhs-ary)
(and (= (get-offset lhs)
(get-offset rhs))))))
(get-offset rhs)))
(and (= (get-length lhs)
(get-length rhs))))))
(partially-alias? [lhs rhs]
(let [lhs-ary (array-backing-store (primitive/->buffer-backing-store lhs))
rhs-ary (array-backing-store (primitive/->buffer-backing-store rhs))]
Expand Down
9 changes: 5 additions & 4 deletions src/tech/compute/tensor.clj
Original file line number Diff line number Diff line change
Expand Up @@ -513,11 +513,12 @@ https://cloojure.github.io/doc/core.matrix/clojure.core.matrix.html#var-select"
(ensure-tensor arg)
arg))))
select-result (apply dims/select (tensor->dimensions tensor) args)
{:keys [dimensions elem-offset]} select-result
{:keys [dimensions elem-offset buffer-length]} select-result
tens-buffer (tens-proto/tensor->buffer tensor)
new-buffer (compute/sub-buffer tens-buffer elem-offset
(- (dtype/ecount tens-buffer)
(long elem-offset)))]
buffer-length (long (or buffer-length
(- (dtype/ecount tens-buffer)
(long elem-offset))))
new-buffer (compute/sub-buffer tens-buffer elem-offset buffer-length)]
(construct-tensor dimensions new-buffer)))


Expand Down
8 changes: 5 additions & 3 deletions src/tech/compute/tensor/dimensions.clj
Original file line number Diff line number Diff line change
Expand Up @@ -435,11 +435,13 @@ https://cloojure.github.io/doc/core.matrix/clojure.core.matrix.html#var-select"
shape (map dims-select/apply-select-arg-to-dimension shape args)
{shape :dimension-seq
strides :strides
offset :offset} (dims-select/dimensions->simpified-dimensions
shape strides)]
offset :offset
buffer-length :length} (dims-select/dimensions->simpified-dimensions
shape strides)]
{:dimensions {:shape shape
:strides strides}
:elem-offset offset})))
:elem-offset offset
:buffer-length buffer-length})))


(defn dimensions->column-stride
Expand Down
14 changes: 10 additions & 4 deletions src/tech/compute/tensor/dimensions/select.clj
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,13 @@ Returns:
(throw (ex-info "Bad dimension type"
{:dimension dimension}))))
[[] [] 0]
(map vector dimension-seq stride-seq))]
{:dimension-seq dimension-seq
:strides strides
:offset offset}))
(map vector dimension-seq stride-seq))
retval

{:dimension-seq dimension-seq
:strides strides
:offset offset
:length (when (shape/direct-shape? dimension-seq)
(apply + 1 (map * (map (comp dec shape/shape-entry->count)
dimension-seq) strides)))}]
retval))
12 changes: 11 additions & 1 deletion test/tech/compute/cpu/tensor_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
(verify-tensor/binary-constant-op (driver) *datatype*))


(deftest binary-op
(def-all-dtype-exception-unsigned binary-op
(verify-tensor/binary-op (driver) *datatype* ))


Expand Down Expand Up @@ -163,6 +163,16 @@
(is (= [1 4] (ct/shape test-tensor))))))


(deftest select-tensor-work-correctly-with-dtype-lib
(testing "Test that dtype/->array, dtype/->array-copy work correctly with select-produced tensors"
(let [test-data (ct/->tensor (partition 3 (range 9)))]
(is (= (ct/ecount (ct/select test-data 0 :all)) 3))
(is (= (ct/ecount (dtype/->array-copy (ct/select test-data 0 :all))) 3))
(is (= (ct/ecount (dtype/->array-copy (ct/select test-data 2 :all))) 3))
(is (= (ct/ecount (dtype/->array-copy (ct/select test-data 1 :all))) 3))
(is (= nil (dtype/->array (ct/select test-data 0 :all)))))))


(def-double-float-test cholesky-decomp
(verify-tensor/cholesky-decomp (driver) *datatype*))

Expand Down

0 comments on commit ca6ee2c

Please sign in to comment.