-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.clj
63 lines (46 loc) · 1.68 KB
/
utils.clj
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
(ns tech.compute.verify.utils
(:require [tech.resource :as resource]
[clojure.test :refer :all]
[tech.v2.datatype :as dtype]
[tech.v2.datatype.casting :as casting]
[tech.compute :as compute])
(:import [java.math BigDecimal MathContext]))
(defn test-wrapper
  "Invoke the zero-argument function `test-fn` inside a fresh
  `resource/stack-resource-context`.

  The signature matches a clojure.test fixture (it receives the test thunk and
  calls it). NOTE(review): presumably resources registered while the test runs
  are released when the context exits; confirm against tech.resource."
  [test-fn]
  (let [run-test test-fn]
    (resource/stack-resource-context
     (run-test))))
(defmacro with-default-device-and-stream
  "Anaphoric macro: evaluate `body` inside a fresh
  `resource/stack-resource-context` with two unhygienic locals in scope:

    device - (compute/default-device driver)
    stream - (compute/default-stream device)

  The symbols `device` and `stream` are intentionally left unqualified so
  that `body` can refer to them directly."
  [driver & body]
  (let [device-sym 'device
        stream-sym 'stream]
    `(resource/stack-resource-context
      (let [~device-sym (compute/default-device ~driver)
            ~stream-sym (compute/default-stream ~device-sym)]
        ~@body))))
(def ^{:dynamic true} *datatype*
  "Datatype keyword for the currently running test. Defaults to :float64;
  the per-datatype tests generated by `datatype-list-tests` rebind it."
  :float64)
(defmacro datatype-list-tests
  "Expand into one `deftest` per entry of `datatype-list`.

  `datatype-list` is walked at macro-expansion time, so it must be a literal
  collection of named values (keywords in this file), or a value spliced in by
  a calling macro. Each generated test:
    - is named `<test-name>-<datatype-name>` (e.g. `foo-float32`),
    - keeps any metadata attached to `test-name` (e.g. `^:slow`),
    - runs `body` with `*datatype*` bound to that datatype."
  [datatype-list test-name & body]
  `(do
     ~@(for [datatype datatype-list
             ;; Carry the caller's metadata (test selectors such as ^:slow)
             ;; over to each generated name; a fresh `symbol` would drop it.
             :let [dtype-test-name (with-meta
                                     (symbol (str test-name "-" (name datatype)))
                                     (meta test-name))]]
         `(deftest ~dtype-test-name
            (binding [*datatype* ~datatype]
              ~@body)))))
(defmacro def-double-float-test
  "Define `<test-name>-float64` and `<test-name>-float32` tests via
  `datatype-list-tests`, each running `body`."
  [test-name & body]
  (list* `datatype-list-tests [:float64 :float32] test-name body))
(defmacro def-int-long-test
  "Define one test per 32- and 64-bit integer datatype, signed and unsigned
  (:int32 :uint32 :int64 :uint64), via `datatype-list-tests`."
  [test-name & body]
  (let [int-dtypes [:int32 :uint32 :int64 :uint64]]
    (list* `datatype-list-tests int-dtypes test-name body)))
(defmacro def-all-dtype-test
  "Define one test per datatype in `casting/numeric-types` via
  `datatype-list-tests`. The collection's value is read at macro-expansion
  time and spliced into the expansion."
  [test-name & body]
  (list* `datatype-list-tests casting/numeric-types test-name body))
(defmacro def-all-dtype-exception-unsigned
  "Define per-datatype tests for platforms that can detect unsigned-integer
  errors.

  - For every datatype in `casting/host-numeric-types`, defines a test that
    runs `body` normally.
  - For every datatype in `casting/unsigned-int-types`, defines a test that
    asserts `body` throws a `Throwable`.

  Tests are generated by `datatype-list-tests`, so they are named
  `<test-name>-<datatype-name>` and run with `*datatype*` bound to their
  datatype. Both collections are read at macro-expansion time."
  [test-name & body]
  `(do
     (datatype-list-tests ~casting/host-numeric-types ~test-name ~@body)
     (datatype-list-tests ~casting/unsigned-int-types ~test-name
       ;; `thrown?` must remain an unqualified symbol: clojure.test's `is`
       ;; dispatches `assert-expr` on the literal symbol `thrown?`, and plain
       ;; syntax-quote would namespace-qualify it, breaking the assertion.
       (is (~'thrown? Throwable
             ~@body)))))