diff --git a/.github/workflows/1xL4_tests.yml b/.github/workflows/1xL4_tests.yml index 58980d8504..7a1c293074 100644 --- a/.github/workflows/1xL4_tests.yml +++ b/.github/workflows/1xL4_tests.yml @@ -51,3 +51,4 @@ jobs: pytest test/dtypes/test_affine_quantized_float.py --verbose -s ./test/float8/test_everything_single_gpu.sh python test/quantization/quantize_/workflows/float8/test_float8_tensor.py + python test/kernel/test_blockwise_triton.py --verbose -s diff --git a/benchmarks/benchmark_blockwise_scaled_linear_triton.py b/benchmarks/benchmark_blockwise_scaled_linear_triton.py index ffdd63ec8d..26ba04f2ce 100644 --- a/benchmarks/benchmark_blockwise_scaled_linear_triton.py +++ b/benchmarks/benchmark_blockwise_scaled_linear_triton.py @@ -13,7 +13,7 @@ from triton.testing import do_bench from torchao.float8.float8_utils import compute_error - from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import ( + from torchao.kernel.blockwise_quantization import ( blockwise_fp8_gemm, fp8_blockwise_act_quant, fp8_blockwise_weight_quant, diff --git a/test/prototype/test_blockwise_triton.py b/test/kernel/test_blockwise_triton.py similarity index 96% rename from test/prototype/test_blockwise_triton.py rename to test/kernel/test_blockwise_triton.py index 89f8cf869e..5de88ab7d9 100644 --- a/test/prototype/test_blockwise_triton.py +++ b/test/kernel/test_blockwise_triton.py @@ -11,7 +11,7 @@ triton = pytest.importorskip("triton", reason="Triton required to run this test") -from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import ( +from torchao.kernel.blockwise_quantization import ( blockwise_fp8_gemm, fp8_blockwise_act_quant, fp8_blockwise_weight_dequant, diff --git a/torchao/prototype/blockwise_fp8_inference/blockwise_quantization.py b/torchao/kernel/blockwise_quantization.py similarity index 100% rename from torchao/prototype/blockwise_fp8_inference/blockwise_quantization.py rename to torchao/kernel/blockwise_quantization.py diff --git a/torchao/prototype/blockwise_fp8_inference/__init__.py b/torchao/prototype/blockwise_fp8_inference/__init__.py index f2842417e4..eb6b7824bc 100644 --- a/torchao/prototype/blockwise_fp8_inference/__init__.py +++ b/torchao/prototype/blockwise_fp8_inference/__init__.py @@ -1,11 +1,12 @@ -from .blockwise_linear import BlockwiseQuantLinear -from .blockwise_quantization import ( +from torchao.kernel.blockwise_quantization import ( blockwise_fp8_gemm, fp8_blockwise_act_quant, fp8_blockwise_weight_dequant, fp8_blockwise_weight_quant, ) +from .blockwise_linear import BlockwiseQuantLinear + __all__ = [ "blockwise_fp8_gemm", "BlockwiseQuantLinear", diff --git a/torchao/prototype/blockwise_fp8_inference/blockwise_linear.py b/torchao/prototype/blockwise_fp8_inference/blockwise_linear.py index ebed3a84a4..a43574fa11 100644 --- a/torchao/prototype/blockwise_fp8_inference/blockwise_linear.py +++ b/torchao/prototype/blockwise_fp8_inference/blockwise_linear.py @@ -7,7 +7,7 @@ import torch from torch import nn -from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import ( +from torchao.kernel.blockwise_quantization import ( blockwise_fp8_gemm, fp8_blockwise_act_quant, )