17 changes: 11 additions & 6 deletions .github/workflows/float8_test.yml
@@ -1,4 +1,4 @@
-name: Run Float8 Tests
+name: Run Float8 Tests Nightly
 
 on:
   push:
@@ -9,6 +9,10 @@ on:
     branches:
       - main
       - 'gh/**'
+  # schedule:
+  #   # 3.27 am PST every day
+  #   - cron: "27 11 * * *"
+  # workflow_dispatch:
 
 concurrency:
   group: float8_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
@@ -28,11 +32,11 @@ jobs:
             torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu124'
             gpu-arch-type: "cuda"
             gpu-arch-version: "12.4"
-          - name: H100
-            runs-on: linux.aws.h100
-            torch-spec: '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124'
-            gpu-arch-type: "cuda"
-            gpu-arch-version: "12.4"
+          # - name: H100
+          #   runs-on: linux.aws.h100
+          #   torch-spec: '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124'
+          #   gpu-arch-type: "cuda"
+          #   gpu-arch-version: "12.4"
 
     permissions:
       id-token: write
@@ -53,3 +57,4 @@ jobs:
         pip install -r dev-requirements.txt
         pip install .
         pytest test/float8 --verbose -s
+        pytest test/dtypes/test_affine_quantized_float.py --verbose -s
32 changes: 24 additions & 8 deletions test/dtypes/test_affine_quantized_float.py
@@ -71,9 +71,9 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
     @common_utils.parametrize("compile", [True, False])
-    @common_utils.parametrize(
-        "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
-    )
+    # @common_utils.parametrize(
+    #     "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
+    # )
     # Inputs are (M,..), K, N
     @common_utils.parametrize(
         "sizes",
@@ -147,7 +150,10 @@ def test_fp8_linear_variants(
     )
     def test_invalid_granularity(self):
         with pytest.raises(ValueError, match="Invalid granularity specification"):
-            float8_dynamic_activation_float8_weight(granularity="invalid")
+            model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
+            quantize_(
+                model, float8_dynamic_activation_float8_weight(granularity="invalid")
+            )
 
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
@@ -157,7 +160,13 @@ def test_mismatched_granularity(self):
             ValueError,
             match="Different granularities for activation and weight are not supported",
         ):
-            float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))
+            model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
+            quantize_(
+                model,
+                float8_dynamic_activation_float8_weight(
+                    granularity=(PerTensor(), PerRow())
+                ),
+            )
 
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
@@ -166,9 +175,16 @@ def test_unsupported_granularity(self):
         class UnsupportedGranularity:
             pass
 
-        with pytest.raises(ValueError, match="Invalid granularity types"):
-            float8_dynamic_activation_float8_weight(
-                granularity=(UnsupportedGranularity(), UnsupportedGranularity())
-            )
+        with pytest.raises(
+            ValueError,
+            match="Invalid granularity types:",
+        ):
+            model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
+            quantize_(
+                model,
+                float8_dynamic_activation_float8_weight(
+                    granularity=(UnsupportedGranularity(), UnsupportedGranularity())
+                ),
+            )
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
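All three negative tests above follow the same new pattern: instead of expecting the ValueError from merely constructing float8_dynamic_activation_float8_weight(...), each test builds a small model and expects the error to be raised when quantize_ applies the config to it, i.e. the granularity checks are exercised at config-application time. The snippet below is a standalone sketch of that flow, not code from this PR: it substitutes a plain nn.Linear for the test file's ToyLinearModel and assumes the torchao.quantization import paths used by the tests.

import torch
import torch.nn as nn

# Assumed import locations; the test file uses the same public names.
from torchao.quantization import (
    float8_dynamic_activation_float8_weight,
    quantize_,
)

# Hypothetical stand-in for the test file's ToyLinearModel; needs a CUDA device.
model = nn.Sequential(nn.Linear(64, 64)).eval().to(torch.float32).to("cuda")

try:
    # Per the updated tests, the ValueError surfaces when quantize_ applies the
    # config to the model's linear layers, not when the config is constructed.
    quantize_(model, float8_dynamic_activation_float8_weight(granularity="invalid"))
except ValueError as err:
    print(f"rejected as expected: {err}")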
2 changes: 1 addition & 1 deletion test/float8/test_base.py
@@ -239,7 +239,7 @@ def test_axiswise_reshape(self):
         ],
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
+    @unittest.skipIf(not is_sm_at_least_89(), "Requires CUDA capability >= 9.0")
     def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
         a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
         b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
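The only change in test_base.py is the hardware gate: test_axiswise_gemm now runs on SM 8.9 (Ada-class) GPUs instead of requiring SM 9.0, consistent with the H100 runner being disabled in the workflow above (note the skip reason string still says ">= 9.0", which no longer matches the 8.9 check). For orientation, a capability check equivalent to what is_sm_at_least_89 / is_sm_at_least_90 express could look like the sketch below; the real helpers live in torchao.utils and this is not their exact implementation.

import torch

def sm_at_least(major: int, minor: int) -> bool:
    """True if the current CUDA device reports compute capability >= (major, minor)."""
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (major, minor)

# On an SM 8.9 GPU (e.g. L4 / RTX 4090) this prints "True False";
# on an SM 9.0 GPU (H100) it prints "True True".
print(sm_at_least(8, 9), sm_at_least(9, 0))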
14 changes: 7 additions & 7 deletions test/float8/test_compile.py
@@ -198,9 +198,9 @@ def test_inductor_from_config_params(
         Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
     ],
 )
-@unittest.skipIf(
-    not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available"
-)
+# @unittest.skipIf(
+#     not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available"
+# )
 def test_inductor_from_recipe(recipe_name):
     torch._dynamo.reset()
     config = Float8LinearConfig.from_recipe_name(recipe_name)
@@ -233,10 +233,10 @@ def forward(self, x):
             return x_fp8
 
     # TODO(future): figure out why the test below fails on CUDA capability 8.9
-    @unittest.skipIf(
-        not torch.cuda.is_available() or not is_sm_at_least_90(),
-        "CUDA with capability 9.0 or greater not available",
-    )
+    # @unittest.skipIf(
+    #     not torch.cuda.is_available() or not is_sm_at_least_90(),
+    #     "CUDA with capability 9.0 or greater not available",
+    # )
     def test_float8_with_graph_break_in_the_middle(self):
         """Test that having Float8Tensor object at the boundary of a subgraph"""
         cnts = CompileCounterWithBackend("inductor")
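With the SM 9.0 skips commented out, these compile tests will now be attempted on the SM 8.9 CI runner; the TODO kept in the diff records that at least one of them is known to fail on capability 8.9. The recipe test's job is to build a Float8LinearConfig from a recipe name, convert the model for float8 training, and check that it still compiles under inductor. The sketch below is an illustrative reconstruction of that flow, not code from the diff; the import paths and the convert_to_float8_training entry point are assumptions based on torchao's float8 training API.

import torch
import torch.nn as nn

# Assumed import locations for the names the test relies on.
from torchao.float8 import (
    Float8LinearConfig,
    Float8LinearRecipeName,
    convert_to_float8_training,
)

# Recipe name taken from the parametrization shown in the diff above.
config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE_WITH_GW_HP)

model = nn.Sequential(nn.Linear(128, 128, bias=False)).to("cuda", torch.bfloat16)
convert_to_float8_training(model, config=config)  # swaps nn.Linear for Float8Linear

compiled = torch.compile(model, backend="inductor")
out = compiled(torch.randn(16, 128, device="cuda", dtype=torch.bfloat16))
print(out.shape)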
2 changes: 1 addition & 1 deletion test/float8/test_numerics_integration.py
@@ -185,7 +185,7 @@ def test_encoder_fw_bw_from_config_params(
     ],
 )
 @pytest.mark.skipif(
-    not is_sm_at_least_90(), reason="requires SM90 compatible machine"
+    not is_sm_at_least_89(), reason="requires SM90 compatible machine"
 )
 @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
 def test_encoder_fw_bw_from_recipe(