-
Notifications
You must be signed in to change notification settings - Fork 3
/
api_sugar.clj
90 lines (77 loc) · 2.57 KB
/
api_sugar.clj
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
(ns tvm-clj.api-sugar
"Operators and bindings to make the clojure code look and work like the python
tvm bindings. This file is purely syntax sugar."
(:require [tvm-clj.api :as api]
[tvm-clj.tvm-jna :as bindings]
[tech.datatype :as dtype])
(:refer-clojure :exclude [+ - * / rem = min max cast]))
;; Ask the compiler to report any Java interop call in this file that it cannot
;; resolve statically (and would therefore dispatch via reflection at runtime).
(set! *warn-on-reflection* true)
(defmacro defbinop
  "Defines `op-symbol` as a binary operator usable on both TVM node handles
  and plain Clojure values.

  With two operands the generated function calls `api-fn` when either operand
  satisfies `bindings/is-node-handle?`, and `scalar-fn` otherwise. With three
  or more operands it folds left: it combines the first two operands, then
  applies itself to that result followed by the remaining operands.

  `scalar-fn` and `api-fn` are spliced into call position rather than invoked
  as values, so a Java static method such as `Math/pow` can be passed."
  [op-symbol scalar-fn api-fn]
  (let [x    (gensym "lhs")
        y    (gensym "rhs")
        z    (gensym "arg")
        more (gensym "args")]
    `(defn ~op-symbol
       ([~x ~y]
        (if-not (or (bindings/is-node-handle? ~x)
                    (bindings/is-node-handle? ~y))
          (~scalar-fn ~x ~y)
          (~api-fn ~x ~y)))
       ([~x ~y ~z & ~more]
        (apply ~op-symbol (~op-symbol ~x ~y) ~z ~more)))))
(defmacro defunop
  "Defines `op-symbol` as a one-argument operator usable on both TVM node
  handles and plain Clojure values.

  The generated function calls `api-fn` when its argument satisfies
  `bindings/is-node-handle?`, and `scalar-fn` otherwise. Both are spliced into
  call position, so static methods such as `Math/exp` and anonymous `#(...)`
  functions are both accepted."
  [op-symbol scalar-fn api-fn]
  (let [operand (gensym "lhs")]
    `(defn ~op-symbol
       [~operand]
       (if-not (bindings/is-node-handle? ~operand)
         (~scalar-fn ~operand)
         (~api-fn ~operand)))))
;; Binary operators. With two operands each one calls the `tvm-clj.api`
;; function when either operand is a TVM node handle and the scalar function
;; otherwise; extra operands are folded left (see `defbinop`).
(defbinop + clojure.core/+ api/add)
(defbinop - clojure.core/- api/sub)
(defbinop * clojure.core/* api/mul)
(defbinop / clojure.core// api/div)
(defbinop rem clojure.core/rem api/mod)
(defn =
  "Equality operator.

  With two operands, returns `(api/eq lhs rhs)` when either operand is a TVM
  node handle and `(clojure.core/= lhs rhs)` otherwise.

  With three or more operands and no node handles, defers to `clojure.core/=`,
  so `(= 2 2 2)` is true. A plain left fold (what `defbinop` generates) would
  instead compare the boolean result of `(= 2 2)` against `2` and return
  false. When any operand is a node handle, the operands are folded left
  pairwise as before."
  ([lhs rhs]
   (if (or (bindings/is-node-handle? lhs)
           (bindings/is-node-handle? rhs))
     (api/eq lhs rhs)
     (clojure.core/= lhs rhs)))
  ([lhs rhs arg & args]
   (if (some #(bindings/is-node-handle? %) (list* lhs rhs arg args))
     ;; NOTE(review): this builds eq(eq(a, b), c) rather than a conjunction
     ;; of pairwise equalities; confirm whether a chained comparison is wanted.
     (apply = (= lhs rhs) arg args)
     (apply clojure.core/= lhs rhs arg args))))
(defbinop min clojure.core/min api/min)
(defbinop max clojure.core/max api/max)
;; `Math/pow` is accepted because `defbinop` splices the scalar fn into call
;; position instead of using it as a value.
(defbinop pow Math/pow api/power)
;; One-argument math operators. Each calls the `tvm-clj.api` function when its
;; argument is a TVM node handle and the scalar implementation otherwise
;; (see `defunop`).
(defunop exp Math/exp api/exp)
(defunop tanh Math/tanh api/tanh)
;; Scalar logistic function 1 / (1 + e^-x). The arithmetic is qualified with
;; clojure.core explicitly: in this namespace `+`, `-` and `/` are the
;; `defbinop` operators, which have no one-argument arity, so an unqualified
;; `(- %)` throws an ArityException for every scalar input.
(defunop sigmoid #(clojure.core// 1.0
                                  (clojure.core/+ 1.0 (Math/exp (clojure.core/- %))))
  api/sigmoid)
(defunop log Math/log api/log)
(defunop sqrt Math/sqrt api/sqrt)
(defunop floor Math/floor api/floor)
(defunop ceil Math/ceil api/ceil)
;; Math/abs and Math/round are overloaded; coercing the argument to double
;; selects the double overload without reflection. Scalar `round` therefore
;; returns a long.
;; NOTE(review): Clojure 1.11+ also defines clojure.core/abs, which is not in
;; this namespace's :refer-clojure :exclude list, so defining `abs` here may
;; print a redefinition warning; confirm against the Clojure version in use.
(defunop abs #(Math/abs (double %)) api/abs)
(defunop round #(Math/round (double %)) api/round)
;; Scalar truncation toward zero: floor for positive inputs, ceil otherwise.
(defunop trunc #(if (> % 0)
                  (Math/floor %)
                  (Math/ceil %)) api/trunc)
;; Scalar popcount: number of one bits in the argument coerced to a long.
(defunop popcount #(Long/bitCount (long %)) api/popcount)
(def tvar
  "Alias for `tvm-clj.api/variable`."
  api/variable)
(def const
  "Alias for `tvm-clj.api/const`."
  api/const)
(def placeholder
  "Alias for `tvm-clj.api/placeholder`."
  api/placeholder)
(def cast
  "Alias for `tvm-clj.api/cast` (clojure.core/cast is excluded by the ns form)."
  api/cast)
(defn compute
  "Builds a compute operation with `api/compute` and returns its output(s)
  instead of the operation itself.

  `dims`, `fn`, `name` and any extra `args` are passed straight through to
  `api/compute`. If the operation has exactly one output tensor, that tensor
  is returned; otherwise the collection returned by `api/output-tensors` is
  returned unchanged. The operation can be recovered from any output via its
  `op` member."
  [dims fn name & args]
  (let [target-op (apply api/compute dims fn name args)
        output-tensors (api/output-tensors target-op)]
    ;; Plain integer comparison: use clojure.core/= directly rather than this
    ;; namespace's TVM-dispatching `=` operator.
    (if (clojure.core/= 1 (count output-tensors))
      (first output-tensors)
      output-tensors)))
(defmacro lambda
  "Shorthand for `api/tvm-fn`: `(lambda [args...] body...)` expands to
  `(api/tvm-fn [args...] body...)`."
  [arglist & body]
  (list* `api/tvm-fn arglist body))
(defmacro tif
  "Conditional that works with both TVM expressions and plain values.

  `bool-stmt` is evaluated exactly once. If its value is a TVM node handle
  (per `bindings/is-node-handle?`), the result is
  `(api/select cond true-stmt false-stmt)`, and both branch expressions are
  evaluated as arguments. Otherwise it behaves like Clojure's `if` and
  evaluates only the selected branch."
  [bool-stmt true-stmt false-stmt]
  (let [condition (gensym "bool-arg")]
    `(let [~condition ~bool-stmt]
       (cond
         (bindings/is-node-handle? ~condition) (api/select ~condition ~true-stmt ~false-stmt)
         ~condition ~true-stmt
         :else ~false-stmt))))
(defmacro tlet
  "Shorthand for `api/tvm-let`: `(tlet expr-pairs body)` expands to
  `(api/tvm-let expr-pairs body)`. Exactly one body form is accepted."
  [expr-pairs body]
  (list `api/tvm-let expr-pairs body))