diff --git a/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py b/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py index c7d1a7c724..363060cf90 100644 --- a/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py @@ -8,6 +8,7 @@ import functools from math import sqrt +from typing import List, Tuple import click import fbgemm_gpu @@ -32,7 +33,7 @@ def generate_unary_feature( num_embeddings: int, # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use # `typing.List[]` to avoid runtime subscripting errors. -) -> tuple[list, list, list]: +) -> Tuple[List, List, List]: lengths = [] offsets = [] indices = [] @@ -52,7 +53,7 @@ def generate_unary_feature( class MyModule(torch.nn.Module): - def __init__(self, num_tasks: int, hash_sizes: list[int]) -> None: + def __init__(self, num_tasks: int, hash_sizes: List[int]) -> None: super().__init__() self.num_tasks = num_tasks self.hash_sizes = hash_sizes @@ -72,7 +73,7 @@ def __init__(self, num_tasks: int, hash_sizes: list[int]) -> None: self.emb_modules.append(emb) def forward( - self, offsets: list[torch.Tensor], indices: list[torch.Tensor] + self, offsets: List[torch.Tensor], indices: List[torch.Tensor] ) -> torch.Tensor: tt_list = [] for n in range(self.num_tasks): diff --git a/fbgemm_gpu/bench/bench_utils.py b/fbgemm_gpu/bench/bench_utils.py index fa28df37e1..ee55ef4f20 100644 --- a/fbgemm_gpu/bench/bench_utils.py +++ b/fbgemm_gpu/bench/bench_utils.py @@ -10,6 +10,7 @@ import logging import threading import time +from typing import List, Tuple import torch @@ -31,7 +32,7 @@ def benchmark_torch_function( # noqa: C901 name: str = "", num_threads: int = 1, copy_f_for_multi_thread_test: bool = False, -) -> tuple[float, torch.Tensor]: +) -> Tuple[float, torch.Tensor]: logging.debug(f"Start to benchmark {name}...") if device != "cpu" and device != "" and device != "cuda": torch.cuda.set_device(device) @@ -67,7 +68,7 @@ def benchmark_torch_function( # noqa: C901 dtype=torch.float, device=device, ) - duration_ms_list: list[float] = [] + duration_ms_list: List[float] = [] f_list = [f] # make deepcopy of f if necessary diff --git a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py index 86f31a2835..e089e08883 100644 --- a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py +++ b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py @@ -8,7 +8,7 @@ import logging import time -from typing import Callable +from typing import Callable, Tuple import click import torch @@ -25,9 +25,9 @@ def benchmark_hbc_function( - func: Callable[[Tensor], tuple[Tensor, Tensor]], + func: Callable[[Tensor], Tuple[Tensor, Tensor]], input: Tensor, -) -> tuple[float, Tensor]: +) -> Tuple[float, Tensor]: if input.is_cuda: torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) @@ -118,7 +118,7 @@ def cli( [num_bins * (num_segments + 1)], dtype=torch.float64 ).fill_(0.0) - def fbgemm_hbc_cpu(input: Tensor) -> tuple[Tensor, Tensor]: + def fbgemm_hbc_cpu(input: Tensor) -> Tuple[Tensor, Tensor]: return torch.ops.fbgemm.histogram_binning_calibration( input, bin_num_examples, @@ -130,7 +130,7 @@ def fbgemm_hbc_cpu(input: Tensor) -> tuple[Tensor, Tensor]: 0.9995, ) - def fbgemm_hbc_by_feature_cpu(input: Tensor) -> tuple[Tensor, Tensor]: + def fbgemm_hbc_by_feature_cpu(input: Tensor) -> Tuple[Tensor, Tensor]: return torch.ops.fbgemm.histogram_binning_calibration_by_feature( input, segment_values, @@ -146,7 
+146,7 @@ def fbgemm_hbc_by_feature_cpu(input: Tensor) -> tuple[Tensor, Tensor]: 0.9995, ) - def fbgemm_generic_hbc_by_feature_cpu(input: Tensor) -> tuple[Tensor, Tensor]: + def fbgemm_generic_hbc_by_feature_cpu(input: Tensor) -> Tuple[Tensor, Tensor]: return torch.ops.fbgemm.generic_histogram_binning_calibration_by_feature( input, segment_values, @@ -186,7 +186,7 @@ def fbgemm_generic_hbc_by_feature_cpu(input: Tensor) -> tuple[Tensor, Tensor]: bin_num_examples_gpu: Tensor = bin_num_examples.cuda() bin_num_positives_gpu: Tensor = bin_num_positives.cuda() - def fbgemm_hbc_gpu(input: Tensor) -> tuple[Tensor, Tensor]: + def fbgemm_hbc_gpu(input: Tensor) -> Tuple[Tensor, Tensor]: return torch.ops.fbgemm.histogram_binning_calibration( input, bin_num_examples_gpu, @@ -206,7 +206,7 @@ def fbgemm_hbc_gpu(input: Tensor) -> tuple[Tensor, Tensor]: by_feature_bin_num_positives.cuda() ) - def fbgemm_hbc_by_feature_gpu(input: Tensor) -> tuple[Tensor, Tensor]: + def fbgemm_hbc_by_feature_gpu(input: Tensor) -> Tuple[Tensor, Tensor]: return torch.ops.fbgemm.histogram_binning_calibration_by_feature( input, segment_values_gpu, @@ -226,7 +226,7 @@ def fbgemm_hbc_by_feature_gpu(input: Tensor) -> tuple[Tensor, Tensor]: def fbgemm_generic_hbc_by_feature_gpu( input: Tensor, - ) -> tuple[Tensor, Tensor]: + ) -> Tuple[Tensor, Tensor]: return ( torch.ops.fbgemm.generic_histogram_binning_calibration_by_feature( input, diff --git a/fbgemm_gpu/bench/jagged_tensor_benchmark.py b/fbgemm_gpu/bench/jagged_tensor_benchmark.py index bcc3e27488..361cd76bc7 100644 --- a/fbgemm_gpu/bench/jagged_tensor_benchmark.py +++ b/fbgemm_gpu/bench/jagged_tensor_benchmark.py @@ -12,6 +12,7 @@ import logging import random from dataclasses import dataclass +from typing import List, Tuple import click import fbgemm_gpu @@ -504,7 +505,7 @@ def masked_select_jagged_1d( def ref( values: torch.Tensor, lengths: torch.Tensor, mask: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: masked_values_ref = values[mask] cum_count = torch.cumsum(mask, 0) cum_count = torch.cat((cum_count, torch.tensor([0]))) @@ -652,9 +653,9 @@ def keyed_jagged_index_select_dim1( ref_inputs.append((key_values, key_lengths, indices, key_weights)) def keyed_jagged_index_select_dim1_ref( - inputs: list[torch.Tensor], + inputs: List[torch.Tensor], has_weights: bool, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: outputs = [] output_weights = [] for key_values, key_lengths, indices, _ in inputs: @@ -757,11 +758,11 @@ def jagged_slice_ref( offsets: torch.Tensor, start: torch.Tensor, max_L: int, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: end_offsets_ = max_L + start + offsets[:-1] end_offsets = torch.where(end_offsets_ > offsets[1:], offsets[1:], end_offsets_) start_offsets = start + offsets[:-1] - indices_to_select: list[torch.Tensor] = [] + indices_to_select: List[torch.Tensor] = [] for i in range(end_offsets.size(0)): indices_to_select.append( torch.arange(start_offsets[i].item(), end_offsets[i].item()) diff --git a/fbgemm_gpu/bench/merge_embeddings_benchmark.py b/fbgemm_gpu/bench/merge_embeddings_benchmark.py index 69e30a704a..6ed8ebb264 100644 --- a/fbgemm_gpu/bench/merge_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/merge_embeddings_benchmark.py @@ -10,6 +10,7 @@ import logging import signal +from typing import List, Tuple import click import fbgemm_gpu @@ -58,7 +59,7 @@ def get_table_batched_offsets_from_dense( merged_indices: torch.Tensor, # 
pyre-fixme[2]: Parameter must be annotated. gpu_num, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: (T, B, L) = merged_indices.size() lengths = np.ones((T, B)) * L flat_lengths = lengths.flatten() @@ -79,7 +80,7 @@ def generate_requests( E: int, # inter-batch indices reuse rate reuse: float = 0.0, -) -> list[tuple[torch.IntTensor, torch.IntTensor, None]]: +) -> List[Tuple[torch.IntTensor, torch.IntTensor, None]]: rs = [] for gpu_num in range(num_gpus): all_indices = torch.randint( diff --git a/fbgemm_gpu/bench/quantize_ops_benchmark.py b/fbgemm_gpu/bench/quantize_ops_benchmark.py index 561f2f2d06..54755fff67 100644 --- a/fbgemm_gpu/bench/quantize_ops_benchmark.py +++ b/fbgemm_gpu/bench/quantize_ops_benchmark.py @@ -9,9 +9,8 @@ import functools import logging import random -from collections.abc import Iterable from contextlib import nullcontext -from typing import Optional, Union +from typing import Iterable, Optional, Union import click import fbgemm_gpu diff --git a/fbgemm_gpu/bench/sparse_ops_benchmark.py b/fbgemm_gpu/bench/sparse_ops_benchmark.py index 136e117538..2ef9abe8fa 100644 --- a/fbgemm_gpu/bench/sparse_ops_benchmark.py +++ b/fbgemm_gpu/bench/sparse_ops_benchmark.py @@ -11,6 +11,7 @@ import logging import math import random +from typing import List import click import fbgemm_gpu @@ -680,8 +681,8 @@ def index_select_bench( optim_group: torch.optim.Optimizer = torch.optim.SGD(gis_inputs, lr=0.1) def index_select_fwd_ref( - inputs: list[torch.Tensor], indices: list[torch.Tensor] - ) -> list[torch.Tensor]: + inputs: List[torch.Tensor], indices: List[torch.Tensor] + ) -> List[torch.Tensor]: outputs = [] for input, index in zip(inputs, indices): optim_index.zero_grad() @@ -689,18 +690,18 @@ def index_select_fwd_ref( return outputs def index_select_bwd_ref( - outputs: list[torch.Tensor], grads: list[torch.Tensor] + outputs: List[torch.Tensor], grads: List[torch.Tensor] ) -> None: for output, grad in zip(outputs, grads): optim_index.zero_grad() output.backward(grad, retain_graph=True) def batch_index_select_fwd( - concat_inputs: list[torch.Tensor], - concat_indices: list[int], - input_num_indices: list[int], - input_rows: list[int], - input_columns: list[int], + concat_inputs: List[torch.Tensor], + concat_indices: List[int], + input_num_indices: List[int], + input_rows: List[int], + input_columns: List[int], ) -> torch.autograd.Variable: optim_batch.zero_grad() return torch.ops.fbgemm.batch_index_select_dim0( @@ -708,14 +709,14 @@ def batch_index_select_fwd( ) def group_index_select_fwd( - gis_inputs: list[torch.Tensor], indices: list[int] + gis_inputs: List[torch.Tensor], indices: List[int] ) -> torch.autograd.Variable: optim_group.zero_grad() return torch.ops.fbgemm.group_index_select_dim0(gis_inputs, indices) def batch_group_index_select_bwd( output: torch.autograd.Variable, - grads: list[torch.Tensor], + grads: List[torch.Tensor], optim: torch.optim.Optimizer, ) -> torch.autograd.Variable: optim.zero_grad() diff --git a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 4ffb7341a5..3518e91a21 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -12,7 +12,7 @@ import os import tempfile from contextlib import nullcontext -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, Optional import click import numpy as np @@ -231,7 +231,7 @@ def 
device( # noqa C901 pooling_mode = PoolingMode.NONE do_pooling = False - common_split_args: dict[str, Any] = { + common_split_args: Dict[str, Any] = { "weights_precision": weights_precision, "stochastic_rounding": stoc, "output_dtype": output_dtype, @@ -1384,7 +1384,7 @@ def vbe( else EmbeddingLocation.HOST ) - common_split_args: dict[str, Any] = { + common_split_args: Dict[str, Any] = { "weights_precision": embconfig.weights_dtype, "stochastic_rounding": embconfig.stochastic_rounding, "output_dtype": embconfig.output_dtype, diff --git a/fbgemm_gpu/bench/tbe/tbe_cache_benchmark.py b/fbgemm_gpu/bench/tbe/tbe_cache_benchmark.py index 225fb2a70d..8d9be72cb6 100644 --- a/fbgemm_gpu/bench/tbe/tbe_cache_benchmark.py +++ b/fbgemm_gpu/bench/tbe/tbe_cache_benchmark.py @@ -8,6 +8,7 @@ import logging import random +from typing import List, Tuple import click import numpy as np @@ -103,7 +104,7 @@ def create_embedding_specs( cached_tables_ratio: float, num_embeddings: int, embedding_dims: int, -) -> list[tuple[str, int, int, SparseType, EmbeddingLocation]]: +) -> List[Tuple[str, int, int, SparseType, EmbeddingLocation]]: """ Returns embedding specs to be used with IntNBitTableBatchedEmbeddingBagsCodegen. """ @@ -156,7 +157,7 @@ def create_embedding_specs( def create_request( num_tables: int, num_embeddings: int, batch: int, avg_pooling_factor: int -) -> tuple[Tensor, Tensor]: +) -> Tuple[Tensor, Tensor]: """ Returns [indices, offsets], which are inputs of embedding bags. """ diff --git a/fbgemm_gpu/bench/tbe/tbe_inference_benchmark.py b/fbgemm_gpu/bench/tbe/tbe_inference_benchmark.py index 9f324cc70d..b1e59a495e 100644 --- a/fbgemm_gpu/bench/tbe/tbe_inference_benchmark.py +++ b/fbgemm_gpu/bench/tbe/tbe_inference_benchmark.py @@ -16,7 +16,7 @@ import statistics from contextlib import nullcontext from pathlib import Path -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, List, Optional import click import numpy as np @@ -1398,14 +1398,14 @@ def nbit_uvm_compare_direct_mapped( ) if mixed: - Ds: list[int] = [ + Ds: List[int] = [ round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4) for _ in range(T) ] # pyre-fixme[9]: D has type `int`; used as `floating[typing.Any]`. 
D = np.average(Ds) else: - Ds: list[int] = [D] * T + Ds: List[int] = [D] * T _requests_uvm = generate_requests( iters, @@ -1417,7 +1417,7 @@ def nbit_uvm_compare_direct_mapped( alpha=alpha, weighted=weighted, ) - requests_uvm: list[TBERequest] = [ + requests_uvm: List[TBERequest] = [ TBERequest(req.indices.int(), req.offsets.int(), req.per_sample_weights) for req in _requests_uvm ] @@ -1429,7 +1429,7 @@ def nbit_uvm_compare_direct_mapped( + param_size_multiplier * B * sum(Ds[:T]) * L ) - stats: dict[str, Any] = { + stats: Dict[str, Any] = { "B": B, "T": T, "E": E, diff --git a/fbgemm_gpu/bench/tbe/tbe_kv_benchmark.py b/fbgemm_gpu/bench/tbe/tbe_kv_benchmark.py index 654996bf0e..7f9b5ee3e1 100644 --- a/fbgemm_gpu/bench/tbe/tbe_kv_benchmark.py +++ b/fbgemm_gpu/bench/tbe/tbe_kv_benchmark.py @@ -9,7 +9,7 @@ import gc import logging import time -from typing import Callable +from typing import Callable, Dict, Type import click import numpy as np @@ -57,7 +57,7 @@ ) -TBE_CLASS_MAP: dict[str, type[IntNBitTableBatchedEmbeddingBagsCodegen]] = { +TBE_CLASS_MAP: Dict[str, Type[IntNBitTableBatchedEmbeddingBagsCodegen]] = { "KVEmbeddingInference": KVEmbeddingInference, "IntNBitTableBatchedEmbeddingBagsCodegen": IntNBitTableBatchedEmbeddingBagsCodegen, } diff --git a/fbgemm_gpu/bench/tbe/tbe_ssd_benchmark.py b/fbgemm_gpu/bench/tbe/tbe_ssd_benchmark.py index 6816e922c5..218a3b390e 100644 --- a/fbgemm_gpu/bench/tbe/tbe_ssd_benchmark.py +++ b/fbgemm_gpu/bench/tbe/tbe_ssd_benchmark.py @@ -12,7 +12,7 @@ import tempfile import time from contextlib import nullcontext -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import click import numpy as np @@ -61,7 +61,7 @@ def benchmark_ssd_function( buf: torch.Tensor, indices: torch.Tensor, indices_per_itr: int, -) -> tuple[float, float]: +) -> Tuple[float, float]: actions_count_cpu = torch.tensor([indices_per_itr]).long().cpu() # warmup for i in range(warmup_iters): @@ -302,14 +302,14 @@ def ssd_training( # noqa C901 else: feature_requires_grad = None if mixed: - Ds: list[int] = [ + Ds: List[int] = [ round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4) for _ in range(T) ] # pyre-fixme[9]: D has type `int`; used as `floating[typing.Any]`. 
D = np.average(Ds) else: - Ds: list[int] = [D] * T + Ds: List[int] = [D] * T if pooling is None or pooling == "sum": pooling = "sum" @@ -323,13 +323,13 @@ def ssd_training( # noqa C901 do_pooling = False feature_table_map = list(range(T)) - common_args: dict[str, Any] = { + common_args: Dict[str, Any] = { "feature_table_map": feature_table_map, "learning_rate": 0.1, "eps": 0.1, "pooling_mode": pooling_mode, } - common_split_tbe_args: dict[str, Any] = { + common_split_tbe_args: Dict[str, Any] = { # SSD only supports rowwise-adagrad "optimizer": OptimType.EXACT_ROWWISE_ADAGRAD, "weights_precision": weights_precision, diff --git a/fbgemm_gpu/bench/tbe/tbe_training_benchmark.py b/fbgemm_gpu/bench/tbe/tbe_training_benchmark.py index b2c768186a..19e8cdbfa5 100644 --- a/fbgemm_gpu/bench/tbe/tbe_training_benchmark.py +++ b/fbgemm_gpu/bench/tbe/tbe_training_benchmark.py @@ -12,7 +12,7 @@ import os import tempfile from contextlib import nullcontext -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, Optional try: from fbgemm_gpu.tbe.trace.fbgemm_kineto_trace_handler import ( @@ -136,7 +136,7 @@ def device( # noqa C901 optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD # Construct the common split arguments for the embedding op - common_split_args: dict[str, Any] = embconfig.split_args() | { + common_split_args: Dict[str, Any] = embconfig.split_args() | { "optimizer": optimizer, "learning_rate": 0.1, "eps": 0.1, @@ -295,7 +295,7 @@ def _context_factory( FbgemmKinetoTraceHandler(p_obj).sync_log( run_id=str(trace_url), test_phase="fwd", - test_name="tbe_training", + test_name=str("tbe_training"), benchmark_duration_us=float(time_per_iter * 1.0e6), achieved_bw_gbps=float(read_write_bytes / time_per_iter / 1.0e9), ) @@ -352,7 +352,7 @@ def _context_factory( FbgemmKinetoTraceHandler(p_obj).sync_log( run_id=str(trace_url), test_phase="fwd_bwd", - test_name="tbe_training", + test_name=str("tbe_training"), benchmark_duration_us=float(time_per_iter * 1.0e6), achieved_bw_gbps=float( 2 * read_write_bytes / time_per_iter / 1.0e9 @@ -457,7 +457,7 @@ def device_with_speclist( # noqa C901 optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD # Construct the common split arguments for the embedding op - common_split_args: dict[str, Any] = embconfig.split_args() | { + common_split_args: Dict[str, Any] = embconfig.split_args() | { "optimizer": optimizer, "learning_rate": 0.1, "eps": 0.1, diff --git a/fbgemm_gpu/codegen/genscript/torch_type_utils.py b/fbgemm_gpu/codegen/genscript/torch_type_utils.py index 206cc9149e..aa442ad374 100644 --- a/fbgemm_gpu/codegen/genscript/torch_type_utils.py +++ b/fbgemm_gpu/codegen/genscript/torch_type_utils.py @@ -12,6 +12,7 @@ from dataclasses import dataclass from enum import IntEnum +from typing import Dict class ArgType(IntEnum): @@ -36,7 +37,7 @@ class TensorType: scalar_type: str -arg_type_to_tensor_type: dict[ArgType, TensorType] = { +arg_type_to_tensor_type: Dict[ArgType, TensorType] = { ArgType.FLOAT_TENSOR: TensorType("float", "at::ScalarType::Float"), ArgType.HALF_TENSOR: TensorType("at::Half", "at::ScalarType::Half"), ArgType.BFLOAT16_TENSOR: TensorType("at::BFloat16", "at::ScalarType::BFloat16"), diff --git a/fbgemm_gpu/experimental/example/test/__init__.py b/fbgemm_gpu/experimental/example/test/__init__.py index abdc1ba070..07f329ba92 100644 --- a/fbgemm_gpu/experimental/example/test/__init__.py +++ b/fbgemm_gpu/experimental/example/test/__init__.py @@ -6,10 +6,12 @@ # 
pyre-strict +from typing import Tuple + import torch -gpu_unavailable: tuple[bool, str] = ( +gpu_unavailable: Tuple[bool, str] = ( not torch.cuda.is_available() or torch.cuda.device_count() == 0, "CUDA is not available or no GPUs detected", ) diff --git a/fbgemm_gpu/experimental/gemm/test/fp4_quantize_test.py b/fbgemm_gpu/experimental/gemm/test/fp4_quantize_test.py index 5f5379c19e..cc5d4c1406 100644 --- a/fbgemm_gpu/experimental/gemm/test/fp4_quantize_test.py +++ b/fbgemm_gpu/experimental/gemm/test/fp4_quantize_test.py @@ -6,6 +6,7 @@ import math import unittest +from typing import Tuple import fbgemm_gpu import torch @@ -40,7 +41,7 @@ def setUp(self) -> None: def test_quantize_fp4(self) -> None: def _test_quantize_fp4( - shape: tuple[int, int], + shape: Tuple[int, int], device: str = "cuda", ) -> None: M, N = shape @@ -91,7 +92,7 @@ def setUp(self) -> None: def test_rms_quantize_fp4(self) -> None: def _test_rms_quantize_fp4( - shape: tuple[int, int], + shape: Tuple[int, int], device: str = "cuda", ) -> None: M, N = shape @@ -157,7 +158,7 @@ def setUp(self) -> None: def test_silu_quantize_fp4(self) -> None: def _test_silu_quantize_fp4( - shape: tuple[int, int], + shape: Tuple[int, int], device: str = "cuda", ) -> None: M, N = shape @@ -215,7 +216,7 @@ def setUp(self) -> None: def test_silu_quantize_nvfp4(self) -> None: def _test_silu_quantize_nvfp4( - shape: tuple[int, int], + shape: Tuple[int, int], device: str = "cuda", ) -> None: M, N = shape @@ -257,7 +258,7 @@ def setUp(self) -> None: def test_rms_quantize_nvfp4(self) -> None: def _test_rms_quantize_nvfp4( - shape: tuple[int, int], + shape: Tuple[int, int], device: str = "cuda", ) -> None: M, N = shape diff --git a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py index 96cf7811f8..77fa04b3a2 100644 --- a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py +++ b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py @@ -6,7 +6,7 @@ # pyre-strict -from typing import Callable +from typing import Callable, Tuple import click @@ -32,7 +32,7 @@ def _run_benchmark( bench_factory: Callable[ [torch.Tensor, torch.Tensor], Callable[[], torch.Tensor] ], - shape: tuple[int, int, int] = (1024, 1024, 1024), + shape: Tuple[int, int, int] = (1024, 1024, 1024), tag: str = "", ) -> None: # Benchmarks the function returned by bench_factory. 
diff --git a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py index 319be4cb55..0712443218 100644 --- a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py +++ b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py @@ -8,7 +8,7 @@ import itertools import unittest -from typing import Optional +from typing import Optional, Tuple import torch import triton @@ -42,7 +42,7 @@ def setUp(self) -> None: def test_quantize_fp8_row(self) -> None: def _test_quantize_fp8_row( - shape: tuple[int, ...], + shape: Tuple[int, ...], use_triton: bool, device: torch.device, output_device: Optional[torch.device] = None, @@ -227,7 +227,7 @@ def _test_quantize_fp8_row( def test_quantize_fp8_packed_row(self) -> None: def _test_quantize_fp8_packed_row( - shape: tuple[int, ...], + shape: Tuple[int, ...], use_triton: bool, device: torch.device, output_device: Optional[torch.device] = None, @@ -337,7 +337,7 @@ def _test_quantize_fp8_packed_row( def test_dequantize_fp8_row(self) -> None: def _test_dequantize_fp8_row( - shape: tuple[int, ...], + shape: Tuple[int, ...], ) -> None: a = torch.randn(shape, dtype=torch.bfloat16, device="cuda") a_fp8, a_scale = quantize_fp8_row( @@ -372,7 +372,7 @@ def _test_dequantize_fp8_row( def test_dequantize_fp8_packed_row(self) -> None: def _test_dequantize_fp8_packed_row( - shape: tuple[int, ...], + shape: Tuple[int, ...], ) -> None: a = torch.randn(shape, dtype=torch.bfloat16, device="cuda") @@ -410,7 +410,7 @@ def _test_dequantize_fp8_packed_row( def test_scale_fp8_row(self) -> None: def _test_scale_fp8_row( - shape: tuple[int, int], + shape: Tuple[int, int], device: torch.device, ) -> None: M, K = shape @@ -438,7 +438,7 @@ def _test_scale_fp8_row( def test_matmul_fp8_row(self) -> None: def _test_matmul_fp8_row( - shape: tuple[int, int, int], + shape: Tuple[int, int, int], device: torch.device, fp8_fast_accum: bool, use_bias: bool = False, @@ -521,7 +521,7 @@ def _fp8_clamp(x: torch.Tensor) -> torch.Tensor: return xq def _test_matmul_fp8_row_skip_scaling( - shape: tuple[int, int, int], + shape: Tuple[int, int, int], device: torch.device, use_bias: bool = True, transpose_input: bool = False, @@ -602,7 +602,7 @@ def _quantize_matmul_fp8( def test_quantize_fp8_group(self) -> None: def _test_quantize_fp8_group( - shape: tuple[int, int], + shape: Tuple[int, int], group_size: int, use_scale_ub: bool = False, ) -> None: @@ -633,8 +633,8 @@ def _test_quantize_fp8_group( def test_quantize_fp8_block(self) -> None: def _test_quantize_fp8_block( - shape: tuple[int, int], - block_shape: tuple[int, int], + shape: Tuple[int, int], + block_shape: Tuple[int, int], use_scale_ub: bool = False, ) -> None: M, K = shape @@ -667,8 +667,8 @@ def _test_quantize_fp8_block( def test_dequantize_fp8_block(self) -> None: def _test_dequantize_fp8_block( - shape: tuple[int, int], - block_shape: tuple[int, int], + shape: Tuple[int, int], + block_shape: Tuple[int, int], use_scale_ub: bool = False, ) -> None: M, K = shape @@ -695,8 +695,8 @@ def _test_dequantize_fp8_block( def test_matmul_fp8_block(self) -> None: def _test_matmul_fp8_block( - shape: tuple[int, int, int], - block_shape: tuple[int, int, int], + shape: Tuple[int, int, int], + block_shape: Tuple[int, int, int], fp8_fast_accum: bool, transpose_input: bool = False, device: str = "cuda", diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py index cec9cc0b2f..59d1977455 100644 --- 
a/fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import math -from typing import Optional, Union +from typing import Optional, Tuple, Union import torch import triton # @manual @@ -275,7 +275,7 @@ def triton_quantize_mx4_unpack( mbits: int = 1, rounding_mode: Union[RoundingMode, int] = RoundingMode.ceil, stochastic_casting: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize a tensor to mx4 format using efficient triton kernels. @@ -701,7 +701,7 @@ def triton_silu_quantize_mx4_unpack( mbits: int = 1, rounding_mode: Union[RoundingMode, int] = RoundingMode.ceil, stochastic_casting: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize a tensor to mx4 format using efficient triton kernels. @@ -1126,7 +1126,7 @@ def triton_rms_quantize_mx4_unpack( rounding_mode: Union[RoundingMode, int] = RoundingMode.ceil, stochastic_casting: bool = False, EPS: float = 1e-5, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize a tensor to mx4 format using efficient triton kernels. @@ -1447,7 +1447,7 @@ def triton_scale_nvfp4_quant( rounding_mode: Union[RoundingMode, int] = RoundingMode.ceil, stochastic_casting: bool = False, EPS: float = 1e-5, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize a tensor to nvfp4 format using efficient triton kernels. @@ -1798,7 +1798,7 @@ def triton_scale_nvfp4_quant_silu( mbits: int = 1, rounding_mode: Union[RoundingMode, int] = RoundingMode.ceil, stochastic_casting: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize a tensor to nvfp4 format using efficient triton kernels. @@ -2161,7 +2161,7 @@ def triton_scale_nvfp4_quant_rms( rounding_mode: Union[RoundingMode, int] = RoundingMode.ceil, stochastic_casting: bool = False, EPS: float = 1e-5, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize a tensor to nvfp4 format using efficient triton kernels. @@ -2546,7 +2546,7 @@ def triton_nvfp4_quant_stacked( rounding_mode: Union[RoundingMode, int] = RoundingMode.ceil, stochastic_casting: bool = False, EPS: float = 1e-5, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize a tensor to nvfp4 format using efficient triton kernels. @@ -3784,7 +3784,7 @@ def mega_fp4_quantize_kernel( rounding_mode: Union[RoundingMode, int] = RoundingMode.ceil, stochastic_casting: bool = False, EPS: float = 1e-5, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: orig_shape = input.shape assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}." @@ -3963,7 +3963,7 @@ def triton_nvfp4_quant_stacked_silu( rounding_mode: Union[RoundingMode, int] = RoundingMode.ceil, stochastic_casting: bool = False, EPS: float = 1e-5, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize a tensor to nvfp4 format using efficient triton kernels. 
@@ -4374,7 +4374,7 @@ def triton_nvfp4_quant_stacked_rms( rounding_mode: Union[RoundingMode, int] = RoundingMode.ceil, stochastic_casting: bool = False, EPS: float = 1e-5, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize a tensor to nvfp4 format using efficient triton kernels. @@ -5222,7 +5222,7 @@ def mega_fp4_unpack( m_sizes: torch.Tensor, input: torch.Tensor, group_size: int = 16, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: orig_shape = input.shape assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}." @@ -5467,7 +5467,7 @@ def _calculate_group_max( def calculate_group_max( input: torch.Tensor, m_sizes: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}." other_dims = 1 if input.ndim == 1 else -1 diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index a08d2934a0..4dc3e4f8e7 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -8,7 +8,7 @@ import functools import logging import os -from typing import Optional, Union +from typing import Dict, List, Optional, Tuple, Union import torch import triton # @manual @@ -68,7 +68,7 @@ def supports_float8_fnuz(throw_on_hip_incompatibility: bool = True) -> bool: return False -def get_fp8_constants() -> tuple[torch.dtype, tl.dtype, float, float]: +def get_fp8_constants() -> Tuple[torch.dtype, tl.dtype, float, float]: """ Helper function to get constant values for the current platform. @@ -106,7 +106,7 @@ def init_to_zero(name): return lambda nargs: nargs[name].zero_() -def get_configs_io_bound() -> list[Config]: +def get_configs_io_bound() -> List[Config]: """ Returns a list of configs for matmul that are IO bound. 
@@ -159,7 +159,7 @@ def dummy_prune_configs(configs, named_args, **kwargs): return configs -MATMUL_CONFIGS: list[Config] = [ +MATMUL_CONFIGS: List[Config] = [ # basic configs for compute-bound matmuls Config( {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, @@ -960,7 +960,7 @@ def make_autotuner_config(dictargs, **kwargs): return Config(dictargs, **kwargs) -def get_ws_configs() -> list[Config]: +def get_ws_configs() -> List[Config]: if not has_warp_specialization: return [] return [ @@ -1281,7 +1281,7 @@ def matmul_fp8_row( output += bias[None, :] return output.to(c.dtype) - def grid(META: dict[str, int]) -> tuple[int, int]: + def grid(META: Dict[str, int]) -> Tuple[int, int]: return ( triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"], @@ -1289,7 +1289,7 @@ def grid(META: dict[str, int]) -> tuple[int, int]: NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - def persistent_grid(META: dict[str, int]) -> tuple[int]: + def persistent_grid(META: Dict[str, int]) -> Tuple[int]: return ( min( NUM_SMS, @@ -1337,7 +1337,7 @@ def persistent_grid(META: dict[str, int]) -> tuple[int]: desc_helper.init_tma_descriptor("b_scale") desc_helper.init_tma_descriptor("bias") - def persistent_grid_tma_ws(META: dict[str, int]) -> tuple[int]: + def persistent_grid_tma_ws(META: Dict[str, int]) -> Tuple[int]: nonlocal desc_helper # noqa: F824 assert a_scale is not None # Type narrowing for Pyre desc_helper.fill_2d_tma_descriptor( @@ -1450,7 +1450,7 @@ def persistent_grid_tma_ws(META: dict[str, int]) -> tuple[int]: desc_helper.init_tma_descriptor("b_scale") desc_helper.init_tma_descriptor("bias") - def persistent_grid_tma(META: dict[str, int]) -> tuple[int]: + def persistent_grid_tma(META: Dict[str, int]) -> Tuple[int]: nonlocal desc_helper # noqa: F824 assert a_scale is not None # Type narrowing for Pyre desc_helper.fill_2d_tma_descriptor( @@ -2113,7 +2113,7 @@ def matmul_fp8_block( raise Exception("'b_scale' must be on the same device as 'a'") # noqa: E731: - def grid(META: dict[str, int]) -> tuple[int, int]: + def grid(META: Dict[str, int]) -> Tuple[int, int]: return ( triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"], @@ -2205,7 +2205,7 @@ def matmul_fp8_block_meta( return torch.empty((M, N), device=a.device, dtype=torch.bfloat16) -def get_matmul_tune(M: int, N: int, K: int) -> tuple[int, int, int]: +def get_matmul_tune(M: int, N: int, K: int) -> Tuple[int, int, int]: """ Generate a simplified matmul tune key for A @ B.T with [M, K] A and [N, K] B to reduce excessive autotuning. @@ -2234,7 +2234,7 @@ def prep_matmul( a: Union[TensorWrapper, torch.Tensor], b: Union[TensorWrapper, torch.Tensor], dot_out_dtype: Optional[torch.dtype], -) -> tuple[ +) -> Tuple[ int, int, int, int, int, int, torch.Tensor, tl.dtype, tl.dtype, torch.device ]: """ @@ -2464,7 +2464,7 @@ def triton_quantize_fp8_row( scale_ub: Optional[Tensor] = None, zero_start_index_M: Optional[Tensor] = None, align_rows_to: Optional[int] = None, -) -> tuple[Tensor, Tensor]: +) -> Tuple[Tensor, Tensor]: """ Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings. 
@@ -2739,7 +2739,7 @@ def triton_quantize_fp8_packed_row( scale_ub: Optional[Tensor] = None, zero_start_index_M: Optional[Tensor] = None, return_only_packed: Optional[bool] = False, -) -> tuple[Optional[Tensor], Optional[Tensor], Tensor]: +) -> Tuple[Optional[Tensor], Optional[Tensor], Tensor]: """ Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings. @@ -2832,7 +2832,7 @@ def quantize_fp8_packed_row( zero_start_index_M: Optional[Tensor] = None, use_triton: bool = True, output_device: Optional[torch.device] = None, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize a to fp8 with row-wise scalings and optionally move to output device. @@ -2930,7 +2930,7 @@ def quantize_fp8_row( use_triton: bool = True, output_device: Optional[torch.device] = None, align_rows_to: Optional[int] = None, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize a to fp8 with row-wise scalings and optionally move to output device. @@ -2990,7 +2990,7 @@ def quantize_fp8_row_meta( use_triton: bool = True, output_device: Optional[torch.device] = None, align_rows_to: Optional[int] = None, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """Shape function for torch compile.""" if output_device is None: output_device = a.device @@ -3209,7 +3209,7 @@ def triton_quantize_fp8_block( block_k: int = 256, scale_ub: Optional[torch.Tensor] = None, k_major: bool = True, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize a tensor to fp8 with block-wise scalings. @@ -3287,7 +3287,7 @@ def quantize_fp8_block( use_triton: bool = True, output_device: Optional[torch.device] = None, k_major: bool = True, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize a tensor to fp8 with block-wise scalings and optionally move to output device. @@ -3520,7 +3520,7 @@ def triton_quantize_fp8_group( scale_ub: Optional[torch.Tensor] = None, m_sizes: Optional[torch.Tensor] = None, k_major: bool = True, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize a tensor to fp8 with group-wise scalings. @@ -3590,7 +3590,7 @@ def quantize_fp8_group( k_major: bool = True, use_triton: bool = True, output_device: Optional[torch.device] = None, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize a tensor to fp8 with group-wise scalings and optionally move to output device. 
@@ -3807,7 +3807,7 @@ def get_full_non_persistent_tuning_space(): return configs -MATMUL_CONFIGS_NON_PERSISTENT: list[Config] = get_full_non_persistent_tuning_space() +MATMUL_CONFIGS_NON_PERSISTENT: List[Config] = get_full_non_persistent_tuning_space() MATMUL_CONFIGS_NON_PERSISTENT_PINGPONG_4K_8K_16K = [ triton.Config( { @@ -4323,7 +4323,7 @@ def dequantize_fp8_row( M = xq.shape[0] use_int64 = xq.numel() > 2**31 - def grid(meta: dict[str, int]) -> tuple[int]: + def grid(meta: Dict[str, int]) -> Tuple[int]: return (triton.cdiv(M, meta["BLOCK_M"]),) with torch.cuda.device(xq.device.index): @@ -4434,7 +4434,7 @@ def dequantize_fp8_packed_row( M = actual_xq.shape[0] use_int64 = actual_xq.numel() > 2**31 - def grid(meta: dict[str, int]) -> tuple[int]: + def grid(meta: Dict[str, int]) -> Tuple[int]: return (triton.cdiv(M, meta["BLOCK_M"]),) with torch.cuda.device(actual_xq.device.index): @@ -4514,7 +4514,7 @@ def dequantize_fp8_block( M, K = xq.size() x_dequant = torch.empty_like(xq, dtype=torch.bfloat16) - def grid(meta: dict[str, int]) -> tuple[int, int]: + def grid(meta: Dict[str, int]) -> Tuple[int, int]: return ( triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(K, meta["BLOCK_K"]), diff --git a/fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py b/fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py index 2e12d41447..b3e0a7eba5 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py @@ -12,6 +12,7 @@ import uuid from functools import lru_cache from pprint import pprint +from typing import Tuple import fbgemm_gpu.experimental.gen_ai # noqa: F401 import pandas as pd @@ -31,7 +32,7 @@ def get_symm_buffer(group): return inp, group.group_name -def _setup(path: str) -> tuple[int, int]: +def _setup(path: str) -> Tuple[int, int]: rank = int(os.environ["LOCAL_RANK"]) W = int(os.environ["WORLD_SIZE"]) device = torch.device(f"cuda:{rank}") diff --git a/fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py b/fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py index 3829d0b583..3a4f5ebce0 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py @@ -6,7 +6,7 @@ import functools import itertools -from typing import Optional +from typing import List, Optional, Tuple import click import torch @@ -156,16 +156,16 @@ def bench_topk_index_shuffling(T: int, E: int, K: int) -> None: torch.manual_seed(0) num_rotating_buffers = min(max(2, triton.cdiv(1024 * 1024 * 1024, T * E * 2)), 1000) - scores_list: list[torch.Tensor] = [ + scores_list: List[torch.Tensor] = [ torch.randn(T, E, device=_ACCELERATOR_TAG, dtype=torch.bfloat16) for i in range(num_rotating_buffers) ] - def fn() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def fn() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: for scores in scores_list: index_shuffling(scores, top_k=K) - def ref_fn() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def ref_fn() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: for scores in scores_list: _, selected_expert_indices = torch.topk(scores, K, dim=1) expert_indices, _ = torch.sort( @@ -235,14 +235,14 @@ def bench_combine_or_split_shuffling( assert token_counts.sum().item() == input_num_tokens num_rotating_buffers = triton.cdiv(1024 * 1024 * 1024, tokens.numel() * 2) - token_list: list[torch.Tensor] = [ + token_list: List[torch.Tensor] = [ tokens.clone() for _ in range(num_rotating_buffers) ] - token_count_list: list[torch.Tensor] = [ + 
token_count_list: List[torch.Tensor] = [ token_counts.clone() for _ in range(num_rotating_buffers) ] - def fn() -> tuple[torch.Tensor, Optional[torch.Tensor]]: + def fn() -> Tuple[torch.Tensor, Optional[torch.Tensor]]: for tokens, token_counts in zip(token_list, token_count_list): if is_combine_shuffling: combine_shuffling( diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py index 0775d4c66e..898653d756 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from datetime import datetime -from typing import Any, Optional +from typing import Any, Dict, List, Optional, Tuple import click @@ -73,7 +73,7 @@ def set_amd_env_vars() -> None: os.environ["PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS"] = "30" -def get_llama_shapes() -> list[tuple[int, int, int, int]]: +def get_llama_shapes() -> List[Tuple[int, int, int, int]]: # Helper function that returns a list of shapes relevant to llama. llama_shapes = [] @@ -103,7 +103,7 @@ def get_llama_shapes() -> list[tuple[int, int, int, int]]: return llama_shapes -def get_ldm_shapes() -> list[tuple[int, int, int, int]]: +def get_ldm_shapes() -> List[Tuple[int, int, int, int]]: # Helper function that returns a list of shapes relevant to ldm. return [ (1, 1536, 3584, 3584), @@ -160,11 +160,11 @@ def __str__(self) -> str: def benchmark_grouped( - quantize_ops: list[QuantizeOpBase], - b: list[int], - m: list[int], - n: list[int], - k: list[int], + quantize_ops: List[QuantizeOpBase], + b: List[int], + m: List[int], + n: List[int], + k: List[int], bench_quantize: bool = False, use_rotating_buffer_bench: bool = False, use_cuda_graph: bool = True, @@ -172,7 +172,7 @@ def benchmark_grouped( num_iters: int = 1, fast_accum: bool = True, torch_compile: bool = False, -) -> dict[str, Any]: +) -> Dict[str, Any]: num_groups = len(m) # Create input tensors. A = [] @@ -193,7 +193,7 @@ def benchmark_grouped( log_m = m[0] if len(np.unique(m)) == 1 else m log_n = n[0] if len(np.unique(n)) == 1 else n log_k = k[0] if len(np.unique(k)) == 1 else k - results: dict[str, Any] = {"M": log_m, "N": log_n, "K": log_k, "groups": num_groups} + results: Dict[str, Any] = {"M": log_m, "N": log_n, "K": log_k, "groups": num_groups} # Benchmark each operator. for quantize_op in quantize_ops: metrics = Metrics(op_name=quantize_op.name) @@ -277,7 +277,7 @@ def benchmark_grouped( def benchmark( - quantize_ops: list[QuantizeOpBase], + quantize_ops: List[QuantizeOpBase], b: int, m: int, n: int, @@ -289,7 +289,7 @@ def benchmark( num_iters: int = 1, fast_accum: bool = True, torch_compile: bool = False, -) -> dict[str, Any]: +) -> Dict[str, Any]: # Create input tensors. if b > 1: A = torch.randn(b, m, k, device="cuda", dtype=torch.bfloat16) @@ -301,7 +301,7 @@ def benchmark( # Compute baseline output for correctness checking. out_ref = torch.matmul(A, torch.transpose(B, -2, -1)) # Keep track of results. - results: dict[str, Any] = {"B": b, "M": m, "N": n, "K": k} + results: Dict[str, Any] = {"B": b, "M": m, "N": n, "K": k} # Benchmark each operator. 
for quantize_op in quantize_ops: metrics = Metrics(op_name=quantize_op.name) @@ -368,7 +368,7 @@ def benchmark( return results -def plot_benchmark(results: list[dict[str, Any]], output_dir: str) -> None: +def plot_benchmark(results: List[Dict[str, Any]], output_dir: str) -> None: """Create a barplot visualizing the TFLOPS of each kernel.""" # Reprocess into new dataframe with proper graph format. data = [] @@ -394,7 +394,7 @@ def plot_benchmark(results: list[dict[str, Any]], output_dir: str) -> None: print(f"Plot saved to {img_fn}") -def collect_kernels_to_profile(kernels: Optional[list[str]]) -> list[QuantizeOpBase]: +def collect_kernels_to_profile(kernels: Optional[List[str]]) -> List[QuantizeOpBase]: # Get existing quantization operators. quantize_ops = get_quantize_ops() quantize_ops = [op for op in quantize_ops if op.supported] @@ -403,7 +403,7 @@ def collect_kernels_to_profile(kernels: Optional[list[str]]) -> list[QuantizeOpB return [op for op in quantize_ops if op.name in kernels] -def print_kernels(kernels: Optional[list[str]]) -> list[QuantizeOpBase]: +def print_kernels(kernels: Optional[List[str]]) -> List[QuantizeOpBase]: data = sorted( [ (op.name, "Yes" if op.cuda else "No", "Yes" if op.hip else "No") diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py index 4e842b3130..cb0be6d43a 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py @@ -6,6 +6,7 @@ # Keep a registry of all quantize operators. import abc +from typing import List, Tuple import fbgemm_gpu.experimental.gen_ai # noqa: F401 import numpy as np @@ -220,7 +221,7 @@ def register_quantize_op(op): return op -def get_quantize_ops() -> list[QuantizeOpBase]: +def get_quantize_ops() -> List[QuantizeOpBase]: """Get all registered quantize ops.""" return quantize_op_registry @@ -1767,7 +1768,7 @@ def _int4_row_quantize( self, x: torch.Tensor, group_size: int = 128, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: n_bit = 4 # Number of target bits. to_quant = x.reshape(-1, group_size).to(torch.float) diff --git a/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py b/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py index 0e2f7e6e9a..8323aba91a 100644 --- a/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +++ b/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Optional +from typing import Optional, Tuple import torch from torch.library import register_fake @@ -35,7 +35,7 @@ def custom_op_fmha( softmax_scale: Optional[float] = None, causal: bool = False, seqlen_kv: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: assert q.is_contiguous(), "q is not contiguous" assert k.is_contiguous(), "k is not contiguous" assert v.is_contiguous(), "v is not contiguous" @@ -123,7 +123,7 @@ def custom_op_fmha_bwd( max_seq_len_q: Optional[int] = None, max_seq_len_k: Optional[int] = None, causal: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return torch.ops.fbgemm.fmha_bwd( dOutput, query, diff --git a/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py b/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py index ea9bdc3768..130711355e 100644 --- a/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +++ b/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Optional +from typing import Any, Optional, Tuple import torch @@ -175,7 +175,7 @@ def forward( # type: ignore max_seq_len_q: Optional[int] = None, max_seq_len_k: Optional[int] = None, seqlen_kv: Optional[torch.Tensor] = None, - window_size: tuple[int, int] = (-1, -1), + window_size: Tuple[int, int] = (-1, -1), bottom_right: bool = True, deterministic: bool = False, ) -> torch.Tensor: @@ -242,7 +242,7 @@ def forward( # type: ignore return out @staticmethod - def backward(ctx, dout: torch.Tensor, *args: Any) -> tuple[ # type: ignore + def backward(ctx, dout: torch.Tensor, *args: Any) -> Tuple[ # type: ignore torch.Tensor, torch.Tensor, torch.Tensor, diff --git a/fbgemm_gpu/experimental/gen_ai/gen_ai/moe/activation.py b/fbgemm_gpu/experimental/gen_ai/gen_ai/moe/activation.py index 39bbec8d9c..de15111358 100644 --- a/fbgemm_gpu/experimental/gen_ai/gen_ai/moe/activation.py +++ b/fbgemm_gpu/experimental/gen_ai/gen_ai/moe/activation.py @@ -6,7 +6,7 @@ # pyre-unsafe -from typing import Optional +from typing import Optional, Tuple import torch import triton @@ -75,7 +75,7 @@ def silu_mul_quant( x1: torch.Tensor, scale_ub: Optional[torch.Tensor] = None, valid_token_count: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Fused silu, mul, and FP8 rowwise quantization operations. diff --git a/fbgemm_gpu/experimental/gen_ai/gen_ai/moe/gather_scatter.py b/fbgemm_gpu/experimental/gen_ai/gen_ai/moe/gather_scatter.py index 0b9f3a19ba..b4d10843bf 100644 --- a/fbgemm_gpu/experimental/gen_ai/gen_ai/moe/gather_scatter.py +++ b/fbgemm_gpu/experimental/gen_ai/gen_ai/moe/gather_scatter.py @@ -6,7 +6,7 @@ # pyre-unsafe -from typing import Optional +from typing import Optional, Tuple import torch import triton @@ -101,7 +101,7 @@ def gather_scale_quant_dense_tokens( scores: torch.Tensor, scale_ub: Optional[torch.Tensor] = None, valid_token_count: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Gather, scale, and quantize dense tokens along 1D indices. 
diff --git a/fbgemm_gpu/experimental/gen_ai/gen_ai/moe/layers.py b/fbgemm_gpu/experimental/gen_ai/gen_ai/moe/layers.py index 1497028d42..bb35d68774 100644 --- a/fbgemm_gpu/experimental/gen_ai/gen_ai/moe/layers.py +++ b/fbgemm_gpu/experimental/gen_ai/gen_ai/moe/layers.py @@ -6,10 +6,9 @@ import os from abc import ABCMeta, abstractmethod -from collections.abc import Mapping from dataclasses import dataclass from functools import cached_property -from typing import Callable, Optional, Union +from typing import Callable, List, Mapping, Optional, Tuple, Union import torch @@ -92,7 +91,7 @@ def num_local_experts(self) -> int: INIT_METHODS_TYPE = Mapping[ str, - Callable[[torch.Tensor], Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]], + Callable[[torch.Tensor], Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], ] @@ -847,8 +846,8 @@ def _gather_shapes(self, send_sizes: torch.Tensor) -> torch.Tensor: def _exchange_tokens( self, send_tokens: torch.Tensor, - send_sizes: Optional[list[int]], - recv_sizes: Optional[list[int]], + send_sizes: Optional[List[int]], + recv_sizes: Optional[List[int]], is_input: bool, ) -> torch.Tensor: """ @@ -1067,7 +1066,7 @@ def _gather_tokens( def _route( self, tokens: torch.Tensor - ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]: B, T, D = tokens.shape tokens = tokens.view(-1, D) diff --git a/fbgemm_gpu/experimental/gen_ai/gen_ai/moe/shuffling.py b/fbgemm_gpu/experimental/gen_ai/gen_ai/moe/shuffling.py index 066cf7210b..5a001c15ce 100644 --- a/fbgemm_gpu/experimental/gen_ai/gen_ai/moe/shuffling.py +++ b/fbgemm_gpu/experimental/gen_ai/gen_ai/moe/shuffling.py @@ -6,7 +6,7 @@ # pyre-unsafe -from typing import Optional, Union +from typing import Optional, Tuple, Union import torch import triton @@ -20,7 +20,7 @@ def combine_shuffling( expert_start: Optional[int] = None, expert_end: Optional[int] = None, is_padded: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: # pyre-ignore return _combine_or_split_shuffling( tokens=tokens, @@ -60,7 +60,7 @@ def _combine_or_split_shuffling( is_padded: bool, is_combine: bool, init_with_zeros: bool = False, -) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: # T is intentionally ignored in kernel interface to avoid recompilation assert tokens.is_contiguous() assert token_counts.is_contiguous() diff --git a/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py b/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py index 776f8e08ae..08eef43a40 100644 --- a/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py +++ b/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py @@ -8,6 +8,7 @@ # Helper functions for using FBGEMM quantized operators. +from typing import Tuple import torch @@ -31,7 +32,7 @@ def pack_int4(x: torch.Tensor) -> torch.Tensor: def int4_row_quantize_zp( x: torch.Tensor, group_size: int = 128, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: n_bit = 4 # Number of target bits. # Split input into chunks of group_size. This approach allows K that isnt divisible by group_size. 
to_quant = torch.split(x.to(torch.float), group_size, dim=-1) @@ -72,7 +73,7 @@ def int4_row_quantize_zp( def int4_row_quantize( x: torch.Tensor, group_size: int = 128, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Helper function to quantize a tensor to int4 with groupwise scales. @@ -110,7 +111,7 @@ def int4_row_quantize( def quantize_int4_preshuffle( w: torch.Tensor, group_size: int = 128, dtype: str = "fp8", use_zp: bool = True -) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: +) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Quantizes an input weight tensor to int4 using preshuffling and scale packing. This function is intended to be used with fbgemms mixed dtype kernels and is expected @@ -130,7 +131,7 @@ def quantize_int4_preshuffle( def _quantize( w: torch.Tensor, dtype: str = "fp8" - ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if dtype == "fp8": # Start by lowering weights to FP8 and producing row scales. @@ -227,7 +228,7 @@ def shuffle_slice( def scale_nvfp4_quant( input: torch.Tensor, input_global_scale: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP4 and return quantized tensor and scale. This function quantizes the last dimension of the given tensor `input`. For diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/common/scripts/make_heuristic.py b/fbgemm_gpu/experimental/gen_ai/src/quantize/common/scripts/make_heuristic.py index 70447d77b9..8f354111bd 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/common/scripts/make_heuristic.py +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/common/scripts/make_heuristic.py @@ -9,6 +9,7 @@ import argparse from collections import defaultdict from dataclasses import dataclass +from typing import Dict, List, Set, Tuple @dataclass(frozen=True) @@ -20,11 +21,11 @@ class ProblemShape: K: int @classmethod - def from_tuple(cls, shape_tuple: tuple[int, int, int]) -> "ProblemShape": + def from_tuple(cls, shape_tuple: Tuple[int, int, int]) -> "ProblemShape": """Create ProblemShape from a tuple.""" return cls(M=shape_tuple[0], N=shape_tuple[1], K=shape_tuple[2]) - def to_tuple(self) -> tuple[int, int, int]: + def to_tuple(self) -> Tuple[int, int, int]: """Convert to tuple for backwards compatibility.""" return (self.M, self.N, self.K) @@ -51,7 +52,7 @@ class NEntry: """Represents an N dimension entry with its K entries.""" N: int - k_entries: list[KEntry] + k_entries: List[KEntry] @dataclass @@ -59,17 +60,17 @@ class MEntry: """Represents an M dimension entry with its N entries.""" M: int - n_entries: list[NEntry] + n_entries: List[NEntry] @dataclass class Heuristic: """Represents the complete heuristic structure.""" - m_entries: list[MEntry] + m_entries: List[MEntry] -def get_kernel_assignment(file_path: str, threshold: float) -> dict[ProblemShape, str]: +def get_kernel_assignment(file_path: str, threshold: float) -> Dict[ProblemShape, str]: """ Assign kernels to problem shape from a set of profiling runs on a kernel. The heuristic is currently built in a greedy approach: @@ -77,11 +78,11 @@ def get_kernel_assignment(file_path: str, threshold: float) -> dict[ProblemShape 2. For the above kernels, count how often it appeared across all problem shapes. 3. When assigning a kernel to a problem shape, prioritize kernels that appear more often to minimize the number of kernels used. 
""" - kernel_results: list[KernelResult] = [] - best_times_ms: dict[ProblemShape, float] = {} - kernel_count: dict[str, int] = defaultdict(int) - kernel_candidates: dict[ProblemShape, set[str]] = defaultdict(set) - kernel_assignment: dict[ProblemShape, str] = {} + kernel_results: List[KernelResult] = [] + best_times_ms: Dict[ProblemShape, float] = {} + kernel_count: Dict[str, int] = defaultdict(int) + kernel_candidates: Dict[ProblemShape, Set[str]] = defaultdict(set) + kernel_assignment: Dict[ProblemShape, str] = {} with open(file_path, "r") as file: # Parse CSV and find the best time for each problem shape @@ -121,7 +122,7 @@ def get_kernel_assignment(file_path: str, threshold: float) -> dict[ProblemShape return kernel_assignment -def get_heuristic(kernel_assignment: dict[ProblemShape, str]) -> Heuristic: +def get_heuristic(kernel_assignment: Dict[ProblemShape, str]) -> Heuristic: """Build hierarchical heuristic structure from kernel assignments.""" M_vals = sorted({problem_shape.M for problem_shape in kernel_assignment.keys()}) N_vals = sorted({problem_shape.N for problem_shape in kernel_assignment.keys()}) diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_custom_op_check.py b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_custom_op_check.py index 8263609d13..c029e0d66e 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_custom_op_check.py +++ b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_custom_op_check.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional +from typing import Optional, Tuple import fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha # noqa @@ -24,7 +24,7 @@ def get_varlen_args( dtype: torch.dtype, causal: bool, fwd_only: bool = False, -) -> tuple[ +) -> Tuple[ torch.Tensor, torch.Tensor, torch.Tensor, diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py index 758332b315..c1ff514270 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py @@ -6,7 +6,7 @@ import random import unittest -from typing import Optional +from typing import Optional, Tuple import hypothesis.strategies as st import torch @@ -95,7 +95,7 @@ def _generate_qkv( head_dim: int, device: torch.device, dtype: torch.dtype, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: q = torch.randn( batch_size, seqlen_q, diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test_deterministic.py b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test_deterministic.py index fe680812bf..f4fb846c2b 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test_deterministic.py +++ b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test_deterministic.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from typing import Tuple import torch @@ -15,7 +16,7 @@ def _allclose( t_1: torch.Tensor, t_2: torch.Tensor, -) -> tuple[float, float]: +) -> Tuple[float, float]: diff = t_1 - t_2 return diff.abs().max().item(), diff.abs().sum().item() @@ -28,7 +29,7 @@ def _generate_inputs( kv_heads: int, head_dim: int, dtype: torch.dtype, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: device = torch.accelerator.current_accelerator() assert device is not None assert seqlen_q <= seqlen_k @@ -69,7 +70,7 @@ def _execute_cutlass_blackwell_attn_dense( v: torch.Tensor, g: torch.Tensor, causal: bool, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # Run tested kernel out = cutlass_blackwell_fmha_func(q, k, v, causal=causal, seqlen_kv=None) ( diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/gqa_test.py b/fbgemm_gpu/experimental/gen_ai/test/attention/gqa_test.py index 0d7b090743..ce55d0d3f8 100755 --- a/fbgemm_gpu/experimental/gen_ai/test/attention/gqa_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/attention/gqa_test.py @@ -9,6 +9,7 @@ import unittest from enum import Enum, unique +from typing import List, Tuple import fbgemm_gpu.experimental.gen_ai # noqa: F401 import hypothesis.strategies as st @@ -37,7 +38,7 @@ class LogicalDtype(Enum): def quant_int4_dequant_bf16( in_tensor: torch.Tensor, num_groups: int = 1 -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ A util function for quantizing a tensor from from a float type (including FP32, FP16, BF16) to INT4 and then dequantize the INT4 result to BF16 @@ -96,9 +97,9 @@ def gqa_reference( Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, - seq_lens: list[int], + seq_lens: List[int], qk_scale: float, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ The reference GQA implementation """ @@ -266,7 +267,7 @@ def test_mqa_main( # noqa C901 self, dtype: str, num_groups: int, - args: tuple[int, int, int, int], + args: Tuple[int, int, int, int], mqa: bool, validate_p_inf_exp: bool, ) -> None: @@ -292,9 +293,9 @@ def mqa_reference( Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, - seq_lens: list[int], + seq_lens: List[int], qk_scale: float, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: (B, T, H, D) = Q.shape (_, MAX_T, Hk, D) = K.shape (_, MAX_T, Hv, D) = V.shape diff --git a/fbgemm_gpu/experimental/gen_ai/test/comm/multi_gpu_car_test.py b/fbgemm_gpu/experimental/gen_ai/test/comm/multi_gpu_car_test.py index 6cf4e9a833..dc814c8d8c 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/comm/multi_gpu_car_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/comm/multi_gpu_car_test.py @@ -14,7 +14,7 @@ import tempfile import unittest import uuid -from typing import Callable, Union +from typing import Callable, Dict, List, Tuple, Union import fbgemm_gpu.experimental.gen_ai # noqa: F401 @@ -46,7 +46,7 @@ def has_nvswitch() -> bool: return "GRANDTETON" in model or "SUPERMICRO" in model -def _setup(path: str) -> tuple[int, int]: +def _setup(path: str) -> Tuple[int, int]: rank = int(os.environ["LOCAL_RANK"]) W = int(os.environ["WORLD_SIZE"]) device = torch.device(f"cuda:{rank}") @@ -151,8 +151,8 @@ def _test_fn( ) rank_start = N // W * rank rank_end = N // W * (rank + 1) - args: list[torch.Tensor] = [y_reducescatter, y] - kwargs: dict[str, Union[bool, torch.Tensor]] = {} + args: 
List[torch.Tensor] = [y_reducescatter, y] + kwargs: Dict[str, Union[bool, torch.Tensor]] = {} if split_last_dim: kwargs["split_last_dim"] = True @@ -372,7 +372,7 @@ def test_allgather(self, dtype: torch.dtype) -> None: ) @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=100000) def test_allgather_dtype_mismatch( - self, dtypes: tuple[torch.dtype, torch.dtype] + self, dtypes: Tuple[torch.dtype, torch.dtype] ) -> None: dst_dtype, src_dtype = dtypes # float8 is only supported in H100 or MI300x diff --git a/fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py index 2e0824ef0f..d72e0f6e29 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py @@ -10,7 +10,7 @@ import logging import unittest from enum import Enum, unique -from typing import Optional +from typing import List, Optional, Tuple import fbgemm_gpu.experimental.gen_ai # noqa: F401 import torch @@ -40,8 +40,8 @@ class LogicalDtype(Enum): def _get_varseq_batch_seqpos( - seqlens_q: list[int], seqlens_kv: list[int], device: torch.device -) -> tuple[torch.Tensor, torch.Tensor]: + seqlens_q: List[int], seqlens_kv: List[int], device: torch.device +) -> Tuple[torch.Tensor, torch.Tensor]: """ varseq_batch[i] is batch index of query i varseq_seqpos[i] is the offset of the last key which query i attends to diff --git a/fbgemm_gpu/experimental/gen_ai/test/kv_cache/quantize_qkv_per_head_test.py b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/quantize_qkv_per_head_test.py index 60aef7d7b5..e9a609a06c 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/kv_cache/quantize_qkv_per_head_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/quantize_qkv_per_head_test.py @@ -9,7 +9,7 @@ import logging import unittest -from typing import Optional, Union +from typing import Optional, Tuple, Union import fbgemm_gpu.experimental.gen_ai # noqa: F401 import torch @@ -121,8 +121,8 @@ def quantize_qkv_per_head_python_reference( cache_K: Optional[torch.Tensor] = None, cache_V: Optional[torch.Tensor] = None, ) -> Union[ - tuple[torch.Tensor, torch.Tensor], - tuple[ + Tuple[torch.Tensor, torch.Tensor], + Tuple[ torch.Tensor, torch.Tensor, torch.Tensor, @@ -293,7 +293,7 @@ def create_test_tensors( D_H: int, MAX_T: int, device: torch.device, -) -> tuple[ +) -> Tuple[ torch.Tensor, torch.Tensor, torch.Tensor, diff --git a/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py index 667289745b..c31d0afbac 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py +++ b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py @@ -6,7 +6,7 @@ # pyre-unsafe -from typing import Optional +from typing import Dict, Optional import torch @@ -46,7 +46,7 @@ from triton.language.extra.cuda.libdevice import pow -_INTERNAL_DTYPE_MAP: dict[str, int] = {"": 0, "f32": 1, "f64": 2} +_INTERNAL_DTYPE_MAP: Dict[str, int] = {"": 0, "f32": 1, "f64": 2} @triton.jit diff --git a/fbgemm_gpu/experimental/gen_ai/test/moe/activation_test.py b/fbgemm_gpu/experimental/gen_ai/test/moe/activation_test.py index dee42e93c6..bfc31e0718 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/moe/activation_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/moe/activation_test.py @@ -10,7 +10,7 @@ import logging import os import unittest -from typing import Optional +from typing import Optional, Tuple import torch import triton # noqa: F401 @@ -147,7 +147,7 @@ 
+147,7 @@
def test_silu_mul_quant( else: scale_ub_tensor = None - def fn() -> tuple[torch.Tensor, torch.Tensor]: + def fn() -> Tuple[torch.Tensor, torch.Tensor]: op = silu_mul_quant if compiled: op = torch.compile(op) @@ -156,7 +156,7 @@ def fn() -> tuple[torch.Tensor, torch.Tensor]: y_fp8, y_scale = fn() y = y_fp8.to(torch.float32) * y_scale[:, None] - def ref_fn() -> tuple[torch.Tensor, torch.Tensor]: + def ref_fn() -> Tuple[torch.Tensor, torch.Tensor]: x0_fp32 = x0.to(torch.float32) x1_fp32 = x1.to(torch.float32) y_fp32 = x0_fp32 * torch.sigmoid(x0_fp32) * x1_fp32 diff --git a/fbgemm_gpu/experimental/gen_ai/test/moe/gather_scatter_test.py b/fbgemm_gpu/experimental/gen_ai/test/moe/gather_scatter_test.py index bc697ab639..2dfc239a74 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/moe/gather_scatter_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/moe/gather_scatter_test.py @@ -10,7 +10,7 @@ import logging import random import unittest -from typing import Optional +from typing import Optional, Tuple import torch import triton # noqa: F401 @@ -171,7 +171,7 @@ def test_gather_scale_quant_dense_tokens( torch.arange(T).cuda() < num_valid_tokens, token_indices, -1 ) - def torch_fn() -> tuple[torch.Tensor, torch.Tensor]: + def torch_fn() -> Tuple[torch.Tensor, torch.Tensor]: shuffled_x = torch.index_select(x, dim=0, index=token_indices) shuffled_scores = torch.index_select(scores, dim=1, index=token_indices) shuffled_selected_scores = torch.gather( @@ -190,7 +190,7 @@ def torch_fn() -> tuple[torch.Tensor, torch.Tensor]: -1, 1 ) - def triton_fn() -> tuple[torch.Tensor, torch.Tensor]: + def triton_fn() -> Tuple[torch.Tensor, torch.Tensor]: scores_ = scores.contiguous().transpose(0, 1) if rowmajor: scores_ = scores_.contiguous() diff --git a/fbgemm_gpu/experimental/gen_ai/test/moe/layers_test.py b/fbgemm_gpu/experimental/gen_ai/test/moe/layers_test.py index a180c289d3..a8b67b3bb4 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/moe/layers_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/moe/layers_test.py @@ -13,6 +13,7 @@ import traceback from datetime import datetime from functools import partial +from typing import Tuple import torch @@ -100,7 +101,7 @@ def default_init_method(x: torch.Tensor) -> torch.Tensor: def fp8_rowwise_init_method( x: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: default_init_method(x) if x.ndim == 3: E, K, N = x.shape diff --git a/fbgemm_gpu/experimental/gen_ai/test/moe/shuffling_test.py b/fbgemm_gpu/experimental/gen_ai/test/moe/shuffling_test.py index bd025b655f..66d448e7f5 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/moe/shuffling_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/moe/shuffling_test.py @@ -11,7 +11,7 @@ import logging import random import unittest -from typing import Optional +from typing import List, Optional, Tuple import torch import triton # noqa: F401 @@ -106,7 +106,7 @@ def test_topk_index_shuffling( if not rowmajor: routing_scores = routing_scores.transpose(0, 1).contiguous().transpose(0, 1) - def fn() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def fn() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: op = index_shuffling if compiled: op = torch.compile(op, backend="inductor", fullgraph=True) @@ -121,7 +121,7 @@ def fn() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: top_k, ) - def ref_fn() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def ref_fn() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: valid_routing_scores = routing_scores[:num_valid_tokens].contiguous() 
selected_expert_indices = torch.topk(valid_routing_scores, top_k, dim=1)[1] expert_indices, flattened_position_indices = torch.sort( @@ -313,10 +313,10 @@ def generate_token_counts(token_per_rank: int) -> torch.Tensor: tokens = tokens.view(-1, dim) tokens[token_counts[:, expert_start:expert_end].sum() :, :] = torch.nan - token_counts_list: list[list[int]] = token_counts.tolist() - token_counts_t_list: list[list[int]] = token_counts.T.tolist() + token_counts_list: List[List[int]] = token_counts.tolist() + token_counts_t_list: List[List[int]] = token_counts.T.tolist() - def slice_tokens() -> tuple[torch.Tensor, ...]: + def slice_tokens() -> Tuple[torch.Tensor, ...]: if is_combine_shuffling: reshuffled_chunks = [[] for _ in range(num_local_experts)] # token_counts: [EP, E] @@ -341,9 +341,9 @@ def slice_tokens() -> tuple[torch.Tensor, ...]: offset += chunk_size return tuple(itertools.chain(*reshuffled_chunks)) - reshuffled_chunks: tuple[torch.Tensor, ...] = slice_tokens() + reshuffled_chunks: Tuple[torch.Tensor, ...] = slice_tokens() - def ref_fn() -> tuple[torch.Tensor, Optional[torch.Tensor]]: + def ref_fn() -> Tuple[torch.Tensor, Optional[torch.Tensor]]: cat_tokens = torch.cat(reshuffled_chunks) if is_combine_shuffling: counts = token_counts[:, expert_start:expert_end].sum(dim=0) @@ -355,7 +355,7 @@ def ref_fn() -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert not ref_output_tokens.isnan().any().item() - def fn() -> tuple[torch.Tensor, Optional[torch.Tensor]]: + def fn() -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if is_combine_shuffling: return combine_shuffling( tokens.view(-1, dim), diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index 14a9b2cf3d..26b6cfdaac 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -10,7 +10,7 @@ import os import unittest -from typing import Optional, Union +from typing import Optional, Tuple, Union import fbgemm_gpu.experimental.gen_ai # noqa: F401 @@ -95,7 +95,7 @@ def evaluate_cuda_platform_version(major: int): open_source: bool = getattr(fbgemm_gpu, "open_source", False) -def fp8_row_quantize_ref(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def fp8_row_quantize_ref(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # Quantize an input tensor and return the fp8 tensor and its inverse scale. x_row_max = torch.max(torch.abs(x), dim=1).values max_scaling_factor = E4M3_MAX_POS * 512.0 # Match kernel logics @@ -104,7 +104,7 @@ def fp8_row_quantize_ref(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return xq, scale.reciprocal().to(torch.float32) -def fp8_col_quantize_ref(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def fp8_col_quantize_ref(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # Quantize an input tensor and return the fp8 tensor and its inverse scale. x_col_max = torch.max(torch.abs(x), dim=0).values max_scaling_factor = E4M3_MAX_POS * 512.0 # Match kernel logics @@ -116,7 +116,7 @@ def fp8_col_quantize_ref(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: def int4_row_quantize( x: torch.Tensor, group_size: int = 128, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: n_bit = 4 # Number of target bits. 
to_quant = x.reshape(-1, group_size).to(torch.float) diff --git a/fbgemm_gpu/experimental/hstu/hstu/cuda_hstu_attention.py b/fbgemm_gpu/experimental/hstu/hstu/cuda_hstu_attention.py index c147044378..6c31a62501 100644 --- a/fbgemm_gpu/experimental/hstu/hstu/cuda_hstu_attention.py +++ b/fbgemm_gpu/experimental/hstu/hstu/cuda_hstu_attention.py @@ -9,7 +9,7 @@ # pyre-strict -from typing import Any, Optional +from typing import Any, Optional, Tuple from .library import * # noqa: F401, F403 import torch @@ -28,7 +28,7 @@ def forward( # pyre-ignore[14] num_contexts: torch.Tensor, num_targets: torch.Tensor, target_group_size: int, - window_size: tuple[int, int] = (-1, -1), + window_size: Tuple[int, int] = (-1, -1), alpha: float = 1.0, rab: Optional[torch.Tensor] = None, # need grad has_drab: bool = False, @@ -227,7 +227,7 @@ def hstu_attn_varlen_func( num_contexts: torch.Tensor, num_targets: torch.Tensor, target_group_size: int = 1, - window_size: tuple[int, int] = (-1, -1), + window_size: Tuple[int, int] = (-1, -1), alpha: float = 1.0, rab: Optional[torch.Tensor] = None, has_drab: bool = False, @@ -325,7 +325,7 @@ def cuda_hstu_attn_varlen( max_seqlen_q: int, max_seqlen_k: int, num_targets: torch.Tensor, - window_size: tuple[int, int] = (-1, -1), + window_size: Tuple[int, int] = (-1, -1), alpha: float = 1.0, is_train: bool = True, ) -> torch.Tensor: diff --git a/fbgemm_gpu/experimental/hstu/test/hstu_test.py b/fbgemm_gpu/experimental/hstu/test/hstu_test.py index 587e05e5e4..b41cf3c300 100755 --- a/fbgemm_gpu/experimental/hstu/test/hstu_test.py +++ b/fbgemm_gpu/experimental/hstu/test/hstu_test.py @@ -11,7 +11,7 @@ import math import os import unittest -from typing import Optional +from typing import Optional, Tuple import torch import torch.nn.functional as F @@ -165,7 +165,7 @@ def generate_input( target_group_size: int, attn_dim: int, hidden_dim: int, - window_size: tuple[int, int], + window_size: Tuple[int, int], dtype: torch.dtype, full_batch: bool, has_drab: bool, @@ -518,11 +518,11 @@ def test_hstu_attn( batch_size: int, heads: int, max_context_len: int, - attn_hidden_dims: tuple[int, int], + attn_hidden_dims: Tuple[int, int], alpha: float, - rab_params: tuple[bool, bool, Optional[int]], - seq_len_params: tuple[int, int, bool], - target_params: tuple[int, tuple[int, int], int], + rab_params: Tuple[bool, bool, Optional[int]], + seq_len_params: Tuple[int, int, bool], + target_params: Tuple[int, Tuple[int, int], int], dtype: torch.dtype, full_batch: bool, ) -> None: @@ -871,11 +871,11 @@ def test_hstu_attn_fp8( self, batch_size: int, heads: int, - seq_len_params: tuple[int, int], - window_size: tuple[int, int], - attn_hidden_dims: tuple[int, int], + seq_len_params: Tuple[int, int], + window_size: Tuple[int, int], + attn_hidden_dims: Tuple[int, int], alpha: float, - rab_params: tuple[bool, bool, Optional[int]], + rab_params: Tuple[bool, bool, Optional[int]], full_batch: bool, dtype: torch.dtype, max_target_len: int, diff --git a/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py index 18e372bba8..db2260df4d 100644 --- a/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py @@ -9,6 +9,7 @@ from math import sqrt +from typing import List import torch @@ -21,7 +22,7 @@ load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") -def wrap_weight_to_parameter(weights: list[torch.Tensor]) -> list[torch.Tensor]: +def wrap_weight_to_parameter(weights: List[torch.Tensor]) -> 
List[torch.Tensor]: for i, v in enumerate(weights): if not isinstance(v, torch.nn.Parameter): weights[i] = torch.nn.Parameter(v) @@ -30,7 +31,7 @@ def wrap_weight_to_parameter(weights: list[torch.Tensor]) -> list[torch.Tensor]: class BatchedUnaryEmbeddingBag(torch.nn.Module): # pyre-fixme[3]: Return type must be annotated. - def __init__(self, num_tasks: int, hash_sizes: list[int], long_index: bool = False): + def __init__(self, num_tasks: int, hash_sizes: List[int], long_index: bool = False): super().__init__() self.num_tasks = num_tasks self.hash_sizes = hash_sizes diff --git a/fbgemm_gpu/fbgemm_gpu/enums.py b/fbgemm_gpu/fbgemm_gpu/enums.py index 7de9246bda..174c030971 100644 --- a/fbgemm_gpu/fbgemm_gpu/enums.py +++ b/fbgemm_gpu/fbgemm_gpu/enums.py @@ -8,13 +8,14 @@ # pyre-strict import enum -from typing import Any, Callable +import typing +from typing import Any, Callable, List, Tuple # Create enums in given namespace with information from query_op def create_enums( - namespace: dict[str, Any], - query_op: Callable[[], list[tuple[str, list[tuple[str, int]]]]], + namespace: typing.Dict[str, Any], + query_op: Callable[[], List[Tuple[str, List[Tuple[str, int]]]]], ) -> None: for enum_name, items in query_op(): # Create matching python enumeration diff --git a/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py b/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py index b9ef0fece5..2f26c35476 100644 --- a/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py +++ b/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py @@ -8,7 +8,7 @@ # pyre-strict from itertools import accumulate -from typing import Optional +from typing import List, Optional import torch @@ -93,8 +93,8 @@ class PermutePooledEmbeddings: def __init__( self, - embs_dims: list[int], - permute: list[int], + embs_dims: List[int], + permute: List[int], device: Optional[torch.device] = None, ) -> None: self._offset_dim_list: torch.Tensor = torch.tensor( @@ -105,7 +105,7 @@ def __init__( permute, device=device, dtype=torch.int64 ) - inv_permute: list[int] = [0] * len(permute) + inv_permute: List[int] = [0] * len(permute) for i, p in enumerate(permute): inv_permute[p] = i diff --git a/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules_split.py b/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules_split.py index a13e82c598..a7245c90f1 100644 --- a/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules_split.py +++ b/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules_split.py @@ -9,7 +9,7 @@ import logging from itertools import accumulate -from typing import Optional +from typing import List, Optional import torch from torch import nn @@ -34,8 +34,8 @@ def _fx_wrap_tensor_to_device(t: torch.Tensor, device: torch.device) -> torch.Te class PermutePooledEmbeddingsSplit(nn.Module): def __init__( self, - embs_dims: list[int], - permute: list[int], + embs_dims: List[int], + permute: List[int], device: Optional[torch.device] = None, ) -> None: super(PermutePooledEmbeddingsSplit, self).__init__() @@ -51,7 +51,7 @@ def __init__( "_permute", torch.tensor(permute, device=device, dtype=torch.int64) ) - inv_permute: list[int] = [0] * len(permute) + inv_permute: List[int] = [0] * len(permute) for i, p in enumerate(permute): inv_permute[p] = i diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py index 3b5c6dfccc..7db84e4bf6 100644 --- a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py +++ b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py @@ -13,7 +13,7 @@ import logging -from typing import Optional, 
TypeVar +from typing import List, Optional, Tuple, TypeVar import torch @@ -67,7 +67,7 @@ class QuantizationContext: row_dim_quant: int = -1 mx_group_size: int = MX_GROUP_SIZE_DEFAULT rounding_mode: Optional[RoundingMode] = RoundingMode.even - padded_dim_sum_per_rank: Optional[list[int]] = None + padded_dim_sum_per_rank: Optional[List[int]] = None def _quantize_tensor( @@ -273,10 +273,10 @@ def create_context(self) -> Optional[QuantizationContext]: def padded_size( self, input_tensor: torch.Tensor, - dim_per_rank: list[int], + dim_per_rank: List[int], my_rank: int, qcomm_ctx: QuantizationContext, - ) -> tuple[int, int]: + ) -> Tuple[int, int]: if input_tensor.ndim == 1: return input_tensor.shape[0], 0 # return padded size for the feature dimension (dim 1), 0 if no padding needed. diff --git a/fbgemm_gpu/fbgemm_gpu/runtime_monitor.py b/fbgemm_gpu/fbgemm_gpu/runtime_monitor.py index 923c249442..7583fa57cb 100644 --- a/fbgemm_gpu/fbgemm_gpu/runtime_monitor.py +++ b/fbgemm_gpu/fbgemm_gpu/runtime_monitor.py @@ -12,7 +12,7 @@ from collections import deque from dataclasses import dataclass from types import TracebackType -from typing import Callable, Optional, TypeVar +from typing import Callable, Deque, Optional, Tuple, Type, TypeVar import torch @@ -171,7 +171,7 @@ def __enter__(self) -> None: def __exit__( self, - exc_type: Optional[type[BaseException]], + exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: @@ -191,7 +191,7 @@ class AsyncSeriesTimer: """ def __init__(self, report_functor: Callable[[T, float], None]) -> None: - self._events_queue: deque[tuple[torch.cuda.Event, torch.cuda.Event, T]] = ( + self._events_queue: Deque[Tuple[torch.cuda.Event, torch.cuda.Event, T]] = ( deque() ) self._active_start_event: Optional[torch.cuda.Event] = None diff --git a/fbgemm_gpu/fbgemm_gpu/sll/cpu/cpu_sll.py b/fbgemm_gpu/fbgemm_gpu/sll/cpu/cpu_sll.py index ae5d524203..f50260ae0e 100644 --- a/fbgemm_gpu/fbgemm_gpu/sll/cpu/cpu_sll.py +++ b/fbgemm_gpu/fbgemm_gpu/sll/cpu/cpu_sll.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
# pyre-strict -from typing import Any +from typing import Any, Tuple import torch @@ -65,7 +65,7 @@ def forward( # pyre-fixme def backward( ctx: Any, grad_output: torch.Tensor # pyre-ignore - ) -> tuple[torch.Tensor, torch.Tensor, None, None, None]: + ) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]: """ # X = [Sum_B, D] # Y = [B, D, T] @@ -128,7 +128,7 @@ def forward( # pyre-fixme def backward( ctx: Any, grad_output: torch.Tensor # pyre-ignore - ) -> tuple[torch.Tensor, torch.Tensor, None, None, None]: + ) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]: """ # X = [Sum_B, D] # Y = [Sum_B, T] @@ -172,7 +172,7 @@ def cpu_dense_jagged_cat_jagged_out( b: torch.Tensor, b_offsets: torch.Tensor, max_seq_len: int, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: assert a.size(0) == b_offsets.size(0) - 1 c = torch.empty(b.size(0) + a.size(0), dtype=a.dtype, device=a.device) c_offsets = b_offsets + torch.arange( @@ -368,7 +368,7 @@ def forward( # pyre-fixme def backward( ctx: Any, grad_output: torch.Tensor # pyre-ignore - ) -> tuple[torch.Tensor, None, None]: + ) -> Tuple[torch.Tensor, None, None]: y, x_offsets = ctx.saved_tensors B = x_offsets.size(0) - 1 @@ -923,7 +923,7 @@ def forward( def backward( ctx, # pyre-ignore grad_output: torch.Tensor, - ) -> tuple[torch.Tensor, None, torch.Tensor, None]: + ) -> Tuple[torch.Tensor, None, torch.Tensor, None]: (offsets,) = ctx.saved_tensors grad_dense = torch.ops.fbgemm.jagged_to_padded_dense( grad_output, [offsets], [ctx.max_seq_len] diff --git a/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py b/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py index 14ab516016..dfeabbce3e 100644 --- a/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +++ b/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py @@ -6,6 +6,7 @@ # pyre-unsafe +from typing import Tuple import torch import triton @@ -195,7 +196,7 @@ def forward( # pyre-fixme def backward( ctx, grad_output: torch.Tensor - ) -> tuple[torch.Tensor, None, None, None]: + ) -> Tuple[torch.Tensor, None, None, None]: max_length = ctx.max_length (lengths, offsets) = ctx.saved_tensors grad_in = padded_dense_to_jagged2_fwd(grad_output, lengths, offsets, max_length) diff --git a/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py b/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py index 3183831f12..a331ded2bf 100644 --- a/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +++ b/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py @@ -6,6 +6,7 @@ # pyre-unsafe +from typing import Tuple import torch import triton @@ -170,7 +171,7 @@ def jagged_dense_flash_attention_fwd( jagged_offsets, max_seq_len, allow_tf32=False, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Q: jagged tensor, [sum_B, D] K: dense tensor, [B, D, T] @@ -649,7 +650,7 @@ def jagged_dense_flash_attention_bwd( jagged_offsets, max_seq_len, allow_tf32=False, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Q: jagged tensor, [sum_B, D] K: dense tensor, [B, D, T] @@ -811,7 +812,7 @@ def forward( # pyre-fixme def backward( ctx, do: torch.Tensor - ) -> tuple[ + ) -> Tuple[ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, None, None, None ]: Q, K, V, attn_bias, jagged_offsets, lse, attn_out = ctx.saved_tensors diff --git 
a/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py b/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py index 15c05dd62c..7443ed934b 100644 --- a/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +++ b/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py @@ -6,6 +6,7 @@ # pyre-unsafe +from typing import Tuple import torch import triton @@ -606,7 +607,7 @@ def forward( # pyre-fixme def backward( ctx, grad_output: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None, None]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None, None]: ( jagged_Q, jagged_K, diff --git a/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py b/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py index cfce271b27..5b42ef7a5f 100644 --- a/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +++ b/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py @@ -6,6 +6,7 @@ # pyre-unsafe +from typing import Tuple import torch import triton @@ -687,7 +688,7 @@ def forward( # pyre-fixme def backward( ctx, grad_output: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None]: ( jagged_Q, jagged_K, diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index a3c93273c2..f351acbd07 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -7,8 +7,7 @@ # pyre-strict import math -from collections.abc import Sequence -from typing import Callable, Optional +from typing import Callable, List, Optional, Sequence, Tuple import torch @@ -75,7 +74,7 @@ def permute_2D_sparse_data_input1D_meta( stride: int, weights: Optional[Tensor] = None, permuted_lengths_sum: Optional[int] = None, -) -> tuple[Tensor, Tensor, Optional[Tensor]]: +) -> Tuple[Tensor, Tensor, Optional[Tensor]]: torch._check( lengths.dim() == 1, lambda: f"expected lengths.dim() == 1, got {lengths.dim()}" ) @@ -112,7 +111,7 @@ def permute_2D_sparse_data_input1D_backward( grad_lengths: torch.Tensor, grad_values: torch.Tensor, grad_weights: torch.Tensor, -) -> tuple[None, Tensor, Tensor, None, Tensor, None]: +) -> Tuple[None, Tensor, Tensor, None, Tensor, None]: inv_permute = torch.ops.fbgemm.invert_permute(ctx.permute) permuted_grad_lengths, permuted_grad_values, permuted_grad_weights = ( torch.ops.fbgemm.permute_2D_sparse_data_input1D( @@ -140,7 +139,7 @@ def permute_2D_sparse_data_meta( values: Tensor, weights: Optional[Tensor] = None, permuted_lengths_sum: Optional[int] = None, -) -> tuple[Tensor, Tensor, Optional[Tensor]]: +) -> Tuple[Tensor, Tensor, Optional[Tensor]]: torch._check( lengths.dim() == 2, lambda: f"expected lengths.dim() == 2, got {lengths.dim()}" ) @@ -198,7 +197,7 @@ def permute_1D_sparse_data_meta( values: Tensor, weights: Optional[Tensor] = None, permuted_lengths_sum: Optional[int] = None, -) -> tuple[Tensor, Tensor, Optional[Tensor]]: +) -> Tuple[Tensor, Tensor, Optional[Tensor]]: indices = values permuted_lengths_size = permute.numel() permuted_lengths = lengths.new_empty([permuted_lengths_size]) @@ -219,7 +218,7 @@ def permute_1D_sparse_data_meta( def masked_select_jagged_1d( values: Tensor, lengths: Tensor, mask: Tensor -) -> tuple[Tensor, Tensor]: +) -> Tuple[Tensor, Tensor]: torch._check(values.dim() == 1) torch._check(lengths.dim() == 1) 
torch._check(values.device == lengths.device) @@ -232,11 +231,11 @@ def masked_select_jagged_1d( def tbe_input_combine_abstract( - indices_list: list[Tensor], - offsets_list: list[Tensor], - per_sample_weights: list[Tensor], + indices_list: List[Tensor], + offsets_list: List[Tensor], + per_sample_weights: List[Tensor], include_last_offsets: Tensor, -) -> tuple[Tensor, Tensor, Tensor]: +) -> Tuple[Tensor, Tensor, Tensor]: torch._check(len(indices_list) > 0) torch._check(len(indices_list) == len(offsets_list)) torch._check(len(indices_list) == len(per_sample_weights)) @@ -269,10 +268,10 @@ def tbe_input_combine_abstract( def tbe_input_combine_with_length_abstract( - indices_list: list[Tensor], - offsets_list: list[Tensor], - per_sample_weights: list[Tensor], -) -> tuple[Tensor, Tensor, Tensor]: + indices_list: List[Tensor], + offsets_list: List[Tensor], + per_sample_weights: List[Tensor], +) -> Tuple[Tensor, Tensor, Tensor]: torch._check(len(indices_list) > 0) torch._check(len(indices_list) == len(offsets_list)) torch._check(len(indices_list) == len(per_sample_weights)) @@ -340,7 +339,7 @@ def expand_into_jagged_permute_meta( permute: Tensor, input_offsets: Tensor, output_offsets: Tensor, - output_size: tuple[int, ...], + output_size: Tuple[int, ...], ) -> Tensor: torch._check(permute.numel() > 0, lambda: "expected {permute.numel} > 0") torch._check( @@ -466,7 +465,7 @@ def block_bucketize_sparse_features_meta( keep_orig_idx: bool = False, total_num_blocks: Optional[torch.Tensor] = None, keep_orig_idx_per_feature: Optional[torch.Tensor] = None, -) -> tuple[ +) -> Tuple[ torch.Tensor, torch.Tensor, Optional[torch.Tensor], @@ -501,7 +500,7 @@ def block_bucketize_sparse_features_2d_weights_meta( keep_orig_idx: bool = False, total_num_blocks: Optional[torch.Tensor] = None, keep_orig_idx_per_feature: Optional[torch.Tensor] = None, -) -> tuple[ +) -> Tuple[ torch.Tensor, torch.Tensor, torch.Tensor, @@ -522,7 +521,7 @@ def block_bucketize_sparse_features_2d_weights_meta( def merge_pooled_embeddings( - pooled_embeddings: list[torch.Tensor], + pooled_embeddings: List[torch.Tensor], uncat_dim_size: int, target_device: torch.device, cat_dim: int = 1, @@ -553,7 +552,7 @@ def merge_pooled_embeddings( def permute_sparse_features_abstract( permute: Tensor, lengths: Tensor, indices: Tensor, weights: Optional[Tensor] = None -) -> tuple[Tensor, Tensor, Optional[Tensor]]: +) -> Tuple[Tensor, Tensor, Optional[Tensor]]: torch._check(lengths.dtype == indices.dtype) torch._check(permute.device == lengths.device) torch._check(permute.device == indices.device) @@ -584,7 +583,7 @@ def segment_sum_csr_abstract( def dense_to_jagged_forward( dense: torch.Tensor, - offsets: list[torch.Tensor], + offsets: List[torch.Tensor], total_L: Optional[torch.SymInt] = None, ) -> torch.Tensor: if total_L is None: @@ -599,9 +598,9 @@ def dense_to_jagged_forward( def dense_to_jagged( dense: torch.Tensor, - offsets: list[torch.Tensor], + offsets: List[torch.Tensor], total_L: Optional[torch.SymInt] = None, -) -> tuple[torch.Tensor, list[torch.Tensor]]: +) -> Tuple[torch.Tensor, List[torch.Tensor]]: if total_L is None: total_L = torch.library.get_ctx().new_dynamic_size() return (dense_to_jagged_forward(dense, offsets, total_L), offsets) @@ -610,9 +609,9 @@ def dense_to_jagged( def batch_index_select_dim0_abstract( inputs: torch.Tensor, indices: torch.Tensor, - input_num_indices: list[int], - input_rows: list[int], - input_columns: list[int], + input_num_indices: List[int], + input_rows: List[int], + input_columns: List[int], 
permute_output_dim_0_1: bool, ) -> torch.Tensor: """ @@ -654,11 +653,11 @@ def batch_index_select_dim0_tensor_abstract( def batch_index_select_dim0_forward_cuda_impl_abstract( inputs: torch.Tensor, indices: torch.Tensor, - input_num_indices: list[int], - input_rows: list[int], - input_columns: list[int], + input_num_indices: List[int], + input_rows: List[int], + input_columns: List[int], permute_output_dim_0_1: bool, -) -> list[torch.Tensor]: +) -> List[torch.Tensor]: num_inputs = len(input_rows) torch._check(len(input_num_indices) == len(input_rows)) torch._check(len(input_num_indices) == len(input_columns)) @@ -695,7 +694,7 @@ def batch_index_select_dim0_tensor_forward_cuda_impl_abstract( input_rows: torch.Tensor, input_columns: torch.Tensor, permute_output_dim_0_1: bool, -) -> list[torch.Tensor]: +) -> List[torch.Tensor]: num_inputs: int = input_rows.size(0) torch._check(input_num_indices.size(0) == input_rows.size(0)) torch._check(input_num_indices.size(0) == input_columns.size(0)) @@ -740,7 +739,7 @@ def keyed_jagged_index_select_dim1_abstract( batch_size: torch.SymInt, weights: Optional[torch.Tensor] = None, selected_lengths_sum: Optional[torch.SymInt] = None, -) -> list[torch.Tensor]: +) -> List[torch.Tensor]: """ This meta function is used to calculate the shape of output tensors from the original function `fbgemm::keyed_jagged_index_select_dim1` without the actual data. @@ -765,7 +764,7 @@ def keyed_jagged_index_select_dim1_abstract( torch.index_select(lengths, 0, length_indices).sum().item() ) - ret: list[torch.Tensor] = [ + ret: List[torch.Tensor] = [ # pyre-ignore values.new_empty([selected_lengths_sum]), lengths.new_empty([indices.shape[0] * num_batches]), @@ -797,11 +796,11 @@ def batch_index_select_dim0_backward_cuda_impl_abstract( def batch_index_select_dim0_forward_cpu_impl_abstract( inputs: torch.Tensor, indices: torch.Tensor, - input_num_indices: list[int], - input_rows: list[int], - input_columns: list[int], + input_num_indices: List[int], + input_rows: List[int], + input_columns: List[int], permute_output_dim_0_1: bool, -) -> list[torch.Tensor]: +) -> List[torch.Tensor]: # input lists must have the same length num_inputs = len(input_num_indices) torch._check(num_inputs == len(input_rows)) @@ -831,7 +830,7 @@ def batch_index_select_dim0_tensor_forward_cpu_impl_abstract( input_rows: torch.Tensor, input_columns: torch.Tensor, permute_output_dim_0_1: bool, -) -> list[torch.Tensor]: +) -> List[torch.Tensor]: # input lists must have the same length num_inputs = len(input_num_indices) torch._check(num_inputs == len(input_rows)) @@ -881,8 +880,8 @@ def bounds_check_indices_abstract( def group_index_select_dim0_gpu_impl_abstract( - inputs: list[torch.Tensor], group_size: int -) -> list[torch.Tensor]: + inputs: List[torch.Tensor], group_size: int +) -> List[torch.Tensor]: """ Calculate output shapes for group_index_select_dim0_gpu_impl without the actual data. @@ -912,8 +911,8 @@ def group_index_select_dim0_gpu_impl_abstract( def group_index_select_dim0_gpu_backward_abstract( - all_inputs: list[torch.Tensor], output_shape_group_ref: list[torch.SymInt] -) -> list[torch.Tensor]: + all_inputs: List[torch.Tensor], output_shape_group_ref: List[torch.SymInt] +) -> List[torch.Tensor]: """ Calculate output shapes for group_index_select_dim0_gpu_backward without the actual data. 
@@ -946,7 +945,7 @@ def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract( batch_size: torch.SymInt, weights: Optional[torch.Tensor] = None, selected_lengths_sum: Optional[torch.SymInt] = None, -) -> list[torch.Tensor]: +) -> List[torch.Tensor]: num_batches = lengths.size(0) // batch_size torch._check(lengths.size(0) + 1 == offsets.size(0)) # pyre-ignore @@ -960,7 +959,7 @@ def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract( selected_lengths_sum = torch.library.get_ctx().new_dynamic_size() torch._check_is_size(selected_lengths_sum) - vlw: list[torch.Tensor] = [ + vlw: List[torch.Tensor] = [ values.new_empty([selected_lengths_sum]), # output lengths.new_empty([indices.shape[0] * num_batches]), # output_lengths ] @@ -1003,7 +1002,7 @@ def histogram_binning_calibration_abstract( upper_bound: float, bin_ctr_in_use_after: int, bin_ctr_weight_value: float, -) -> tuple[Tensor, Tensor]: +) -> Tuple[Tensor, Tensor]: return torch.empty_like(logit), torch.empty([logit.numel()], dtype=torch.int64) @@ -1154,7 +1153,7 @@ def generic_histogram_binning_calibration_by_feature( positive_weight: float, bin_ctr_in_use_after: int, bin_ctr_weight_value: float, -) -> tuple[Tensor, Tensor]: +) -> Tuple[Tensor, Tensor]: torch._check(bin_num_examples.numel() == bin_num_positives.numel()) torch._check( bin_num_examples.numel() == (num_segments + 1) * (bin_boundaries.numel() + 1) @@ -1165,13 +1164,13 @@ def generic_histogram_binning_calibration_by_feature( def permute_multi_embedding_function_impl_abstract( - pooled_embs: list[Tensor], + pooled_embs: List[Tensor], permutes: Tensor, in_shapes: Tensor, out_shapes: Tensor, - out_lengths: list[int], + out_lengths: List[int], reverse: bool = False, -) -> list[Tensor]: +) -> List[Tensor]: out_dtype = pooled_embs[0].dtype bs = pooled_embs[0].shape[0] torch._check(permutes.shape[1] == 6, lambda: "permutes must have 6 columns") @@ -1197,9 +1196,9 @@ def lengths_range_abstract( def all_to_one_device( - input_tensors: list[Tensor], + input_tensors: List[Tensor], target_device: torch.device, -) -> list[Tensor]: +) -> List[Tensor]: return [ torch.empty_like(input_tensor, device=torch.device("meta")) for input_tensor in input_tensors diff --git a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py index a2f41ac35d..dc78676ba7 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py +++ b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py @@ -9,7 +9,7 @@ import enum import itertools -from typing import Any, Dict # noqa: F401 +from typing import Any, Dict, List, Optional, Tuple # noqa: F401 import torch @@ -81,13 +81,13 @@ def __str__(self) -> str: return self.value def _extract_dtype( - self, optimizer_state_dtypes: dict[str, "SparseType"], name: str + self, optimizer_state_dtypes: Dict[str, "SparseType"], name: str ) -> torch.dtype: if optimizer_state_dtypes is None or name not in optimizer_state_dtypes: return torch.float32 return optimizer_state_dtypes[name].as_dtype() - def state_names(self) -> list[str]: + def state_names(self) -> List[str]: """ Returns the names of the optimizer states. 
The order of the states will be the order in which they are processed and returned in @@ -101,7 +101,7 @@ def state_names(self) -> list[str]: else: return [] - def state_size_table(self, D: int) -> dict[str, int]: + def state_size_table(self, D: int) -> Dict[str, int]: """ Returns the table of state names to state sizes in terms of number of elements (per table row) @@ -118,7 +118,7 @@ def state_size_table(self, D: int) -> dict[str, int]: def state_size_nbytes( self, D: int, - optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006 + optimizer_state_dtypes: Dict[str, "SparseType"] = {}, # noqa: B006 ) -> int: """ Returns the size of the data (in bytes) required to hold the optimizer @@ -143,8 +143,8 @@ def byte_offsets_along_row( self, D: int, weights_precision: "SparseType", - optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006 - ) -> dict[str, tuple[int, int]]: + optimizer_state_dtypes: Dict[str, "SparseType"] = {}, # noqa: B006 + ) -> Dict[str, Tuple[int, int]]: """ Returns the start and end byte offsets of each optimizer state along a cache row with optimizer state offloading enabled. @@ -184,10 +184,10 @@ def byte_offsets_along_row( def empty_states( self, - rows: list[int], - dims: list[int], - optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006 - ) -> list[list[torch.Tensor]]: + rows: List[int], + dims: List[int], + optimizer_state_dtypes: Dict[str, "SparseType"] = {}, # noqa: B006 + ) -> List[List[torch.Tensor]]: """ Creates sets of empty tensors per table to hold optimizer states based on the specified optimizer type, state dtypes, embedding specs, and @@ -196,7 +196,7 @@ def empty_states( # Else, check that the local row count for each table is set assert len(rows) == len(dims) - opt_states_set: list[list[torch.Tensor]] = [] + opt_states_set: List[List[torch.Tensor]] = [] for r, D in zip(rows, dims): # Set up the table of state names to state sizes, ordered by their @@ -223,10 +223,10 @@ def empty_states( def ssd_state_splits( self, - embedding_specs: list[tuple[int, int]], # Tuple of (rows, dims) - optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006 + embedding_specs: List[Tuple[int, int]], # Tuple of (rows, dims) + optimizer_state_dtypes: Dict[str, "SparseType"] = {}, # noqa: B006 enable_optimizer_offloading: bool = False, - ) -> list[tuple[SplitState, str, torch.dtype]]: + ) -> List[Tuple[SplitState, str, torch.dtype]]: """ Returns the split planning for the optimizer states """ @@ -234,9 +234,9 @@ def ssd_state_splits( T_ = len(embedding_specs) # This is the cumulative row counts for rowwise states - row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows)) + row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows)) # This is the cumulative element counts for elementwise states - table_size_cumsum: list[int] = [0] + list( + table_size_cumsum: List[int] = [0] + list( itertools.accumulate([r * d for r, d in embedding_specs]) ) @@ -441,7 +441,7 @@ def default_config(self) -> QuantizationConfig: return QuantizationConfig() -ELEMENT_SIZE: dict[SparseType, int] = { +ELEMENT_SIZE: Dict[SparseType, int] = { SparseType.FP32: 4, SparseType.FP16: 2, SparseType.FP8: 1, diff --git a/fbgemm_gpu/fbgemm_gpu/split_embedding_inference_converter.py b/fbgemm_gpu/fbgemm_gpu/split_embedding_inference_converter.py index 1f1434f614..500157b4cd 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_embedding_inference_converter.py +++ b/fbgemm_gpu/fbgemm_gpu/split_embedding_inference_converter.py @@ -10,7 +10,7 @@ import logging import 
math -from typing import cast, Optional +from typing import cast, Optional, Tuple import torch @@ -53,7 +53,7 @@ def convert_model(self, model: torch.nn.Module) -> torch.nn.Module: return model # pyre-fixme[2]: Parameter must be annotated. - def _prune_by_weights_l2_norm(self, new_num_rows, weights) -> tuple[Tensor, float]: + def _prune_by_weights_l2_norm(self, new_num_rows, weights) -> Tuple[Tensor, float]: assert new_num_rows > 0 from numpy.linalg import norm @@ -75,7 +75,7 @@ def _prune_embs( idx: int, num_rows: int, module: SplitTableBatchedEmbeddingBagsCodegen, - ) -> tuple[Tensor, Optional[Tensor]]: + ) -> Tuple[Tensor, Optional[Tensor]]: # TODO(yingz): Avoid DtoH / HtoD overhead. weights = module.split_embedding_weights()[idx].cpu() if self.pruning_ratio is None: @@ -100,7 +100,7 @@ def _get_quantization_config(self, name): def _quantize_embs( self, weight: Tensor, weight_ty: SparseType - ) -> tuple[Tensor, Optional[Tensor]]: + ) -> Tuple[Tensor, Optional[Tensor]]: fp8_quant_config = cast(FP8QuantizationConfig, self.quantization_config) return quantize_embs(weight, weight_ty, fp8_quant_config) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py index 4d55ed2738..4f4ef46f4b 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py @@ -11,7 +11,7 @@ import enum from dataclasses import dataclass -from typing import NamedTuple, Optional +from typing import List, NamedTuple, Optional, Tuple import torch from torch import Tensor @@ -73,25 +73,25 @@ class EvictionPolicy(NamedTuple): eviction_mem_threshold_gb: Optional[int] = ( None # eviction trigger condition if trigger mode is mem_util ) - counter_thresholds: Optional[list[int]] = ( + counter_thresholds: Optional[List[int]] = ( None # count_thresholds for each table if eviction strategy is counter ) - ttls_in_mins: Optional[list[int]] = ( + ttls_in_mins: Optional[List[int]] = ( None # ttls_in_mins for each table if eviction strategy is timestamp ) - counter_decay_rates: Optional[list[float]] = ( + counter_decay_rates: Optional[List[float]] = ( None # count_decay_rates for each table if eviction strategy is counter ) - feature_score_counter_decay_rates: Optional[list[float]] = ( + feature_score_counter_decay_rates: Optional[List[float]] = ( None # feature_score_counter_decay_rates for each table if eviction strategy is feature score ) - training_id_eviction_trigger_count: Optional[list[int]] = ( + training_id_eviction_trigger_count: Optional[List[int]] = ( None # training_id_eviction_trigger_count for each table ) - training_id_keep_count: Optional[list[int]] = ( + training_id_keep_count: Optional[List[int]] = ( None # training_id_keep_count for each table ) - l2_weight_thresholds: Optional[list[float]] = ( + l2_weight_thresholds: Optional[List[float]] = ( None # l2_weight_thresholds for each table if eviction strategy is feature l2 norm ) threshold_calculation_bucket_stride: Optional[float] = ( @@ -113,7 +113,7 @@ class EvictionPolicy(NamedTuple): interval_for_feature_statistics_decay_s: int = ( 24 * 3600 # 1 day, interval for feature statistics decay ) - meta_header_lens: Optional[list[int]] = None # metaheader length for each table + meta_header_lens: Optional[List[int]] = None # metaheader length for each table def validate(self) -> None: assert self.eviction_trigger_mode in [0, 1, 2, 3, 4], ( @@ -217,10 +217,10 @@ def validate(self) -> None: 
class KVZCHParams(NamedTuple): # global bucket id start and global bucket id end offsets for each logical table, # where start offset is inclusive and end offset is exclusive - bucket_offsets: list[tuple[int, int]] = [] + bucket_offsets: List[Tuple[int, int]] = [] # bucket size for each logical table # the value indicates corresponding input space for each bucket id, e.g. 2^50 / total_num_buckets - bucket_sizes: list[int] = [] + bucket_sizes: List[int] = [] # enable optimizer offloading or not enable_optimizer_offloading: bool = False # when enabled, backend will return whole row(metaheader + weight + optimizer) instead of weight only @@ -340,8 +340,8 @@ class EmbeddingSpecInfo(enum.IntEnum): ("dev_size", int), ("host_size", int), ("uvm_size", int), - ("placements", list[EmbeddingLocation]), - ("offsets", list[int]), + ("placements", List[EmbeddingLocation]), + ("offsets", List[int]), ], ) @@ -349,15 +349,15 @@ class EmbeddingSpecInfo(enum.IntEnum): @dataclass class CacheState: # T + 1 elements and cache_hash_size_cumsum[-1] == total_cache_hash_size - cache_hash_size_cumsum: list[int] - cache_index_table_map: list[int] + cache_hash_size_cumsum: List[int] + cache_index_table_map: List[int] total_cache_hash_size: int def construct_cache_state( - row_list: list[int], - location_list: list[EmbeddingLocation], - feature_table_map: list[int], + row_list: List[int], + location_list: List[EmbeddingLocation], + feature_table_map: List[int], ) -> CacheState: _cache_hash_size_cumsum = [0] total_cache_hash_size = 0 diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index f036c3ce74..4f4399bc30 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -12,7 +12,7 @@ import logging import uuid from itertools import accumulate -from typing import Optional, Union +from typing import List, Optional, Tuple, Union import fbgemm_gpu # noqa: F401 import torch # usort:skip @@ -92,14 +92,14 @@ def align_to_cacheline(a: int) -> int: def nbit_construct_split_state( - embedding_specs: list[tuple[str, int, int, SparseType, EmbeddingLocation]], + embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]], cacheable: bool, row_alignment: int, scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES, cacheline_alignment: bool = True, ) -> SplitState: - placements = torch.jit.annotate(list[EmbeddingLocation], []) - offsets = torch.jit.annotate(list[int], []) + placements = torch.jit.annotate(List[EmbeddingLocation], []) + offsets = torch.jit.annotate(List[int], []) dev_size = 0 host_size = 0 uvm_size = 0 @@ -165,7 +165,7 @@ def inputs_to_device( offsets: torch.Tensor, per_sample_weights: Optional[torch.Tensor], bounds_check_warning: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: if bounds_check_warning.device.type == "meta": return indices, offsets, per_sample_weights @@ -331,7 +331,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module): Options are `torch.int32` and `torch.int64`. """ - embedding_specs: list[tuple[str, int, int, SparseType, EmbeddingLocation]] + embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]] record_cache_metrics: RecordCacheMetrics # pyre-fixme[13]: Attribute `cache_miss_counter` is never initialized. 
cache_miss_counter: torch.Tensor @@ -346,15 +346,15 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module): def __init__( # noqa C901 self, - embedding_specs: list[ - tuple[str, int, int, SparseType, EmbeddingLocation] + embedding_specs: List[ + Tuple[str, int, int, SparseType, EmbeddingLocation] ], # tuple of (feature_names, rows, dims, SparseType, EmbeddingLocation/placement) - feature_table_map: Optional[list[int]] = None, # [T] - index_remapping: Optional[list[Tensor]] = None, + feature_table_map: Optional[List[int]] = None, # [T] + index_remapping: Optional[List[Tensor]] = None, pooling_mode: PoolingMode = PoolingMode.SUM, device: Optional[Union[str, int, torch.device]] = None, bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, - weight_lists: Optional[list[tuple[Tensor, Optional[Tensor]]]] = None, + weight_lists: Optional[List[Tuple[Tensor, Optional[Tensor]]]] = None, pruning_hash_load_factor: float = 0.5, use_array_for_index_remapping: bool = True, output_dtype: SparseType = SparseType.FP16, @@ -373,7 +373,7 @@ def __init__( # noqa C901 cacheline_alignment: bool = True, uvm_host_mapped: bool = False, # True to use cudaHostAlloc; False to use cudaMallocManaged. reverse_qparam: bool = False, # True to load qparams at end of each row; False to load qparam at begnning of each row. - feature_names_per_table: Optional[list[list[str]]] = None, + feature_names_per_table: Optional[List[List[str]]] = None, indices_dtype: torch.dtype = torch.int32, # Used for construction of the remap_indices tensors. Should match the dtype of the indices passed in the forward() call (INT32 or INT64). ) -> None: # noqa C901 # tuple of (rows, dims,) super(IntNBitTableBatchedEmbeddingBagsCodegen, self).__init__() @@ -406,14 +406,14 @@ def __init__( # noqa C901 self.indices_dtype = indices_dtype # (feature_names, rows, dims, weights_tys, locations) = zip(*embedding_specs) # Pyre workaround - self.feature_names: list[str] = [e[0] for e in embedding_specs] + self.feature_names: List[str] = [e[0] for e in embedding_specs] self.cache_load_factor: float = cache_load_factor self.cache_sets: int = cache_sets self.cache_reserved_memory: float = cache_reserved_memory - rows: list[int] = [e[1] for e in embedding_specs] - dims: list[int] = [e[2] for e in embedding_specs] - weights_tys: list[SparseType] = [e[3] for e in embedding_specs] - locations: list[EmbeddingLocation] = [e[4] for e in embedding_specs] + rows: List[int] = [e[1] for e in embedding_specs] + dims: List[int] = [e[2] for e in embedding_specs] + weights_tys: List[SparseType] = [e[3] for e in embedding_specs] + locations: List[EmbeddingLocation] = [e[4] for e in embedding_specs] # if target device is meta then we set use_cpu based on the embedding location # information in embedding_specs. 
if self.current_device.type == "meta": @@ -453,7 +453,7 @@ def __init__( # noqa C901 T_ = len(self.embedding_specs) assert T_ > 0 - self.feature_table_map: list[int] = ( + self.feature_table_map: List[int] = ( feature_table_map if feature_table_map is not None else list(range(T_)) ) T = len(self.feature_table_map) @@ -676,7 +676,7 @@ def get_table_wise_cache_miss(self) -> Tensor: return self.table_wise_cache_miss @torch.jit.export - def get_feature_num_per_table(self) -> list[int]: + def get_feature_num_per_table(self) -> List[int]: if self.feature_names_per_table is None: return [] return [len(feature_names) for feature_names in self.feature_names_per_table] @@ -1211,8 +1211,8 @@ def _apply_split( dev_size: int, host_size: int, uvm_size: int, - placements: list[int], - offsets: list[int], + placements: List[int], + offsets: List[int], enforce_hbm: bool, ) -> None: assert not self.weight_initialized, "Weights have already been initialized." @@ -1602,7 +1602,7 @@ def update_cache_load_factor(self, cache_load_factor: float = 0.2) -> None: @torch.jit.export def split_embedding_weights_with_scale_bias( self, split_scale_bias_mode: int = 1 - ) -> list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]]: + ) -> List[Tuple[Tensor, Optional[Tensor], Optional[Tensor]]]: """ Returns a list of weights, split by table split_scale_bias_mode: @@ -1611,7 +1611,7 @@ def split_embedding_weights_with_scale_bias( 2: return weights, scale, bias. """ assert self.weight_initialized - splits: list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = [] + splits: List[Tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = [] for t, (_, rows, dim, weight_ty, _) in enumerate(self.embedding_specs): placement = self.weights_physical_placements[t] if ( @@ -1736,12 +1736,12 @@ def split_embedding_weights( # the second with scale_bias. # This should've been named as split_scale_bias. # Keep as is for backward compatibility. - ) -> list[tuple[Tensor, Optional[Tensor]]]: + ) -> List[Tuple[Tensor, Optional[Tensor]]]: """ Returns a list of weights, split by table """ # fmt: off - splits: list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = ( + splits: List[Tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = ( self.split_embedding_weights_with_scale_bias( split_scale_bias_mode=(1 if split_scale_shifts else 0) ) @@ -1779,7 +1779,7 @@ def fill_random_weights(self) -> None: ) def assign_embedding_weights( - self, q_weight_list: list[tuple[Tensor, Optional[Tensor]]] + self, q_weight_list: List[Tuple[Tensor, Optional[Tensor]]] ) -> None: """ Assigns self.split_embedding_weights() with values from the input list of weights and scale_shifts. 
@@ -1799,11 +1799,11 @@ def assign_embedding_weights( @torch.jit.export def set_index_remappings_array( self, - index_remapping: list[Tensor], + index_remapping: List[Tensor], ) -> None: - rows: list[int] = [e[1] for e in self.embedding_specs] + rows: List[int] = [e[1] for e in self.embedding_specs] index_remappings_array_offsets = [0] - original_feature_rows = torch.jit.annotate(list[int], []) + original_feature_rows = torch.jit.annotate(List[int], []) last_offset = 0 for t, mapping in enumerate(index_remapping): if mapping is not None: @@ -1842,11 +1842,11 @@ def set_index_remappings_array( def set_index_remappings( self, - index_remapping: list[Tensor], + index_remapping: List[Tensor], pruning_hash_load_factor: float = 0.5, use_array_for_index_remapping: bool = True, ) -> None: - rows: list[int] = [e[1] for e in self.embedding_specs] + rows: List[int] = [e[1] for e in self.embedding_specs] T = len(self.embedding_specs) # Hash mapping pruning if not use_array_for_index_remapping: @@ -1916,7 +1916,7 @@ def set_index_remappings( def _embedding_inplace_update_per_table( self, update_table_idx: int, - update_row_indices: list[int], + update_row_indices: List[int], update_weights: Tensor, ) -> None: row_size = len(update_row_indices) @@ -1941,9 +1941,9 @@ def _embedding_inplace_update_per_table( @torch.jit.export def embedding_inplace_update( self, - update_table_indices: list[int], - update_row_indices: list[list[int]], - update_weights: list[Tensor], + update_table_indices: List[int], + update_row_indices: List[List[int]], + update_weights: List[Tensor], ) -> None: for i in range(len(update_table_indices)): self._embedding_inplace_update_per_table( @@ -1954,8 +1954,8 @@ def embedding_inplace_update( def embedding_inplace_update_internal( self, - update_table_indices: list[int], - update_row_indices: list[int], + update_table_indices: List[int], + update_row_indices: List[int], update_weights: Tensor, ) -> None: assert len(update_table_indices) == len(update_row_indices) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 1b45f7f147..d075c12c8d 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -18,7 +18,7 @@ from dataclasses import dataclass, field from itertools import accumulate from math import log2 -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import torch # usort:skip from torch import nn, Tensor # usort:skip @@ -191,25 +191,25 @@ class UVMCacheStatsIndex(enum.IntEnum): class RESParams: res_server_port: int = 0 # the port of the res server res_store_shards: int = 1 # the number of shards to store the raw embeddings - table_names: list[str] = field(default_factory=list) # table names the TBE holds - table_offsets: list[int] = field( + table_names: List[str] = field(default_factory=list) # table names the TBE holds + table_offsets: List[int] = field( default_factory=list ) # table offsets for the global rows the TBE holds - table_sizes: list[int] = field( + table_sizes: List[int] = field( default_factory=list ) # table sizes for the global rows the TBE holds def construct_split_state( - embedding_specs: list[tuple[int, int, EmbeddingLocation, ComputeDevice]], + embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]], rowwise: bool, cacheable: bool, precision: SparseType = 
SparseType.FP32, int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET, placement: Optional[EmbeddingLocation] = None, ) -> SplitState: - placements: list[EmbeddingLocation] = [] - offsets: list[int] = [] + placements: List[EmbeddingLocation] = [] + offsets: List[int] = [] dev_size: int = 0 host_size: int = 0 uvm_size: int = 0 @@ -251,18 +251,18 @@ def construct_split_state( def apply_split_helper( persistent_state_fn: Callable[[str, Tensor], None], set_attr_fn: Callable[ - [str, Union[Tensor, list[int], list[EmbeddingLocation]]], None + [str, Union[Tensor, List[int], List[EmbeddingLocation]]], None ], current_device: torch.device, use_cpu: bool, - feature_table_map: list[int], + feature_table_map: List[int], split: SplitState, prefix: str, - dtype: type[torch.dtype], + dtype: Type[torch.dtype], enforce_hbm: bool = False, make_dev_param: bool = False, - dev_reshape: Optional[tuple[int, ...]] = None, - uvm_tensors_log: Optional[list[str]] = None, + dev_reshape: Optional[Tuple[int, ...]] = None, + uvm_tensors_log: Optional[List[str]] = None, uvm_host_mapped: bool = False, ) -> None: set_attr_fn(f"{prefix}_physical_placements", split.placements) @@ -622,12 +622,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module): (preshard_table_height, preshard_table_dim, height_offset, dim_offset) """ - embedding_specs: list[tuple[int, int, EmbeddingLocation, ComputeDevice]] + embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]] optimizer_args: invokers.lookup_args.OptimizerArgs - lxu_cache_locations_list: list[Tensor] + lxu_cache_locations_list: List[Tensor] lxu_cache_locations_empty: Tensor - timesteps_prefetched: list[int] - prefetched_info: list[tuple[Tensor, Tensor, Optional[Tensor]]] + timesteps_prefetched: List[int] + prefetched_info: List[Tuple[Tensor, Tensor, Optional[Tensor]]] record_cache_metrics: RecordCacheMetrics # pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized. 
uvm_cache_stats: torch.Tensor @@ -641,10 +641,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module): def __init__( # noqa C901 self, - embedding_specs: list[ - tuple[int, int, EmbeddingLocation, ComputeDevice] + embedding_specs: List[ + Tuple[int, int, EmbeddingLocation, ComputeDevice] ], # tuple of (rows, dims, placements, compute_devices) - feature_table_map: Optional[list[int]] = None, # [T] + feature_table_map: Optional[List[int]] = None, # [T] cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU, cache_load_factor: float = 0.2, cache_sets: int = 0, @@ -682,8 +682,8 @@ def __init__( # noqa C901 use_experimental_tbe: bool = False, prefetch_pipeline: bool = False, stats_reporter_config: Optional[TBEStatsReporterConfig] = None, - table_names: Optional[list[str]] = None, - optimizer_state_dtypes: Optional[dict[str, SparseType]] = None, + table_names: Optional[List[str]] = None, + optimizer_state_dtypes: Optional[Dict[str, SparseType]] = None, multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None, global_weight_decay: Optional[GlobalWeightDecayDefinition] = None, uvm_host_mapped: bool = False, @@ -691,7 +691,7 @@ def __init__( # noqa C901 tbe_input_multiplexer_config: Optional[TBEInputMultiplexerConfig] = None, embedding_table_index_type: torch.dtype = torch.int64, embedding_table_offset_type: torch.dtype = torch.int64, - embedding_shard_info: Optional[list[tuple[int, int, int, int]]] = None, + embedding_shard_info: Optional[List[Tuple[int, int, int, int]]] = None, enable_raw_embedding_streaming: bool = False, res_params: Optional[RESParams] = None, ) -> None: @@ -800,7 +800,7 @@ def __init__( # noqa C901 self.embedding_specs = embedding_specs (rows, dims, locations, compute_devices) = zip(*embedding_specs) T_ = len(self.embedding_specs) - self.dims: list[int] = dims + self.dims: List[int] = dims assert T_ > 0 # mixed D is not supported by no bag kernels mixed_D = False @@ -877,7 +877,7 @@ def __init__( # noqa C901 self.stats_reporter: Optional[TBEStatsReporter] = ( stats_reporter_config.create_reporter() if stats_reporter_config else None ) - self._uvm_tensors_log: list[str] = [] + self._uvm_tensors_log: List[str] = [] self.bwd_wait_prefetch_timer: Optional[AsyncSeriesTimer] = None self.prefetch_duration_timer: Optional[AsyncSeriesTimer] = None @@ -904,7 +904,7 @@ def __init__( # noqa C901 self.int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET - self.feature_table_map: list[int] = ( + self.feature_table_map: List[int] = ( feature_table_map if feature_table_map is not None else list(range(T_)) ) @@ -1110,13 +1110,13 @@ def __init__( # noqa C901 if ensemble_mode is None: ensemble_mode = EnsembleModeDefinition() - self._ensemble_mode: dict[str, float] = { + self._ensemble_mode: Dict[str, float] = { key: float(fval) for key, fval in ensemble_mode.__dict__.items() } if emainplace_mode is None: emainplace_mode = EmainplaceModeDefinition() - self._emainplace_mode: dict[str, float] = { + self._emainplace_mode: Dict[str, float] = { key: float(fval) for key, fval in emainplace_mode.__dict__.items() } @@ -1421,7 +1421,7 @@ def __init__( # noqa C901 self.step = 0 self.last_reported_step = 0 - self.last_reported_uvm_stats: list[float] = [] + self.last_reported_uvm_stats: List[float] = [] # Check whether to use TBE v2 is_experimental = False @@ -1470,8 +1470,8 @@ def __init__( # noqa C901 ) self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type - self.prefetched_info: list[tuple[Tensor, Tensor, Optional[Tensor]]] = ( - torch.jit.annotate(list[tuple[Tensor, 
Tensor, Optional[Tensor]]], []) + self.prefetched_info: List[Tuple[Tensor, Tensor, Optional[Tensor]]] = ( + torch.jit.annotate(List[Tuple[Tensor, Tensor, Optional[Tensor]]], []) ) if self.enable_raw_embedding_streaming: self.res_params: RESParams = res_params or RESParams() @@ -1537,7 +1537,7 @@ def _register_nonpersistent_buffers(self, prefix: str) -> None: ) @staticmethod - def get_table_name_for_logging(table_names: Optional[list[str]]) -> str: + def get_table_name_for_logging(table_names: Optional[List[str]]) -> str: """ Given a list of all table names in the TBE, generate a string to represent them in logging. If there is more than one table, this method @@ -1563,7 +1563,7 @@ def get_prefetch_passes( multipass_prefetch_config: Optional[MultiPassPrefetchConfig], input_tensor: Tensor, output_tensor: Tensor, - ) -> list[tuple[Tensor, Tensor, int]]: + ) -> List[Tuple[Tensor, Tensor, int]]: """ Given inputs (the indices to forward), partition the input and output into smaller chunks and return them as a list of tuples @@ -1611,7 +1611,7 @@ def get_prefetch_passes( ) ) - def get_states(self, prefix: str) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + def get_states(self, prefix: str) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """ Get a state of a given tensor (`prefix`) @@ -1650,7 +1650,7 @@ def get_states(self, prefix: str) -> tuple[Tensor, Tensor, Tensor, Tensor, Tenso torch.tensor(offsets, dtype=torch.int64), ) - def get_all_states(self) -> list[tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]: + def get_all_states(self) -> List[Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]: """ Get all states in the TBE (`weights`, `momentum1`, `momentum2`, `prev_iter`, and `row_counter`) @@ -1782,7 +1782,7 @@ def _report_io_size_count(self, event: str, data: Tensor) -> Tensor: def _generate_vbe_metadata( self, offsets: Tensor, - batch_size_per_feature_per_rank: Optional[list[list[int]]], + batch_size_per_feature_per_rank: Optional[List[List[int]]], ) -> invokers.lookup_args.VBEMetadata: # Blocking D2H copy, but only runs at first call self.feature_dims = self.feature_dims.cpu() @@ -1842,7 +1842,7 @@ def writeback_update_gradient( return mask # pyre-fixme[2]: For 1st argument expected not ANY - def writeback_hook(self, module: Any, grad: Tensor) -> tuple[Tensor]: + def writeback_hook(self, module: Any, grad: Tensor) -> Tuple[Tensor]: indices = self._indices offsets = self._offsets @@ -1854,7 +1854,7 @@ def forward( # noqa: C901 offsets: Tensor, per_sample_weights: Optional[Tensor] = None, feature_requires_grad: Optional[Tensor] = None, - batch_size_per_feature_per_rank: Optional[list[list[int]]] = None, + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, total_unique_indices: Optional[int] = None, ) -> Tensor: """ @@ -2346,7 +2346,7 @@ def forward( # noqa: C901 raise ValueError(f"Invalid OptimType: {self.optimizer}") - def ema_inplace(self, emainplace_mode: dict[str, float]) -> None: + def ema_inplace(self, emainplace_mode: Dict[str, float]) -> None: """ Perform ema operations on the full sparse embedding tables. We organize the sparse table, in the following way. @@ -2376,7 +2376,7 @@ def ema_inplace(self, emainplace_mode: dict[str, float]) -> None: emainplace_mode["step_ema_coef"], ) - def ensemble_and_swap(self, ensemble_mode: dict[str, float]) -> None: + def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None: """ Perform ensemble and swap operations on the full sparse embedding tables. 
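The torch.jit.annotate(List[...], []) calls rewritten above exist because TorchScript assumes a default element type for an empty list literal; the explicit annotation (again spelled with typing generics) tells the compiler what the container will actually hold. A small sketch of the pattern, independent of the TBE code:

    import torch
    from typing import List, Optional, Tuple
    from torch import Tensor

    @torch.jit.script
    def collect(x: Tensor) -> List[Tuple[Tensor, Tensor, Optional[Tensor]]]:
        # Annotate the empty list so TorchScript knows its element type up front.
        out = torch.jit.annotate(List[Tuple[Tensor, Tensor, Optional[Tensor]]], [])
        out.append((x, x + 1, None))
        return out

    print(len(collect(torch.zeros(2))))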
@@ -2424,7 +2424,7 @@ def get_uvm_cache_stats(self, use_local_cache: bool = False) -> Tensor: ), "gather_uvm_cache_stats should be set to true to access uvm cache stats." return self.local_uvm_cache_stats if use_local_cache else self.uvm_cache_stats - def _get_uvm_cache_print_state(self, use_local_cache: bool = False) -> list[float]: + def _get_uvm_cache_print_state(self, use_local_cache: bool = False) -> List[float]: snapshot = self.get_uvm_cache_stats(use_local_cache) if use_local_cache: return snapshot.tolist() @@ -2437,7 +2437,7 @@ def _get_uvm_cache_print_state(self, use_local_cache: bool = False) -> list[floa @torch.jit.ignore def print_uvm_cache_stats(self, use_local_cache: bool = False) -> None: # TODO: Create a separate reporter class to unify the stdlog reporting - uvm_cache_stats: list[float] = self._get_uvm_cache_print_state(use_local_cache) + uvm_cache_stats: List[float] = self._get_uvm_cache_print_state(use_local_cache) N = max(1, uvm_cache_stats[0]) m = { "N_called": uvm_cache_stats[UVMCacheStatsIndex.num_calls], @@ -2481,14 +2481,14 @@ def _report_uvm_cache_stats(self) -> None: if not stats_reporter.should_report(self.step): return - uvm_cache_stats: list[float] = self.get_uvm_cache_stats( + uvm_cache_stats: List[float] = self.get_uvm_cache_stats( use_local_cache=False ).tolist() self.last_reported_step = self.step if len(self.last_reported_uvm_stats) == 0: self.last_reported_uvm_stats = [0.0] * len(uvm_cache_stats) - uvm_cache_stats_delta: list[float] = [0.0] * len(uvm_cache_stats) + uvm_cache_stats_delta: List[float] = [0.0] * len(uvm_cache_stats) for i in range(len(uvm_cache_stats)): uvm_cache_stats_delta[i] = ( uvm_cache_stats[i] - self.last_reported_uvm_stats[i] @@ -2517,7 +2517,7 @@ def prefetch( indices: Tensor, offsets: Tensor, forward_stream: Optional[torch.cuda.Stream] = None, - batch_size_per_feature_per_rank: Optional[list[list[int]]] = None, + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, ) -> None: if self.prefetch_stream is None and forward_stream is not None: self.prefetch_stream = torch.cuda.current_stream() @@ -2792,7 +2792,7 @@ def init_embedding_weights_uniform(self, min_val: float, max_val: float) -> None param.uniform_(min_val, max_val) @torch.jit.ignore - def split_embedding_weights(self) -> list[Tensor]: + def split_embedding_weights(self) -> List[Tensor]: """ Returns a list of embedding weights (view), split by table @@ -2834,7 +2834,7 @@ def get_optimizer_buffer(self, state: str) -> torch.Tensor: raise ValueError(f"Optimizer buffer {state} not found") @torch.jit.export - def get_optimizer_state(self) -> list[dict[str, torch.Tensor]]: + def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]: r""" Get the optimizer state dict that matches the OSS Pytorch optims TODO: populate the supported list of optimizers @@ -2918,7 +2918,7 @@ def get_optimizer_state(self) -> list[dict[str, torch.Tensor]]: @torch.jit.ignore def split_optimizer_states( self, - ) -> list[list[torch.Tensor]]: + ) -> List[List[torch.Tensor]]: """ Returns a list of optimizer states (view), split by table @@ -2966,7 +2966,7 @@ def get_optimizer_states( state_offsets: Tensor, state_placements: Tensor, rowwise: bool, - ) -> list[torch.Tensor]: + ) -> List[torch.Tensor]: splits = [] for t, (rows, dim, _, _) in enumerate(self.embedding_specs): offset = state_offsets[t] @@ -2985,7 +2985,7 @@ def get_optimizer_states( splits.append(state.detach()[offset : offset + rows].view(rows)) return splits - states: list[list[torch.Tensor]] = [] + states: 
List[List[torch.Tensor]] = [] if self.optimizer not in (OptimType.EXACT_SGD,): states.append( get_optimizer_states( @@ -3111,7 +3111,7 @@ def get_learning_rate(self) -> float: return self.learning_rate_tensor.item() @torch.jit.ignore - def update_hyper_parameters(self, params_dict: dict[str, float]) -> None: + def update_hyper_parameters(self, params_dict: Dict[str, float]) -> None: """ Sets hyper-parameters from external control flow. @@ -3187,10 +3187,10 @@ def _apply_split( self, split: SplitState, prefix: str, - dtype: type[torch.dtype], + dtype: Type[torch.dtype], enforce_hbm: bool = False, make_dev_param: bool = False, - dev_reshape: Optional[tuple[int, ...]] = None, + dev_reshape: Optional[Tuple[int, ...]] = None, uvm_host_mapped: bool = False, ) -> None: apply_split_helper( @@ -3436,7 +3436,7 @@ def _sync_stream_post_backward( def _update_cache_counter_and_locations( self, module: nn.Module, - grad_input: Union[tuple[Tensor, ...], Tensor], + grad_input: Union[Tuple[Tensor, ...], Tensor], ) -> None: """ Backward prehook function when prefetch_pipeline is enabled. @@ -3632,10 +3632,10 @@ def prepare_inputs( indices: Tensor, offsets: Tensor, per_sample_weights: Optional[Tensor] = None, - batch_size_per_feature_per_rank: Optional[list[list[int]]] = None, + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, force_cast_input_types: bool = True, prefetch_pipeline: bool = False, - ) -> tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]: + ) -> Tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]: """ Prepare TBE inputs as follows: @@ -3825,7 +3825,7 @@ def _debug_print_input_stats_factory_impl( # Counts of indices that segment lengths > 1024 counts_cta_per_row_mth = counts_cta_per_row[counts_cta_per_row > 1024] - def compute_numel_and_avg(counts: Tensor) -> tuple[int, float]: + def compute_numel_and_avg(counts: Tensor) -> Tuple[int, float]: numel = counts.numel() avg = (counts.sum().item() / numel) if numel != 0 else -1.0 return numel, avg @@ -4026,12 +4026,12 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module): max_D: int hash_size_cumsum: Tensor total_hash_size_bits: int - embedding_specs: list[tuple[int, int]] + embedding_specs: List[Tuple[int, int]] def __init__( self, - embedding_specs: list[tuple[int, int]], # tuple of (rows, dims) - feature_table_map: Optional[list[int]] = None, # [T] + embedding_specs: List[Tuple[int, int]], # tuple of (rows, dims) + feature_table_map: Optional[List[int]] = None, # [T] weights_precision: SparseType = SparseType.FP32, pooling_mode: PoolingMode = PoolingMode.SUM, use_cpu: bool = False, @@ -4144,7 +4144,7 @@ def __init__( row for (row, _) in embedding_specs[:t] ) - self.weights_physical_offsets: list[int] = weights_offsets + self.weights_physical_offsets: List[int] = weights_offsets weights_offsets = [weights_offsets[t] for t in feature_table_map] self.register_buffer( "weights_offsets", @@ -4171,7 +4171,7 @@ def log(self, msg: str) -> None: def _generate_vbe_metadata( self, offsets: Tensor, - batch_size_per_feature_per_rank: Optional[list[list[int]]], + batch_size_per_feature_per_rank: Optional[List[List[int]]], ) -> invokers.lookup_args.VBEMetadata: # Blocking D2H copy, but only runs at first call self.feature_dims = self.feature_dims.cpu() @@ -4189,7 +4189,7 @@ def forward( offsets: Tensor, per_sample_weights: Optional[Tensor] = None, feature_requires_grad: Optional[Tensor] = None, - batch_size_per_feature_per_rank: Optional[list[list[int]]] = None, + batch_size_per_feature_per_rank: 
Optional[List[List[int]]] = None, ) -> Tensor: # Generate VBE metadata vbe_metadata = self._generate_vbe_metadata( @@ -4228,7 +4228,7 @@ def forward( ) @torch.jit.export - def split_embedding_weights(self) -> list[Tensor]: + def split_embedding_weights(self) -> List[Tensor]: """ Returns a list of weights, split by table """ diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py index 40080b0fba..79050eb805 100755 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py @@ -7,7 +7,7 @@ # pyre-unsafe -from typing import Optional +from typing import List, Optional import torch from torch import Tensor @@ -36,7 +36,7 @@ def is_torchdynamo_compiling() -> bool: # type: ignore[misc] def generate_vbe_metadata( offsets: Tensor, - batch_size_per_feature_per_rank: Optional[list[list[int]]], + batch_size_per_feature_per_rank: Optional[List[List[int]]], pooling_mode: PoolingMode, feature_dims_cpu: Tensor, device: torch.device, diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_config.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_config.py index 3ffe0889e3..6e51d715dd 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_config.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_config.py @@ -10,7 +10,7 @@ import dataclasses import json from enum import Enum -from typing import Any, Optional +from typing import Any, Dict, Optional import click @@ -34,7 +34,7 @@ class TBEBenchmarkingConfig: @classmethod # pyre-ignore [3] - def from_dict(cls, data: dict[str, Any]): + def from_dict(cls, data: Dict[str, Any]): return cls(**data) @classmethod @@ -42,7 +42,7 @@ def from_dict(cls, data: dict[str, Any]): def from_json(cls, data: str): return cls.from_dict(json.loads(data)) - def dict(self) -> dict[str, Any]: + def dict(self) -> Dict[str, Any]: return dataclasses.asdict(self) def json(self, format: bool = False) -> str: diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 1243f14db4..a0f0c518d8 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -12,7 +12,7 @@ import threading import time from subprocess import Popen -from typing import Callable, Optional +from typing import Callable, List, Optional, Tuple import torch @@ -49,7 +49,7 @@ def bench_warmup_with_spec( warmup_ms: int, warmup_runs: int, func: Callable[ - [torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[list[list[int]]]], + [torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[List[List[int]]]], torch.Tensor, ], bwd_only: bool = False, @@ -92,7 +92,7 @@ def wait(self) -> None: def cpu_tbe_worker( - requests_: list[TBERequest], + requests_: List[TBERequest], func_: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor], use_barrier: bool = False, ) -> float: @@ -124,7 +124,7 @@ def cpu_tbe_worker( def benchmark_cpu_requests_mp( - requests: list[TBERequest], + requests: List[TBERequest], emb_module: torch.nn.Module, num_warmups: int = 0, num_copies: int = 1, @@ -207,7 +207,7 @@ def benchmark_cpu_requests_mp( def benchmark_cpu_requests( - requests: list[TBERequest], + requests: List[TBERequest], func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor], num_warmups: int = 0, ) -> float: @@ -225,7 +225,7 @@ def benchmark_cpu_requests( def benchmark_requests( # noqa: 
C901 - requests: list[TBERequest], + requests: List[TBERequest], func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor], flush_gpu_cache_size_mb: int = 0, check_median: bool = False, @@ -335,9 +335,9 @@ def benchmark_requests( # noqa: C901 def benchmark_requests_with_spec( # noqa: C901 - requests: list[TBERequest], + requests: List[TBERequest], func: Callable[ - [torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[list[list[int]]]], + [torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[List[List[int]]]], torch.Tensor, ], flush_gpu_cache_size_mb: int = 0, @@ -450,7 +450,7 @@ def benchmark_requests_with_spec( # noqa: C901 def benchmark_requests_refer( - requests: list[TBERequest], + requests: List[TBERequest], T: int, B: int, L: int, @@ -542,12 +542,12 @@ def benchmark_requests_refer( def benchmark_pipelined_requests( - requests: list[TBERequest], + requests: List[TBERequest], func1: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], None], func2: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], None], flush_gpu_cache_size_mb: int = 0, check_median: bool = False, -) -> tuple[float, float]: +) -> Tuple[float, float]: torch.cuda.synchronize() start_events = [ (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)) @@ -599,10 +599,10 @@ def benchmark_pipelined_requests( def benchmark_vbe( - requests: list[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], + requests: List[Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor], num_warmups: int = 0, -) -> tuple[float, float]: +) -> Tuple[float, float]: """ A benchmark function to return the average execution time in seconds of forward and backward of VBE kernels. 
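For context on the Tuple[float, float] timings returned by helpers such as benchmark_pipelined_requests and benchmark_vbe: forward and backward are bracketed with CUDA events, as the start_events/enable_timing code above suggests. The sketch below is a deliberately simplified, hypothetical version — the real helpers loop over many requests, optionally flush the GPU cache (flush_gpu_cache_size_mb), and can report medians (check_median), all of which is omitted here:

    import torch
    from typing import Callable, Optional, Tuple

    def time_fwd_bwd_once(
        func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
        indices: torch.Tensor,
        offsets: torch.Tensor,
    ) -> Tuple[float, float]:
        # Time one forward and one backward pass (milliseconds) with CUDA events.
        start = torch.cuda.Event(enable_timing=True)
        mid = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out = func(indices, offsets, None)
        mid.record()
        out.backward(torch.ones_like(out))
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(mid), mid.elapsed_time(end)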
diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/eeg_cli.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/eeg_cli.py index 770fdd9d61..9206287589 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/eeg_cli.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/eeg_cli.py @@ -6,6 +6,7 @@ # pyre-strict +from typing import List, Tuple import click import torch @@ -81,7 +82,7 @@ def estimate(indices: str) -> None: ) def generate( hitters: str, - zipf: tuple[float, float], + zipf: Tuple[float, float], max_index: int, num_indices: int, output: str, @@ -113,7 +114,7 @@ def generate( assert output != "", "Output file path must be provided" try: - _hitters: list[float] = ( + _hitters: List[float] = ( [float(x) for x in hitters.split(",")] if hitters else [] ) except Exception as e: diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/embedding_ops_common_config.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/embedding_ops_common_config.py index d9d8d46c58..e9dde03e51 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/embedding_ops_common_config.py @@ -8,7 +8,7 @@ # pyre-strict import dataclasses -from typing import Any, Optional +from typing import Any, Dict, Optional import click import torch @@ -44,7 +44,7 @@ class EmbeddingOpsCommonConfig: def validate(self): return self - def split_args(self) -> dict[str, Any]: + def split_args(self) -> Dict[str, Any]: return { "weights_precision": self.weights_dtype, "stochastic_rounding": self.stochastic_rounding, diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/eval_compression.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/eval_compression.py index bf5a4c52a9..bacd62177d 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/eval_compression.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/eval_compression.py @@ -10,7 +10,7 @@ import logging import statistics from dataclasses import dataclass -from typing import Callable +from typing import Callable, List, Tuple import torch @@ -29,8 +29,8 @@ class EvalCompressionBenchmarkOutput: def benchmark_eval_compression( - baseline_requests: list[tuple[torch.Tensor, torch.Tensor]], - compressed_requests: list[tuple[torch.Tensor, torch.Tensor]], + baseline_requests: List[Tuple[torch.Tensor, torch.Tensor]], + compressed_requests: List[Tuple[torch.Tensor, torch.Tensor]], baseline_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], compressed_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], reindex: torch.Tensor, diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py index dcb95da3be..8ad4c7c3f1 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py @@ -9,7 +9,7 @@ import dataclasses import json -from typing import Any, Optional +from typing import Any, Dict, List, Optional import torch @@ -46,9 +46,9 @@ class TBEDataConfig: # Force generated tensors to be on CPU use_cpu: bool = False # Number of embeddings in each embedding features (number of rows) - Es: Optional[list[int]] = None + Es: Optional[List[int]] = None # Target embedding dimension for each features (number of columns) - Ds: Optional[list[int]] = None + Ds: Optional[List[int]] = None # Maximum number of indices max_indices: Optional[int] = None # Maximum number of indices @@ -60,7 +60,7 @@ def __post_init__(self) -> None: self.validate() @staticmethod - def complex_fields() -> dict[str, Any]: + def complex_fields() -> Dict[str, Any]: return { "batch_params": BatchParams, "indices_params": IndicesParams, @@ -69,7 +69,7 @@ def 
complex_fields() -> dict[str, Any]: @classmethod # pyre-ignore [3] - def from_dict(cls, data: dict[str, Any]): + def from_dict(cls, data: Dict[str, Any]): for field, Type in cls.complex_fields().items(): if not isinstance(data[field], Type): data[field] = Type.from_dict(data[field]) @@ -80,7 +80,7 @@ def from_dict(cls, data: dict[str, Any]): def from_json(cls, data: str): return cls.from_dict(json.loads(data)) - def dict(self) -> dict[str, Any]: + def dict(self) -> Dict[str, Any]: tmp = dataclasses.asdict(self) for field in TBEDataConfig.complex_fields().keys(): tmp[field] = self.__dict__[field].dict() diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py index f2c63546ca..e7944c51f3 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py @@ -7,7 +7,7 @@ # pyre-strict -from typing import Optional +from typing import List, Optional, Tuple import numpy as np import torch @@ -34,7 +34,7 @@ def _generate_batch_sizes( tbe_data_config: TBEDataConfig, -) -> tuple[list[int], Optional[list[list[int]]]]: +) -> Tuple[List[int], Optional[List[List[int]]]]: if tbe_data_config.variable_B(): assert ( tbe_data_config.batch_params.vbe_num_ranks is not None @@ -54,7 +54,7 @@ def _generate_batch_sizes( def _generate_pooling_info( - tbe_data_config: TBEDataConfig, iters: int, Bs: list[int] + tbe_data_config: TBEDataConfig, iters: int, Bs: List[int] ) -> torch.Tensor: if tbe_data_config.variable_L(): # Generate L from stats @@ -77,7 +77,7 @@ def _generate_pooling_info( def _generate_indices( tbe_data_config: TBEDataConfig, iters: int, - Bs: list[int], + Bs: List[int], L_offsets: torch.Tensor, ) -> torch.Tensor: @@ -107,11 +107,11 @@ def _generate_indices( def _build_requests_jagged( tbe_data_config: TBEDataConfig, iters: int, - Bs: list[int], - Bs_feature_rank: Optional[list[list[int]]], + Bs: List[int], + Bs_feature_rank: Optional[List[List[int]]], L_offsets: torch.Tensor, all_indices: torch.Tensor, -) -> list[TBERequest]: +) -> List[TBERequest]: total_B = sum(Bs) all_indices = all_indices.flatten() requests = [] @@ -142,7 +142,7 @@ def _build_requests_jagged( def _build_requests_dense( tbe_data_config: TBEDataConfig, iters: int, all_indices: torch.Tensor -) -> list[TBERequest]: +) -> List[TBERequest]: # NOTE: We're using existing code from requests.py to build the # requests, and since the existing code requires 2D view of all_indices, # the existing all_indices must be reshaped @@ -175,8 +175,8 @@ def _build_requests_dense( def generate_requests( tbe_data_config: TBEDataConfig, iters: int = 1, - batch_size_per_feature_per_rank: Optional[list[list[int]]] = None, -) -> list[TBERequest]: + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, +) -> List[TBERequest]: # Generate batch sizes if batch_size_per_feature_per_rank: @@ -221,8 +221,8 @@ def generate_requests_with_Llist( tbe_data_config: TBEDataConfig, L_list: torch.Tensor, iters: int = 1, - batch_size_per_feature_per_rank: Optional[list[list[int]]] = None, -) -> list[TBERequest]: + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, +) -> List[TBERequest]: """ Generate a list of TBERequest objects based on the provided TBE data configuration and L_list This function generates batch sizes and pooling information from the input L_list, @@ -284,7 +284,7 @@ def generate_requests_with_Llist( return _build_requests_dense(tbe_data_config, iters, 
all_indices) -def generate_embedding_dims(tbe_data_config: TBEDataConfig) -> tuple[int, list[int]]: +def generate_embedding_dims(tbe_data_config: TBEDataConfig) -> Tuple[int, List[int]]: if tbe_data_config.mixed_dim: Ds = [ round_up( diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py index 7d145c3136..1519788ef2 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py @@ -9,7 +9,7 @@ import dataclasses import json -from typing import Any, Optional +from typing import Any, Dict, List, Optional import torch @@ -40,7 +40,7 @@ class IndicesParams: @classmethod # pyre-ignore [3] - def from_dict(cls, data: dict[str, Any]): + def from_dict(cls, data: Dict[str, Any]): if not isinstance(data["heavy_hitters"], torch.Tensor): data["heavy_hitters"] = torch.tensor( data["heavy_hitters"], dtype=torch.float32 @@ -54,7 +54,7 @@ def from_dict(cls, data: dict[str, Any]): def from_json(cls, data: str): return cls.from_dict(json.loads(data)) - def dict(self) -> dict[str, Any]: + def dict(self) -> Dict[str, Any]: # https://stackoverflow.com/questions/73735974/convert-dataclass-of-dataclass-to-json-string tmp = dataclasses.asdict(self) # Convert tensor to list for JSON serialization @@ -99,11 +99,11 @@ class BatchParams: # Number of ranks for variable batch size generation vbe_num_ranks: Optional[int] = None # List of target batch sizes, i.e. number of batch lookups per table - Bs: Optional[list[int]] = None + Bs: Optional[List[int]] = None @classmethod # pyre-ignore [3] - def from_dict(cls, data: dict[str, Any]): + def from_dict(cls, data: Dict[str, Any]): return cls(**data) @classmethod @@ -111,7 +111,7 @@ def from_dict(cls, data: dict[str, Any]): def from_json(cls, data: str): return cls.from_dict(json.loads(data)) - def dict(self) -> dict[str, Any]: + def dict(self) -> Dict[str, Any]: return dataclasses.asdict(self) def json(self, format: bool = False) -> str: @@ -145,7 +145,7 @@ class PoolingParams: @classmethod # pyre-ignore [3] - def from_dict(cls, data: dict[str, Any]): + def from_dict(cls, data: Dict[str, Any]): return cls(**data) @classmethod @@ -153,7 +153,7 @@ def from_dict(cls, data: dict[str, Any]): def from_json(cls, data: str): return cls.from_dict(json.loads(data)) - def dict(self) -> dict[str, Any]: + def dict(self) -> Dict[str, Any]: return dataclasses.asdict(self) def json(self, format: bool = False) -> str: diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py index dc536dccdb..ad006a74b0 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py @@ -10,7 +10,7 @@ # pyre-ignore-all-errors[56] -from typing import Optional, Union +from typing import List, Optional, Tuple, Union import torch # usort:skip from torch import Tensor # usort:skip @@ -47,15 +47,15 @@ class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen): def __init__( # noqa C901 self, - embedding_specs: list[ - tuple[str, int, int, SparseType, EmbeddingLocation] + embedding_specs: List[ + Tuple[str, int, int, SparseType, EmbeddingLocation] ], # tuple of (feature_names, rows, dims, SparseType, EmbeddingLocation/placement) - feature_table_map: Optional[list[int]] = None, # [T] - index_remapping: Optional[list[Tensor]] = None, + feature_table_map: Optional[List[int]] = None, # 
[T] + index_remapping: Optional[List[Tensor]] = None, pooling_mode: PoolingMode = PoolingMode.SUM, device: Optional[Union[str, int, torch.device]] = None, bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, - weight_lists: Optional[list[tuple[Tensor, Optional[Tensor]]]] = None, + weight_lists: Optional[List[Tuple[Tensor, Optional[Tensor]]]] = None, pruning_hash_load_factor: float = 0.5, use_array_for_index_remapping: bool = True, output_dtype: SparseType = SparseType.FP16, @@ -74,7 +74,7 @@ def __init__( # noqa C901 cacheline_alignment: bool = True, uvm_host_mapped: bool = False, # True to use cudaHostAlloc; False to use cudaMallocManaged. reverse_qparam: bool = False, # True to load qparams at end of each row; False to load qparam at begnning of each row. - feature_names_per_table: Optional[list[list[str]]] = None, + feature_names_per_table: Optional[List[List[str]]] = None, indices_dtype: torch.dtype = torch.int32, # Used for construction of the remap_indices tensors. Should match the dtype of the indices passed in the forward() call (INT32 or INT64). ) -> None: # noqa C901 # tuple of (rows, dims,) super(KVEmbeddingInference, self).__init__( @@ -119,12 +119,12 @@ def __init__( # noqa C901 num_shards, uniform_init_lower, uniform_init_upper ) - self.specs: list[tuple[int, int, int]] = [ + self.specs: List[Tuple[int, int, int]] = [ (rows, dims, sparse_type.as_int()) for (_, rows, dims, sparse_type, _) in self.embedding_specs ] # table shard offset if inference sharding is enabled, otherwise, should be all zeros - self.table_sharding_offset: list[int] = [0] * len(self.embedding_specs) + self.table_sharding_offset: List[int] = [0] * len(self.embedding_specs) self.kv_embedding_cache_initialized = False self.hash_size_cumsum: torch.Tensor = torch.zeros( 0, @@ -137,7 +137,7 @@ def __init__( # noqa C901 dtype=torch.int64, ) - def construct_hash_size_cumsum(self) -> list[int]: + def construct_hash_size_cumsum(self) -> List[int]: hash_size_cumsum = [0] for spec in self.embedding_specs: rows = spec[1] @@ -146,7 +146,7 @@ def construct_hash_size_cumsum(self) -> list[int]: def calculate_indices_and_weights_offsets( self, indices: Tensor, offsets: Tensor - ) -> tuple[Tensor, Tensor]: + ) -> Tuple[Tensor, Tensor]: if self.pooling_mode is not PoolingMode.NONE: T = self.weights_offsets.numel() else: @@ -280,7 +280,7 @@ def fill_random_weights(self) -> None: self.weight_initialized = True @torch.jit.export - def init_tbe_config(self, table_sharding_offset: list[int]) -> None: + def init_tbe_config(self, table_sharding_offset: List[int]) -> None: """ Initialize the dynamic TBE table configs, e.g. sharded table offsets, etc. Should be called before loading weights. 
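embedding_inplace_update above takes parallel, nested lists with one entry per updated table. A hypothetical call shape (table ids, row ids, and the embedding dimension are invented; the final call is left commented out because it needs a constructed TBE module):

    import torch
    from typing import List

    update_table_indices: List[int] = [0, 2]             # update tables 0 and 2
    update_row_indices: List[List[int]] = [[0, 3], [7]]  # rows to overwrite, per table
    update_weights: List[torch.Tensor] = [
        torch.randn(2, 8),  # replacement rows for table 0 (2 rows x dim 8)
        torch.randn(1, 8),  # replacement row for table 2
    ]
    # tbe.embedding_inplace_update(update_table_indices, update_row_indices, update_weights)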
@@ -290,9 +290,9 @@ def init_tbe_config(self, table_sharding_offset: list[int]) -> None: @torch.jit.export def embedding_inplace_update( self, - update_table_indices: list[int], - update_row_indices: list[list[int]], - update_weights: list[Tensor], + update_table_indices: List[int], + update_row_indices: List[List[int]], + update_weights: List[Tensor], ) -> None: # function is not used for now on the inference side for i in range(len(update_table_indices)): diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py b/fbgemm_gpu/fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py index 36f2c689a0..4eb82385a0 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py @@ -6,7 +6,7 @@ # pyre-unsafe -from typing import Optional, Union +from typing import Optional, Tuple, Union import torch @@ -17,13 +17,13 @@ def get_unique_indices_v2( compute_count: bool = False, compute_inverse_indices: bool = False, ) -> Union[ - tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]], - tuple[ + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]], + Tuple[ torch.Tensor, torch.Tensor, Optional[torch.Tensor], ], - tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor], ]: """ A wrapper for get_unique_indices for overloading the return type diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/inference.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/inference.py index a5782546fd..e1cb4a22b1 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/inference.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/inference.py @@ -13,7 +13,7 @@ import os import tempfile from math import log2 -from typing import Optional +from typing import List, Optional, Tuple import torch # usort:skip @@ -42,15 +42,15 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module): Inference version, with FP32/FP16/FP8/INT8/INT4/INT2 supports """ - embedding_specs: list[tuple[str, int, int, SparseType]] + embedding_specs: List[Tuple[str, int, int, SparseType]] _local_instance_index: int = -1 def __init__( self, - embedding_specs: list[ - tuple[str, int, int, SparseType] + embedding_specs: List[ + Tuple[str, int, int, SparseType] ], # tuple of (feature_names, rows, dims, SparseType) - feature_table_map: Optional[list[int]] = None, # [T] + feature_table_map: Optional[List[int]] = None, # [T] pooling_mode: PoolingMode = PoolingMode.SUM, output_dtype: SparseType = SparseType.FP16, row_alignment: Optional[int] = None, @@ -73,7 +73,7 @@ def __init__( ssd_uniform_init_lower: float = -0.01, ssd_uniform_init_upper: float = 0.01, # Parameter Server Configs - ps_hosts: Optional[tuple[tuple[str, int]]] = None, + ps_hosts: Optional[Tuple[Tuple[str, int]]] = None, ps_max_key_per_request: Optional[int] = None, ps_client_thread_num: Optional[int] = None, ps_max_local_index_length: Optional[int] = None, @@ -99,7 +99,7 @@ def __init__( self.current_device = torch.device(device) self.use_cpu: bool = self.current_device.type == "cpu" - self.feature_table_map: list[int] = ( + self.feature_table_map: List[int] = ( feature_table_map if feature_table_map is not None else list(range(T_)) ) T = len(self.feature_table_map) @@ -112,9 +112,9 @@ def __init__( self.output_dtype: int = output_dtype.as_int() # (feature_names, rows, dims, weights_tys) = zip(*embedding_specs) # Pyre workaround - rows: list[int] = [e[1] for e in embedding_specs] - dims: list[int] = [e[2] for e in embedding_specs] - weights_tys: list[SparseType] = [e[3] for e in 
embedding_specs] + rows: List[int] = [e[1] for e in embedding_specs] + dims: List[int] = [e[2] for e in embedding_specs] + weights_tys: List[SparseType] = [e[3] for e in embedding_specs] D_offsets = [dims[t] for t in self.feature_table_map] D_offsets = [0] + list(itertools.accumulate(D_offsets)) @@ -169,7 +169,7 @@ def max_ty_D(ty: SparseType) -> int: offsets.append(uvm_size) uvm_size += state_size - self.weights_physical_offsets: list[int] = offsets + self.weights_physical_offsets: List[int] = offsets weights_tys_int = [weights_tys[t].as_int() for t in self.feature_table_map] self.register_buffer( @@ -517,13 +517,13 @@ def forward( @torch.jit.export def split_embedding_weights( self, split_scale_shifts: bool = True - ) -> list[tuple[Tensor, Optional[Tensor]]]: + ) -> List[Tuple[Tensor, Optional[Tensor]]]: """ Returns a list of weights, split by table. Testing only, very slow. """ - splits: list[tuple[Tensor, Optional[Tensor]]] = [] + splits: List[Tuple[Tensor, Optional[Tensor]]] = [] rows_cumsum = 0 for _, row, dim, weight_ty in self.embedding_specs: weights = torch.empty( diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index f6a2adb4a5..23c45309f5 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -18,7 +18,7 @@ import time from functools import cached_property from math import floor, log2 -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import torch # usort:skip # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers @@ -75,10 +75,10 @@ class IterData: @dataclass class KVZCHCachedData: - cached_optimizer_states_per_table: list[list[torch.Tensor]] - cached_weight_tensor_per_table: list[torch.Tensor] - cached_id_tensor_per_table: list[torch.Tensor] - cached_bucket_splits: list[torch.Tensor] + cached_optimizer_states_per_table: List[List[torch.Tensor]] + cached_weight_tensor_per_table: List[torch.Tensor] + cached_id_tensor_per_table: List[torch.Tensor] + cached_bucket_splits: List[torch.Tensor] class SSDTableBatchedEmbeddingBags(nn.Module): @@ -99,12 +99,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module): weights_offsets: Tensor _local_instance_index: int = -1 res_params: RESParams - table_names: list[str] + table_names: List[str] def __init__( self, - embedding_specs: list[tuple[int, int]], # tuple of (rows, dims) - feature_table_map: Optional[list[int]], # [T] + embedding_specs: List[Tuple[int, int]], # tuple of (rows, dims) + feature_table_map: Optional[List[int]], # [T] cache_sets: int, # A comma-separated string, e.g. "/data00_nvidia0,/data01_nvidia0/", db shards # will be placed in these paths round-robin. 
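As elsewhere in this diff, embedding_specs here is a list of (rows, dims) pairs and feature_table_map maps each feature to its table, defaulting to list(range(T_)) when omitted. A tiny illustration with made-up sizes:

    from typing import List, Tuple

    # Two tables as (rows, dims); three features, where features 1 and 2 share table 1.
    embedding_specs: List[Tuple[int, int]] = [(1000, 128), (500, 64)]
    feature_table_map: List[int] = [0, 1, 1]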
@@ -147,7 +147,7 @@ def __init__( pooling_mode: PoolingMode = PoolingMode.SUM, bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, # Parameter Server Configs - ps_hosts: Optional[tuple[tuple[str, int]]] = None, + ps_hosts: Optional[Tuple[Tuple[str, int]]] = None, ps_max_key_per_request: Optional[int] = None, ps_client_thread_num: Optional[int] = None, ps_max_local_index_length: Optional[int] = None, @@ -176,9 +176,9 @@ def __init__( enable_raw_embedding_streaming: bool = False, # whether enable raw embedding streaming res_params: Optional[RESParams] = None, # raw embedding streaming sharding info flushing_block_size: int = 2_000_000_000, # 2GB - table_names: Optional[list[str]] = None, + table_names: Optional[List[str]] = None, use_rowwise_bias_correction: bool = False, # For Adam use - optimizer_state_dtypes: dict[str, SparseType] = {}, # noqa: B006 + optimizer_state_dtypes: Dict[str, SparseType] = {}, # noqa: B006 ) -> None: super(SSDTableBatchedEmbeddingBags, self).__init__() @@ -197,11 +197,11 @@ def __init__( if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD: # Adagrad currently only supports FP32 for momentum1 - self.optimizer_state_dtypes: dict[str, SparseType] = { + self.optimizer_state_dtypes: Dict[str, SparseType] = { "momentum1": SparseType.FP32, } else: - self.optimizer_state_dtypes: dict[str, SparseType] = optimizer_state_dtypes + self.optimizer_state_dtypes: Dict[str, SparseType] = optimizer_state_dtypes # Zero collision TBE configurations self.kv_zch_params = kv_zch_params @@ -260,7 +260,7 @@ def __init__( f"get env {self.res_params.res_server_port=}, at rank {dist.get_rank()}, with {self.res_params=}" ) - self.feature_table_map: list[int] = ( + self.feature_table_map: List[int] = ( feature_table_map if feature_table_map is not None else list(range(T_)) ) T = len(self.feature_table_map) @@ -561,9 +561,9 @@ def __init__( """ self._cached_kvzch_data: Optional[KVZCHCachedData] = None # initial embedding rows on this rank per table, this is used for loading checkpoint - self.local_weight_counts: list[int] = [0] * T_ + self.local_weight_counts: List[int] = [0] * T_ # groundtruth global id on this rank per table, this is used for loading checkpoint - self.global_id_per_rank: list[torch.Tensor] = [torch.zeros(0)] * T_ + self.global_id_per_rank: List[torch.Tensor] = [torch.zeros(0)] * T_ # loading checkpoint flag, set by checkpoint loader, and cleared after weight is applied to backend self.load_state_dict: bool = False @@ -817,22 +817,22 @@ def __init__( ) # (Indices, Count) - self.prefetched_info: list[tuple[Tensor, Tensor]] = [] + self.prefetched_info: List[Tuple[Tensor, Tensor]] = [] - self.timesteps_prefetched: list[int] = [] + self.timesteps_prefetched: List[int] = [] # TODO: add type annotation # pyre-fixme[4]: Attribute must be annotated. 
self.ssd_prefetch_data = [] # Scratch pad eviction data queue - self.ssd_scratch_pad_eviction_data: list[ - tuple[Tensor, Tensor, Tensor, bool] + self.ssd_scratch_pad_eviction_data: List[ + Tuple[Tensor, Tensor, Tensor, bool] ] = [] - self.ssd_location_update_data: list[tuple[Tensor, Tensor]] = [] + self.ssd_location_update_data: List[Tuple[Tensor, Tensor]] = [] if self.prefetch_pipeline: # Scratch pad value queue - self.ssd_scratch_pads: list[tuple[Tensor, Tensor, Tensor]] = [] + self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor]] = [] # pyre-ignore[4] # Scratch pad index queue @@ -934,7 +934,7 @@ def __init__( self.ssd_cache_stats_size = 6 # 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses, # 4: N_conflict_unique_misses, 5: N_conflict_misses - self.last_reported_ssd_stats: list[float] = [] + self.last_reported_ssd_stats: List[float] = [] self.last_reported_step = 0 self.register_buffer( @@ -965,7 +965,7 @@ def __init__( self.prefetch_parallel_stream_cnt: int = 2 # tuple of iteration, prefetch parallel stream cnt, reported duration # since there are 2 stream in parallel in prefetch, we want to count the longest one - self.prefetch_duration_us: tuple[int, int, float] = ( + self.prefetch_duration_us: Tuple[int, int, float] = ( -1, self.prefetch_parallel_stream_cnt, 0, @@ -1272,10 +1272,10 @@ def _apply_split( self, split: SplitState, prefix: str, - dtype: type[torch.dtype], + dtype: Type[torch.dtype], enforce_hbm: bool = False, make_dev_param: bool = False, - dev_reshape: Optional[tuple[int, ...]] = None, + dev_reshape: Optional[Tuple[int, ...]] = None, ) -> None: apply_split_helper( self.register_buffer, @@ -1298,11 +1298,11 @@ def to_pinned_cpu(self, t: torch.Tensor) -> torch.Tensor: def to_pinned_cpu_on_stream_wait_on_another_stream( self, - tensors: list[Tensor], + tensors: List[Tensor], stream: torch.cuda.Stream, stream_to_wait_on: torch.cuda.Stream, post_event: Optional[torch.cuda.Event] = None, - ) -> list[Tensor]: + ) -> List[Tensor]: """ Transfer input tensors from GPU to CPU using a pinned host buffer. 
The transfer is carried out on the given stream @@ -1542,7 +1542,7 @@ def _evict_from_scratch_pad(self, grad: Tensor) -> None: def _update_cache_counter_and_pointers( self, module: nn.Module, - grad_input: Union[tuple[Tensor, ...], Tensor], + grad_input: Union[Tuple[Tensor, ...], Tensor], ) -> None: """ Update cache line locking counter and pointers before backward @@ -1650,7 +1650,7 @@ def prefetch( offsets: Tensor, weights: Optional[Tensor] = None, # todo: need to update caller forward_stream: Optional[torch.cuda.Stream] = None, - batch_size_per_feature_per_rank: Optional[list[list[int]]] = None, + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, ) -> None: if self.prefetch_stream is None and forward_stream is not None: # Set the prefetch stream to the current stream @@ -2194,7 +2194,7 @@ def _prefetch( # noqa C901 def _generate_vbe_metadata( self, offsets: Tensor, - batch_size_per_feature_per_rank: Optional[list[list[int]]], + batch_size_per_feature_per_rank: Optional[List[List[int]]], ) -> invokers.lookup_args.VBEMetadata: # Blocking D2H copy, but only runs at first call self.feature_dims = self.feature_dims.cpu() @@ -2242,7 +2242,7 @@ def forward( weights: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None, feature_requires_grad: Optional[Tensor] = None, - batch_size_per_feature_per_rank: Optional[list[list[int]]] = None, + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`. ) -> Tensor: self.clear_cache() @@ -2398,7 +2398,7 @@ def forward( @torch.jit.ignore def _split_optimizer_states_non_kv_zch( self, - ) -> list[list[torch.Tensor]]: + ) -> List[List[torch.Tensor]]: """ Returns a list of optimizer states (view), split by table. 
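to_pinned_cpu_on_stream_wait_on_another_stream, whose docstring appears above, describes a device-to-host copy into pinned buffers that is ordered after another stream. A rough, hypothetical sketch of that pattern using only public torch.cuda APIs (the real method may differ, e.g. in how buffers are allocated or reused):

    import torch
    from typing import List, Optional

    def to_pinned_cpu_sketch(
        tensors: List[torch.Tensor],
        copy_stream: torch.cuda.Stream,
        producer_stream: torch.cuda.Stream,
        post_event: Optional[torch.cuda.Event] = None,
    ) -> List[torch.Tensor]:
        # Wait for the producer stream, then copy each tensor into a pinned host
        # buffer on the copy stream so the D2H transfer can overlap other work.
        out: List[torch.Tensor] = []
        with torch.cuda.stream(copy_stream):
            copy_stream.wait_stream(producer_stream)
            for t in tensors:
                host = torch.empty_like(t, device="cpu", pin_memory=True)
                host.copy_(t, non_blocking=True)
                out.append(host)
            if post_event is not None:
                post_event.record(copy_stream)
        return out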
@@ -2417,9 +2417,9 @@ def _split_optimizer_states_non_kv_zch( # Row count per table (rows, dims) = zip(*self.embedding_specs) # Cumulative row counts per table for rowwise states - row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows)) + row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows)) # Cumulative element counts per table for elementwise states - elem_count_cumsum: list[int] = [0] + list( + elem_count_cumsum: List[int] = [0] + list( itertools.accumulate([r * d for r, d in self.embedding_specs]) ) @@ -2475,14 +2475,14 @@ def _slice(tensor: Tensor, t: int, rowwise: bool) -> Tensor: def _split_optimizer_states_kv_zch_no_offloading( self, sorted_ids: torch.Tensor, - ) -> list[list[torch.Tensor]]: + ) -> List[List[torch.Tensor]]: # Row count per table (rows, dims) = zip(*self.embedding_specs) # Cumulative row counts per table for rowwise states - row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows)) + row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows)) # Cumulative element counts per table for elementwise states - elem_count_cumsum: list[int] = [0] + list( + elem_count_cumsum: List[int] = [0] + list( itertools.accumulate([r * d for r, d in self.embedding_specs]) ) @@ -2567,12 +2567,12 @@ def _split_optimizer_states_kv_zch_w_offloading( sorted_ids: torch.Tensor, no_snapshot: bool = True, should_flush: bool = False, - ) -> list[list[torch.Tensor]]: + ) -> List[List[torch.Tensor]]: dtype = self.weights_precision.as_dtype() # Row count per table (rows_, dims_) = zip(*self.embedding_specs) # Cumulative row counts per table for rowwise states - row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows_)) + row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows_)) snapshot_handle, _ = self._may_create_snapshot_for_state_dict( no_snapshot=no_snapshot, @@ -2582,7 +2582,7 @@ def _split_optimizer_states_kv_zch_w_offloading( # pyre-ignore[53] def _fetch_offloaded_optimizer_states( t: int, - ) -> list[Tensor]: + ) -> List[Tensor]: e: int = rows_[t] d: int = dims_[t] @@ -2660,7 +2660,7 @@ def _fetch_offloaded_optimizer_states( ) # Now split up the buffer into N views, N for each optimizer state - optimizer_states: list[Tensor] = [] + optimizer_states: List[Tensor] = [] for state_name in self.optimizer.state_names(): # Extract the offsets (start, end) = optimizer_state_byte_offsets[state_name] @@ -2704,7 +2704,7 @@ def _split_optimizer_states_kv_zch_whole_row( sorted_ids: torch.Tensor, no_snapshot: bool = True, should_flush: bool = False, - ) -> list[list[torch.Tensor]]: + ) -> List[List[torch.Tensor]]: dtype = self.weights_precision.as_dtype() # Row and dimension counts per table @@ -2712,7 +2712,7 @@ def _split_optimizer_states_kv_zch_whole_row( (rows_, dims_) = zip(*self.embedding_specs) # Cumulative row counts per (virtual) table for rowwise states - row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows_)) + row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows_)) snapshot_handle, _ = self._may_create_snapshot_for_state_dict( no_snapshot=no_snapshot, @@ -2722,7 +2722,7 @@ def _split_optimizer_states_kv_zch_whole_row( # pyre-ignore[53] def _fetch_offloaded_optimizer_states( t: int, - ) -> list[Tensor]: + ) -> List[Tensor]: d: int = dims_[t] # pyre-ignore[16] @@ -2758,7 +2758,7 @@ def _fetch_offloaded_optimizer_states( ) # Now split up the buffer into N views, N for each optimizer state - optimizer_states: list[PartiallyMaterializedTensor] = [] + optimizer_states: 
List[PartiallyMaterializedTensor] = [] for state_name in self.optimizer.state_names(): state_dtype = self.optimizer_state_dtypes.get( state_name, SparseType.FP32 @@ -2838,10 +2838,10 @@ def _fetch_offloaded_optimizer_states( @torch.jit.export def split_optimizer_states( self, - sorted_id_tensor: Optional[list[torch.Tensor]] = None, + sorted_id_tensor: Optional[List[torch.Tensor]] = None, no_snapshot: bool = True, should_flush: bool = False, - ) -> list[list[torch.Tensor]]: + ) -> List[List[torch.Tensor]]: """ Returns a list of optimizer states split by table. @@ -2906,14 +2906,14 @@ def split_optimizer_states( @torch.jit.export def get_optimizer_state( self, - sorted_id_tensor: Optional[list[torch.Tensor]], + sorted_id_tensor: Optional[List[torch.Tensor]], no_snapshot: bool = True, should_flush: bool = False, - ) -> list[dict[str, torch.Tensor]]: + ) -> List[Dict[str, torch.Tensor]]: """ Returns a list of dictionaries of optimizer states split by table. """ - states_list: list[list[Tensor]] = self.split_optimizer_states( + states_list: List[List[Tensor]] = self.split_optimizer_states( sorted_id_tensor=sorted_id_tensor, no_snapshot=no_snapshot, should_flush=should_flush, @@ -2922,7 +2922,7 @@ def get_optimizer_state( return [dict(zip(state_names, states)) for states in states_list] @torch.jit.export - def debug_split_embedding_weights(self) -> list[torch.Tensor]: + def debug_split_embedding_weights(self) -> List[torch.Tensor]: """ Returns a list of weights, split by table. @@ -3014,11 +3014,11 @@ def split_embedding_weights( self, no_snapshot: bool = True, should_flush: bool = False, - ) -> tuple[ # TODO: make this a NamedTuple for readability - Union[list[PartiallyMaterializedTensor], list[torch.Tensor]], - Optional[list[torch.Tensor]], - Optional[list[torch.Tensor]], - Optional[list[torch.Tensor]], + ) -> Tuple[ # TODO: make this a NamedTuple for readability + Union[List[PartiallyMaterializedTensor], List[torch.Tensor]], + Optional[List[torch.Tensor]], + Optional[List[torch.Tensor]], + Optional[List[torch.Tensor]], ]: """ This method is intended to be used by the checkpointing engine @@ -3208,7 +3208,7 @@ def _apply_state_dict_w_offloading(self) -> None: # Row count per table (rows, _) = zip(*self.embedding_specs) # Cumulative row counts per table for rowwise states - row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows)) + row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows)) for t, _ in enumerate(self.embedding_specs): # pyre-ignore [16] @@ -3237,7 +3237,7 @@ def _apply_state_dict_no_offloading(self) -> None: # Row count per table (rows, _) = zip(*self.embedding_specs) # Cumulative row counts per table for rowwise states - row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows)) + row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows)) def copy_optimizer_state_(dst: Tensor, src: Tensor, indices: Tensor) -> None: device = dst.device @@ -3328,7 +3328,7 @@ def apply_state_dict(self) -> None: def streaming_write_weight_and_id_per_table( self, weight_state: torch.Tensor, - opt_states: list[torch.Tensor], + opt_states: List[torch.Tensor], id_tensor: torch.Tensor, row_offset: int, ) -> None: @@ -3544,8 +3544,8 @@ def prepare_inputs( indices: Tensor, offsets: Tensor, per_sample_weights: Optional[Tensor] = None, - batch_size_per_feature_per_rank: Optional[list[list[int]]] = None, - ) -> tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]: + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, + ) -> 
Tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]: """ Prepare TBE inputs """ @@ -3612,7 +3612,7 @@ def _report_ssd_l1_cache_stats(self) -> None: ssd_cache_stats = self.ssd_cache_stats.tolist() if len(self.last_reported_ssd_stats) == 0: self.last_reported_ssd_stats = [0.0] * len(ssd_cache_stats) - ssd_cache_stats_delta: list[float] = [0.0] * len(ssd_cache_stats) + ssd_cache_stats_delta: List[float] = [0.0] * len(ssd_cache_stats) for i in range(len(ssd_cache_stats)): ssd_cache_stats_delta[i] = ( ssd_cache_stats[i] - self.last_reported_ssd_stats[i] @@ -4179,7 +4179,7 @@ def _recording_to_timer( def fetch_from_l1_sp_w_row_ids( self, row_ids: torch.Tensor, only_get_optimizer_states: bool = False - ) -> tuple[list[torch.Tensor], torch.Tensor]: + ) -> Tuple[List[torch.Tensor], torch.Tensor]: """ Fetch the optimizer states and/or weights from L1 and SP for given linearized row_ids. @return: updated_weights/optimizer_states, mask of which rows are filled @@ -4201,7 +4201,7 @@ def fetch_from_l1_sp_w_row_ids( def split_results_by_opt_states( updated_weights: torch.Tensor, cache_location_mask: torch.Tensor - ) -> tuple[list[torch.Tensor], torch.Tensor]: + ) -> Tuple[List[torch.Tensor], torch.Tensor]: if not only_get_optimizer_states: return [updated_weights], cache_location_mask # TODO: support mixed dimension case diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py index 46ad811e52..a84333101c 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py @@ -9,7 +9,7 @@ from __future__ import annotations import functools -from typing import Optional, Union +from typing import List, Optional, Union import torch @@ -248,7 +248,7 @@ def __eq__(self, tensor1, tensor2, **kwargs): return torch.equal(tensor1.full_tensor(), tensor2.full_tensor()) - def get_kvtensor_serializable_metadata(self) -> list[str]: + def get_kvtensor_serializable_metadata(self) -> List[str]: return self._wrapped.get_kvtensor_serializable_metadata() def __hash__(self): diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py b/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py index 847cfe5764..0fa6e40ff1 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py @@ -11,7 +11,7 @@ import json import logging import os -from typing import Optional +from typing import List, Optional import fbgemm_gpu # noqa F401 import torch # usort:skip @@ -144,7 +144,7 @@ def extract_params( indices: torch.Tensor, offsets: torch.Tensor, per_sample_weights: Optional[torch.Tensor] = None, - batch_size_per_feature_per_rank: Optional[list[list[int]]] = None, + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, ) -> TBEDataConfig: """ Extracts parameters from the embedding operation, input indices, and offsets to create a TBEDataConfig. @@ -266,7 +266,7 @@ def report_stats( offsets: torch.Tensor, op_id: str = "", per_sample_weights: Optional[torch.Tensor] = None, - batch_size_per_feature_per_rank: Optional[list[list[int]]] = None, + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, ) -> None: """ Reports the configuration of the embedding operation and input data, then writes the TBE configuration to the filestore. 
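batch_size_per_feature_per_rank, which appears throughout these signatures as Optional[List[List[int]]], carries the variable-batch-size (VBE) lookup shape. Assuming, as the name suggests, that the outer list is indexed by feature and the inner list by rank, per-feature and total batch sizes fall out as follows (numbers invented):

    from typing import List

    # Hypothetical VBE batch sizes: 2 features, 3 ranks each.
    batch_size_per_feature_per_rank: List[List[int]] = [[8, 4, 6], [2, 2, 2]]
    Bs: List[int] = [sum(per_rank) for per_rank in batch_size_per_feature_per_rank]  # [18, 6]
    total_B = sum(Bs)  # 24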
diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/utils/offsets.py b/fbgemm_gpu/fbgemm_gpu/tbe/utils/offsets.py index e0ba2851e8..8b43441534 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/utils/offsets.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/utils/offsets.py @@ -6,7 +6,7 @@ # pyre-strict -from typing import Callable, Optional +from typing import Callable, Optional, Tuple import numpy as np import torch @@ -21,7 +21,7 @@ def get_table_batched_offsets_from_dense( L: Optional[int] = None, total_B: Optional[int] = None, use_cpu: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: if L is None and total_B is None: (T, B, L) = merged_indices.size() total_B = T * B @@ -37,7 +37,7 @@ def get_table_batched_offsets_from_dense( ) -def get_offsets_from_dense(indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def get_offsets_from_dense(indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: (B, L) = indices.size() return ( indices.contiguous().view(-1), diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/utils/quantize.py b/fbgemm_gpu/fbgemm_gpu/tbe/utils/quantize.py index 23db8bf637..99d8a28011 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/utils/quantize.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/utils/quantize.py @@ -7,7 +7,7 @@ # pyre-strict # pyre-ignore-all-errors[61] -from typing import Optional +from typing import Optional, Tuple import torch @@ -22,7 +22,7 @@ def quantize_embs( weight: torch.Tensor, weight_ty: SparseType, fp8_config: Optional[FP8QuantizationConfig] = None, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: weight = weight.detach() if weight_ty == SparseType.FP32: q_weight = weight.float() diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py b/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py index c27296ac05..bd64223a09 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py @@ -8,7 +8,7 @@ import logging from dataclasses import dataclass -from typing import Optional +from typing import List, Optional, Tuple import numpy as np import numpy.typing as npt @@ -32,20 +32,20 @@ class TBERequest: indices: torch.Tensor offsets: torch.Tensor per_sample_weights: Optional[torch.Tensor] = None - Bs_per_feature_per_rank: Optional[list[list[int]]] = None + Bs_per_feature_per_rank: Optional[List[List[int]]] = None - def unpack_2(self) -> tuple[torch.Tensor, torch.Tensor]: + def unpack_2(self) -> Tuple[torch.Tensor, torch.Tensor]: return (self.indices, self.offsets) def unpack_3( self, - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: return (self.indices, self.offsets, self.per_sample_weights) def unpack_4( self, - ) -> tuple[ - torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[list[list[int]]] + ) -> Tuple[ + torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[List[List[int]]] ]: return ( self.indices, @@ -68,7 +68,7 @@ def generate_requests_from_data_file( tables: Optional[str] = None, index_dtype: Optional[torch.dtype] = None, offset_dtype: Optional[torch.dtype] = None, -) -> list[TBERequest]: +) -> List[TBERequest]: """ Generate TBE requests from the input data file. If `requests_data_file` is provided, `indices_file` and `offsets_file` should not be provided. 
If either `indices_file` @@ -178,12 +178,12 @@ def generate_int_data_from_stats( def generate_pooling_factors_from_stats( iters: int, - Bs: list[int], + Bs: List[int], L: int, sigma_L: int, # distribution of pooling factors length_dist: str, -) -> tuple[int, torch.Tensor]: +) -> Tuple[int, torch.Tensor]: """ Generate pooling factors for the TBE requests from the given stats """ @@ -211,7 +211,7 @@ def generate_batch_sizes_from_stats( vbe_num_ranks: int, # Distribution of batch sizes batch_size_dist: str, -) -> tuple[list[int], list[list[int]]]: +) -> Tuple[List[int], List[List[int]]]: """ Generate batch sizes for features from the given stats """ @@ -234,7 +234,7 @@ def generate_batch_sizes_from_stats( def generate_indices_uniform( iters: int, - Bs: list[int], + Bs: List[int], L: int, E: int, use_variable_L: bool, @@ -267,7 +267,7 @@ def generate_indices_uniform( def generate_indices_zipf( iters: int, - Bs: list[int], + Bs: List[int], L: int, E: int, alpha: float, @@ -324,7 +324,7 @@ def generate_indices_zipf( def update_indices_with_random_reuse( iters: int, - Bs: list[int], + Bs: List[int], L: int, reuse: float, indices: torch.Tensor, @@ -411,7 +411,7 @@ def generate_requests( # noqa C901 vbe_num_ranks: Optional[int] = None, index_dtype: Optional[torch.dtype] = None, offset_dtype: Optional[torch.dtype] = None, -) -> list[TBERequest]: +) -> List[TBERequest]: # TODO: refactor and split into helper functions to separate load from file, # generate from distribution, and other future methods of generating data if ( diff --git a/fbgemm_gpu/fbgemm_gpu/tbe_input_multiplexer.py b/fbgemm_gpu/fbgemm_gpu/tbe_input_multiplexer.py index c6b74d7ad9..a5c5238c9e 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe_input_multiplexer.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe_input_multiplexer.py @@ -10,7 +10,7 @@ import abc from dataclasses import dataclass -from typing import Optional +from typing import List, Optional from torch import Tensor @@ -32,15 +32,15 @@ class TBEInfo: col_offset: the shard offset of the current rank on column (dim) """ - table_names: list[str] - table_heights: list[int] + table_names: List[str] + table_heights: List[int] tbe_uuid: str - feature_table_map: list[int] - table_dims: list[int] - full_table_heights: list[int] - full_table_dims: list[int] - row_offset: list[int] - col_offset: list[int] + feature_table_map: List[int] + table_dims: List[int] + full_table_heights: List[int] + full_table_dims: List[int] + row_offset: List[int] + col_offset: List[int] @dataclass(frozen=True) @@ -55,7 +55,7 @@ class TBEInputInfo: indices: Tensor offsets: Tensor - batch_size_per_feature_per_rank: Optional[list[list[int]]] = None + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None class TBEInputMultiplexer(abc.ABC): diff --git a/fbgemm_gpu/fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py b/fbgemm_gpu/fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py index a9560ffb1c..964612b45e 100644 --- a/fbgemm_gpu/fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py @@ -9,7 +9,7 @@ # pyre-ignore-all-errors[6] -from typing import Optional, Union +from typing import List, Optional, Tuple, Union import torch import triton # @manual @@ -472,7 +472,7 @@ def triton_jagged_to_dense_optimization_2d( # In FBGEMM it was computed by GPU but in triton currently has some compilation issue so we use CUP computation method as workaround # However in real-world case if we only dealing with 2d jagged tensor we don't need to use this function at all def 
_jagged_offsets_to_dense_indice( - offsets: list[torch.Tensor], dense_strides: list[int], dense_sizes: list[int] + offsets: List[torch.Tensor], dense_strides: List[int], dense_sizes: List[int] ) -> torch.Tensor: output_offset = torch.zeros(len(offsets[-1]) - 1, device="cpu", dtype=torch.int32) @@ -532,8 +532,8 @@ def _jagged_offsets_to_dense_indice( # not be affected at all def jagged_to_dense( jagged_values: torch.Tensor, - jagged_offsets: list[torch.Tensor], - jagged_max_lengths: list[int], + jagged_offsets: List[torch.Tensor], + jagged_max_lengths: List[int], padding_value: float = 0.0, # padding value currently use 0.0 as default value operation_function: Union[ str, None @@ -720,10 +720,10 @@ def triton_dense_to_jagged( def dense_to_jagged( dense: torch.Tensor, - jagged_offsets: list[torch.Tensor], + jagged_offsets: List[torch.Tensor], operation_function: Union[str, None] = None, operation_jagged_values: Union[torch.Tensor, None] = None, -) -> tuple[torch.Tensor, list[torch.Tensor]]: +) -> Tuple[torch.Tensor, List[torch.Tensor]]: thread_block_row_size = 32 thread_block_col_size = 32 @@ -780,7 +780,7 @@ def dense_to_jagged( # jagged_tensor + dense -> dense def jagged_dense_elementwise_add_dense_output( jagged_values: Tensor, - jagged_offsets: list[Tensor], + jagged_offsets: List[Tensor], # pyre-fixme[2]: Parameter must be annotated. dense, ) -> Tensor: @@ -800,8 +800,8 @@ def jagged_dense_elementwise_add_dense_output( # jagged_tensor + dense -> jagged_tensor def jagged_dense_elementwise_add_jagged_output( - jagged_values: Optional[Tensor], jagged_offsets: list[Tensor], dense: Tensor -) -> tuple[Tensor, list[Tensor]]: + jagged_values: Optional[Tensor], jagged_offsets: List[Tensor], dense: Tensor +) -> Tuple[Tensor, List[Tensor]]: return dense_to_jagged( dense, @@ -813,8 +813,8 @@ def jagged_dense_elementwise_add_jagged_output( # jagged_tensor * dense -> jagged_tensor def jagged_dense_elementwise_mul_jagged_output( - jagged_values: Optional[Tensor], jagged_offsets: list[Tensor], dense: Tensor -) -> tuple[Tensor, list[Tensor]]: + jagged_values: Optional[Tensor], jagged_offsets: List[Tensor], dense: Tensor +) -> Tuple[Tensor, List[Tensor]]: return dense_to_jagged( dense, diff --git a/fbgemm_gpu/fbgemm_gpu/utils/torch_library.py b/fbgemm_gpu/fbgemm_gpu/utils/torch_library.py index 3d0a6ce97d..a05b791ce8 100644 --- a/fbgemm_gpu/fbgemm_gpu/utils/torch_library.py +++ b/fbgemm_gpu/fbgemm_gpu/utils/torch_library.py @@ -8,7 +8,7 @@ # pyre-strict import re -from typing import Callable +from typing import Callable, Dict import torch @@ -112,7 +112,7 @@ def register_dispatch(self, op_name: str, dispatch_key: str, fn: Callable) -> No self.lib.impl(op_name, fn, dispatch_key) # pyre-ignore[24] - def register(self, op_name: str, functors: dict[str, Callable]) -> None: + def register(self, op_name: str, functors: Dict[str, Callable]) -> None: """ Registers a set of dispatches for a defined operator. 
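The torch_library.py hunk above types the functors argument of register() as Dict[str, Callable], i.e. a mapping from dispatch key to implementation that is applied one entry at a time via lib.impl(...). Below is a self-contained sketch of that dispatch-table pattern using torch.library directly; the "demo" namespace and the "double" operator are invented for illustration and are not part of FBGEMM.

from typing import Callable, Dict

import torch

# Register one implementation per dispatch key, mirroring the
# Dict[str, Callable] signature of register() above.
lib = torch.library.Library("demo", "DEF")
lib.define("double(Tensor x) -> Tensor")


def double_cpu(x: torch.Tensor) -> torch.Tensor:
    return x * 2


functors: Dict[str, Callable] = {"CPU": double_cpu}
for dispatch_key, fn in functors.items():
    lib.impl("double", fn, dispatch_key)

print(torch.ops.demo.double(torch.ones(3)))  # tensor([2., 2., 2.])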
diff --git a/fbgemm_gpu/list_versions/cli_run.py b/fbgemm_gpu/list_versions/cli_run.py index 3f6df7a0bf..654aca3203 100644 --- a/fbgemm_gpu/list_versions/cli_run.py +++ b/fbgemm_gpu/list_versions/cli_run.py @@ -9,8 +9,9 @@ import logging import subprocess +import typing from datetime import datetime -from typing import Union +from typing import List, Union import click @@ -36,7 +37,7 @@ def __init__( self._timestamp = timestamp self._visible = visible - def to_dict(self) -> dict[str, Union[int, str]]: + def to_dict(self) -> typing.Dict[str, Union[int, str]]: return { "cli": self._cli, "stdout": self._stdout, @@ -51,7 +52,7 @@ class CLI: def __init__(self) -> None: pd.options.display.max_rows pd.set_option("display.max_colwidth", None) - self._cli_outputs: list[CLIOutput] = [ + self._cli_outputs: List[CLIOutput] = [ CLIOutput( cli="python –c “import torch; print(torch.__version__)”", stdout="{}".format(torch.__version__), @@ -64,7 +65,7 @@ def __init__(self) -> None: def run( self, - cli: Union[str, list[str]], + cli: Union[str, List[str]], visible: bool = True, input: str = "", capture_output: bool = True, @@ -100,7 +101,7 @@ def run( self._cli_outputs.append(result) return result - def run_piped(self, clis: list[str]) -> None: + def run_piped(self, clis: List[str]) -> None: the_input = "" for cli in clis[:-1]: result = self.run( diff --git a/fbgemm_gpu/setup.py b/fbgemm_gpu/setup.py index 1031297307..f07c4650e8 100644 --- a/fbgemm_gpu/setup.py +++ b/fbgemm_gpu/setup.py @@ -16,7 +16,7 @@ import textwrap from dataclasses import dataclass from datetime import date -from typing import Optional +from typing import List, Optional import setuptools import setuptools_git_versioning as gitversion @@ -31,12 +31,12 @@ @dataclass(frozen=True) class FbgemmGpuBuild: args: argparse.Namespace - other_args: list[str] + other_args: List[str] """FBGEMM_GPU Package Build Configuration""" @classmethod - def from_args(cls, argv: list[str]): + def from_args(cls, argv: List[str]): parser = argparse.ArgumentParser(description="FBGEMM_GPU Build Setup") parser.add_argument( "--verbose", @@ -268,7 +268,7 @@ def package_version(self): ) return full_version_string - def cmake_args(self) -> list[str]: + def cmake_args(self) -> List[str]: def _get_cxx11_abi(): try: value = int(torch._C._GLIBCXX_USE_CXX11_ABI) @@ -581,7 +581,7 @@ def run(self): self.print_versions() -def main(argv: list[str]) -> None: +def main(argv: List[str]) -> None: # Handle command line args before passing to main setup() method. build = FbgemmGpuBuild.from_args(argv) # Repair command line args for setup() method. 
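In the setup.py hunks above, FbgemmGpuBuild keeps the parsed argparse.Namespace alongside the leftover arguments (other_args: List[str]) that are later handed back to setuptools. A small sketch of that pattern follows, under the assumption that the split is done with parse_known_args; only the --verbose flag is taken from the diff, the rest is illustrative.

import argparse
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class BuildConfig:
    # Parsed build flags plus whatever is left over for setuptools.
    args: argparse.Namespace
    other_args: List[str]

    @classmethod
    def from_args(cls, argv: List[str]) -> "BuildConfig":
        parser = argparse.ArgumentParser(description="Build setup (sketch)")
        parser.add_argument("--verbose", action="store_true")
        args, other_args = parser.parse_known_args(argv)
        return cls(args=args, other_args=other_args)


cfg = BuildConfig.from_args(["--verbose", "bdist_wheel"])
print(cfg.args.verbose, cfg.other_args)  # True ['bdist_wheel']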
diff --git a/fbgemm_gpu/test/batched_unary_embeddings_test.py b/fbgemm_gpu/test/batched_unary_embeddings_test.py index 10f7eed548..3aa51304e8 100644 --- a/fbgemm_gpu/test/batched_unary_embeddings_test.py +++ b/fbgemm_gpu/test/batched_unary_embeddings_test.py @@ -12,7 +12,7 @@ import sys import unittest from math import sqrt -from typing import Callable +from typing import Callable, List, Tuple import fbgemm_gpu.batched_unary_embeddings_ops as batched_unary_embeddings_ops import numpy as np @@ -59,7 +59,7 @@ def torch_compiled(model: Callable, **kwargs) -> Callable: class TableBatchedEmbeddingsTest(unittest.TestCase): class RefEmb(torch.nn.Module): - def __init__(self, num_tasks: int, hash_sizes: list[int]) -> None: + def __init__(self, num_tasks: int, hash_sizes: List[int]) -> None: super().__init__() self.num_tasks = num_tasks self.hash_sizes = hash_sizes @@ -79,7 +79,7 @@ def __init__(self, num_tasks: int, hash_sizes: list[int]) -> None: self.emb_modules.append(emb) def forward( - self, offsets: list[torch.Tensor], indices: list[torch.Tensor] + self, offsets: List[torch.Tensor], indices: List[torch.Tensor] ) -> torch.Tensor: tt_list = [] for n in range(self.num_tasks): @@ -99,7 +99,7 @@ def _generate_unary_features( num_embeddings: int, # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use # `typing.List[]` to avoid runtime subscripting errors. - ) -> tuple[list, list, list]: + ) -> Tuple[List, List, List]: lengths = [] offsets = [] indices = [] diff --git a/fbgemm_gpu/test/combine/common.py b/fbgemm_gpu/test/combine/common.py index 098ef5fe58..75c8bb5291 100644 --- a/fbgemm_gpu/test/combine/common.py +++ b/fbgemm_gpu/test/combine/common.py @@ -6,7 +6,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict -from typing import Optional +from typing import List, Optional, Tuple import fbgemm_gpu import torch @@ -25,17 +25,17 @@ class TBEInputPrepareReference(torch.nn.Module): - def __init__(self, include_last_offsets: list[bool]) -> None: + def __init__(self, include_last_offsets: List[bool]) -> None: super().__init__() self.include_last_offsets = include_last_offsets def forward( # noqa C901 self, - indices_list: list[torch.Tensor], - offsets_list: list[torch.Tensor], - per_sample_weights_list: list[torch.Tensor], + indices_list: List[torch.Tensor], + offsets_list: List[torch.Tensor], + per_sample_weights_list: List[torch.Tensor], batch_size: Optional[int] = None, - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: size = 0 assert len(indices_list) > 0 assert len(indices_list) == len(offsets_list) @@ -73,7 +73,7 @@ def forward( # noqa C901 offsets_accs[i + 1] = offsets_accs[i] + indices_list[i].size(0) assert offsets_accs[-1] == combined_indices.size(0) - combined_offsets_size: list[int] = ( + combined_offsets_size: List[int] = ( [int(offsets_starts[-1].item()) + 1] if batch_size is None else [batch_size * len(offsets_list) + 1] diff --git a/fbgemm_gpu/test/combine/input_combine_test.py b/fbgemm_gpu/test/combine/input_combine_test.py index 4e3c428591..f0239bca93 100644 --- a/fbgemm_gpu/test/combine/input_combine_test.py +++ b/fbgemm_gpu/test/combine/input_combine_test.py @@ -8,6 +8,8 @@ # pyre-strict import unittest +from typing import List, Tuple + import fbgemm_gpu # noqa: F401 import torch @@ -57,15 +59,15 @@ def _get_inputs(self, dtypes, device=DEFAULT_DEVICE): def _get_prepadded_inputs( self, - dtypes: list[torch.dtype], + dtypes: List[torch.dtype], device: 
torch._C.device = DEFAULT_DEVICE, include_last: bool = True, - ) -> tuple[ - list[torch.Tensor], - list[torch.Tensor], - list[torch.Tensor], - list[torch.Tensor], - list[bool], + ) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + List[torch.Tensor], + List[torch.Tensor], + List[bool], ]: indices_list = [ torch.tensor([1, 2, 3, 123, 123, 123], dtype=dtypes[0], device=device), diff --git a/fbgemm_gpu/test/jagged/common.py b/fbgemm_gpu/test/jagged/common.py index d8838b8447..491176f769 100644 --- a/fbgemm_gpu/test/jagged/common.py +++ b/fbgemm_gpu/test/jagged/common.py @@ -11,7 +11,7 @@ import itertools import sys import unittest -from typing import Callable +from typing import Callable, Dict, List, Tuple import fbgemm_gpu import fbgemm_gpu.sparse_ops @@ -26,7 +26,7 @@ if not open_source: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") -suppressed_list: list[HealthCheck] = ( +suppressed_list: List[HealthCheck] = ( [HealthCheck.differing_executors] if getattr(HealthCheck, "differing_executors", False) else [] @@ -43,7 +43,7 @@ # Please avoid putting tests here, you should put operator-specific # skips and failures in deeplearning/fbgemm/fbgemm_gpu/test/failures_dict.json # pyre-ignore[24]: Generic type `Callable` expects 2 type parameters. -additional_decorators: dict[str, list[Callable]] = { +additional_decorators: Dict[str, List[Callable]] = { "test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add": [ # This operator has been grandfathered in. We need to fix this test failure. unittest.expectedFailure, @@ -117,9 +117,9 @@ def generate_jagged_tensor( # dynamo to mark the input as dynamic shape to make sure symbolic # shape is generated mark_dynamic: bool = False, -) -> tuple[torch.Tensor, list[torch.LongTensor], npt.NDArray]: +) -> Tuple[torch.Tensor, List[torch.LongTensor], npt.NDArray]: max_lengths = np.random.randint(low=1, high=10, size=(num_jagged_dim,)) - x_offsets: list[torch.LongTensor] = [] + x_offsets: List[torch.LongTensor] = [] num_lengths = outer_dense_size for d in range(num_jagged_dim): # Sometimes length[i] exceed max_L meaning jagged->dense will be @@ -161,7 +161,7 @@ def generate_jagged_tensor( def to_padded_dense( values: torch.Tensor, - offsets: list[torch.LongTensor], + offsets: List[torch.LongTensor], max_lengths: npt.NDArray, padding_value: float = 0, ) -> torch.Tensor: diff --git a/fbgemm_gpu/test/jagged/dense_to_jagged_test.py b/fbgemm_gpu/test/jagged/dense_to_jagged_test.py index 0e6e08e56a..b774537a25 100644 --- a/fbgemm_gpu/test/jagged/dense_to_jagged_test.py +++ b/fbgemm_gpu/test/jagged/dense_to_jagged_test.py @@ -9,6 +9,7 @@ # pyre-ignore-all-errors[56] import unittest +from typing import List, Tuple import hypothesis.strategies as st import torch @@ -248,8 +249,8 @@ def test_dense_to_jagged_dynamic_shape( def jagged_to_dense( values: torch.Tensor, - offsets: list[torch.LongTensor], - max_lengths: list[int], + offsets: List[torch.LongTensor], + max_lengths: List[int], ) -> torch.Tensor: return torch.ops.fbgemm.jagged_to_padded_dense(values, offsets, max_lengths) @@ -264,13 +265,13 @@ def jagged_to_dense( torch._dynamo.mark_dynamic(dense, -1) def dense_to_jagged_withL( - dense: torch.Tensor, offsets: list[torch.LongTensor], total_L: list[int] - ) -> tuple[torch.Tensor, torch.Tensor]: + dense: torch.Tensor, offsets: List[torch.LongTensor], total_L: List[int] + ) -> Tuple[torch.Tensor, torch.Tensor]: return torch.ops.fbgemm.dense_to_jagged(dense, offsets, total_L) def dense_to_jagged_noL( - dense: torch.Tensor, offsets: 
list[torch.LongTensor] - ) -> tuple[torch.Tensor, torch.Tensor]: + dense: torch.Tensor, offsets: List[torch.LongTensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: return torch.ops.fbgemm.dense_to_jagged(dense, offsets) jagged_values, jagged_offsets = dense_to_jagged_noL(dense, offsets) diff --git a/fbgemm_gpu/test/jagged/elementwise_binary_test.py b/fbgemm_gpu/test/jagged/elementwise_binary_test.py index 5503ce26b2..6c6c45d330 100644 --- a/fbgemm_gpu/test/jagged/elementwise_binary_test.py +++ b/fbgemm_gpu/test/jagged/elementwise_binary_test.py @@ -9,6 +9,7 @@ # pyre-ignore-all-errors[56] import unittest +from typing import List, Tuple import hypothesis.strategies as st import numpy as np @@ -227,20 +228,20 @@ def test_jagged_elementwise_binary_dynamic_shape( x_padded = to_padded_dense(x_values, x_offsets, max_lengths) def jagged_dense_elementwise_add( - x_values: torch.Tensor, x_offsets: list[torch.LongTensor], y: torch.Tensor + x_values: torch.Tensor, x_offsets: List[torch.LongTensor], y: torch.Tensor ) -> torch.Tensor: return torch.ops.fbgemm.jagged_dense_elementwise_add(x_values, x_offsets, y) def jagged_dense_elementwise_add_jagged_output( - x_values: torch.Tensor, x_offsets: list[torch.LongTensor], y: torch.Tensor - ) -> tuple[torch.Tensor, list[torch.LongTensor]]: + x_values: torch.Tensor, x_offsets: List[torch.LongTensor], y: torch.Tensor + ) -> Tuple[torch.Tensor, List[torch.LongTensor]]: return torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output( x_values, x_offsets, y ) def jagged_dense_elementwise_mul( - x_values: torch.Tensor, x_offsets: list[torch.LongTensor], y: torch.Tensor - ) -> tuple[torch.Tensor, list[torch.LongTensor]]: + x_values: torch.Tensor, x_offsets: List[torch.LongTensor], y: torch.Tensor + ) -> Tuple[torch.Tensor, List[torch.LongTensor]]: return torch.ops.fbgemm.jagged_dense_elementwise_mul(x_values, x_offsets, y) if operation == "add": diff --git a/fbgemm_gpu/test/jagged/expand_into_jagged_permute_test.py b/fbgemm_gpu/test/jagged/expand_into_jagged_permute_test.py index 060443536f..590e43b8fd 100644 --- a/fbgemm_gpu/test/jagged/expand_into_jagged_permute_test.py +++ b/fbgemm_gpu/test/jagged/expand_into_jagged_permute_test.py @@ -11,6 +11,7 @@ import itertools import random import unittest +from typing import List import hypothesis.strategies as st import torch @@ -30,9 +31,9 @@ class ExpandIntoJaggedPermuteTest(unittest.TestCase): @staticmethod def expand_into_jagged_permute_ref_( - permute: list[int], - length: list[int], - ) -> list[int]: + permute: List[int], + length: List[int], + ) -> List[int]: offsets = [0] + list(itertools.accumulate(length)) output_permute = [] for r in permute: diff --git a/fbgemm_gpu/test/jagged/slice_test.py b/fbgemm_gpu/test/jagged/slice_test.py index 73d2f6d581..98ced3647f 100644 --- a/fbgemm_gpu/test/jagged/slice_test.py +++ b/fbgemm_gpu/test/jagged/slice_test.py @@ -10,6 +10,7 @@ import random import unittest +from typing import List, Tuple import hypothesis.strategies as st import torch @@ -69,13 +70,13 @@ def jagged_slice_ref( offsets: torch.Tensor, start: torch.Tensor, slice_length: int, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: end_offsets_ = slice_length + start + offsets[:-1] end_offsets = torch.where( end_offsets_ > offsets[1:], offsets[1:], end_offsets_ ) start_offsets = start + offsets[:-1] - indices_to_select: list[torch.Tensor] = [] + indices_to_select: List[torch.Tensor] = [] for i in range(end_offsets.size(0)): indices_to_select.append( 
torch.arange(start_offsets[i].item(), end_offsets[i].item()) diff --git a/fbgemm_gpu/test/jagged/unique_indices_test.py b/fbgemm_gpu/test/jagged/unique_indices_test.py index 68d41e9fe1..78999b421a 100644 --- a/fbgemm_gpu/test/jagged/unique_indices_test.py +++ b/fbgemm_gpu/test/jagged/unique_indices_test.py @@ -11,6 +11,7 @@ import itertools import random import unittest +from typing import List import hypothesis.strategies as st import numpy as np @@ -31,7 +32,7 @@ ) -def hash_size_cumsum_to_offsets(hash_size_cum_sum_list: list[int]) -> list[int]: +def hash_size_cumsum_to_offsets(hash_size_cum_sum_list: List[int]) -> List[int]: hash_size_offsets_list = [0] count = 0 for f in range(1, len(hash_size_cum_sum_list)): diff --git a/fbgemm_gpu/test/merge_pooled_embeddings_test.py b/fbgemm_gpu/test/merge_pooled_embeddings_test.py index 9d9ab8f4ab..996c992115 100644 --- a/fbgemm_gpu/test/merge_pooled_embeddings_test.py +++ b/fbgemm_gpu/test/merge_pooled_embeddings_test.py @@ -9,6 +9,7 @@ import unittest +from typing import Tuple import fbgemm_gpu @@ -29,7 +30,7 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings") -typed_gpu_unavailable: tuple[bool, str] = gpu_unavailable +typed_gpu_unavailable: Tuple[bool, str] = gpu_unavailable @unittest.skipIf(*gpu_unavailable) diff --git a/fbgemm_gpu/test/permute/common.py b/fbgemm_gpu/test/permute/common.py index d6bc4436d4..fb8668d3e6 100644 --- a/fbgemm_gpu/test/permute/common.py +++ b/fbgemm_gpu/test/permute/common.py @@ -7,7 +7,7 @@ # pyre-strict # pyre-ignore-all-errors[56] -from typing import Any +from typing import Any, List, Tuple import fbgemm_gpu import torch @@ -30,10 +30,10 @@ ) from fbgemm_gpu.test.test_utils import gpu_unavailable, on_arm_platform -typed_gpu_unavailable: tuple[bool, str] = gpu_unavailable -typed_on_arm_platform: tuple[bool, str] = on_arm_platform +typed_gpu_unavailable: Tuple[bool, str] = gpu_unavailable +typed_on_arm_platform: Tuple[bool, str] = on_arm_platform -suppressed_list: list[HealthCheck] = ( +suppressed_list: List[HealthCheck] = ( [HealthCheck.not_a_test_method] if getattr(HealthCheck, "not_a_test_method", False) else [] diff --git a/fbgemm_gpu/test/permute/permute_pooled_embedding_test.py b/fbgemm_gpu/test/permute/permute_pooled_embedding_test.py index a828c82ce6..cfc40a3ace 100644 --- a/fbgemm_gpu/test/permute/permute_pooled_embedding_test.py +++ b/fbgemm_gpu/test/permute/permute_pooled_embedding_test.py @@ -10,6 +10,7 @@ import inspect import sys import unittest +from typing import List import hypothesis.strategies as st import torch @@ -25,7 +26,7 @@ else: from fbgemm_gpu.test.test_utils import cpu_and_maybe_gpu, on_arm_platform, optests -suppressed_list: list[HealthCheck] = ( +suppressed_list: List[HealthCheck] = ( [HealthCheck.not_a_test_method] if getattr(HealthCheck, "not_a_test_method", False) else [] diff --git a/fbgemm_gpu/test/quantize/bfloat16_test.py b/fbgemm_gpu/test/quantize/bfloat16_test.py index a79bc47e45..04a1e2d4c8 100644 --- a/fbgemm_gpu/test/quantize/bfloat16_test.py +++ b/fbgemm_gpu/test/quantize/bfloat16_test.py @@ -8,6 +8,7 @@ import unittest from ctypes import c_float, c_int32, cast, POINTER, pointer +from typing import Tuple import hypothesis.strategies as st import numpy as np @@ -123,7 +124,7 @@ def test_quantize_and_dequantize_op(self, nrows: int, ncols: int) -> None: ) @settings(deadline=10000, suppress_health_check=[HealthCheck.filter_too_much]) def test_quantize_and_dequantize_op_cuda_large_nrows_bf16( - self, ncols_nrows: tuple[int, int] + self, 
ncols_nrows: Tuple[int, int] ) -> None: ncols, nrows = ncols_nrows input_data = torch.rand(nrows, ncols).float() diff --git a/fbgemm_gpu/test/quantize/comm_codec_test.py b/fbgemm_gpu/test/quantize/comm_codec_test.py index 0f88090ca8..5d36d1d8f4 100644 --- a/fbgemm_gpu/test/quantize/comm_codec_test.py +++ b/fbgemm_gpu/test/quantize/comm_codec_test.py @@ -8,7 +8,7 @@ # pyre-strict import unittest -from typing import Optional +from typing import Optional, Tuple import hypothesis.strategies as st import torch @@ -45,7 +45,7 @@ class QuantizedCommCodecTest(unittest.TestCase): ) def test_quantized_comm_codec( self, - comm_precisions_loss_scale: tuple[SparseType, Optional[float]], + comm_precisions_loss_scale: Tuple[SparseType, Optional[float]], row_size: int, col_size: int, rand_seed: int, diff --git a/fbgemm_gpu/test/quantize/hfp8_test.py b/fbgemm_gpu/test/quantize/hfp8_test.py index 876fdae8b7..0d9dcdf70a 100644 --- a/fbgemm_gpu/test/quantize/hfp8_test.py +++ b/fbgemm_gpu/test/quantize/hfp8_test.py @@ -7,6 +7,7 @@ # pyre-strict import unittest +from typing import Dict, Tuple import hypothesis.strategies as st import torch @@ -21,7 +22,7 @@ class TestHFP8QuantizationConversion(unittest.TestCase): # min_normal_pos is the minimal of normal numbers def _get_hfp8_dynamic_range( self, ebits: int, mbits: int, bias: int - ) -> tuple[int, int, int]: + ) -> Tuple[int, int, int]: max_pos = (1 << ((1 << ebits) - 2 - bias)) * (2 - 2 ** (-mbits)) min_pos = 2 ** (1 - bias - mbits) min_normal_pos = 2 ** (1 - bias) @@ -29,7 +30,7 @@ def _get_hfp8_dynamic_range( def _get_hfp8_config( self, - ) -> tuple[int, int, dict[int, int], dict[int, int], dict[int, int]]: + ) -> Tuple[int, int, Dict[int, int], Dict[int, int], Dict[int, int]]: # TODO: set up test for 1-5-2 format # TODO: parameterize ebits and mbits in unit test ebits = 4 diff --git a/fbgemm_gpu/test/quantize/mx/common.py b/fbgemm_gpu/test/quantize/mx/common.py index 02105e1aee..7b64039ff0 100644 --- a/fbgemm_gpu/test/quantize/mx/common.py +++ b/fbgemm_gpu/test/quantize/mx/common.py @@ -10,7 +10,7 @@ import struct from enum import Enum, IntEnum -from typing import Optional, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -31,7 +31,7 @@ class RoundingMode(IntEnum): even = 2 @staticmethod - def string_enums() -> list[str]: + def string_enums() -> List[str]: return [s.name for s in list(RoundingMode)] @@ -74,12 +74,12 @@ def _get_max_norm(ebits: int, mbits: int) -> float: return 2**emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2) -_FORMAT_CACHE: dict[ElemFormat, tuple[int, int, int, float, float]] = {} +_FORMAT_CACHE: Dict[ElemFormat, Tuple[int, int, int, float, float]] = {} def _get_format_params( # noqa fmt: Union[ElemFormat, str, None], -) -> tuple[int, int, int, float, float]: +) -> Tuple[int, int, int, float, float]: """Allowed formats: - intX: 2 <= X <= 32, assume sign-magnitude, 1.xxx representation - floatX/fpX: 16 <= X <= 28, assume top exp is used for NaN/Inf @@ -150,8 +150,8 @@ def _get_format_params( # noqa def _reshape_to_blocks( - A: torch.Tensor, axes: list[int], block_size: int -) -> tuple[torch.Tensor, list[int], torch.Size, torch.Size]: + A: torch.Tensor, axes: List[int], block_size: int +) -> Tuple[torch.Tensor, List[int], torch.Size, torch.Size]: if axes is None: raise Exception( "axes required in order to determine which " @@ -192,7 +192,7 @@ def _reshape_to_blocks( pad = list(reversed(pad)) A = torch.nn.functional.pad(A, pad, mode="constant") - def _reshape(shape: list[int], 
reshape_block_size: int) -> list[int]: + def _reshape(shape: List[int], reshape_block_size: int) -> List[int]: for axis in axes: # Reshape to tiles if axis length > reshape_block_size if shape[axis] >= reshape_block_size: @@ -214,7 +214,7 @@ def _reshape(shape: list[int], reshape_block_size: int) -> list[int]: def _undo_reshape_to_blocks( - A: torch.Tensor, padded_shape: torch.Size, orig_shape: torch.Size, axes: list[int] + A: torch.Tensor, padded_shape: torch.Size, orig_shape: torch.Size, axes: List[int] ) -> torch.Tensor: # Undo tile reshaping A = A.view(padded_shape) @@ -228,7 +228,7 @@ def _undo_reshape_to_blocks( return A -def get_s_e_m(value_in_float: float) -> tuple[int, int, int]: +def get_s_e_m(value_in_float: float) -> Tuple[int, int, int]: def float_to_bits(value_in_float: float) -> int: s = struct.pack("@f", value_in_float) return struct.unpack("@I", s)[0] @@ -411,7 +411,7 @@ def _shared_exponents( A: torch.Tensor, method: str = "max", rounding_mode: str = "even", - axes: Optional[list[int]] = None, + axes: Optional[List[int]] = None, ebits: int = 0, ) -> torch.Tensor: """ diff --git a/fbgemm_gpu/test/quantize/mx4_test.py b/fbgemm_gpu/test/quantize/mx4_test.py index fd2a0fb47b..983c1db47d 100644 --- a/fbgemm_gpu/test/quantize/mx4_test.py +++ b/fbgemm_gpu/test/quantize/mx4_test.py @@ -7,6 +7,7 @@ # pyre-strict import unittest +from typing import List, Tuple import fbgemm_gpu.quantize.quantize_ops # noqa F401 import hypothesis.strategies as st @@ -80,7 +81,7 @@ def fake_quantize_mx( max_norm: float = 6.0, group_size: int = 32, shared_exp_method: str = "max", - axes: list[int] = [-1], + axes: List[int] = [-1], # noqa round: str = "nearest", flush_fp32_subnorms: bool = False, ) -> torch.Tensor: @@ -253,11 +254,11 @@ def test_mx4(self, power: int, sizes: int) -> None: @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) def test_mx4_cases( self, - shape: list[int], + shape: List[int], group_size: int, rounding_mode: RoundingMode, magnitude: int, - mx4_format: tuple[int, int], + mx4_format: Tuple[int, int], device: str, ) -> None: """Test correctness of mx4 routines with random inputs and unusual shapes.""" @@ -320,11 +321,11 @@ def test_mx4_index_overflow(self) -> None: @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) def test_mx4_large_cases( self, - shape: list[int], + shape: List[int], group_size: int, rounding_mode: RoundingMode, magnitude: int, - mx4_format: tuple[int, int], + mx4_format: Tuple[int, int], device: str, ) -> None: """Test correctness of mx4 routines with random inputs and shapes that overflow int32.""" diff --git a/fbgemm_gpu/test/release/utils.py b/fbgemm_gpu/test/release/utils.py index 1e376cb75a..4218ebc950 100644 --- a/fbgemm_gpu/test/release/utils.py +++ b/fbgemm_gpu/test/release/utils.py @@ -8,8 +8,7 @@ import inspect import typing -from collections.abc import Iterable, Sequence # noqa: F401 -from typing import Optional, Union +from typing import Iterable, List, Optional, Sequence, Union # noqa: F401 import torch from torch import device, dtype, Tensor, types @@ -64,7 +63,7 @@ def get_supported_param_types(): SUPPORTED_RETURN_TYPES = { Tensor: "Tensor", - list[Tensor]: "Tensor[]", + typing.List[Tensor]: "Tensor[]", int: "int", float: "float", bool: "bool", diff --git a/fbgemm_gpu/test/runtime_monitor_test.py b/fbgemm_gpu/test/runtime_monitor_test.py index 2d687a648e..487af7368a 100644 --- a/fbgemm_gpu/test/runtime_monitor_test.py +++ b/fbgemm_gpu/test/runtime_monitor_test.py @@ -8,7 +8,7 @@ import unittest -from typing 
import cast +from typing import cast, List, Tuple import fbgemm_gpu import torch @@ -30,7 +30,7 @@ class TesteeAsyncSeriesTimer(AsyncSeriesTimer): - outputs: list[tuple[str, float]] + outputs: List[Tuple[str, float]] def __init__(self) -> None: self.outputs = [] @@ -43,7 +43,7 @@ def report_callback(ctx: str, duration: float) -> None: class RuntimeMonitorTest(unittest.TestCase): def assert_context( - self, timer: TesteeAsyncSeriesTimer, context_list: list[str] + self, timer: TesteeAsyncSeriesTimer, context_list: List[str] ) -> None: timer._lazy_report() self.assertEqual([t[0] for t in timer.outputs], context_list) diff --git a/fbgemm_gpu/test/sparse/block_bucketize_2d_weights_test.py b/fbgemm_gpu/test/sparse/block_bucketize_2d_weights_test.py index 5c63c1b4f9..427e8aaf91 100644 --- a/fbgemm_gpu/test/sparse/block_bucketize_2d_weights_test.py +++ b/fbgemm_gpu/test/sparse/block_bucketize_2d_weights_test.py @@ -10,6 +10,7 @@ # pyre-ignore-all-errors[56] import unittest +from typing import Type import hypothesis.strategies as st import torch @@ -60,7 +61,7 @@ def validate_out_of_order_output( @settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None) def test_block_bucketize_sparse_features_2d_weights( self, - index_type: type[torch.dtype], + index_type: Type[torch.dtype], bucketize_pos: bool, sequence: bool, weights_dim: int, @@ -179,7 +180,7 @@ def test_block_bucketize_sparse_features_2d_weights( @settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None) def test_block_bucketize_sparse_features_2d_weights_vs_original( self, - index_type: type[torch.dtype], + index_type: Type[torch.dtype], bucketize_pos: bool, sequence: bool, ) -> None: @@ -335,7 +336,7 @@ def test_block_bucketize_sparse_features_2d_weights_vs_original( @settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None) def test_block_bucketize_sparse_features_2d_weights_pooled_vs_sequence( self, - index_type: type[torch.dtype], + index_type: Type[torch.dtype], bucketize_pos: bool, weights_dtype: torch.dtype, ) -> None: @@ -484,7 +485,7 @@ def test_block_bucketize_sparse_features_2d_weights_pooled_vs_sequence( @settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None) def test_block_bucketize_sparse_features_2d_weights_keep_orig_idx( self, - index_type: type[torch.dtype], + index_type: Type[torch.dtype], bucketize_pos: bool, sequence: bool, keep_orig_idx: bool, @@ -595,7 +596,7 @@ def test_block_bucketize_sparse_features_2d_weights_keep_orig_idx( @settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None) def test_block_bucketize_sparse_features_2d_weights_keep_orig_idx_per_feature( self, - index_type: type[torch.dtype], + index_type: Type[torch.dtype], bucketize_pos: bool, sequence: bool, ) -> None: @@ -701,7 +702,7 @@ def test_block_bucketize_sparse_features_2d_weights_keep_orig_idx_per_feature( @settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None) def test_block_bucketize_sparse_features_2d_weights_total_num_blocks( self, - index_type: type[torch.dtype], + index_type: Type[torch.dtype], bucketize_pos: bool, sequence: bool, ) -> None: @@ -803,7 +804,7 @@ def test_block_bucketize_sparse_features_2d_weights_total_num_blocks( @settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None) def test_block_bucketize_sparse_features_2d_weights_block_bucketize_pos( self, - index_type: type[torch.dtype], + index_type: Type[torch.dtype], bucketize_pos: bool, sequence: bool, ) -> None: @@ -913,7 +914,7 @@ def 
test_block_bucketize_sparse_features_2d_weights_block_bucketize_pos( @settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None) def test_block_bucketize_sparse_features_2d_weights_with_variable_batch_sizes( self, - index_type: type[torch.dtype], + index_type: Type[torch.dtype], bucketize_pos: bool, sequence: bool, ) -> None: diff --git a/fbgemm_gpu/test/sparse/block_bucketize_test.py b/fbgemm_gpu/test/sparse/block_bucketize_test.py index 282145fab2..811b4a1cd4 100644 --- a/fbgemm_gpu/test/sparse/block_bucketize_test.py +++ b/fbgemm_gpu/test/sparse/block_bucketize_test.py @@ -11,7 +11,7 @@ import random import unittest -from typing import Optional +from typing import Optional, Type import hypothesis.strategies as st import torch @@ -963,7 +963,7 @@ def test_block_bucketize_sparse_features_total_num_blocks_raw_ids( @settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None) def test_block_bucketize_sparse_features( self, - index_type: type[torch.dtype], + index_type: Type[torch.dtype], has_weight: bool, bucketize_pos: bool, sequence: bool, @@ -1128,7 +1128,7 @@ def test_block_bucketize_sparse_features( @settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None) def test_block_bucketize_sparse_features_inference( self, - index_type: type[torch.dtype], + index_type: Type[torch.dtype], ) -> None: # pyre-ignore [6] lengths = torch.tensor([0, 2, 1, 3, 2, 3, 3, 1], dtype=index_type) @@ -1208,7 +1208,7 @@ def test_block_bucketize_sparse_features_inference( @settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None) def test_populate_bucketized_permute( self, - index_type: type[torch.dtype], + index_type: Type[torch.dtype], ) -> None: # pyre-ignore [6] lengths = torch.tensor([0, 2, 1, 3, 2, 3, 3, 1], dtype=index_type) @@ -1566,7 +1566,7 @@ def test_block_bucketize_sparse_features_with_block_bucketize_pos( @settings(verbosity=Verbosity.verbose, max_examples=32, deadline=None) def test_block_bucketize_sparse_features_large( self, - index_type: type[torch.dtype], + index_type: Type[torch.dtype], has_weight: bool, bucketize_pos: bool, sequence: bool, @@ -1673,7 +1673,7 @@ def test_block_bucketize_sparse_features_large( @settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None) def test_block_bucketize_sparse_features_float64_weights( self, - index_type: type[torch.dtype], + index_type: Type[torch.dtype], bucketize_pos: bool, sequence: bool, ) -> None: diff --git a/fbgemm_gpu/test/sparse/common.py b/fbgemm_gpu/test/sparse/common.py index bf8aa9c5e2..8abddca752 100644 --- a/fbgemm_gpu/test/sparse/common.py +++ b/fbgemm_gpu/test/sparse/common.py @@ -11,7 +11,7 @@ import os import unittest from itertools import accumulate -from typing import Callable, Optional +from typing import Callable, Dict, List, Optional, Tuple, Type import fbgemm_gpu import torch @@ -26,7 +26,7 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops") -suppressed_list: list[HealthCheck] = ( +suppressed_list: List[HealthCheck] = ( [HealthCheck.differing_executors] if getattr(HealthCheck, "differing_executors", False) else [] @@ -40,7 +40,7 @@ def permute_indices_ref_( weights: Optional[torch.Tensor], permute: torch.LongTensor, is_1D: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: T = lengths.size(0) B = lengths.size(1) if T == 0 or B == 0: @@ -97,7 +97,7 @@ def permute_indices_ref_( 
@torch.jit.script def permute_scripted( permute: torch.Tensor, lengths: torch.Tensor, indices: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: ( permuted_lengths_cpu, permuted_indices_cpu, @@ -111,12 +111,12 @@ def permute_scripted( def extend_test_class( - klass: type[unittest.TestCase], + klass: Type[unittest.TestCase], # e.g. "test_faketensor__test_cumsum": [unittest.expectedFailure] # Please avoid putting tests here, you should put operator-specific # skips and failures in deeplearning/fbgemm/fbgemm_gpu/test/failures_dict.json # pyre-ignore[24]: Generic type `Callable` expects 2 type parameters. - additional_decorators: Optional[dict[str, list[Callable]]] = None, + additional_decorators: Optional[Dict[str, List[Callable]]] = None, ) -> None: failures_dict_path: str = get_file_path_2( "", os.path.dirname(__file__), "failures_dict.json" diff --git a/fbgemm_gpu/test/sparse/cumsum_test.py b/fbgemm_gpu/test/sparse/cumsum_test.py index 79a3d8776b..5df9cf25fd 100644 --- a/fbgemm_gpu/test/sparse/cumsum_test.py +++ b/fbgemm_gpu/test/sparse/cumsum_test.py @@ -9,6 +9,7 @@ # pyre-ignore-all-errors[56] import unittest +from typing import Tuple, Type import hypothesis.strategies as st import numpy as np @@ -41,7 +42,7 @@ class CumSumTest(unittest.TestCase): def test_cumsum( self, n: int, - index_types: tuple[type[object], type[object]], + index_types: Tuple[Type[object], Type[object]], device: torch.device, ) -> None: (pt_index_dtype, np_index_dtype) = index_types @@ -105,7 +106,7 @@ def test_asynchronous_complete_cumsum_2d( self, n: int, b: int, - index_types: tuple[type[object], type[object]], + index_types: Tuple[Type[object], Type[object]], device: torch.device, ) -> None: (pt_index_dtype, np_index_dtype) = index_types diff --git a/fbgemm_gpu/test/sparse/index_select_test.py b/fbgemm_gpu/test/sparse/index_select_test.py index 6c61b77bf8..6b457025d8 100644 --- a/fbgemm_gpu/test/sparse/index_select_test.py +++ b/fbgemm_gpu/test/sparse/index_select_test.py @@ -14,7 +14,7 @@ import logging import random import unittest -from typing import Callable +from typing import Callable, Dict, List import hypothesis.strategies as st import numpy as np @@ -48,7 +48,7 @@ class IndexSelectTest(unittest.TestCase): def test_index_select_dim0( self, N: int, - shape: list[int], + shape: List[int], dtype: torch.dtype, use_cpu: bool, consecutive_indices: bool, @@ -120,7 +120,7 @@ def test_group_index_select_dim0( self, num_indices: int, max_num_input_rows: int, - shape: list[int], + shape: List[int], dtype: torch.dtype, use_cpu: bool, num_groups: int, @@ -130,10 +130,10 @@ def test_group_index_select_dim0( ) -> None: device = torch.device("cpu" if use_cpu else "cuda") - input_group: list[torch.Tensor] = [] - input_ref_group: list[torch.Tensor] = [] - indices_group: list[torch.Tensor] = [] - grad_group: list[torch.Tensor] = [] + input_group: List[torch.Tensor] = [] + input_ref_group: List[torch.Tensor] = [] + indices_group: List[torch.Tensor] = [] + grad_group: List[torch.Tensor] = [] for _ in range(num_groups): if use_var_num_input_rows: num_input_rows = ( @@ -206,10 +206,10 @@ def test_group_index_select_dim0( cat_output.backward(cat_grad) def compare_tensor_groups( - test_group: list[torch.Tensor], - ref_group: list[torch.Tensor], + test_group: List[torch.Tensor], + ref_group: List[torch.Tensor], tensor_type: str, - tols: dict["str", float], + tols: Dict["str", float], ) -> None: passed = True failure_count = 0 @@ -285,9 
+285,9 @@ def test_batch_index_select_dim0( # noqa: C901 ).tolist() def validate( - test_list: list[torch.Tensor], - ref_list: list[torch.Tensor], - rows: list[int], + test_list: List[torch.Tensor], + ref_list: List[torch.Tensor], + rows: List[int], val_fn: Callable[[torch.Tensor, torch.Tensor], bool], name: str, ) -> None: @@ -413,7 +413,7 @@ def validate( # Please avoid putting tests here, you should put operator-specific # skips and failures in deeplearning/fbgemm/fbgemm_gpu/test/failures_dict.json # pyre-ignore[24]: Generic type `Callable` expects 2 type parameters. -additional_decorators: dict[str, list[Callable]] = { +additional_decorators: Dict[str, List[Callable]] = { "test_aot_dispatch_dynamic__test_index_select_dim0": [unittest.skip("hangs")], "test_aot_dispatch_static__test_index_select_dim0": [unittest.skip("hangs")], "test_faketensor__test_index_select_dim0": [unittest.skip("hangs")], diff --git a/fbgemm_gpu/test/sparse/misc_ops_test.py b/fbgemm_gpu/test/sparse/misc_ops_test.py index f3371677ed..013de3b01d 100644 --- a/fbgemm_gpu/test/sparse/misc_ops_test.py +++ b/fbgemm_gpu/test/sparse/misc_ops_test.py @@ -12,7 +12,7 @@ import itertools import random import unittest -from typing import Union +from typing import Type, Union import hypothesis.strategies as st import numpy as np @@ -67,7 +67,7 @@ def test_offsets_range( N: int, # pyre-fixme[11]: Annotation `int32` is not defined as a type. # pyre-fixme[11]: Annotation `int64` is not defined as a type. - offsets_type: "Union[type[torch.int32], type[torch.int64]]", + offsets_type: "Union[Type[torch.int32], Type[torch.int64]]", ) -> None: lengths = np.array([np.random.randint(low=0, high=20) for _ in range(N)]) offsets = np.cumsum(np.concatenate([[0], lengths]))[:-1] @@ -94,7 +94,7 @@ def test_offsets_range( @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) def test_bucketize_sparse_features( self, - index_type: type[torch.dtype], + index_type: Type[torch.dtype], has_weight: bool, bucketize_pos: bool, ) -> None: diff --git a/fbgemm_gpu/test/sparse/permute_embeddings_test.py b/fbgemm_gpu/test/sparse/permute_embeddings_test.py index e4394a7033..dae48ec3f9 100644 --- a/fbgemm_gpu/test/sparse/permute_embeddings_test.py +++ b/fbgemm_gpu/test/sparse/permute_embeddings_test.py @@ -11,7 +11,7 @@ import random import unittest -from typing import Any, Callable +from typing import Any, Callable, Tuple import hypothesis.strategies as st import torch @@ -30,9 +30,9 @@ class PermuteEmbeddingsTest(unittest.TestCase): @staticmethod def permute_embeddings_( - permute_fn: Callable[..., tuple[torch.Tensor, ...]], + permute_fn: Callable[..., Tuple[torch.Tensor, ...]], *args: Any, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: if permute_fn == torch.ops.fbgemm.permute_2D_sparse_data: permuted_lengths, permuted_embeddings, _ = permute_fn(*args, None) return permuted_lengths, permuted_embeddings @@ -59,7 +59,7 @@ def test_permute_embeddings( T: int, L: int, long_index: bool, - permute_fn: Callable[..., tuple[torch.Tensor, ...]], + permute_fn: Callable[..., Tuple[torch.Tensor, ...]], ) -> None: index_dtype = torch.int64 if long_index else torch.int32 lengths = torch.randint(low=1, high=L, size=(T, B)).type(index_dtype) diff --git a/fbgemm_gpu/test/sparse/permute_indices_test.py b/fbgemm_gpu/test/sparse/permute_indices_test.py index a48f43c80e..d661dec171 100644 --- a/fbgemm_gpu/test/sparse/permute_indices_test.py +++ b/fbgemm_gpu/test/sparse/permute_indices_test.py @@ -12,7 +12,7 @@ import random 
import unittest from itertools import accumulate -from typing import Optional +from typing import List, Optional import hypothesis.strategies as st import torch @@ -55,7 +55,7 @@ def test_permute_indices( W: int, ) -> None: index_dtype = torch.int64 if long_index else torch.int32 - length_splits: Optional[list[torch.Tensor]] = None + length_splits: Optional[List[torch.Tensor]] = None if is_1D: if B == 0: batch_sizes = [0] * W diff --git a/fbgemm_gpu/test/sparse/permute_sparse_features_test.py b/fbgemm_gpu/test/sparse/permute_sparse_features_test.py index 38c9c9581d..55f2f425ed 100644 --- a/fbgemm_gpu/test/sparse/permute_sparse_features_test.py +++ b/fbgemm_gpu/test/sparse/permute_sparse_features_test.py @@ -11,7 +11,7 @@ import random import unittest -from typing import cast, Optional +from typing import cast, Optional, Tuple import hypothesis.strategies as st import torch @@ -35,7 +35,7 @@ def permute_sparse_features_ref_( indices: torch.Tensor, weights: Optional[torch.Tensor], permute: torch.LongTensor, - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: T = lengths.size(0) B = lengths.size(1) permuted_lengths = torch.index_select(lengths.view(T, B), 0, permute) @@ -81,7 +81,7 @@ def test_permute_sparse_features( indices = torch.randint( low=1, high=int(1e5), - size=cast(tuple[int, ...], (lengths.sum().item(),)), + size=cast(Tuple[int, ...], (lengths.sum().item(),)), ).type(index_dtype) permute_list = list(range(T)) random.shuffle(permute_list) @@ -143,7 +143,7 @@ def test_permute_sparse_features_with_repeats( indices = torch.randint( low=1, high=int(1e5), - size=cast(tuple[int, ...], (lengths.sum().item(),)), + size=cast(Tuple[int, ...], (lengths.sum().item(),)), ).type(index_dtype) permute_list = list(range(T)) diff --git a/fbgemm_gpu/test/tbe/bench/tbe_data_config_loader_test.py b/fbgemm_gpu/test/tbe/bench/tbe_data_config_loader_test.py index 1db3841732..35acb66903 100644 --- a/fbgemm_gpu/test/tbe/bench/tbe_data_config_loader_test.py +++ b/fbgemm_gpu/test/tbe/bench/tbe_data_config_loader_test.py @@ -9,6 +9,7 @@ import random import unittest +from typing import List import click import torch @@ -26,7 +27,7 @@ def rand_int(min_value: int, max_value: int) -> int: return torch.randint(min_value, max_value, (1,)).tolist()[0] -def clean_command(command: str) -> list[str]: +def clean_command(command: str) -> List[str]: return [x for x in command.strip().split() if x] diff --git a/fbgemm_gpu/test/tbe/cache/cache_common.py b/fbgemm_gpu/test/tbe/cache/cache_common.py index 73af22b249..48b1df66ed 100644 --- a/fbgemm_gpu/test/tbe/cache/cache_common.py +++ b/fbgemm_gpu/test/tbe/cache/cache_common.py @@ -10,7 +10,7 @@ # pyre-ignore-all-errors[56] from dataclasses import dataclass -from typing import Optional, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -49,7 +49,7 @@ class TestingStatsReporter(TBEStatsReporter): def __init__(self, reporting_interval: int = 1) -> None: # Event -> args for that call - self.reported_data: dict[str, list[list[Union[int, str, float]]]] = {} + self.reported_data: Dict[str, List[List[Union[int, str, float]]]] = {} self.reporting_interval = reporting_interval def should_report(self, iteration_step: int) -> bool: @@ -106,7 +106,7 @@ def generate_cache_tbes( gather_uvm_cache_stats: bool = False, reporter_config: Optional[TestingStatsReporterConfig] = None, multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None, -) -> tuple[ +) -> 
Tuple[ SplitTableBatchedEmbeddingBagsCodegen, SplitTableBatchedEmbeddingBagsCodegen, int, diff --git a/fbgemm_gpu/test/tbe/cache/cache_test.py b/fbgemm_gpu/test/tbe/cache/cache_test.py index 8f3a3eec52..a19579bd9a 100644 --- a/fbgemm_gpu/test/tbe/cache/cache_test.py +++ b/fbgemm_gpu/test/tbe/cache/cache_test.py @@ -11,7 +11,7 @@ import random import unittest -from typing import Any, cast, Optional +from typing import Any, cast, List, Optional, Tuple import hypothesis.strategies as st import numpy as np @@ -56,10 +56,10 @@ class CacheTest(unittest.TestCase): def _compute_grad_output_shape( self, B: int, - D_offsets: list[int], + D_offsets: List[int], mixed_B: bool, - Bs_feature_rank: Optional[list[list[int]]] = None, - ) -> tuple[int, ...]: + Bs_feature_rank: Optional[List[List[int]]] = None, + ) -> Tuple[int, ...]: """ Compute output gradient shape If mixed_B = True (variable batch size), the shape is sum(Bi * Di for @@ -296,10 +296,10 @@ def _prefetch( _prefetch(cc, batch_i) - input_batch_count: list[int] = [] + input_batch_count: List[int] = [] intput_original_size: int = 0 intput_long_size: int = 0 - output_batch_count: list[int] = [] + output_batch_count: List[int] = [] output_original_size: int = 0 while batch_i: indices, offsets, _, Bs_feature_rank = batch_i.unpack_4() @@ -361,8 +361,8 @@ def _prefetch( def assert_event_exist( event_name: str, - steps: list[int], - expected_value: Optional[list[int]] = None, + steps: List[int], + expected_value: Optional[List[int]] = None, ) -> None: self.assertEqual( len(stats_reporter.reported_data[event_name]), len(steps) diff --git a/fbgemm_gpu/test/tbe/cache/lxu_cache_test.py b/fbgemm_gpu/test/tbe/cache/lxu_cache_test.py index aa0f1d6c50..abe8ce99e2 100644 --- a/fbgemm_gpu/test/tbe/cache/lxu_cache_test.py +++ b/fbgemm_gpu/test/tbe/cache/lxu_cache_test.py @@ -12,6 +12,7 @@ import random import unittest from itertools import accumulate +from typing import Tuple import hypothesis.strategies as st import numpy as np @@ -176,7 +177,7 @@ def unique_lookup( offsets: Tensor, cache_hash_size_cumsum: Tensor, total_cache_hash_size: int, - ) -> tuple[Tensor, Tensor]: + ) -> Tuple[Tensor, Tensor]: linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( cache_hash_size_cumsum, indices, diff --git a/fbgemm_gpu/test/tbe/common.py b/fbgemm_gpu/test/tbe/common.py index 95de120cc2..861a810665 100644 --- a/fbgemm_gpu/test/tbe/common.py +++ b/fbgemm_gpu/test/tbe/common.py @@ -7,6 +7,7 @@ # pyre-strict +from typing import List, Tuple import fbgemm_gpu import numpy as np @@ -49,7 +50,7 @@ VERBOSITY: Verbosity = Verbosity.verbose -def gen_mixed_B_batch_sizes(B: int, T: int) -> tuple[list[list[int]], list[int]]: +def gen_mixed_B_batch_sizes(B: int, T: int) -> Tuple[List[List[int]], List[int]]: num_ranks = np.random.randint(low=1, high=4) low = max(int(0.25 * B), 1) high = int(B) @@ -65,7 +66,7 @@ def gen_mixed_B_batch_sizes(B: int, T: int) -> tuple[list[list[int]], list[int]] def format_ref_tensors_in_mixed_B_layout( - ref_tensors: list[torch.Tensor], Bs_rank_feature: list[list[int]] + ref_tensors: List[torch.Tensor], Bs_rank_feature: List[List[int]] ) -> torch.Tensor: # Relayout the reference tensor # Jagged dimension: (rank, table, local batch) diff --git a/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py b/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py index 34d996ea6f..eb208d2f8d 100644 --- a/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py +++ b/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py @@ -10,6 +10,7 @@ import time import 
unittest from time import sleep +from typing import List import fbgemm_gpu import torch @@ -98,7 +99,7 @@ def test_set_get_embeddings(self) -> None: ), ) - def equal_one_of(t1: torch.Tensor, t2: list[torch.Tensor]) -> bool: + def equal_one_of(t1: torch.Tensor, t2: List[torch.Tensor]) -> bool: any_equal = False for t in t2: any_equal = torch.equal(t1, t) @@ -141,7 +142,7 @@ def test_inplace_update(self) -> None: ) full_ids: torch.Tensor = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.int64) - def equal_one_of(t1: torch.Tensor, t2: list[torch.Tensor]) -> bool: + def equal_one_of(t1: torch.Tensor, t2: List[torch.Tensor]) -> bool: any_equal = False for t in t2: any_equal = torch.equal(t1, t) diff --git a/fbgemm_gpu/test/tbe/dram_kv/dram_kv_test.py b/fbgemm_gpu/test/tbe/dram_kv/dram_kv_test.py index 9d6a6c3797..65f188c429 100644 --- a/fbgemm_gpu/test/tbe/dram_kv/dram_kv_test.py +++ b/fbgemm_gpu/test/tbe/dram_kv/dram_kv_test.py @@ -10,7 +10,7 @@ import logging import unittest -from typing import Any +from typing import Any, Dict, List, Tuple import hypothesis.strategies as st import numpy as np @@ -29,7 +29,7 @@ ) MAX_EXAMPLES = 20 -default_st: dict[str, Any] = { +default_st: Dict[str, Any] = { "T": st.integers(min_value=1, max_value=10), "D": st.integers(min_value=2, max_value=128), "log_E": st.integers(min_value=2, max_value=3), @@ -37,7 +37,7 @@ "weights_precision": st.sampled_from([SparseType.FP32, SparseType.FP16]), } -default_settings: dict[str, Any] = { +default_settings: Dict[str, Any] = { "verbosity": Verbosity.verbose, "max_examples": MAX_EXAMPLES, "deadline": None, @@ -54,7 +54,7 @@ def generate_fbgemm_kv_tbe( log_E: int, weights_precision: SparseType, mixed: bool, - ) -> tuple[SSDTableBatchedEmbeddingBags, list[int], list[int], int]: + ) -> Tuple[SSDTableBatchedEmbeddingBags, List[int], List[int], int]: E = int(10**log_E) D = D * 4 if not mixed: diff --git a/fbgemm_gpu/test/tbe/inference/inference_converter_test.py b/fbgemm_gpu/test/tbe/inference/inference_converter_test.py index 32cd92383e..e9c93b518e 100644 --- a/fbgemm_gpu/test/tbe/inference/inference_converter_test.py +++ b/fbgemm_gpu/test/tbe/inference/inference_converter_test.py @@ -12,7 +12,7 @@ import math import random import unittest -from typing import Optional +from typing import Optional, Tuple import hypothesis.strategies as st import numpy as np @@ -62,7 +62,7 @@ def to_device(t: torch.Tensor, use_cpu: bool) -> torch.Tensor: def get_table_batched_offsets_from_dense( merged_indices: torch.Tensor, use_cpu: bool = False -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: (T, B, L) = merged_indices.size() lengths = np.ones((T, B)) * L flat_lengths = lengths.flatten() diff --git a/fbgemm_gpu/test/tbe/inference/nbit_cache_test.py b/fbgemm_gpu/test/tbe/inference/nbit_cache_test.py index 9b9450a71f..8d0918f1f9 100644 --- a/fbgemm_gpu/test/tbe/inference/nbit_cache_test.py +++ b/fbgemm_gpu/test/tbe/inference/nbit_cache_test.py @@ -10,7 +10,7 @@ # pyre-ignore-all-errors[56] import unittest -from typing import Callable +from typing import Callable, Dict, List import hypothesis.strategies as st import numpy as np @@ -44,7 +44,7 @@ # pyre-ignore -additional_decorators: dict[str, list[Callable]] = { +additional_decorators: Dict[str, List[Callable]] = { "test_faketensor__test_nbit_uvm_cache_stats": [ unittest.skip("very slow"), ], diff --git a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py index 15f98428d8..1a1cde1753 100644 --- 
a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py +++ b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py @@ -10,7 +10,7 @@ import random import unittest -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import hypothesis.strategies as st import numpy as np @@ -51,7 +51,7 @@ # pyre-ignore -additional_decorators: dict[str, list[Callable]] = { +additional_decorators: Dict[str, List[Callable]] = { "test_faketensor__test_nbit_forward_uvm_cache": [ unittest.skip("CUDA Assert"), ], @@ -587,7 +587,7 @@ def test_nbit_forward_cpu( @settings(deadline=None) @unittest.skipIf(*gpu_unavailable) def test_nbit_forward_gpu_no_cache_max_sizes( - self, indices_dtype: torch.dtype, weights_ty_and_D: tuple[SparseType, int] + self, indices_dtype: torch.dtype, weights_ty_and_D: Tuple[SparseType, int] ) -> None: weights_ty, D = weights_ty_and_D self.execute_nbit_forward_( @@ -876,7 +876,7 @@ def test_nbit_forward_cpu_seq_int8( quant_cc.fill_random_weights() raw_embedding_weights = quant_cc.split_embedding_weights() # we mimic 1.0 scale, 0.0 bias for better results comparison - embedding_weights: list[tuple[torch.Tensor, Optional[torch.Tensor]]] = [ + embedding_weights: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] = [ (table_weight, torch.tensor([1, 0], dtype=torch.float16).view(torch.uint8)) for table_weight, _ in raw_embedding_weights ] diff --git a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py index 3a82868ec2..381b2ba1e1 100644 --- a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py +++ b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py @@ -11,7 +11,7 @@ import random import unittest -from typing import Callable +from typing import Callable, Dict, List import hypothesis.strategies as st import numpy as np @@ -40,7 +40,7 @@ VERBOSITY: Verbosity = Verbosity.verbose # pyre-ignore -additional_decorators: dict[str, list[Callable]] = { +additional_decorators: Dict[str, List[Callable]] = { "test_faketensor__test_nbit_forward_cpu_seq_int4": { unittest.skip( "Operator outputs int4 tensors which do not support opcheck tests" diff --git a/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py b/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py index 2f7c0b08aa..19b8abfa69 100644 --- a/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py +++ b/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py @@ -15,7 +15,7 @@ import time import unittest -from typing import Any, Optional +from typing import Any, Dict, List, Optional, Tuple from unittest.mock import patch import hypothesis.strategies as st @@ -41,7 +41,7 @@ MAX_EXAMPLES = 100 WORLD_SIZE = 4 -default_st: dict[str, Any] = { +default_st: Dict[str, Any] = { "T": st.integers(min_value=1, max_value=10), "D": st.integers(min_value=2, max_value=128), "log_E": st.integers(min_value=2, max_value=3), @@ -49,7 +49,7 @@ "weights_precision": st.sampled_from([SparseType.FP32, SparseType.FP16]), } -default_settings: dict[str, Any] = { +default_settings: Dict[str, Any] = { "verbosity": Verbosity.verbose, "max_examples": MAX_EXAMPLES, "deadline": None, @@ -146,7 +146,7 @@ def generate_fbgemm_kv_tbe( kv_zch_params: Optional[KVZCHParams] = None, backend_type: BackendType = BackendType.SSD, flushing_block_size: int = 1000, - ) -> tuple[SSDTableBatchedEmbeddingBags, list[int], list[int]]: + ) -> Tuple[SSDTableBatchedEmbeddingBags, List[int], List[int]]: E = int(10**log_E) D = D * 4 if not mixed: diff --git a/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py 
b/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py index 0fba336e8e..adb3102bfb 100644 --- a/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py +++ b/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py @@ -9,7 +9,7 @@ import gc import tempfile import unittest -from typing import Any +from typing import Any, Dict from unittest import TestCase import fbgemm_gpu # noqa E402 @@ -30,7 +30,7 @@ MAX_EXAMPLES = 20 MAX_D = 256 -default_settings: dict[str, Any] = { +default_settings: Dict[str, Any] = { "verbosity": Verbosity.verbose, "max_examples": MAX_EXAMPLES, "deadline": None, diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py index c3f12903fa..56a8a66026 100644 --- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py +++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py @@ -10,7 +10,7 @@ import tempfile import unittest -from typing import Any, Optional +from typing import Any, List, Optional import hypothesis.strategies as st import numpy as np @@ -47,7 +47,7 @@ @unittest.skipIf(*running_in_oss) @unittest.skipIf(*gpu_unavailable) class SSDSplitTableBatchedEmbeddingsTest(SSDSplitTableBatchedEmbeddingsTestCommon): - def get_physical_table_arg_indices_(self, feature_table_map: list[int]): + def get_physical_table_arg_indices_(self, feature_table_map: List[int]): """ Get the physical table arg indices for the reference and TBE. The first element in each tuple is for accessing the reference embedding diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_tbe_training_adam_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_tbe_training_adam_test.py index 93703de35e..bebcfe9b50 100644 --- a/fbgemm_gpu/test/tbe/ssd/ssd_tbe_training_adam_test.py +++ b/fbgemm_gpu/test/tbe/ssd/ssd_tbe_training_adam_test.py @@ -8,7 +8,7 @@ # pyre-ignore-all-errors[3,6,56] import unittest -from typing import Any +from typing import Any, Dict import hypothesis.strategies as st import torch @@ -29,7 +29,7 @@ VIRTUAL_TABLE_ROWS, ) -default_st: dict[str, Any] = default_strategies | { +default_st: Dict[str, Any] = default_strategies | { "m1_dtype": st.sampled_from([SparseType.BF16, SparseType.FP32]), "m2_dtype": st.sampled_from([SparseType.BF16, SparseType.FP32]), "num_buckets": st.integers(min_value=10, max_value=15), diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_tbe_training_partial_rowwise_adam_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_tbe_training_partial_rowwise_adam_test.py index 4a64603a6d..e997dafba1 100644 --- a/fbgemm_gpu/test/tbe/ssd/ssd_tbe_training_partial_rowwise_adam_test.py +++ b/fbgemm_gpu/test/tbe/ssd/ssd_tbe_training_partial_rowwise_adam_test.py @@ -8,7 +8,7 @@ # pyre-ignore-all-errors[3,6,56] import unittest -from typing import Any +from typing import Any, Dict import hypothesis.strategies as st import torch @@ -31,7 +31,7 @@ VIRTUAL_TABLE_ROWS, ) -default_st: dict[str, Any] = default_strategies | { +default_st: Dict[str, Any] = default_strategies | { "m1_dtype": st.sampled_from([SparseType.BF16, SparseType.FP32]), "m2_dtype": st.sampled_from([SparseType.BF16, SparseType.FP32]), "num_buckets": st.integers(min_value=10, max_value=15), diff --git a/fbgemm_gpu/test/tbe/ssd/training_common.py b/fbgemm_gpu/test/tbe/ssd/training_common.py index 1dcf584c3b..f62a6fb656 100644 --- a/fbgemm_gpu/test/tbe/ssd/training_common.py +++ b/fbgemm_gpu/test/tbe/ssd/training_common.py @@ -11,7 +11,7 @@ import unittest from enum import Enum -from typing import Any, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import hypothesis.strategies as st import 
numpy as np @@ -39,7 +39,7 @@ def find_different_rows( atol: float = 1.0e-4, rtol: float = 1.0e-4, return_values: bool = False, -) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """ Find the indices of rows that are different between two tensors. @@ -158,7 +158,7 @@ def print_different_rows( 2**18 ) # relatively large for now given optimizer is still pre-allocated -default_strategies: dict["str", Any] = { +default_strategies: Dict["str", Any] = { "T": st.integers(min_value=1, max_value=10), "D": st.integers(min_value=2, max_value=128), "B": st.integers(min_value=1, max_value=128), @@ -190,7 +190,7 @@ class FlushLocation(Enum): class SSDSplitTableBatchedEmbeddingsTestCommon(unittest.TestCase): - def get_physical_table_arg_indices_(self, feature_table_map: list[int]): + def get_physical_table_arg_indices_(self, feature_table_map: List[int]): """ Get the physical table arg indices for the reference and TBE. The first element in each tuple is for accessing the reference embedding @@ -214,11 +214,11 @@ def get_physical_table_arg_indices_(self, feature_table_map: list[int]): def generate_in_bucket_indices( self, hash_mode: int, - bucket_id_range: tuple[int, int], + bucket_id_range: Tuple[int, int], bucket_size: int, # max height in ref_emb, the logical id high, physically id in kv is a shift from [0,h) to [table_offset, table_offset+h] high: int, - size: tuple[int, int], + size: Tuple[int, int], ) -> torch.Tensor: """ Generate indices in embedding bucket, this is guarantee on the torchrec input_dist @@ -246,21 +246,21 @@ def generate_inputs_( self, B: int, L: int, - Es: list[int], - feature_table_map: list[int], + Es: List[int], + feature_table_map: List[int], weights_precision: SparseType = SparseType.FP32, trigger_bounds_check: bool = False, mixed_B: bool = False, is_kv_tbes: bool = False, - bucket_offsets: Optional[list[tuple[int, int]]] = None, - bucket_sizes: Optional[list[int]] = None, - ) -> tuple[ - list[torch.Tensor], - list[torch.Tensor], + bucket_offsets: Optional[List[Tuple[int, int]]] = None, + bucket_sizes: Optional[List[int]] = None, + ) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor, - Optional[list[list[int]]], + Optional[List[List[int]]], ]: """ Generate indices and per sample weights @@ -356,15 +356,15 @@ def generate_kvzch_tbes( mixed: bool = False, enable_optimizer_offloading: bool = False, backend_return_whole_row: bool = False, - optimizer_state_dtypes: dict[OptimType, SparseType] = {}, # noqa: B006 + optimizer_state_dtypes: Dict[OptimType, SparseType] = {}, # noqa: B006 embedding_cache_mode: bool = False, - ) -> tuple[ + ) -> Tuple[ SSDTableBatchedEmbeddingBags, - list[torch.nn.EmbeddingBag], - list[int], - list[int], - list[tuple[int, int]], - list[int], + List[torch.nn.EmbeddingBag], + List[int], + List[int], + List[Tuple[int, int]], + List[int], ]: """ Generate embedding modules (i,e., SSDTableBatchedEmbeddingBags and @@ -563,8 +563,8 @@ def generate_ssd_tbes( lazy_bulk_init_enabled: bool = False, backend_type: BackendType = BackendType.SSD, enable_raw_embedding_streaming: bool = False, - optimizer_state_dtypes: dict[str, SparseType] = {}, # noqa: B006 - ) -> tuple[SSDTableBatchedEmbeddingBags, list[torch.nn.EmbeddingBag]]: + optimizer_state_dtypes: Dict[str, SparseType] = {}, # noqa: B006 + ) -> Tuple[SSDTableBatchedEmbeddingBags, List[torch.nn.EmbeddingBag]]: """ Generate embedding modules (i,e., 
SSDTableBatchedEmbeddingBags and torch.nn.EmbeddingBags) @@ -702,7 +702,7 @@ def generate_ssd_tbes( def concat_ref_tensors( self, - tensors: list[torch.Tensor], + tensors: List[torch.Tensor], do_pooling: bool, B: int, D: int, @@ -713,8 +713,8 @@ def concat_ref_tensors( def concat_ref_tensors_vbe( self, - tensors: list[torch.Tensor], - batch_size_per_feature_per_rank: list[list[int]], + tensors: List[torch.Tensor], + batch_size_per_feature_per_rank: List[List[int]], ) -> torch.Tensor: """ rearrange tensors into VBE format and concat them into one tensor @@ -742,9 +742,9 @@ def concat_ref_tensors_vbe( def execute_ssd_forward_( self, emb: SSDTableBatchedEmbeddingBags, - emb_ref: list[torch.nn.EmbeddingBag], - indices_list: list[torch.Tensor], - per_sample_weights_list: list[torch.Tensor], + emb_ref: List[torch.nn.EmbeddingBag], + indices_list: List[torch.Tensor], + per_sample_weights_list: List[torch.Tensor], indices: torch.Tensor, offsets: torch.Tensor, per_sample_weights: torch.Tensor, @@ -753,8 +753,8 @@ def execute_ssd_forward_( weighted: bool, tolerance: Optional[float] = None, it: int = -1, - batch_size_per_feature_per_rank: Optional[list[list[int]]] = None, - ) -> tuple[list[torch.Tensor], torch.Tensor]: + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, + ) -> Tuple[List[torch.Tensor], torch.Tensor]: """ Execute the forward functions of SSDTableBatchedEmbeddingBags and torch.nn.EmbeddingBag and compare outputs @@ -837,12 +837,12 @@ def execute_ssd_forward_( def execute_ssd_backward_( self, - output_ref_list: list[torch.Tensor], + output_ref_list: List[torch.Tensor], output: torch.Tensor, B: int, D: int, pooling_mode: PoolingMode, - batch_size_per_feature_per_rank: Optional[list[list[int]]] = None, + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, ) -> None: # Generate output gradient output_grad_list = [torch.randn_like(out) for out in output_ref_list] @@ -867,7 +867,7 @@ def execute_ssd_backward_( def split_optimizer_states_( self, emb: SSDTableBatchedEmbeddingBags - ) -> list[list[torch.Tensor]]: + ) -> List[List[torch.Tensor]]: _, bucket_asc_ids_list, _, _ = emb.split_embedding_weights( no_snapshot=False, should_flush=True ) diff --git a/fbgemm_gpu/test/tbe/training/backward_adagrad_common.py b/fbgemm_gpu/test/tbe/training/backward_adagrad_common.py index 2aec95c674..8b89c4e2ea 100755 --- a/fbgemm_gpu/test/tbe/training/backward_adagrad_common.py +++ b/fbgemm_gpu/test/tbe/training/backward_adagrad_common.py @@ -9,7 +9,7 @@ import sys -from typing import Any +from typing import Any, Dict import hypothesis.strategies as st @@ -68,7 +68,7 @@ VERBOSITY: Verbosity = Verbosity.verbose -common_strategy: dict[str, Any] = { +common_strategy: Dict[str, Any] = { "T": st.integers(min_value=1, max_value=5), "D": st.integers(min_value=2, max_value=128), "B": st.integers(min_value=1, max_value=128), @@ -89,7 +89,7 @@ ), } -common_settings: dict[str, Any] = { +common_settings: Dict[str, Any] = { "verbosity": VERBOSITY, "max_examples": MAX_EXAMPLES_LONG_RUNNING, "deadline": None, @@ -525,7 +525,7 @@ def execute_backward_adagrad( # noqa C901 ) -def adjust_mixed_B_st(kwargs: dict[str, Any]) -> dict[str, Any]: +def adjust_mixed_B_st(kwargs: Dict[str, Any]) -> Dict[str, Any]: # VBE is supported in rowwise_adagrad only assert "row_wise" in kwargs if not kwargs["row_wise"]: diff --git a/fbgemm_gpu/test/tbe/training/backward_adagrad_global_weight_decay_test.py b/fbgemm_gpu/test/tbe/training/backward_adagrad_global_weight_decay_test.py index 2e50662377..ab33600fee 100644 
--- a/fbgemm_gpu/test/tbe/training/backward_adagrad_global_weight_decay_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_adagrad_global_weight_decay_test.py @@ -11,7 +11,7 @@ import copy import unittest -from typing import Any +from typing import Any, Dict, List import numpy as np import torch @@ -46,7 +46,7 @@ # Set up test strategy -test_st: dict[str, Any] = { +test_st: Dict[str, Any] = { "T": st.integers(min_value=1, max_value=5), "D": st.integers(min_value=2, max_value=128), "B": st.integers(min_value=1, max_value=128), @@ -140,7 +140,7 @@ def apply_gwd_per_table( def apply_gwd( T: int, - Bs: list[int], + Bs: List[int], emb: SplitTableBatchedEmbeddingBagsCodegen, prev_iter_dev: torch.Tensor, step: int, diff --git a/fbgemm_gpu/test/tbe/training/backward_adagrad_large_dim_test.py b/fbgemm_gpu/test/tbe/training/backward_adagrad_large_dim_test.py index e4093563c3..4d61783be9 100644 --- a/fbgemm_gpu/test/tbe/training/backward_adagrad_large_dim_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_adagrad_large_dim_test.py @@ -10,7 +10,7 @@ # pyre-ignore-all-errors[56] import unittest -from typing import Any +from typing import Any, Dict from hypothesis import given, settings @@ -29,7 +29,7 @@ ) # Set up test strategy -test_st: dict[str, Any] = common_strategy.copy() +test_st: Dict[str, Any] = common_strategy.copy() test_st["D"] = st.integers(min_value=128, max_value=512) diff --git a/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py b/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py index 2c22a3fd5e..abe9865002 100644 --- a/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py @@ -10,7 +10,7 @@ # pyre-ignore-all-errors[56] import unittest -from typing import Any +from typing import Any, Dict import torch @@ -30,9 +30,9 @@ ) # Set up test strategy -test_st: dict[str, Any] = common_strategy.copy() +test_st: Dict[str, Any] = common_strategy.copy() test_st["D"] = st.integers(min_value=2, max_value=128) -test_st_cpu: dict[str, Any] = test_st.copy() +test_st_cpu: Dict[str, Any] = test_st.copy() test_st_cpu["use_cpu"] = st.just(True) test_st_cpu["row_wise"] = st.just(True) test_st_cpu["output_dtype"] = st.sampled_from([SparseType.FP32, SparseType.FP16]) diff --git a/fbgemm_gpu/test/tbe/training/backward_none_test.py b/fbgemm_gpu/test/tbe/training/backward_none_test.py index 2899d72288..1680313650 100644 --- a/fbgemm_gpu/test/tbe/training/backward_none_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_none_test.py @@ -11,7 +11,7 @@ import random import unittest -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import hypothesis.strategies as st import numpy as np @@ -304,7 +304,7 @@ def execute_backward_none_( # noqa C901 # as weight if weights_precision != output_dtype: fs = [f.to(output_dtype.as_dtype()) for f in fs] - gos: Union[list[Tensor], Tensor] = [torch.randn_like(f) for f in fs] + gos: Union[List[Tensor], Tensor] = [torch.randn_like(f) for f in fs] [f.backward(go) for (f, go) in zip(fs, gos)] else: bs_ = SplitTableBatchedEmbeddingBagsCodegen( @@ -331,7 +331,7 @@ def execute_backward_none_( # noqa C901 to_device(xw.contiguous().view(-1), use_cpu), ) ) - gos: Union[list[Tensor], Tensor] = torch.rand_like(fs) + gos: Union[List[Tensor], Tensor] = torch.rand_like(fs) fs.backward(gos) cc = SplitTableBatchedEmbeddingBagsCodegen( diff --git a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py index 14477b3883..5147f770e9 
100644 --- a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py @@ -11,7 +11,7 @@ import math import unittest -from typing import Any, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import hypothesis.strategies as st import numpy as np @@ -105,7 +105,7 @@ def execute_backward_optimizers_( # noqa C901 use_cpu: bool, weight_decay_mode: WeightDecayMode = WeightDecayMode.NONE, uvm_non_rowwise_momentum: bool = False, - optimizer_state_dtypes: Optional[dict[str, SparseType]] = None, + optimizer_state_dtypes: Optional[Dict[str, SparseType]] = None, use_rowwise_bias_correction: bool = False, counter_weight_decay_mode: Optional[CounterWeightDecayMode] = None, counter_halflife: int = -1, @@ -258,7 +258,7 @@ def execute_backward_optimizers_( # noqa C901 [f.backward(go) for (f, go) in zip(fs, gos)] # do SGD update - optimizer_kwargs: dict[str, Any] = {"learning_rate": 0.5} + optimizer_kwargs: Dict[str, Any] = {"learning_rate": 0.5} (lr, eps, beta1, beta2, weight_decay, momentum, eta) = ( 0.5, 1e-4, @@ -794,7 +794,7 @@ def _get_grad_from_counter_adagrad( prev_iter: torch.Tensor, iter_: int, weight_decay: float, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: row_counter = row_counter.view(row_counter.numel(), 1) prev_iter = prev_iter.view(prev_iter.numel(), 1) freq = torch.ones_like(row_counter) @@ -1129,7 +1129,7 @@ def test_backward_optimizers_partial_rowwise_adam_bf16_momentum( # noqa C901 pooling_mode: PoolingMode, use_cpu: bool, uvm_non_rowwise_momentum: bool, - optimizer_state_dtypes: dict[str, SparseType], + optimizer_state_dtypes: Dict[str, SparseType], ) -> None: self.execute_backward_optimizers_( T, @@ -1364,7 +1364,7 @@ def test_backward_optimizers_ensemble_rowwise_adagrad( # noqa C901 pooling_mode: PoolingMode, use_cpu: bool, uvm_non_rowwise_momentum: bool, - optimizer_state_dtypes: dict[str, SparseType], + optimizer_state_dtypes: Dict[str, SparseType], ) -> None: self.execute_backward_optimizers_( T, diff --git a/fbgemm_gpu/test/tbe/utils/generate_vbe_metadata_test.py b/fbgemm_gpu/test/tbe/utils/generate_vbe_metadata_test.py index 08ca3eab9a..59348521f2 100755 --- a/fbgemm_gpu/test/tbe/utils/generate_vbe_metadata_test.py +++ b/fbgemm_gpu/test/tbe/utils/generate_vbe_metadata_test.py @@ -24,13 +24,15 @@ else: from fbgemm_gpu.test.test_utils import gpu_unavailable +from typing import List + class GenerateVBEMetadataTest(unittest.TestCase): def generate_vbe_metadata_ref( self, offsets: torch.Tensor, info_B_num_bits: int, - batch_size_per_feature_per_rank: list[list[int]], + batch_size_per_feature_per_rank: List[List[int]], output_offset_feature_rank: torch.Tensor, feature_dims: torch.Tensor, ): diff --git a/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py b/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py index 3dd1bc2cd4..4f18a4b99f 100644 --- a/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py +++ b/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py @@ -11,7 +11,7 @@ import random import unittest -from typing import Callable +from typing import Callable, Dict, List import hypothesis.strategies as st import numpy as np @@ -49,7 +49,7 @@ # pyre-ignore -additional_decorators: dict[str, list[Callable]] = {} +additional_decorators: Dict[str, List[Callable]] = {} @optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators) @@ -440,8 +440,8 @@ def test_reset_embedding_weight_momentum( emb_op = 
SplitTableBatchedEmbeddingBagsCodegen E = int(10**log_E) D = D * 4 - Ds: list[int] = [] - Es: list[int] = [] + Ds: List[int] = [] + Es: List[int] = [] if not mixed: Ds = [D] * T Es = [E] * T @@ -482,10 +482,10 @@ def test_reset_embedding_weight_momentum( output_dtype=output_dtype, ) - pruned_indices: list[int] = [] - pruned_indices_offsets: list[int] = [0] - logical_table_ids: list[int] = [] - buffer_ids: list[int] = [] + pruned_indices: List[int] = [] + pruned_indices_offsets: List[int] = [0] + logical_table_ids: List[int] = [] + buffer_ids: List[int] = [] for i in range(len(Es)): indices = [ np.random.randint(low=1, high=int(Es[i] - 2)) @@ -514,10 +514,10 @@ def test_reset_embedding_weight_momentum( torch.tensor(buffer_ids, dtype=torch.int32, requires_grad=False), False ) - momentum1: list[Tensor] = [ + momentum1: List[Tensor] = [ s for (s,) in cc.split_optimizer_states() ] # List[rows] - weight: list[Tensor] = cc.split_embedding_weights() # List[(rows, dim)] + weight: List[Tensor] = cc.split_embedding_weights() # List[(rows, dim)] for t in range(T): momentum1[t].fill_(1) weight[t].fill_(1) diff --git a/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py b/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py index 0e5282b9f0..fd9ccdcad1 100644 --- a/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py +++ b/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py @@ -14,6 +14,7 @@ import random import unittest from itertools import accumulate +from typing import List, Tuple import fbgemm_gpu # noqa E402 @@ -45,10 +46,10 @@ def gen_inputs( - hash_sizes: list[int], + hash_sizes: List[int], batch_size: int, max_len: int, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ the lengths of bags are chosen from a uniform distribution from [0, max_len] @@ -79,7 +80,7 @@ def transpose_embedding_input_ref( indices: torch.Tensor, offsets: torch.Tensor, info_B_num_bits: int, -) -> tuple[ +) -> Tuple[ torch.Tensor, torch.Tensor, torch.Tensor, @@ -487,7 +488,7 @@ def test_pruning( use_cpu_hashtable: bool, use_array_for_index_remapping: bool, ) -> None: - E = 1000 + E = int(1000) LOAD_FACTOR = 0.8 pruning_ratio = 0.5 diff --git a/fbgemm_gpu/test/test_utils.py b/fbgemm_gpu/test/test_utils.py index 5d56bad9a7..48d5c28f3b 100644 --- a/fbgemm_gpu/test/test_utils.py +++ b/fbgemm_gpu/test/test_utils.py @@ -12,7 +12,7 @@ import unittest from contextlib import contextmanager from functools import wraps -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import fbgemm_gpu import hypothesis.strategies as st @@ -28,7 +28,7 @@ # Skip pt2 compliant tag test for certain operators # TODO: remove this once the operators are pt2 compliant # pyre-ignore -additional_decorators: dict[str, list[Callable]] = { +additional_decorators: Dict[str, List[Callable]] = { # vbe_generate_metadata_cpu return different values from vbe_generate_metadata_meta # this fails fake_tensor test as the test expects them to be the same # fake_tensor test is added in failures_dict but failing fake_tensor test still cause pt2_compliant tag test to fail @@ -45,7 +45,7 @@ } # Used for `@unittest.skipIf` -gpu_unavailable: tuple[bool, str] = ( +gpu_unavailable: Tuple[bool, str] = ( not torch.cuda.is_available() or torch.cuda.device_count() == 0, "CUDA is not available or no GPUs detected", ) @@ -58,31 +58,31 @@ is_nvidia_device and torch.cuda.get_device_capability()[0] >= 8 ) -running_on_sm70: tuple[bool, str] = ( +running_on_sm70: 
Tuple[bool, str] = ( not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 8, "Skip test if SM70, since the code is hardcoded to sm80+ support", ) # Used for `@unittest.skipIf` for tests that pass in internal CI, but fail on the GitHub runners -running_on_github: tuple[bool, str] = ( +running_on_github: Tuple[bool, str] = ( os.getenv("GITHUB_ENV") is not None, "Test is currently known to fail or hang when run in the GitHub runners", ) -running_in_oss: tuple[bool, str] = ( +running_in_oss: Tuple[bool, str] = ( # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. getattr(fbgemm_gpu, "open_source", False), "Test is currently known to fail in OSS mode", ) -running_on_rocm: tuple[bool, str] = ( +running_on_rocm: Tuple[bool, str] = ( TEST_WITH_ROCM, "Test currently doesn't work on the ROCm stack", ) # Tests with this marker generally fails with `free(): corrupted unsorted chunks` # errors when fbgemm_gpu is compiled under Clang -on_oss_clang: tuple[bool, str] = ( +on_oss_clang: Tuple[bool, str] = ( ( hasattr(fbgemm_gpu, "open_source") and os.system("c++ --version | grep -i clang") == 0 @@ -91,7 +91,7 @@ ) # Used for `@unittest.skipIf` for tests that currently fail on ARM platform -on_arm_platform: tuple[bool, str] = ( +on_arm_platform: Tuple[bool, str] = ( subprocess.run(["uname", "-m"], stdout=subprocess.PIPE) .stdout.decode("utf-8") .strip() @@ -100,7 +100,7 @@ ) -def cpu_and_maybe_gpu() -> st.SearchStrategy[list[torch.device]]: +def cpu_and_maybe_gpu() -> st.SearchStrategy[List[torch.device]]: gpu_available = torch.cuda.is_available() and torch.cuda.device_count() > 0 # st.sampled_from is not guaranteed to test all the values passed to it. # However, Hypothesis, by default, generates 100 test cases from the specified strategies. @@ -138,7 +138,7 @@ def generate_opcheck_tests( *, fast: bool = False, # pyre-ignore[24]: Generic type `Callable` expects 2 type parameters. - additional_decorators: Optional[dict[str, Callable]] = None, + additional_decorators: Optional[Dict[str, Callable]] = None, ): if additional_decorators is None: additional_decorators = {} @@ -221,7 +221,7 @@ def gradcheck( # pyre-ignore[24]: Generic type `Callable` expects 2 type parameters. f: Callable, # pyre-ignore[2] - inputs: Union[torch.Tensor, tuple[Any, ...]], + inputs: Union[torch.Tensor, Tuple[Any, ...]], *args: Any, **kwargs: Any, ) -> None: @@ -234,7 +234,7 @@ def gradcheck( torch.autograd.gradcheck(f, inputs, *args, **kwargs) -def cpu_only() -> st.SearchStrategy[list[torch.device]]: +def cpu_only() -> st.SearchStrategy[List[torch.device]]: return st.sampled_from([torch.device("cpu")]) @@ -326,7 +326,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return decorator -def symint_vector_unsupported() -> tuple[bool, str]: +def symint_vector_unsupported() -> Tuple[bool, str]: major, minor = torch.__version__.split(".")[0:2] return ( int(major) < 2 or (int(major) == 2 and int(minor) < 1), diff --git a/fbgemm_gpu/test/uvm/copy_test.py b/fbgemm_gpu/test/uvm/copy_test.py index d6f38fb844..3e857d71e9 100644 --- a/fbgemm_gpu/test/uvm/copy_test.py +++ b/fbgemm_gpu/test/uvm/copy_test.py @@ -9,6 +9,7 @@ # pyre-ignore-all-errors[56] import unittest +from typing import List import fbgemm_gpu import hypothesis.strategies as st @@ -48,7 +49,7 @@ class CopyTest(unittest.TestCase): ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) # pyre-fixme[2]: Parameter must be annotated. 
- def test_uvm_to_cpu(self, sizes: list[int], uvm_op) -> None: + def test_uvm_to_cpu(self, sizes: List[int], uvm_op) -> None: if uvm_op is torch.ops.fbgemm.new_unified_tensor: is_host_mapped = False uvm_t = uvm_op( @@ -90,7 +91,7 @@ def test_uvm_to_cpu(self, sizes: list[int], uvm_op) -> None: ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) # pyre-fixme[2]: Parameter must be annotated. - def test_uvm_to_device(self, sizes: list[int], uvm_op) -> None: + def test_uvm_to_device(self, sizes: List[int], uvm_op) -> None: if uvm_op is torch.ops.fbgemm.new_unified_tensor: is_host_mapped = False uvm_t = uvm_op( @@ -132,7 +133,7 @@ def test_uvm_to_device(self, sizes: list[int], uvm_op) -> None: ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) # pyre-fixme[2]: Parameter must be annotated. - def test_uvm_to_cpu_clone(self, sizes: list[int], uvm_op) -> None: + def test_uvm_to_cpu_clone(self, sizes: List[int], uvm_op) -> None: if uvm_op is torch.ops.fbgemm.new_unified_tensor: is_host_mapped = False uvm_t = uvm_op( diff --git a/fbgemm_gpu/test/uvm/uvm_test.py b/fbgemm_gpu/test/uvm/uvm_test.py index fc0b6cd8f5..d1e691463b 100644 --- a/fbgemm_gpu/test/uvm/uvm_test.py +++ b/fbgemm_gpu/test/uvm/uvm_test.py @@ -10,6 +10,7 @@ import random import unittest +from typing import List import fbgemm_gpu import hypothesis.strategies as st @@ -52,7 +53,7 @@ class UvmTest(unittest.TestCase): ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) # pyre-fixme[2]: Parameter must be annotated. - def test_is_uvm_tensor(self, sizes: list[int], uvm_op) -> None: + def test_is_uvm_tensor(self, sizes: List[int], uvm_op) -> None: if uvm_op is torch.ops.fbgemm.new_unified_tensor: is_host_mapped = random.choice([True, False]) uvm_t = uvm_op( @@ -86,7 +87,7 @@ def test_enum(self) -> None: ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) # pyre-fixme[2]: Parameter must be annotated. - def test_cudaMemAdvise(self, sizes: list[int], uvm_op) -> None: + def test_cudaMemAdvise(self, sizes: List[int], uvm_op) -> None: if uvm_op is torch.ops.fbgemm.new_unified_tensor: is_host_mapped = False uvm_t = uvm_op( @@ -118,7 +119,7 @@ def test_cudaMemAdvise(self, sizes: list[int], uvm_op) -> None: ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) # pyre-fixme[2]: Parameter must be annotated. - def test_cudaMemPrefetchAsync(self, sizes: list[int], uvm_op) -> None: + def test_cudaMemPrefetchAsync(self, sizes: List[int], uvm_op) -> None: if uvm_op is torch.ops.fbgemm.new_unified_tensor: is_host_mapped = False uvm_t = uvm_op( @@ -152,7 +153,7 @@ def test_cudaMemPrefetchAsync(self, sizes: list[int], uvm_op) -> None: ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) # pyre-fixme[2]: Parameter must be annotated. - def test_uvm_slice(self, sizes: list[int], uvm_op) -> None: + def test_uvm_slice(self, sizes: List[int], uvm_op) -> None: if uvm_op is torch.ops.fbgemm.new_unified_tensor: is_host_mapped = False uvm_t = uvm_op( @@ -193,7 +194,7 @@ def test_uvm_slice(self, sizes: list[int], uvm_op) -> None: ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) # pyre-fixme[2]: Parameter must be annotated. 
- def test_uvm_memadviceDontFork(self, sizes: list[int], uvm_op) -> None: + def test_uvm_memadviceDontFork(self, sizes: List[int], uvm_op) -> None: if uvm_op is torch.ops.fbgemm.new_unified_tensor: is_host_mapped = False uvm_t = uvm_op( @@ -218,7 +219,7 @@ def test_uvm_memadviceDontFork(self, sizes: list[int], uvm_op) -> None: ), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) - def test_new_managed_tensor_meta(self, sizes: list[int]) -> None: + def test_new_managed_tensor_meta(self, sizes: List[int]) -> None: cpu_tensor = torch.empty(sizes).to("meta") cpu_tensor_meta = torch.ops.fbgemm.new_managed_tensor(cpu_tensor, sizes) assert cpu_tensor.shape == cpu_tensor_meta.shape
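The hunks above all apply the same mechanical change: subscripted built-in generics (list[...], tuple[...], dict[...]) are rewritten as their typing-module equivalents (List, Tuple, Dict), with the corresponding names added to each file's `from typing import ...` line; the built-in forms are only subscriptable at runtime on Python 3.9+. A minimal sketch of the resulting annotation style follows -- it is illustrative only, with hypothetical names not taken from any of the files in this patch:

# Illustrative sketch, not part of the patch: the annotation convention the
# diff converges on (typing.List / Tuple / Dict instead of list[...] etc.).
from typing import Any, Dict, List, Tuple

import torch


def lengths_to_offsets(lengths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (lengths, offsets) tensors; offsets is the exclusive cumsum."""
    lengths_t = torch.tensor(lengths, dtype=torch.long)
    offsets_t = torch.cat([torch.zeros(1, dtype=torch.long), lengths_t.cumsum(0)])
    return lengths_t, offsets_t


# Module-level annotations follow the same pattern the patch applies to
# default_st / default_settings / additional_decorators in the test files.
example_settings: Dict[str, Any] = {"max_examples": 20, "deadline": None}
example_dims: List[int] = [64, 128, 256]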