
Commit 676f5b8

format python
1 parent 5ffbd7a commit 676f5b8

5 files changed: +20 additions, -15 deletions

test/test_mp_all_gather.py

Lines changed: 2 additions & 1 deletion
@@ -43,7 +43,8 @@ def _mp_fn(index):
 
   # Testing with a single replica group, channel_id and use_global_device_ids
   ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
-  result = xm.all_gather(ordinal_tensor, dim=0, channel_id=0, use_global_device_ids=True)
+  result = xm.all_gather(
+      ordinal_tensor, dim=0, channel_id=0, use_global_device_ids=True)
   xm.mark_step()
 
   cpu_result = result.cpu()

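For context, a minimal sketch of how the reformatted call is exercised in an xmp-spawned test. The surrounding harness (device setup, printing) is illustrative; the all_gather arguments mirror the diff:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  device = xm.xla_device()
  # Each replica contributes its ordinal; all_gather concatenates them along
  # dim 0, here over a single replica group identified by channel_id=0.
  ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
  result = xm.all_gather(
      ordinal_tensor, dim=0, channel_id=0, use_global_device_ids=True)
  xm.mark_step()
  print(result.cpu())


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
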
test/test_mp_reduce_scatter.py

Lines changed: 8 additions & 2 deletions
@@ -29,8 +29,14 @@ def _mp_fn(index):
     assert res.cpu().allclose(expected)
     xm.rendezvous('test_reduce_scatter')
 
-    res = xm.reduce_scatter(xm.REDUCE_SUM, xrand, scale, scatter_dim,
-                            world_size, channel_id=0, use_global_device_ids=True)
+    res = xm.reduce_scatter(
+        xm.REDUCE_SUM,
+        xrand,
+        scale,
+        scatter_dim,
+        world_size,
+        channel_id=0,
+        use_global_device_ids=True)
     expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand, scale)
     xm.mark_step()

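A corresponding sketch of the reformatted reduce_scatter call. The helper name, tensor shape, and scale below are hypothetical; the argument order and the channel_id/use_global_device_ids keywords follow the diff:

import torch
import torch_xla.core.xla_model as xm


def _reduce_scatter_example(device, world_size):
  # Hypothetical helper: reduce-sum a random tensor across all replicas, then
  # scatter equal shards of dimension 0 back to each replica.
  scatter_dim = 0
  scale = 1.0 / world_size
  xrand = torch.rand((world_size * 2, 3)).to(device)
  res = xm.reduce_scatter(
      xm.REDUCE_SUM,
      xrand,
      scale,
      scatter_dim,
      world_size,
      channel_id=0,
      use_global_device_ids=True)
  xm.mark_step()
  return res.cpu()
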
torch_xla/_internal/tpu.py

Lines changed: 1 addition & 1 deletion
@@ -200,7 +200,7 @@ def get_tpu_type() -> str:
     env = get_tpu_env()
   except requests.HTTPError as e:
     raise EnvironmentError('Failed to get TPU metadata') from e
-
+
   match = re.search(r"^([^-]*)-", env[xenv.ACCELERATOR_TYPE])
   if match:
     return match.group(1)

torch_xla/core/xla_model.py

Lines changed: 6 additions & 5 deletions
@@ -534,8 +534,8 @@ def all_gather(value: torch.Tensor,
                groups: Optional[List[List[int]]] = None,
                output: Optional[torch.Tensor] = None,
                pin_layout: bool = True,
-               channel_id = None,
-               use_global_device_ids = None) -> torch.Tensor:
+               channel_id=None,
+               use_global_device_ids=None) -> torch.Tensor:
   """Performs an all-gather operation along a given dimension.
 
   Args:
@@ -588,7 +588,8 @@ def all_gather(value: torch.Tensor,
       return output
 
     result = torch_xla._XLAC._xla_all_gather(value, dim, shard_count, groups or
-                                             [], pin_layout, channel_id, use_global_device_ids)
+                                             [], pin_layout, channel_id,
+                                             use_global_device_ids)
     return result
 
   # Now the input should be a list of Tensors.
@@ -875,8 +876,8 @@ def reduce_scatter(reduce_type: str,
                    output: Optional[Union[torch.Tensor,
                                           List[torch.Tensor]]] = None,
                    pin_layout: bool = True,
-                   channel_id = None,
-                   use_global_device_ids = None) -> torch.Tensor:
+                   channel_id=None,
+                   use_global_device_ids=None) -> torch.Tensor:
   """Performs a XLA `ReduceScatter()` operation on the input tensor.
 
   See: https://www.tensorflow.org/xla/operation_semantics#reducescatter
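
The only change to both signatures is dropping the spaces around '=' on keyword defaults, per the PEP 8 convention for default parameter values. A minimal illustration with hypothetical function names:

# PEP 8: no spaces around '=' when it marks a keyword default.
def before(channel_id = None, use_global_device_ids = None):  # pre-format style
  pass


def after(channel_id=None, use_global_device_ids=None):  # post-format style
  pass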

torch_xla/distributed/xla_multiprocessing.py

Lines changed: 3 additions & 6 deletions
@@ -123,13 +123,13 @@ def run(self, fn):
     """
     with self._lock:
       return fn()
-
+
+
 ###############################################################################
 #
 # The following is modified from JAX: https://github.com/jax-ml/jax/blob/main/jax/_src/mesh_utils.py
 #
 ###############################################################################
-
 
 _TPU_V5P = "v5p"
 _TPU_V6E = "v6e"
@@ -174,10 +174,7 @@ def _v6e_create_replica_groups() -> List | None:
   return None
 
 
-device_kind_handler_dict: dict[
-    str,
-    Callable[..., List | None],
-] = {
+device_kind_handler_dict: dict[str, Callable[..., List | None],] = {
     _TPU_V5P: _v5p_create_replica_groups,
     _TPU_V6E: _v6e_create_replica_groups
 }