-
Notifications
You must be signed in to change notification settings - Fork 1
/
tensor.clj
33 lines (29 loc) · 1.03 KB
/
tensor.clj
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
(ns tech.compute.verify.tensor
(:require [tech.compute.context :as compute-ctx]
[tech.v2.tensor.impl :as dtt-impl]
[tech.v2.datatype.functional :as dfn]
[tech.v2.datatype :as dtype]
[tech.v2.tensor :as dtt]
[tech.compute.tensor :as ct]
[tech.resource :as resource]
[clojure.test :refer :all]))
(defmacro verify-context
  "Wraps `body` in three nested contexts, outermost first:
    1. resource/stack-resource-context
    2. compute-ctx/with-context, given the map {:driver driver}
    3. dtt-impl/with-datatype, given `datatype`
  The forms of `body` are spliced, in order, into the innermost context.
  NOTE(review): `datatype` is presumably the default element type for tensors
  created within `body`, and the resource stack presumably releases resources
  allocated there on exit — confirm against tech.v2.tensor.impl / tech.resource."
  [driver datatype & body]
  ;; Assemble the expansion from the inside out; the resulting form is the
  ;; same nested template, just built one layer at a time.
  (let [typed-body `(dtt-impl/with-datatype ~datatype ~@body)
        ctx-body   `(compute-ctx/with-context {:driver ~driver} ~typed-body)]
    `(resource/stack-resource-context ~ctx-body)))
(defn clone
  "Verifies that copying a tensor to the device of `driver` and back to the
  host yields a tensor equal to the original (compared with dfn/equals).

  Runs inside `verify-context` with `driver` and `datatype`. Checks two cases:
    * a 3x3 tensor built from (partition 3 (range 9));
    * the window of that tensor selected by rows [0 1] and columns [0 1].
  Each comparison is reported through clojure.test/is."
  [driver datatype]
  (verify-context
   driver datatype
   ;; Host -> device -> host. The copies are bound outside `is`, so an
   ;; exception during either copy propagates rather than being recorded
   ;; as a test error.
   (let [round-trip (comp ct/clone-to-host ct/clone-to-device)
         dense      (ct/->tensor (partition 3 (range 9)))
         dense-back (round-trip dense)]
     (is (dfn/equals dense dense-back))
     ;; NOTE(review): the selection is presumably a strided (non-contiguous)
     ;; view, which would exercise copying of non-dense tensors — confirm.
     (let [window      (dtt/select dense [0 1] [0 1])
           window-back (round-trip window)]
       (is (dfn/equals window window-back))))))