This repository was archived by the owner on Dec 11, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathutils.lisp
205 lines (170 loc) · 8.38 KB
/
utils.lisp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
(in-package :cl-waffe2.frontends/onnx)
(eval-when (:compile-toplevel :load-toplevel :execute)
(defun symb (&rest inputs)
(intern (with-output-to-string (out) (dolist (sym inputs) (princ sym out))))))
;; ~~ Converter ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
(defparameter *converter-features* (make-hash-table :test #'equal))
(defmacro defop ((opset-name min-opset-version) ((gph inputs attrs &rest more) &body body))
"Defines a converter"
(let ((tmp (gensym)))
`(let ((,tmp (alexandria:named-lambda ,(intern (format nil "~a_~a" opset-name min-opset-version)) (,gph ,inputs ,attrs ,@more) (declare (ignorable ,gph)) ,@body)))
(if (gethash ,opset-name *converter-features*)
(let ((values (gethash ,opset-name *converter-features*)))
(let ((position (position ,min-opset-version values :test #'= :key #'cdr)))
(if position
(setf (nth position (gethash ,opset-name *converter-features*)) (cons ,tmp ,min-opset-version))
(push (cons ,tmp ,min-opset-version) (gethash ,opset-name *converter-features*)))))
(setf (gethash ,opset-name *converter-features*) (list (cons ,tmp ,min-opset-version)))))))
(defun get-converter (op-type opset-version)
(declare (type string op-type)
(type fixnum opset-version))
(restart-case (progn
(let ((candidates (gethash op-type *converter-features*)))
(when (null candidates)
(error "get-converter: Convertion pattern for ~a(version=~a) is not defined yet." op-type opset-version))
(let ((candidates (sort candidates #'> :key #'cdr)))
(loop for (impl . version) in candidates
if (<= version opset-version)
do (return-from get-converter impl)))
(error "get-converter: there's no implementation for ~a satisfying opset=~a.~%Candidates:~a"
op-type
opset-version
candidates)))
(reload-and-retry-defop ()
:report "Reload the (expected to be updated on REPL) converter and restart from the point in the error."
(get-converter op-type opset-version))))
;; ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
(defstruct (Graph-Proto-Helper
(:conc-name gp-)
(:constructor make-graph-proto-helper (graph-proto opset-version)))
(graph-proto grpah-proto :type Graph-Proto)
(opset-version opset-version :type fixnum)
;; ID -> Waffe2 Tensor
(name2value (make-hash-table :test #'equal) :type hash-table))
(defun tensor-shape-proto.dimension->aten (proto)
(declare (type tensor-shape-proto.dimension proto))
(let ((value (oneof-value (tensor-shape-proto.dimension-value proto))))
(etypecase value
(number value)
(string (intern value)))))
(defun tensor-proto->aten (tensor-proto)
(declare (type tensor-proto tensor-proto))
(restart-case
(let ((array (raw->array tensor-proto)))
;; [TODO] Add change-facet function when adding a support for CUDA.
(change-facet array :direction 'AbstractTensor))
(update-configuration-and-retry ()
:report "Retry the operation"
(tensor-proto->aten tensor-proto))))
(defun value-info-proto->aten (value-info-proto)
(declare (type Value-Info-Proto value-info-proto))
(with-slots ((name cl-onnx::name) (type cl-onnx::type)) value-info-proto
(let ((type (oneof-value (type-proto-value type))))
(typecase type
(type-proto.tensor
(with-slots ((elem-type cl-onnx::elem-type) (shape cl-onnx::shape)) type
(let* ((dims (tensor-shape-proto-dim shape))
(dims (map 'list #'tensor-shape-proto.dimension->aten dims)))
(make-input dims (intern name "KEYWORD") :dtype (int->data-type elem-type)))))
(T
(error "[from-onnx] value-info-proto->aten: Not implemented: ~a" type))))))
(defun value->users (graph-proto name)
(declare (type string name)
(type graph-proto graph-proto))
(loop for node in (graph-proto-node graph-proto)
for position upfrom 0
if (find name (node-proto-input node) :test #'equal)
collect node))
(defun user->values (graph-proto name)
(declare (type string name)
(type graph-proto graph-proto))
(loop for node in (graph-proto-node graph-proto)
for position upfrom 0
if (find name (node-proto-output node) :test #'equal)
collect node))
(defun gp-name->value (graph-proto-helper name)
(declare (type string name))
(gethash name (gp-name2value graph-proto-helper)))
(defun (setf gp-name->value) (value graph-proto-helper name)
(setf (gethash name (gp-name2value graph-proto-helper)) value))
(defun call-converter (converter node-proto graph-proto-helper input attrs)
(restart-case
(handler-bind
((error #'(lambda (c) (error "An error was occured during the translation:~%Position:~a(version=~a)~%Error:~% ~a" (node-proto-name node-proto) (gp-opset-version graph-proto-helper) c))))
(multiple-value-list (funcall converter graph-proto-helper input attrs)))
(reload-and-retry-defop ()
:report "Reload the (expected to be updated on REPL) converter and restart from the point in the error."
(call-converter
(get-converter (node-proto-op-type node-proto) (gp-opset-version graph-proto-helper))
node-proto
graph-proto-helper
input
attrs))))
(defun node-proto->aten (graph-proto-helper node-proto)
(declare (type graph-proto-helper graph-proto-helper)
(type node-proto node-proto))
(dolist (input (node-proto-input node-proto))
(when (null (gethash input (gp-name2value graph-proto-helper)))
(let ((values (user->values (gp-graph-proto graph-proto-helper) input)))
(dolist (val values)
(node-proto->aten graph-proto-helper val)))))
(let* ((converter (get-converter (node-proto-op-type node-proto) (gp-opset-version graph-proto-helper)))
(input
(map
'list
#'(lambda (x)
(let ((value (gp-name->value graph-proto-helper x)))
;;(when (null value) (error "node-proto->aten: ~a is not declared?" x))
value))
(node-proto-input node-proto)))
(attrs (make-hash-table :test #'equal)))
(dolist (attr (node-proto-attribute node-proto))
(setf (gethash (attribute-proto-name attr) attrs)
(cl-onnx::read-attr attr)))
(let ((out-in-tensors (call-converter converter node-proto graph-proto-helper input attrs))
(output
(node-proto-output node-proto)))
(assert (= (length output) (length out-in-tensors))
()
"Assertion Failed: Converter for (~a opset=~a) expected to return ~a tensors butgot ~a."
(node-proto-op-type node-proto)
(gp-opset-version graph-proto-helper)
(length output)
(length out-in-tensors))
(loop for tensor in out-in-tensors
for name in output
do (setf (gp-name->value graph-proto-helper name) tensor))
(apply #'values out-in-tensors))))
(defun from-onnx (graph-proto-helper)
"A helper function to translate graph-proto"
(declare (type Graph-Proto-Helper graph-proto-helper))
(let* ((graph-proto (gp-graph-proto graph-proto-helper))
(input-names (map 'list #'value-info-proto-name (graph-proto-input graph-proto)))
(output-names (map 'list #'value-info-proto-name (graph-proto-output graph-proto)))
(inputs (map 'list #'value-info-proto->aten (graph-proto-input graph-proto))))
(loop for name in input-names
for input in inputs
do (setf (gp-name->value graph-proto-helper name) input))
(dolist (initializer (graph-proto-initializer graph-proto))
;; Initializer = tensor-proto
(setf (gp-name->value graph-proto-helper (tensor-proto-name initializer))
(tensor-proto->aten initializer)))
(assert (null (graph-proto-sparse-initializer graph-proto)) () "[TODO] not ready")
(let ((innermosts
(apply #'append (map 'list #'(lambda (x) (user->values graph-proto x)) output-names))))
(dolist (node-proto innermosts)
(node-proto->aten graph-proto-helper node-proto)))
(let ((results (map 'list #'(lambda (x) (gp-name->value graph-proto-helper x)) output-names)))
(when (some #'null results)
(error "Failed to trace the graph: ~a -> ~a" output-names results))
(apply
#'cl-waffe2/base-impl:lazy-values
(append
(loop for value being each hash-value of (gp-name2value graph-proto-helper)
if (typep value 'AbstractTensor)
collect value)
(if (<= (length results) 1)
results
(apply #'cl-waffe2/base-impl:lazy-values results)))))))
(defmacro range (from to &optional (by 1))
`(loop for i from ,from to ,to by ,by collect i))