Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion fbgemm_gpu/bench/sparse_ops_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,21 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:

# Benchmark forward
time_ref, output_ref = benchmark_torch_function(
torch.index_select, (input, 0, offset_indices), **bench_kwargs
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
torch.index_select,
(input, 0, offset_indices),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

input_group = input.split(batch_size, 0)
time, output_group = benchmark_torch_function(
torch.ops.fbgemm.group_index_select_dim0,
(input_group, indices_group),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)
logging.info(
Expand All @@ -306,13 +314,19 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:
time_ref, _ = benchmark_torch_function(
functools.partial(output_ref.backward, retain_graph=True),
(grad,),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

# pyre-fixme[6]: For 1st argument expected `Union[List[Tensor],
# typing.Tuple[Tensor, ...]]` but got `Tensor`.
cat_output = torch.cat(output_group)
time, _ = benchmark_torch_function(
functools.partial(cat_output.backward, retain_graph=True),
(grad,),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)
logging.info(
Expand Down Expand Up @@ -714,6 +728,8 @@ def batch_group_index_select_bwd(
time_pyt, out_pyt = benchmark_torch_function(
index_select_fwd_ref,
(inputs, indices),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

Expand All @@ -726,12 +742,16 @@ def batch_group_index_select_bwd(
input_rows,
input_columns,
),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

time_gis, out_gis = benchmark_torch_function(
group_index_select_fwd,
(gis_inputs, indices),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

Expand All @@ -746,6 +766,8 @@ def batch_group_index_select_bwd(
time_bwd_pyt, _ = benchmark_torch_function(
index_select_bwd_ref,
(out_pyt, grads),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

Expand All @@ -756,6 +778,8 @@ def batch_group_index_select_bwd(
concat_grads,
optim_batch,
),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

Expand All @@ -766,6 +790,8 @@ def batch_group_index_select_bwd(
concat_grads,
optim_group,
),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

Expand Down
10 changes: 10 additions & 0 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,9 @@ def run_bench(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor) -> N

time_per_iter = benchmark_requests(
requests_uvm,
# pyre-fixme[6]: For 2nd argument expected `(Tensor, Tensor,
# Optional[Tensor]) -> Tensor` but got `(indices: Tensor, offsets: Tensor,
# per_sample_weights: Tensor) -> None`.
run_bench,
flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
num_warmups=warmup_runs,
Expand Down Expand Up @@ -1934,6 +1937,9 @@ def nbit_uvm(
indices,
offsets,
),
# pyre-fixme[6]: For 3rd argument expected `(Tensor, Tensor,
# Optional[Tensor]) -> None` but got `(indices: Any, offsets: Any,
# indices_weights: Any) -> Tensor`.
lambda indices, offsets, indices_weights: emb_mixed.forward(
indices,
offsets,
Expand Down Expand Up @@ -2421,6 +2427,9 @@ def nbit_cache( # noqa C901
indices,
offsets,
),
# pyre-fixme[6]: For 3rd argument expected `(Tensor, Tensor,
# Optional[Tensor]) -> None` but got `(indices: Any, offsets: Any,
# indices_weights: Any) -> Tensor`.
lambda indices, offsets, indices_weights: emb.forward(
indices,
offsets,
Expand Down Expand Up @@ -3061,6 +3070,7 @@ def device_with_spec( # noqa C901
reuse=reuse,
alpha=alpha,
weighted=weighted,
# pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined.
sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None,
zipf_oversample_ratio=3 if Ls[t] > 5 else 5,
)
Expand Down
5 changes: 4 additions & 1 deletion fbgemm_gpu/codegen/genscript/jinja_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,10 @@ def replace_pta_namespace(pta_str_list: List[str]) -> List[str]:


def replace_placeholder_types(
arg_str_list: List[str], type_combo: Optional[Dict[str, TensorType]]
# pyre-fixme[11]: Annotation `TensorType` is not defined as a type.
arg_str_list: List[str],
# pyre-fixme[11]: Annotation `TensorType` is not defined as a type.
type_combo: Optional[Dict[str, TensorType]],
) -> List[str]:
"""
Replace the placeholder types with the primitive types
Expand Down
21 changes: 20 additions & 1 deletion fbgemm_gpu/codegen/genscript/optimizer_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,11 @@ def schema_sym_int_arg_no_default(name: str) -> str:


def make_kernel_arg(
ty: ArgType, name: str, default: Union[int, float, None], pass_by_ref: bool = False
# pyre-fixme[11]: Annotation `ArgType` is not defined as a type.
ty: ArgType,
name: str,
default: Union[int, float, None],
pass_by_ref: bool = False,
) -> str:
return {
ArgType.TENSOR: lambda x: acc_cache_tensor_arg(x, pass_by_ref=pass_by_ref),
Expand Down Expand Up @@ -318,6 +322,7 @@ class OptimizerArgs:
split_variables: List[str]
split_ref_kernel_args: List[str]
placeholder_tensor_names: List[str]
# pyre-fixme[11]: Annotation `TensorType` is not defined as a type.
placeholder_type_combos: Union[List[Dict[str, TensorType]], List[None]]

@staticmethod
Expand Down Expand Up @@ -345,6 +350,7 @@ def create(
else:
ph_combos = [None]

# pyre-fixme[28]: Unexpected keyword argument `placeholder_type_combos`.
return OptimizerArgs(
# GPU kernel args
split_kernel_args=[
Expand Down Expand Up @@ -434,6 +440,7 @@ def create_optim_args(
split_arg_spec = []
for s in arg_spec:
if s.ty in (ArgType.FLOAT, ArgType.INT, ArgType.SYM_INT):
# pyre-fixme[19]: Expected 1 positional argument.
split_arg_spec.append(OptimItem(s.ty, s.name, s.default))
else:
assert s.ty in (ArgType.TENSOR, ArgType.PLACEHOLDER_TENSOR)
Expand All @@ -446,8 +453,11 @@ def extend_for_cpu(spec: OptimItem) -> List[OptimItem]:
name = spec.name
default = spec.default
return [
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ArgType.TENSOR, f"{name}_host", default),
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ArgType.INT_TENSOR, f"{name}_placements", default),
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ArgType.LONG_TENSOR, f"{name}_offsets", default),
]

Expand All @@ -458,9 +468,13 @@ def extend_for_cuda(spec: OptimItem) -> List[OptimItem]:
ty = spec.ty
ph_tys = spec.ph_tys
return [
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ty, f"{name}_dev", default, ph_tys),
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ty, f"{name}_uvm", default, ph_tys),
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ArgType.INT_TENSOR, f"{name}_placements", default),
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ArgType.LONG_TENSOR, f"{name}_offsets", default),
]

Expand All @@ -471,10 +485,15 @@ def extend_for_any(spec: OptimItem) -> List[OptimItem]:
ty = spec.ty
ph_tys = spec.ph_tys
return [
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ArgType.TENSOR, f"{name}_host", default),
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ty, f"{name}_dev", default, ph_tys),
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ty, f"{name}_uvm", default, ph_tys),
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ArgType.INT_TENSOR, f"{name}_placements", default),
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ArgType.LONG_TENSOR, f"{name}_offsets", default),
]

Expand Down
6 changes: 5 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/quantize_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@


def none_throws(
optional: Optional[TypeVar("_T")], message: str = "Unexpected `None`"
# pyre-fixme[31]: Expression `typing.Optional[typing.TypeVar("_T")]` is not a
# valid type.
optional: Optional[TypeVar("_T")],
message: str = "Unexpected `None`",
# pyre-fixme[31]: Expression `typing.TypeVar("_T")` is not a valid type.
) -> TypeVar("_T"):
if optional is None:
raise AssertionError(message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,15 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):

embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]]
record_cache_metrics: RecordCacheMetrics
# pyre-fixme[13]: Attribute `cache_miss_counter` is never initialized.
cache_miss_counter: torch.Tensor
# pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
uvm_cache_stats: torch.Tensor
# pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
local_uvm_cache_stats: torch.Tensor
# pyre-fixme[13]: Attribute `weights_offsets` is never initialized.
weights_offsets: torch.Tensor
# pyre-fixme[13]: Attribute `weights_placements` is never initialized.
weights_placements: torch.Tensor

def __init__( # noqa C901
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
lxu_cache_locations_empty: Tensor
timesteps_prefetched: List[int]
record_cache_metrics: RecordCacheMetrics
# pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
uvm_cache_stats: torch.Tensor
# pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
local_uvm_cache_stats: torch.Tensor
uuid: str
# pyre-fixme[13]: Attribute `last_uvm_cache_print_state` is never initialized.
last_uvm_cache_print_state: torch.Tensor
_vbe_B_offsets: Optional[torch.Tensor]
_vbe_max_B: int
Expand Down
1 change: 1 addition & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,6 +1431,7 @@ def forward(
offsets: Tensor,
per_sample_weights: Optional[Tensor] = None,
feature_requires_grad: Optional[Tensor] = None,
# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
) -> Tensor:
indices, offsets, per_sample_weights = self.prepare_inputs(
indices, offsets, per_sample_weights
Expand Down
1 change: 1 addition & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/utils/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def dequantize_embs(
weight_ty: SparseType,
use_cpu: bool,
fp8_config: Optional[FP8QuantizationConfig] = None,
# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
) -> torch.Tensor:
print(f"weight_ty: {weight_ty}")
assert (
Expand Down