
Commit 50889bd

Thomas Polasek authored and facebook-github-bot committed

Convert directory fbcode/torchrec to use the Ruff Formatter (#2566)

Summary:
Pull Request resolved: #2566

Converts the directory specified to use the Ruff formatter in pyfmt (ruff_dog).

If this diff causes merge conflicts when rebasing, please run `hg status -n -0 --change . -I '**/*.{py,pyi}' | xargs -0 arc pyfmt` on your diff, and amend any changes before rebasing onto latest. That should help reduce or eliminate any merge conflicts.

allow-large-files

Reviewed By: amyreese
Differential Revision: D66013071
fbshipit-source-id: ee300de212220e068deb208761940a24624a3a5b

1 parent 1f955b5 · commit 50889bd
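Note for open-source readers: `pyfmt` and `arc` are Meta-internal tools. The closest public equivalent of this reformat (an assumption, not something this commit documents) is running Ruff's formatter directly over the tree, e.g. `pip install ruff` followed by `ruff format torchrec/`.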

File tree

107 files changed: +369 −472 lines changed


benchmarks/ebc_benchmarks.py

Lines changed: 0 additions & 2 deletions
@@ -163,7 +163,6 @@ def get_fused_ebc_uvm_time(
     location: EmbeddingLocation,
     epochs: int = 100,
 ) -> Tuple[float, float]:
-
     fused_ebc = FusedEmbeddingBagCollection(
         tables=embedding_bag_configs,
         optimizer_type=torch.optim.SGD,
@@ -195,7 +194,6 @@ def get_ebc_comparison(
     device: torch.device,
     epochs: int = 100,
 ) -> Tuple[float, float, float, float, float]:
-
     # Simple EBC module wrapping a list of nn.EmbeddingBag
     ebc = EmbeddingBagCollection(
         tables=embedding_bag_configs,
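The deleted lines here are the blank lines that previously opened each function body. This matches the Ruff formatter's handling of blank lines at the start of a block, one of its visible differences from the Black style previously used in this tree. A minimal before/after sketch (function names are illustrative, not from this diff):

```python
def before(epochs: int = 100) -> int:

    return epochs  # the old style kept the blank line after the signature


def after(epochs: int = 100) -> int:
    return epochs  # Ruff removes blank lines that open a function body
```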

benchmarks/ebc_benchmarks_utils.py

Lines changed: 0 additions & 4 deletions
@@ -26,7 +26,6 @@ def get_random_dataset(
     embedding_bag_configs: List[EmbeddingBagConfig],
     pooling_factors: Optional[Dict[str, int]] = None,
 ) -> IterableDataset[Batch]:
-
     if pooling_factors is None:
         pooling_factors = {}
 
@@ -57,7 +56,6 @@ def train_one_epoch(
     dataset: IterableDataset[Batch],
     device: torch.device,
 ) -> float:
-
     start_time = time.perf_counter()
 
     for data in dataset:
@@ -82,7 +80,6 @@ def train_one_epoch_fused_optimizer(
     dataset: IterableDataset[Batch],
     device: torch.device,
 ) -> float:
-
     start_time = time.perf_counter()
 
     for data in dataset:
@@ -106,7 +103,6 @@ def train(
     device: torch.device,
     epochs: int = 100,
 ) -> Tuple[float, float]:
-
     training_time = []
     for _ in range(epochs):
         if optimizer:

examples/bert4rec/bert4rec_main.py

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@
 
 # OSS import
 try:
-
     # pyre-ignore[21]
     # @manual=//torchrec/github/examples/bert4rec:bert4rec_metrics
     from bert4rec_metrics import recalls_and_ndcgs_for_ks

examples/golden_training/train_dlrm_data_parallel.py

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@ def train(
     )
 
 def dense_filter(
-    named_parameters: Iterator[Tuple[str, nn.Parameter]]
+    named_parameters: Iterator[Tuple[str, nn.Parameter]],
 ) -> Iterator[Tuple[str, nn.Parameter]]:
     for fqn, param in named_parameters:
         if "sparse" not in fqn:

examples/retrieval/two_tower_retrieval.py

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@
 
 # OSS import
 try:
-
     # pyre-ignore[21]
     # @manual=//torchrec/github/examples/retrieval:knn_index
     from knn_index import get_index

tools/lint/black_linter.py

Lines changed: 3 additions & 1 deletion
@@ -179,7 +179,9 @@ def main() -> None:
         level=(
             logging.NOTSET
             if args.verbose
-            else logging.DEBUG if len(args.filenames) < 1000 else logging.INFO
+            else logging.DEBUG
+            if len(args.filenames) < 1000
+            else logging.INFO
         ),
         stream=sys.stderr,
     )
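This hunk shows the two formatters' treatment of chained conditional expressions: the old style packed the nested `if`/`else` onto one line, while Ruff gives each branch its own line. A runnable sketch of the Ruff-style layout (the variable values are illustrative stand-ins for the linter's parsed arguments):

```python
import logging
import sys

verbose = False  # stand-in for args.verbose
num_filenames = 1500  # stand-in for len(args.filenames)

logging.basicConfig(
    level=(
        logging.NOTSET
        if verbose
        else logging.DEBUG
        if num_filenames < 1000
        else logging.INFO
    ),
    stream=sys.stderr,
)
```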

torchrec/datasets/criteo.py

Lines changed: 3 additions & 2 deletions
@@ -351,8 +351,9 @@ def get_file_row_ranges_and_remainder(
 
         # If the ranges overlap.
         if rank_left_g <= file_right_g and rank_right_g >= file_left_g:
-            overlap_left_g, overlap_right_g = max(rank_left_g, file_left_g), min(
-                rank_right_g, file_right_g
+            overlap_left_g, overlap_right_g = (
+                max(rank_left_g, file_left_g),
+                min(rank_right_g, file_right_g),
             )
 
             # Convert overlap in global numbers to (local) numbers specific to the
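Here Ruff parenthesizes the right-hand side of the tuple assignment instead of letting the second call split across lines mid-tuple, so each element reads whole. A minimal runnable sketch with illustrative values:

```python
rank_left_g, rank_right_g = 0, 10  # illustrative global row range for this rank
file_left_g, file_right_g = 5, 8  # illustrative global row range for this file

# Ruff style: wrap the tuple in parentheses and give each element its own line.
overlap_left_g, overlap_right_g = (
    max(rank_left_g, file_left_g),
    min(rank_right_g, file_right_g),
)

assert (overlap_left_g, overlap_right_g) == (5, 8)
```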

torchrec/datasets/random.py

Lines changed: 0 additions & 2 deletions
@@ -33,7 +33,6 @@ def __init__(
         *,
         min_ids_per_features: Optional[List[int]] = None,
     ) -> None:
-
         self.keys = keys
         self.keys_length: int = len(keys)
         self.batch_size = batch_size
@@ -76,7 +75,6 @@ def __next__(self) -> Batch:
         return batch
 
     def _generate_batch(self) -> Batch:
-
         values = []
         lengths = []
         for key_idx, _ in enumerate(self.keys):

torchrec/datasets/test_utils/criteo_test_utils.py

Lines changed: 0 additions & 1 deletion
@@ -103,7 +103,6 @@ def _create_dataset_npys(
         labels: Optional[np.ndarray] = None,
     ) -> Generator[Tuple[str, ...], None, None]:
         with tempfile.TemporaryDirectory() as tmpdir:
-
             if filenames is None:
                 filenames = [filename]
 

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 12 additions & 17 deletions
@@ -785,9 +785,7 @@ def purge(self) -> None:
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
-        assert (
-            remove_duplicate
-        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
+        assert remove_duplicate, "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, param in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),
@@ -899,9 +897,7 @@ def named_parameters(
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
-        assert (
-            remove_duplicate
-        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
+        assert remove_duplicate, "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, tensor in zip(
             self._config.embedding_tables,
             self.split_embedding_weights(),
@@ -1082,8 +1078,9 @@ def named_parameters(
         combined_key = "/".join(
             [config.name for config in self._config.embedding_tables]
         )
-        yield append_prefix(prefix, f"{combined_key}.weight"), cast(
-            nn.Parameter, self._emb_module.weights
+        yield (
+            append_prefix(prefix, f"{combined_key}.weight"),
+            cast(nn.Parameter, self._emb_module.weights),
         )
 
 
@@ -1101,7 +1098,8 @@ def __init__(
         self._pg = pg
 
         self._pooling: PoolingMode = pooling_type_to_pooling_mode(
-            config.pooling, sharding_type  # pyre-ignore[6]
+            config.pooling,
+            sharding_type,  # pyre-ignore[6]
         )
 
         self._local_rows: List[int] = []
@@ -1220,9 +1218,7 @@ def purge(self) -> None:
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
-        assert (
-            remove_duplicate
-        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
+        assert remove_duplicate, "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, tensor in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),
@@ -1362,9 +1358,7 @@ def named_parameters(
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
-        assert (
-            remove_duplicate
-        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
+        assert remove_duplicate, "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, tensor in zip(
             self._config.embedding_tables,
             self.split_embedding_weights(),
@@ -1567,6 +1561,7 @@ def named_parameters(
         combined_key = "/".join(
             [config.name for config in self._config.embedding_tables]
         )
-        yield append_prefix(prefix, f"{combined_key}.weight"), cast(
-            nn.Parameter, self._emb_module.weights
+        yield (
+            append_prefix(prefix, f"{combined_key}.weight"),
+            cast(nn.Parameter, self._emb_module.weights),
         )
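Two patterns recur throughout this file. First, Ruff keeps an `assert condition, "long message"` on a single line rather than wrapping the condition in parentheses, since the message string cannot be split, so wrapping the condition never makes the statement fit anyway. Second, it parenthesizes a yielded tuple so each element sits on its own line. A minimal runnable sketch of both (identifiers are illustrative, not TorchRec's):

```python
from typing import Iterator, Tuple


def named_weights(remove_duplicate: bool = True) -> Iterator[Tuple[str, int]]:
    # Ruff style: the assert stays on one line even past the length limit,
    # because its message string is unsplittable.
    assert remove_duplicate, "remove_duplicate=False not supported in this sketch"
    # Ruff style: a parenthesized yield lines the tuple elements up vertically.
    yield (
        "table_0.weight",
        42,
    )
```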

0 commit comments
