Skip to content

Commit

Permalink
remove pyre-fixme annotations from distributed util (#650)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #650

# Context
As part of reducing pyre-fixme annotations

# This diff
Remove pyre-fixme annotations from distributed util

Reviewed By: JKSenthil

Differential Revision: D51930004

fbshipit-source-id: 03fd6da134267bcae6ca746f091ecd931b91f2c9
  • Loading branch information
galrotem authored and facebook-github-bot committed Dec 18, 2023
1 parent 22abb8c commit 799918a
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions torchtnt/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
import tempfile
from functools import wraps
from typing import Any, Callable, List, Optional, TypeVar, Union
from typing import Any, Callable, cast, List, Optional, TypeVar, Union

import torch
import torch.distributed as dist
Expand All @@ -18,6 +18,9 @@
from torch.distributed.elastic.utils.distributed import get_free_port
from typing_extensions import Literal

T = TypeVar("T")
DistObjList = Union[List[T], List[None]]


class PGWrapper:
"""
Expand Down Expand Up @@ -54,25 +57,22 @@ def barrier(self) -> None:
else:
dist.barrier(group=self.pg)

# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
def broadcast_object_list(self, obj_list: List[Any], src: int = 0) -> None:
def broadcast_object_list(self, obj_list: DistObjList, src: int = 0) -> None:
if self.pg is None:
return
dist.broadcast_object_list(obj_list, src=src, group=self.pg)

# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
def all_gather_object(self, obj_list: List[Any], obj: Any) -> None:
def all_gather_object(self, obj_list: DistObjList, obj: T) -> None:
if self.pg is None:
obj_list = cast(List[T], obj_list) # to make pyre happy
obj_list[0] = obj
return
dist.all_gather_object(obj_list, obj, group=self.pg)

def scatter_object_list(
self,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
output_list: List[Any],
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
input_list: Optional[List[Any]],
output_list: List[None],
input_list: Optional[DistObjList],
src: int = 0,
) -> None:
rank = self.get_rank()
Expand Down

0 comments on commit 799918a

Please sign in to comment.