@@ -1,3 +1,4 @@
+from typing import Mapping, Any
 import collections
 import copy
 import functools
@@ -132,6 +133,88 @@ def call(*args, **kwargs): |

     self._jitted[key] = call

+  def cpu_state_dict(self, *args, **kwargs):
+    """Wrapper around ``state_dict`` that moves all parameters to CPU.
+
+    Gathering the parameters on CPU makes it easier to save the state
+    dict with ``torch.save``.
+
+    Returns:
+      Mapping[str, Any]: A mapping of parameter names to their values
+      (as torch CPU tensors).
+    """
+    state_dict = super().state_dict(*args, **kwargs)
+    state_dict = pytree.tree_map(lambda t: t.cpu(), state_dict)
+    return state_dict
+
+  def load_state_dict(self,
+                      state_dict: Mapping[str, Any],
+                      strict: bool = True,
+                      assign: bool = False):
+    """Wrapper around ``load_state_dict`` for torch CPU state dicts.
+
+    This function assumes ``state_dict`` contains torch CPU tensors and
+    transfers them to the correct device and dtype before loading them
+    into the model.
+
+    Args:
+      state_dict (Mapping[str, Any]): A mapping of parameter names to
+        their values (as torch CPU tensors).
+      strict (bool, optional): Whether to strictly enforce that the keys
+        in :attr:`state_dict` match the keys returned by this module's
+        :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
+      assign (bool, optional): When set to ``False``, the properties of
+        the tensors in the current module are preserved, whereas setting
+        it to ``True`` preserves the properties of the tensors in the
+        state dict. The only exception is the ``requires_grad`` field of
+        :class:`~torch.nn.Parameter`, for which the value from the
+        module is preserved. Default: ``False``
+
+    Returns:
+      ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
+        * **missing_keys** is a list of str containing any keys that are
+          expected by this module but missing from the provided
+          ``state_dict``.
+        * **unexpected_keys** is a list of str containing the keys that
+          are not expected by this module but present in the provided
+          ``state_dict``.
+    """
+    # Move tensors to JAX to make it easier to extract sharding information.
+    current_state_dict = super().state_dict()
+    current_state_dict = jax_view(current_state_dict)
+
+    # Create out shardings that either reuse a parameter's current sharding
+    # or replicate it (an empty PartitionSpec) when it is not already known.
+    def extract_sharding_or_replicate(name):
+      if name in current_state_dict:
+        return current_state_dict[name].sharding
+      return jax.sharding.PartitionSpec()
+
+    output_shards = {
+        name: extract_sharding_or_replicate(name) for name in state_dict
+    }
+
+    def convert_to_xla_tensor_if_needed(t):
+      is_torch_tensor = isinstance(t, torch.Tensor)
+      is_xla_tensor = isinstance(t, torchax.tensor.Tensor)
+      if is_xla_tensor:
+        t = jax_view(t)
+      elif is_torch_tensor:
+        # convert a plain torch tensor to a JAX array
+        t = tensor.t2j(t)
+      return t
+
+    # Convert the state dict entries to JAX arrays. Using the helper above
+    # avoids double-converting values that are already torchax tensors.
+    state_dict = pytree.tree_map(
+        convert_to_xla_tensor_if_needed,
+        state_dict,
+    )
+    # Convert the OrderedDict to a plain dict; pjit's type checks do not
+    # accept OrderedDict inputs.
+    state_dict = dict(state_dict)
+    # Jitting an identity function with explicit out_shardings moves each
+    # array to the target devices and shards (or replicates) it there.
+    jitted = jax_jit(
+        lambda t: t, kwargs_for_jax_jit={"out_shardings": output_shards})
+    state_dict = jitted(state_dict)
+    # View the results as torch tensors again so that load_state_dict can
+    # copy (or, with assign=True, assign) them into the module.
+    state_dict = torch_view(state_dict)
+
+    return super().load_state_dict(state_dict, strict, assign)
+

 class CompileMixin:

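For context, a minimal usage sketch of the two new methods. It assumes the class this diff extends is torchax's JittableModule wrapper; the wrapper name, the model, and the checkpoint path are illustrative assumptions, and only cpu_state_dict/load_state_dict come from the diff itself:

import torch
from torchax.interop import JittableModule

# Hypothetical setup: wrap an ordinary torch.nn.Module so the two new
# methods are available (JittableModule as the enclosing class is an
# assumption about this diff's context).
model = torch.nn.Linear(8, 8)
jittable = JittableModule(model)

# Saving: cpu_state_dict gathers all parameters on CPU first, so
# torch.save needs no device-specific handling.
torch.save(jittable.cpu_state_dict(), "checkpoint.pt")

# Loading: the CPU state dict is converted to JAX arrays, moved to the
# correct devices/shardings, then copied back into the module.
jittable.load_state_dict(torch.load("checkpoint.pt"), strict=True)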
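The resharding step relies on a standard JAX pattern: jitting an identity function with explicit out_shardings, which the diff's jax_jit helper appears to forward to jax.jit via kwargs_for_jax_jit. A self-contained sketch of just that pattern in plain JAX (the mesh and axis name are illustrative):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative 1-D mesh over all available devices.
mesh = Mesh(jax.devices(), axis_names=("x",))
target = NamedSharding(mesh, PartitionSpec("x"))

# The jitted identity performs no compute; out_shardings makes XLA
# produce the output with the requested sharding, resharding the input.
reshard = jax.jit(lambda t: t, out_shardings=target)

x = jnp.arange(16.0)  # lives on the default device
y = reshard(x)        # now sharded along mesh axis "x"
print(y.sharding)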