# jaxify.py
# 997 lines (725 loc) · 24.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
from collections.abc import Sequence
from functools import reduce
from functools import singledispatch as dispatch
from functools import update_wrapper
from warnings import warn
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import theano
from theano.compile.ops import (
DeepCopyOp,
Rebroadcast,
Shape,
Shape_i,
SpecifyShape,
ViewOp,
)
from theano.gof import FunctionGraph
from theano.ifelse import IfElse
from theano.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp
from theano.scan.op import Scan
from theano.scan.utils import scan_args as ScanArgs
from theano.tensor.basic import (
Alloc,
AllocEmpty,
ARange,
Dot,
Eye,
Join,
MaxAndArgmax,
Reshape,
ScalarFromTensor,
TensorFromScalar,
)
from theano.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from theano.tensor.extra_ops import (
Bartlett,
CumOp,
DiffOp,
FillDiagonal,
FillDiagonalOffset,
RavelMultiIndex,
RepeatOp,
Unique,
UnravelIndex,
)
from theano.tensor.nlinalg import (
SVD,
AllocDiag,
Det,
Eig,
Eigh,
ExtractDiag,
MatrixInverse,
QRFull,
QRIncomplete,
)
from theano.tensor.nnet.sigm import ScalarSoftplus
from theano.tensor.opt import MakeVector
from theano.tensor.slinalg import Cholesky, Solve
from theano.tensor.subtensor import ( # This is essentially `np.take`; Boolean mask indexing and setting
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
get_idx_list,
)
from theano.tensor.type_other import MakeSlice
# Match JAX's floating-point width to Theano's configured `floatX`:
# 64-bit support is switched on exactly when `floatX` is "float64".
jax.config.update("jax_enable_x64", theano.config.floatX == "float64")

# XXX: Omnistaging is disabled because enabling it will break some
# shape-based functionality, and severely limit the types of graphs that can
# be converted.
# See https://github.com/google/jax/blob/4d556837cc9003492f674c012689efc3d68fdf5f/design_notes/omnistaging.md
# JAX releases older than 0.2.0 lack this flag, so there is nothing to
# disable there.
try:
    jax.config.disable_omnistaging()
except AttributeError:
    pass

# `Op` types converted by `jax_funcify_Subtensor`.
subtensor_ops = (
    Subtensor,
    AdvancedSubtensor1,
    AdvancedSubtensor,
)
# `Op` types converted by `jax_funcify_IncSubtensor`.
incsubtensor_ops = (
    IncSubtensor,
    AdvancedIncSubtensor1,
)
def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
    """Compose JAX implementations of node operations.

    Parameters
    ----------
    out_node: Node
        The output node.
    fgraph_inputs: List[Variable]
        The inputs--in a `FunctionGraph` sense--to `out_node`.
    memo: Mapping (Optional)
        A map from visited nodes to their JAX functions.  It is filled in as
        nodes are visited, so shared sub-graphs are only converted once.

    Returns
    -------
    A `function` object that represents the composed JAX operations and takes
    the same form of inputs as `fgraph_inputs`.  When `out_node` has more than
    one output, a list of such functions is returned instead: one per function
    produced by `jax_funcify(out_node.op)`.
    """
    if memo is None:
        memo = {}

    if out_node in memo:
        return memo[out_node]

    jax_return_func = jax_funcify(out_node.op)

    input_funcs = []
    for i in out_node.inputs:
        if i in fgraph_inputs:
            # A graph input: select the matching positional argument and
            # convert it to the variable's dtype.
            idx = fgraph_inputs.index(i)
            i_dtype = getattr(i, "dtype", None)

            def jax_inputs_func(*inputs, i_dtype=i_dtype, idx=idx):
                return jnp.array(inputs[idx], dtype=jnp.dtype(i_dtype))

            input_f = jax_inputs_func

        elif i.owner is None:
            # An owner-less variable that is not a graph input: return the
            # value stored in its `data` attribute.
            i_dtype = getattr(i, "dtype", None)
            i_data = i.data

            def jax_data_func(*inputs, i_dtype=i_dtype, i_data=i_data):
                if i_dtype is None:
                    return i_data
                else:
                    return jnp.array(i_data, dtype=jnp.dtype(i_dtype))

            input_f = jax_data_func
        else:
            # NOTE(review): when `i.owner` has several outputs this returns a
            # list (see the end of this function), which is not callable in
            # `jax_func` below -- confirm multi-output producers are handled.
            input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)

        input_funcs.append(input_f)

    if not isinstance(jax_return_func, Sequence):
        jax_return_func = [jax_return_func]

    jax_funcs = []
    for return_func in jax_return_func:

        # `return_func` is bound as a default argument.  A plain closure
        # would late-bind to the loop variable, and every composed function
        # would then call the *last* entry of `jax_return_func`.
        def jax_func(*inputs, return_func=return_func):
            func_args = [fn(*inputs) for fn in input_funcs]
            return return_func(*func_args)

        jax_funcs.append(update_wrapper(jax_func, return_func))

    if len(out_node.outputs) == 1:
        jax_funcs = jax_funcs[0]

    memo[out_node] = jax_funcs

    return jax_funcs
@dispatch
def jax_funcify(op):
    """Create a JAX "perform" function for a Theano `Variable` and its `Op`.

    This is the `singledispatch` fallback, used for any type with no
    registered converter.  It always raises `NotImplementedError`.
    Converters are added with ``@jax_funcify.register(OpType)``.
    """
    message = "No JAX conversion for the given `Op`: {}".format(op)
    raise NotImplementedError(message)
@jax_funcify.register(MakeSlice)
def jax_funcify_MakeSlice(op):
    """Convert a `MakeSlice` into a function that builds a Python `slice`."""

    def makeslice(*slice_parts):
        # Forwarded positionally, exactly as the `slice` builtin accepts them.
        return slice(*slice_parts)

    return makeslice
@jax_funcify.register(ScalarOp)
def jax_funcify_ScalarOp(op):
    """Map a Theano `ScalarOp` to the JAX function named in its `nfunc_spec`.

    A name without dots is looked up on `jax.numpy`.  A dotted name is
    resolved attribute-by-attribute starting from the `jax` module.
    """
    func_name = op.nfunc_spec[0]
    if "." not in func_name:
        jnp_func = getattr(jnp, func_name)
    else:
        jnp_func = reduce(getattr, [jax] + func_name.split("."))

    if not hasattr(op, "nfunc_variadic"):
        return jnp_func

    # These are special cases that handle invalid arities due to the broken
    # Theano `Op` type contract (e.g. binary `Op`s that also function as
    # their own variadic counterparts--even when those counterparts already
    # exist as independent `Op`s).
    jax_variadic_func = getattr(jnp, op.nfunc_variadic)

    def elemwise(*args):
        if len(args) <= op.nfunc_spec[1]:
            return jnp_func(*args)
        # More arguments than the declared arity: broadcast them together,
        # stack along a new leading axis and reduce over that axis.
        stacked = jnp.stack(jnp.broadcast_arrays(*args), axis=0)
        return jax_variadic_func(stacked, axis=0)

    return elemwise
@jax_funcify.register(Clip)
def jax_funcify_Clip(op):
    """Convert a scalar `Clip` into an element-wise clamp built from `jnp.where`."""

    def clip(x, lo, hi):
        # Values above `hi` become `hi`; values below `lo` then become `lo`.
        capped_above = jnp.where(x > hi, hi, x)
        return jnp.where(x < lo, lo, capped_above)

    return clip
@jax_funcify.register(Identity)
def jax_funcify_Identity(op):
    """Convert a scalar `Identity` into a function that returns its argument."""

    def identity(value):
        return value

    return identity
@jax_funcify.register(ScalarSoftplus)
def jax_funcify_ScalarSoftplus(op):
    """Convert `ScalarSoftplus` into a piecewise JAX evaluation of log(1 + exp(x))."""

    def scalarsoftplus(x):
        # Exactly 0.0 below -30, the identity above 30, log1p(exp(x)) between.
        middle = jnp.log1p(jnp.exp(x))
        upper = jnp.where(x > 30.0, x, middle)
        return jnp.where(x < -30.0, 0.0, upper)

    return scalarsoftplus
@jax_funcify.register(AllocEmpty)
def jax_funcify_AllocEmpty(op):
    """Convert `AllocEmpty` into `jnp.empty` with the `Op`'s dtype."""

    def allocempty(*dims):
        # `dims` is the tuple of positional dimension sizes.
        return jnp.empty(dims, dtype=op.dtype)

    return allocempty
@jax_funcify.register(Alloc)
def jax_funcify_Alloc(op):
    """Convert `Alloc` into a broadcast of the fill value to the requested shape."""

    def alloc(x, *dims):
        return jnp.broadcast_to(x, dims)

    return alloc
@jax_funcify.register(Dot)
def jax_funcify_Dot(op):
    """Convert `Dot` into `jnp.dot`."""

    def dot(lhs, rhs):
        return jnp.dot(lhs, rhs)

    return dot
@jax_funcify.register(ARange)
def jax_funcify_ARange(op):
    """Convert `ARange` into `jnp.arange` with the `Op`'s dtype."""

    # XXX: This currently requires concrete arguments.
    def arange(start, stop, step):
        return jnp.arange(start, stop, step, dtype=op.dtype)

    return arange
def jnp_safe_copy(x):
    """Return a copy of ``x``, degrading gracefully when `jnp.copy` fails.

    ``jnp.copy(x)`` is tried first.  If it raises `NotImplementedError`, a
    warning is emitted and ``jnp.array(x.copy())`` is returned instead.  If
    ``x`` has no ``copy`` method, a second warning is emitted and ``x``
    itself is returned.
    """
    try:
        return jnp.copy(x)
    except NotImplementedError:
        warn(
            "`jnp.copy` is not implemented yet. "
            "Using the object's `copy` method."
        )

    if not hasattr(x, "copy"):
        warn("Object has no `copy` method: {}".format(x))
        return x

    return jnp.array(x.copy())
@jax_funcify.register(DeepCopyOp)
def jax_funcify_DeepCopyOp(op):
    """Convert `DeepCopyOp` into a function that copies via `jnp_safe_copy`."""

    def deepcopyop(x):
        copied = jnp_safe_copy(x)
        return copied

    return deepcopyop
@jax_funcify.register(Shape)
def jax_funcify_Shape(op):
    """Convert `Shape` into `jnp.shape`."""

    def shape(x):
        return jnp.shape(x)

    return shape
@jax_funcify.register(Shape_i)
def jax_funcify_Shape_i(op):
    """Convert `Shape_i` into a lookup of a single dimension's length."""
    dim = op.i  # read once, at conversion time

    def shape_i(x):
        return jnp.shape(x)[dim]

    return shape_i
@jax_funcify.register(SpecifyShape)
def jax_funcify_SpecifyShape(op):
    """Convert `SpecifyShape` into a pass-through that checks the input shape.

    The returned function asserts that ``x`` has as many dimensions as the
    expected shape has entries and that ``x.shape`` equals it, then returns
    ``x`` unchanged.  The checks are `assert` statements, so they are skipped
    under ``python -O``.
    """

    def specifyshape(x, expected_shape):
        assert x.ndim == len(expected_shape)
        # The assertion message reports both shapes on mismatch.
        assert jnp.all(x.shape == tuple(expected_shape)), (
            "got shape",
            x.shape,
            "expected",
            expected_shape,
        )
        return x

    return specifyshape
@jax_funcify.register(Rebroadcast)
def jax_funcify_Rebroadcast(op):
    """Convert `Rebroadcast` into a pass-through with a length check.

    ``op.axis`` maps dimension indices to flags.  Every flagged dimension
    must have length 1, otherwise `ValueError` is raised.
    """
    axis_flags = op.axis

    def rebroadcast(x):
        for dim, must_be_one in axis_flags.items():
            if not must_be_one or x.shape[dim] == 1:
                continue
            raise ValueError(
                "Dimension %s in Rebroadcast's input was"
                " supposed to be 1 (got %s instead)" % (dim, x.shape[dim])
            )
        return x

    return rebroadcast
@jax_funcify.register(ViewOp)
def jax_funcify_ViewOp(op):
    """Convert `ViewOp` into a function that returns its input unchanged."""

    def viewop(value):
        return value

    return viewop
@jax_funcify.register(Cast)
def jax_funcify_Cast(op):
    """Convert a scalar `Cast` into an `astype` to the `Op`'s output dtype."""

    def cast(x):
        as_array = jnp.array(x)
        return as_array.astype(op.o_type.dtype)

    return cast
@jax_funcify.register(TensorFromScalar)
def jax_funcify_TensorFromScalar(op):
    """Convert `TensorFromScalar` into a `jnp.array` wrap."""

    def tensor_from_scalar(scalar):
        return jnp.array(scalar)

    return tensor_from_scalar
@jax_funcify.register(ScalarFromTensor)
def jax_funcify_ScalarFromTensor(op):
    """Convert `ScalarFromTensor` into extraction of the first flattened element."""

    def scalar_from_tensor(x):
        flat = jnp.array(x).flatten()
        return flat[0]

    return scalar_from_tensor
@jax_funcify.register(Elemwise)
def jax_funcify_Elemwise(op):
    """Convert `Elemwise` by converting its wrapped scalar op."""
    return jax_funcify(op.scalar_op)
@jax_funcify.register(Composite)
def jax_funcify_Composite(op):
    """Convert a `Composite` by converting its inner `FunctionGraph`."""
    return jax_funcify(op.fgraph)
@jax_funcify.register(Scan)
def jax_funcify_Scan(op):
    """Convert a Theano `Scan` into a `jax.lax.scan` call (work in progress).

    NOTE(review): this conversion is unfinished.  Both argument-translation
    helpers below raise `NotImplementedError`, so calling the returned `scan`
    function does not currently produce a result.
    """
    inner_fg = FunctionGraph(op.inputs, op.outputs)
    # NOTE(review): `jax_funcify` on a `FunctionGraph` returns a *list* of
    # functions, yet `jax_tt_inner_func` is called directly in
    # `jax_inner_func` below -- confirm the intended indexing.
    jax_tt_inner_func = jax_funcify(inner_fg)

    def scan(*outer_inputs):
        scan_args = ScanArgs(
            outer_inputs, [None] * op.n_outs, op.inputs, op.outputs, op.info
        )

        # `outer_inputs` is a list with the following composite form:
        # [n_steps]
        # + outer_in_seqs
        # + outer_in_mit_mot
        # + outer_in_mit_sot
        # + outer_in_sit_sot
        # + outer_in_shared
        # + outer_in_nit_sot
        # + outer_in_non_seqs
        n_steps = scan_args.n_steps
        seqs = scan_args.outer_in_seqs

        n_non_seqs = len(scan_args.outer_in_non_seqs)

        # TODO: sit_sots
        # For each mit-sot input, keep the leading slice covering its taps
        # (largest negative tap magnitude + largest positive tap).
        mit_sot_in_slices = []
        for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
            neg_taps = [abs(t) for t in tap if t < 0]
            pos_taps = [abs(t) for t in tap if t > 0]
            max_neg = max(neg_taps) if neg_taps else 0
            max_pos = max(pos_taps) if pos_taps else 0
            init_slice = seq[: max_neg + max_pos]
            mit_sot_in_slices.append(init_slice)

        # NOTE(review): `init_carry` has two entries, but
        # `jax_args_to_inner_scan` unpacks `carry` into five.
        init_carry = [mit_sot_in_slices, scan_args.outer_in_non_seqs]

        def jax_args_to_inner_scan(op, carry, x):
            # `carry` contains all inner-output taps, non_seqs, and shared
            # terms
            (
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            ) = carry

            # `x` contains the in_seqs
            inner_in_seqs = x

            # `inner_scan_inputs` is a list with the following composite form:
            # inner_in_seqs
            # + sum(inner_in_mit_mot, [])
            # + sum(inner_in_mit_sot, [])
            # + inner_in_sit_sot
            # + inner_in_shared
            # + inner_in_non_seqs
            # NOTE(review): `inner_in_shared` is not included in the list
            # below, although the layout above lists it.
            inner_scan_inputs = [
                inner_in_seqs,
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_non_seqs,
            ]

            # Unfinished: always raises before returning.
            raise NotImplementedError()
            return inner_scan_inputs

        def inner_scan_outs_to_jax_outs(
            op,
            old_carry,
            inner_scan_outs,
        ):
            # `inner_scan_outs` is a list with the following
            # composite form:
            # outer_out_mit_mot
            # + outer_out_mit_sot
            # + outer_out_sit_sot
            # + outer_out_nit_sot
            # + outer_out_shared
            # + cond
            (
                outer_out_mit_mot,
                outer_out_mit_sot,
                outer_out_sit_sot,
                outer_out_nit_sot,
                outer_out_shared,
                cond,
            ) = inner_scan_outs
            # NOTE(review): `[:-n_non_seqs]` *drops* the last `n_non_seqs`
            # entries; selecting them would be `[-n_non_seqs:]`. Confirm.
            outer_out_non_seqs = old_carry[:-n_non_seqs]

            # This should contain all inner-output taps, non_seqs, and shared
            # terms
            carry = [
                outer_out_mit_mot,
                outer_out_mit_sot,
                outer_out_sit_sot,
                outer_out_shared,
                outer_out_non_seqs,
            ]
            # This should contain all inner-outputs that produce
            # outer-outputs
            y = []

            # Unfinished: always raises before returning.
            raise NotImplementedError()
            return (carry, y)

        def jax_inner_func(carry, x):
            inner_args = jax_args_to_inner_scan(op, carry, x)
            inner_scan_outs = jax_tt_inner_func(*inner_args)
            # NOTE(review): called with 2 arguments, but
            # `inner_scan_outs_to_jax_outs` takes 3 (`op`, `old_carry`,
            # `inner_scan_outs`); `carry` is presumably the missing
            # `old_carry`.
            new_carry, y = inner_scan_outs_to_jax_outs(op, inner_scan_outs)
            return new_carry, y

        return jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)

    return scan
@jax_funcify.register(IfElse)
def jax_funcify_IfElse(op):
    """Convert `IfElse` into a Python-level branch on a concrete condition.

    The returned function receives the "then" outputs followed by the
    "else" outputs (``op.n_outs`` of each).  It returns the selected group:
    a tuple when ``n_outs > 1``, else the single value.
    """
    n_outs = op.n_outs

    def ifelse(cond, *args, n_outs=n_outs):
        chosen = args[:n_outs] if cond else args[n_outs:]
        if n_outs > 1:
            return chosen
        return chosen[0]

    return ifelse
def convert_indices(indices, entry):
    """Replace symbolic placeholders in an ``idx_list`` entry with values.

    Each `theano.gof.Type` placeholder consumes (pops) the next value from
    the front of ``indices``.  This happens only while ``indices`` is
    non-empty; otherwise the placeholder is returned unchanged.  `slice`
    entries are rebuilt field by field (start, stop, step).  Any other
    entry is returned as-is.
    """
    if isinstance(entry, slice):
        return slice(
            convert_indices(indices, entry.start),
            convert_indices(indices, entry.stop),
            convert_indices(indices, entry.step),
        )
    if indices and isinstance(entry, theano.gof.Type):
        return indices.pop(0)
    return entry
@jax_funcify.register(Subtensor)
def jax_funcify_Subtensor(op):
    """Convert the basic/advanced `Subtensor` family into JAX indexing.

    If the `Op` carries a non-empty ``idx_list``, the index is rebuilt from
    it and the runtime index inputs via `get_idx_list`.  Otherwise the
    runtime index inputs are used directly.  A single index is passed
    unwrapped; several are passed as one multi-dimensional index.
    """
    idx_list = getattr(op, "idx_list", None)

    def subtensor(x, *ilists):
        if idx_list:
            cdata = get_idx_list((x,) + ilists, idx_list)
        else:
            cdata = ilists

        if len(cdata) == 1:
            cdata = cdata[0]

        return x[cdata]

    return subtensor


# Register the same converter for every subtensor-like `Op` type.  A plain
# loop is used because the registration is done only for its side effect.
for _subtensor_op in subtensor_ops:
    jax_funcify.register(_subtensor_op, jax_funcify_Subtensor)
def jax_funcify_IncSubtensor(op):
    """Convert `IncSubtensor`/`AdvancedIncSubtensor1` into a JAX index update.

    ``jax.ops.index_update`` is used when the `Op` sets values
    (``set_instead_of_inc``).  Otherwise ``jax.ops.index_add`` is used.
    The index is rebuilt from the `Op`'s ``idx_list`` when it has one.
    Otherwise the runtime index inputs are used directly.
    """
    # `AdvancedIncSubtensor1` has no `idx_list` attribute (its index arrives
    # as a runtime input), so `op.idx_list` would raise `AttributeError` for
    # one of the two `Op` types registered below.
    idx_list = getattr(op, "idx_list", None)

    if getattr(op, "set_instead_of_inc", False):
        jax_fn = jax.ops.index_update
    else:
        jax_fn = jax.ops.index_add

    def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
        if idx_list:
            # `convert_indices` pops from this list, so work on a copy.
            _ilist = list(ilist)
            cdata = tuple(convert_indices(_ilist, idx) for idx in idx_list)
        else:
            cdata = ilist

        if len(cdata) == 1:
            cdata = cdata[0]

        return jax_fn(x, cdata, y)

    return incsubtensor


# Register the same converter for every inc-subtensor-like `Op` type.  A
# plain loop is used because the registration is done only for its side effect.
for _incsubtensor_op in incsubtensor_ops:
    jax_funcify.register(_incsubtensor_op, jax_funcify_IncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op):
    """Convert `AdvancedIncSubtensor` into a JAX index update.

    All runtime index inputs are passed together, as a tuple, as the index.
    """
    jax_fn = (
        jax.ops.index_update
        if getattr(op, "set_instead_of_inc", False)
        else jax.ops.index_add
    )

    def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
        return jax_fn(x, ilist, y)

    return advancedincsubtensor
@jax_funcify.register(FunctionGraph)
def jax_funcify_FunctionGraph(fgraph):
    """Convert a `FunctionGraph` into a list of composed JAX functions.

    One entry is produced per graph output that has an owner node.  Outputs
    without an owner are skipped, so the list can be shorter than
    ``fgraph.outputs``.
    """
    jax_funcs = []
    for out_var in fgraph.outputs:
        node = out_var.owner
        if node is None:
            continue
        jax_funcs.append(compose_jax_funcs(node, fgraph.inputs))
    return jax_funcs
@jax_funcify.register(CAReduce)
def jax_funcify_CAReduce(op):
    """Convert a Theano `CAReduce` into a JAX reduction.

    If the `Op` names a NumPy equivalent in its ``nfunc_spec``, that
    `jax.numpy` function is used.  Otherwise the scalar op's `jax.lax`
    counterpart is applied with `jax.lax.reduce`, seeded with the scalar op's
    ``identity``.  Results are cast to ``acc_dtype``, which defaults to the
    input's dtype.

    Raises
    ------
    NotImplementedError
        At call time, if no NumPy equivalent is given and the scalar op has
        neither an ``nfunc_spec`` nor a ``name``.
    """
    axis = op.axis
    op_nfunc_spec = getattr(op, "nfunc_spec", None)
    scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None)
    scalar_op_name = getattr(op.scalar_op, "name", None)
    scalar_op_identity = getattr(op.scalar_op, "identity", None)
    acc_dtype = getattr(op, "acc_dtype", None)

    def careduce(x):
        # Resolve the per-call defaults into locals.  Rebinding the closure
        # variables (via `nonlocal`) would freeze the first call's
        # `x.ndim`/`x.dtype` into every later call.
        reduce_axes = list(range(x.ndim)) if axis is None else axis
        out_dtype = x.dtype.type if acc_dtype is None else acc_dtype

        if op_nfunc_spec:
            jax_op = getattr(jnp, op_nfunc_spec[0])
            return jax_op(x, axis=reduce_axes).astype(out_dtype)

        # The Theano `Op` didn't tell us which NumPy equivalent to use (or
        # there isn't one), so we use this fallback approach
        if scalar_nfunc_spec:
            scalar_fn_name = scalar_nfunc_spec[0]
        elif scalar_op_name:
            scalar_fn_name = scalar_op_name
        else:
            raise NotImplementedError(
                "No JAX conversion for the given `Op`: {}".format(op)
            )

        # A list, not `reversed(...)`: an iterator is always truthy, which
        # made the "nothing to reduce" branch below unreachable.
        to_reduce = sorted(reduce_axes, reverse=True)

        if to_reduce:
            # In this case, we need to use the `jax.lax` function (if there
            # is one), and not the `jnp` version.
            jax_op = getattr(jax.lax, scalar_fn_name)
            init_value = jnp.array(scalar_op_identity, dtype=out_dtype)
            return jax.lax.reduce(x, init_value, jax_op, to_reduce).astype(
                out_dtype
            )
        else:
            return x

    return careduce
@jax_funcify.register(MakeVector)
def jax_funcify_MakeVector(op):
    """Convert `MakeVector` into a 1-D `jnp.array` with the `Op`'s dtype."""

    def makevector(*elements):
        return jnp.array(elements, dtype=op.dtype)

    return makevector
@jax_funcify.register(Reshape)
def jax_funcify_Reshape(op):
    """Convert `Reshape` into `jnp.reshape`."""

    def reshape(x, new_shape):
        return jnp.reshape(x, new_shape)

    return reshape
@jax_funcify.register(DimShuffle)
def jax_funcify_DimShuffle(op):
    """Convert `DimShuffle` into a transpose followed by a reshape.

    The kept axes are moved to the front in ``op.shuffle`` order, followed by
    ``op.drop``.  The reshape keeps only the first ``len(op.shuffle)``
    lengths and inserts a length-1 axis at each position in ``op.augment``.
    Unless ``op.inplace`` is set, the result is copied.
    """

    def dimshuffle(x):
        permuted = jnp.transpose(x, op.shuffle + op.drop)
        # Only the shuffled axes' lengths survive; the dropped trailing axes
        # are presumably length 1 -- TODO confirm against `DimShuffle`.
        target_shape = list(permuted.shape[: len(op.shuffle)])
        for position in op.augment:
            target_shape.insert(position, 1)
        reshaped = jnp.reshape(permuted, target_shape)
        return reshaped if op.inplace else jnp_safe_copy(reshaped)

    return dimshuffle
@jax_funcify.register(Join)
def jax_funcify_Join(op):
    """Convert `Join` into `jnp.concatenate`.

    When ``op.view`` is not -1 and every *other* tensor is empty along
    ``axis``, the ``op.view`` tensor is returned as-is.  An ``axis`` below
    ``-ndim`` raises `IndexError`.
    """

    def join(axis, *tensors):
        view = op.view
        if view != -1:
            others = tensors[:view] + tensors[view + 1 :]
            if all([tensor.shape[axis] == 0 for tensor in others]):
                return tensors[view]

        ndim = tensors[0].ndim
        if axis < -ndim:
            raise IndexError("Join axis %d out of bounds [0, %d)" % (axis, ndim))
        return jnp.concatenate(tensors, axis=axis)

    return join
@jax_funcify.register(MaxAndArgmax)
def jax_funcify_MaxAndArgmax(op):
    """Convert `MaxAndArgmax` into a JAX max plus a multi-axis argmax.

    The returned function gives ``(max, argmax)`` over ``op.axis`` (all axes
    when it is ``None``).  For several axes, the argmax is a flat index into
    the reduced axes taken in the given order.
    """
    axis = op.axis

    def maxandargmax(x, axis=axis):
        if axis is None:
            axes = tuple(range(x.ndim))
        else:
            axes = tuple(int(ax) for ax in axis)

        # Use the normalized tuple so `jnp.max` always gets an int tuple.
        max_res = jnp.max(x, axes)

        # NumPy does not support multiple axes for argmax; this is a
        # work-around: move the kept axes to the front, flatten the reduced
        # axes into a single trailing axis, then take the argmax over it.
        # Permutation and sizes are static metadata, so they are built from
        # plain Python ints rather than device arrays.
        keep_axes = tuple(i for i in range(x.ndim) if i not in axes)
        transposed_x = jnp.transpose(x, keep_axes + axes)
        kept_shape = transposed_x.shape[: len(keep_axes)]
        reduced_shape = transposed_x.shape[len(keep_axes) :]
        # Product of the reduced lengths (1 when nothing is reduced).
        reduced_size = reduce(lambda acc, dim: acc * dim, reduced_shape, 1)
        reshaped_x = transposed_x.reshape(kept_shape + (reduced_size,))
        max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")

        return max_res, max_idx_res

    return maxandargmax
@jax_funcify.register(ExtractDiag)
def jax_funcify_ExtractDiag(op):
    """Convert `ExtractDiag` into `jnp.diagonal` with the `Op`'s offset/axes."""

    def extract_diag(x, offset=op.offset, axis1=op.axis1, axis2=op.axis2):
        return jnp.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)

    return extract_diag
@jax_funcify.register(Cholesky)
def jax_funcify_Cholesky(op):
    """Convert `Cholesky` into `jax.scipy.linalg.cholesky`, keeping the input dtype."""

    def cholesky(a, lower=op.lower):
        factor = jsp.linalg.cholesky(a, lower=lower)
        return factor.astype(a.dtype)

    return cholesky
@jax_funcify.register(AllocDiag)
def jax_funcify_AllocDiag(op):
    """Convert `AllocDiag` into `jnp.diag`."""

    def alloc_diag(diagonal):
        return jnp.diag(diagonal)

    return alloc_diag
@jax_funcify.register(Solve)
def jax_funcify_Solve(op):
    """Convert `Solve` into the matching `jax.scipy.linalg` solver.

    For ``A_structure`` values ``"lower_triangular"`` and
    ``"upper_triangular"``, `solve_triangular` is used.  It reads only the
    indicated triangle of ``a``.  Every other structure uses the general
    `solve`.
    """
    lower = op.A_structure == "lower_triangular"
    # The general `solve` ignores `lower` unless `sym_pos` is set, so the
    # previous code solved triangular systems with a full LU on the whole
    # matrix.
    triangular = op.A_structure in ("lower_triangular", "upper_triangular")

    def solve(a, b, lower=lower):
        if triangular:
            return jsp.linalg.solve_triangular(a, b, lower=lower)
        return jsp.linalg.solve(a, b, lower=lower)

    return solve
@jax_funcify.register(Det)
def jax_funcify_Det(op):
    """Convert `Det` into `jnp.linalg.det`."""

    def det(matrix):
        return jnp.linalg.det(matrix)

    return det
@jax_funcify.register(Eig)
def jax_funcify_Eig(op):
    """Convert `Eig` into `jnp.linalg.eig`."""

    def eig(matrix):
        return jnp.linalg.eig(matrix)

    return eig
@jax_funcify.register(Eigh)
def jax_funcify_Eigh(op):
    """Convert `Eigh` into `jnp.linalg.eigh`, forwarding the `Op`'s ``UPLO``."""

    def eigh(x, uplo=op.UPLO):
        return jnp.linalg.eigh(x, UPLO=uplo)

    return eigh
@jax_funcify.register(MatrixInverse)
def jax_funcify_MatrixInverse(op):
    """Convert `MatrixInverse` into `jnp.linalg.inv`."""

    def matrix_inverse(matrix):
        return jnp.linalg.inv(matrix)

    return matrix_inverse
@jax_funcify.register(QRFull)
def jax_funcify_QRFull(op):
    """Convert `QRFull` into `jnp.linalg.qr` with the `Op`'s mode."""

    def qr_full(x, mode=op.mode):
        return jnp.linalg.qr(x, mode=mode)

    return qr_full
@jax_funcify.register(QRIncomplete)
def jax_funcify_QRIncomplete(op):
    """Convert `QRIncomplete` into `jnp.linalg.qr` with the `Op`'s mode."""

    def qr_incomplete(x, mode=op.mode):
        return jnp.linalg.qr(x, mode=mode)

    return qr_incomplete
@jax_funcify.register(SVD)
def jax_funcify_SVD(op):
    """Convert `SVD` into `jnp.linalg.svd`, forwarding the `Op`'s options."""

    def svd(x, full_matrices=op.full_matrices, compute_uv=op.compute_uv):
        return jnp.linalg.svd(
            x, full_matrices=full_matrices, compute_uv=compute_uv
        )

    return svd
@jax_funcify.register(CumOp)
def jax_funcify_CumOp(op):
    """Convert `CumOp` into `jnp.cumsum` (mode "add") or `jnp.cumprod` (any other mode)."""

    def cumop(x, axis=op.axis, mode=op.mode):
        accumulate = jnp.cumsum if mode == "add" else jnp.cumprod
        return accumulate(x, axis=axis)

    return cumop
@jax_funcify.register(DiffOp)
def jax_funcify_DiffOp(op):
    """Convert `DiffOp` into `jnp.diff` with the `Op`'s order and axis."""

    def diffop(x, n=op.n, axis=op.axis):
        return jnp.diff(x, n=n, axis=axis)

    return diffop
@jax_funcify.register(RepeatOp)
def jax_funcify_RepeatOp(op):
    """Convert `RepeatOp` into `jnp.repeat` along the `Op`'s axis."""

    def repeatop(x, repeats, axis=op.axis):
        return jnp.repeat(x, repeats, axis=axis)

    return repeatop
@jax_funcify.register(Bartlett)
def jax_funcify_Bartlett(op):
    """Convert `Bartlett` into `jnp.bartlett`."""

    def bartlett(window_length):
        return jnp.bartlett(window_length)

    return bartlett
@jax_funcify.register(FillDiagonal)
def jax_funcify_FillDiagonal(op):
    """Reject `FillDiagonal`: no JAX conversion is provided.

    The sketched implementation assigned through ``a.flat[:end:step]`` with
    ``step = a.shape[1] + 1`` and ``end = a.shape[1] * a.shape[1]`` for 2-D
    inputs, or called ``fill_diagonal`` otherwise.  It was abandoned
    because JAX arrays have no ``flat`` iterator.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError("flatiter not implemented in JAX")
@jax_funcify.register(FillDiagonalOffset)
def jax_funcify_FillDiagonalOffset(op):
    """Reject `FillDiagonalOffset`: no JAX conversion is provided.

    The sketched implementation computed a start position and step count
    from ``offset`` and the matrix height/width.  It then assigned through
    ``a.flat[start:end:step]`` with ``step = a.shape[1] + 1``.  It was
    abandoned because JAX arrays have no ``flat`` iterator.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError("flatiter not implemented in JAX")
@jax_funcify.register(Unique)
def jax_funcify_Unique(op):
    """Convert `Unique` into the public `jnp.unique`.

    Raises
    ------
    NotImplementedError
        At conversion time, when the `Op` specifies an ``axis``.
    """
    axis = op.axis

    if axis is not None:
        raise NotImplementedError(
            "jax.numpy.unique is not implemented for the axis argument"
        )

    return_index = op.return_index
    return_inverse = op.return_inverse
    return_counts = op.return_counts

    def unique(
        x,
        return_index=return_index,
        return_inverse=return_inverse,
        return_counts=return_counts,
        axis=axis,
    ):
        # Use the public API rather than the private
        # `jnp.lax_numpy._unique1d`.  `jnp.unique` returns a bare array when
        # no extra outputs are requested, and a tuple otherwise.
        return jnp.unique(
            x,
            return_index=return_index,
            return_inverse=return_inverse,
            return_counts=return_counts,
            axis=axis,
        )

    return unique
@jax_funcify.register(UnravelIndex)
def jax_funcify_UnravelIndex(op):
    """Convert `UnravelIndex` into `jnp.unravel_index`.

    `jnp.unravel_index` takes no ``order`` argument, so the `Op`'s ``order``
    is not forwarded.  A warning is emitted at conversion time only when a
    non-"C" order was requested, since that request is ignored.
    """
    order = op.order

    # Warning for the default "C" order was pure noise on every conversion.
    if order != "C":
        warn("JAX ignores the `order` parameter in `unravel_index`.")

    def unravelindex(indices, dims, order=order):
        return jnp.unravel_index(indices, dims)

    return unravelindex
@jax_funcify.register(RavelMultiIndex)
def jax_funcify_RavelMultiIndex(op):
    """Convert `RavelMultiIndex` into `jnp.ravel_multi_index`.

    The last positional input is the dimensions; all preceding inputs form
    the multi-index.  The `Op`'s ``mode`` and ``order`` are forwarded.
    """

    def ravelmultiindex(*inp, mode=op.mode, order=op.order):
        dims = inp[-1]
        multi_index = inp[:-1]
        return jnp.ravel_multi_index(multi_index, dims, mode=mode, order=order)

    return ravelmultiindex
@jax_funcify.register(Eye)
def jax_funcify_Eye(op):
    """Convert `Eye` into `jnp.eye` with the `Op`'s dtype."""
    out_dtype = op.dtype  # read once, at conversion time

    def eye(N, M, k):
        return jnp.eye(N, M, k, dtype=out_dtype)

    return eye