-
Notifications
You must be signed in to change notification settings - Fork 1
/
jax-methods-returning-self.txt
162 lines (131 loc) · 5.43 KB
/
jax-methods-returning-self.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
Methods with `self` or `cls` annotations: 0
Methods returning `self` or `cls(...)`: 28
def numpy_array(x):
    """Return ``x`` unchanged (identity pass-through)."""
    return x
def strip_weak_type(self):
    """Return this object itself; this implementation changes nothing."""
    return self
def join(self, other):
    """Join with ``other`` and return this object.

    ``other`` must be an instance of this object's class; otherwise an
    ``AssertionError`` carrying ``other`` is raised (the check is skipped
    when Python runs with assertions disabled).
    """
    assert isinstance(other, self.__class__), other
    return self
@classmethod
def tree_unflatten(cls, aux_data, children):
    """Rebuild an instance by passing ``children`` positionally to ``cls``.

    ``aux_data`` is accepted for signature compatibility and is not used.
    """
    del aux_data  # unused
    return cls(*children)
@classmethod
def tree_unflatten(cls, aux_data, children):
    """Construct ``cls`` from ``children``, unpacked as positional arguments.

    ``aux_data`` is part of the signature but plays no role here.
    """
    del aux_data  # unused
    return cls(*children)
def full_lower(self):
    """Default implementation: return this object unchanged."""
    return self
def full_lower(self):
    """Lower to the known constant when there is one, else return self.

    When ``self.pval.is_known`` is falsy this object is returned as-is.
    Otherwise the result of ``full_lower(self.pval.const)`` is returned.
    NOTE(review): the bare ``full_lower`` name is looked up at module
    level, so it presumably refers to a module-level helper rather than
    this method -- confirm against the enclosing module.
    """
    if not self.pval.is_known:
        return self
    return full_lower(self.pval.const)
def __copy__(self):
    """Shallow-copy hook: hand back this same object, not a duplicate."""
    return self
def __deepcopy__(self, unused_memo):
    """Deep-copy hook: hand back this same object, not a duplicate.

    ``unused_memo`` is the memo dictionary supplied by ``copy.deepcopy``;
    it is ignored.
    """
    return self
def strip_weak_type(self) -> 'AbstractValue':
    """Return this object itself; this implementation changes nothing."""
    return self
def strip_named_shape(self) -> 'AbstractValue':
    """Return this object itself; this implementation changes nothing."""
    return self
def join(self, other):
    """Join with ``other`` and return this object.

    When ``config.jax_enable_checks`` is truthy, ``other`` is asserted to
    be the ``abstract_unit`` object itself (an identity check); with checks
    disabled no validation happens.
    """
    if config.jax_enable_checks:
        assert other is abstract_unit, other
    return self
def block_until_ready(self):
    """Block on every buffer in ``self.device_buffers``, then return self.

    ``self._check_if_deleted()`` runs first, before any buffer is touched.
    Each buffer's ``block_host_until_ready()`` is then called in order.
    Returning ``self`` allows call chaining.
    """
    self._check_if_deleted()
    for buffer in self.device_buffers:
        buffer.block_host_until_ready()
    return self
def full_lower(self):
    """Return this object unchanged; nothing is lowered here."""
    return self
def block_until_ready(self):
    """Blocks the caller until the buffer's value has been computed on device.

    This method is mostly useful for timing microbenchmarks that wish to
    time how long a computation takes, without transferring the result back
    to the host.

    ``self._check_if_deleted()`` is called before waiting on
    ``self.device_buffer``.

    Returns the buffer object (`self`).
    """
    self._check_if_deleted()
    self.device_buffer.block_host_until_ready()  # pytype: disable=attribute-error
    return self
def full_lower(self):
    """Return this object unchanged; nothing is lowered here."""
    return self
@classmethod
def from_axis_resources(cls,
                        axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]],
                        resource_env: ResourceEnv,
                        global_axis_sizes: Dict[AxisName, int]):
    """Alternate constructor built from an axis-name -> resources mapping.

    Args:
      axis_resources: maps each axis name to the tuple of resource axes
        assigned to it.
      resource_env: forwarded to the helper functions and to ``cls``.
      global_axis_sizes: maps each axis name to its global size.

    Returns:
      ``cls(resource_env, physical_axis_resources, loop_axis_resources,
      axis_subst_dict, axis_vmap_size)``.

    Raises:
      AssertionError: if an axis's global size is not divisible by its
        ``nglobal`` resource count.
    """
    physical_axis_resources, loop_axis_resources = _unzip_axis_resources(
        axis_resources, resource_env)
    axis_resource_count = _get_axis_resource_count(axis_resources, resource_env)
    # Copy so that appending fresh resource names below does not mutate the
    # caller's mapping.
    axis_subst_dict = dict(axis_resources)
    axis_vmap_size: Dict[AxisName, Optional[int]] = {}
    # Axes are visited in order of their string form.
    # NOTE(review): presumably this keeps the fresh_resource_name calls in a
    # deterministic order -- confirm.
    for naxis, raxes in sorted(axis_resources.items(), key=lambda x: str(x[0])):
        num_resources = axis_resource_count[naxis]
        assert global_axis_sizes[naxis] % num_resources.nglobal == 0
        local_tile_size = global_axis_sizes[naxis] // num_resources.nglobal
        # We have to vmap when there are no resources (to handle the axis name!) or
        # when every resource gets chunks of values.
        if not raxes or local_tile_size > 1:
            axis_vmap_size[naxis] = local_tile_size
            axis_subst_dict[naxis] += (fresh_resource_name(naxis),)
        else:
            axis_vmap_size[naxis] = None
    return cls(resource_env,
               physical_axis_resources, loop_axis_resources,
               axis_subst_dict, axis_vmap_size)
@classmethod
def from_user_input(cls, entry, arg_name):
    """Alternate constructor that normalizes a user-supplied spec.

    Args:
      entry: ``None`` or a ``PartitionSpec``.
      arg_name: name of the user-facing argument, used only in the error
        message.

    Returns:
      ``cls(entry, ())`` when ``entry`` is ``None``; otherwise
      ``cls(entry, axis_specs)`` where ``axis_specs`` is a list holding one
      tuple per element of ``entry``.

    Raises:
      TypeError: if ``entry`` is neither ``None`` nor a ``PartitionSpec``.
    """
    if entry is None:
        return cls(entry, ())
    if not isinstance(entry, PartitionSpec):
        raise TypeError(f"{arg_name} are expected to be "
                        f"PartitionSpec instances or None, but got {entry}")
    axis_specs = []
    for axis_spec in entry:
        # Normalize every element to a tuple: None -> (), list/tuple ->
        # tuple(...), anything else -> a 1-tuple.
        if axis_spec is None:
            normalized = ()
        elif isinstance(axis_spec, (list, tuple)):
            normalized = tuple(axis_spec)
        else:
            normalized = (axis_spec,)
        axis_specs.append(normalized)
    return cls(entry, axis_specs)
def __enter__(self):
    """Enter the ``with`` block; the ``as`` target is bound to this object."""
    return self
def __iter__(self):
    """Called before starting the first iteration.

    Resets ``self.first_iteration`` to ``True`` and returns this object as
    its own iterator.
    """
    self.first_iteration = True  # In case we reuse the range
    return self
def full_lower(self):
    """Return this object unchanged; nothing is lowered here."""
    return self
@staticmethod
def update(optimizer, inputs, labels):
    """Apply one gradient step and return the resulting optimizer.

    The gradient of ``FlaxMNIST.loss`` is evaluated at
    ``(optimizer.target, inputs, labels)``; ``jax.grad`` differentiates
    with respect to the first argument by default. The value returned by
    ``optimizer.apply_gradient(grad)`` is handed back to the caller.
    """
    grad = jax.grad(FlaxMNIST.loss)(optimizer.target, inputs, labels)
    return optimizer.apply_gradient(grad)
@classmethod
def tree_unflatten(cls, aux_data, children):
    """Rebuild an instance from ``children`` plus keyword ``aux_data``.

    ``children`` is passed as the single positional argument and the
    ``aux_data`` mapping is expanded into keyword arguments.
    """
    return cls(children, **aux_data)
@classmethod
def fromdense(cls, mat, *, nnz=None, index_dtype=np.int32):
    """Alternate constructor from a dense matrix via ``csr_fromdense``.

    Args:
      mat: dense input; must support ``!=``, ``.sum()`` and ``.shape``.
      nnz: number of stored elements; when ``None`` it is computed as the
        count of entries of ``mat`` that are not equal to zero.
      index_dtype: forwarded to ``csr_fromdense``.

    Returns:
      ``cls(csr_fromdense(mat, ...), shape=mat.shape)``.
    """
    if nnz is None:
        nnz = (mat != 0).sum()
    return cls(csr_fromdense(mat, nnz=nnz, index_dtype=index_dtype),
               shape=mat.shape)
@classmethod
def fromdense(cls, mat, *, nnz=None, index_dtype=np.int32):
    """Alternate constructor from a dense matrix via ``csr_fromdense``.

    Unlike a plain CSR construction, the *transpose* ``mat.T`` is what gets
    passed to ``csr_fromdense``, while ``shape`` is the untransposed
    ``mat.shape``.

    Args:
      mat: dense input; must support ``!=``, ``.sum()``, ``.T`` and
        ``.shape``.
      nnz: number of stored elements; when ``None`` it is computed as the
        count of entries of ``mat`` that are not equal to zero.
      index_dtype: forwarded to ``csr_fromdense``.

    Returns:
      ``cls(csr_fromdense(mat.T, ...), shape=mat.shape)``.
    """
    if nnz is None:
        nnz = (mat != 0).sum()
    return cls(csr_fromdense(mat.T, nnz=nnz, index_dtype=index_dtype),
               shape=mat.shape)
@classmethod
def fromdense(cls, mat, *, nnz=None, index_dtype=np.int32):
    """Alternate constructor from a dense matrix via ``coo_fromdense``.

    Args:
      mat: dense input; must support ``!=``, ``.sum()`` and ``.shape``.
      nnz: number of stored elements; when ``None`` it is computed as the
        count of entries of ``mat`` that are not equal to zero.
      index_dtype: forwarded to ``coo_fromdense``.

    Returns:
      ``cls(coo_fromdense(mat, ...), shape=mat.shape)``.
    """
    if nnz is None:
        nnz = (mat != 0).sum()
    return cls(coo_fromdense(mat, nnz=nnz, index_dtype=index_dtype),
               shape=mat.shape)
@classmethod
def fromdense(cls, mat, *, nnz=None, index_dtype=np.int32, n_dense=0, n_batch=0):
    """Alternate constructor from a dense matrix via ``bcoo_fromdense``.

    Args:
      mat: dense input; must expose ``.shape``.
      nnz: forwarded to ``bcoo_fromdense`` as its ``nse`` argument. No
        default is computed here, so ``None`` is passed through as-is.
      index_dtype: forwarded to ``bcoo_fromdense``.
      n_dense: forwarded to ``bcoo_fromdense``.
      n_batch: forwarded to ``bcoo_fromdense``.

    Returns:
      ``cls(bcoo_fromdense(mat, ...), shape=mat.shape)``.
    """
    data = bcoo_fromdense(mat, nse=nnz, index_dtype=index_dtype,
                          n_dense=n_dense, n_batch=n_batch)
    return cls(data, shape=mat.shape)
@classmethod
def tree_unflatten(cls, aux, consts):
    """Rebuild an instance from ``aux`` and ``consts``.

    ``aux`` must unpack into exactly three items
    ``(jaxpr, in_tree, out_tree)``; they are passed to ``cls`` positionally,
    followed by ``consts``.
    """
    jaxpr, in_tree, out_tree = aux
    return cls(jaxpr, in_tree, out_tree, consts)