Commit 060c9e9

fix format issue

1 parent 4e46f12

2 files changed: +14 additions, -6 deletions


test/dtypes/test_affine_quantized.py

Lines changed: 1 addition & 0 deletions

@@ -50,6 +50,7 @@
 )
 _DEVICE = get_current_accelerator_device()
 
+
 def get_quantization_functions(
     do_sparse: bool, do_int4: bool, device: str = _DEVICE, int4_zp_int: bool = False
 ):
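The change in this file is purely whitespace: a second blank line is added after the module-level _DEVICE assignment, restoring PEP 8's two blank lines before a top-level def. For orientation, here is a hypothetical sketch of what a helper like get_current_accelerator_device() can reduce to, built only on the public torch.accelerator API; the function body is an assumption, not the torchao source:

# Hypothetical sketch, not the torchao source: a device helper like
# get_current_accelerator_device() can reduce to reporting the active
# accelerator type ("cuda", "xpu", ...) and falling back to "cpu".
import torch


def get_current_accelerator_device() -> str:
    if torch.accelerator.is_available():
        # current_accelerator() returns a torch.device for the active backend.
        return torch.accelerator.current_accelerator().type
    return "cpu"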

test/dtypes/test_affine_quantized_float.py

Lines changed: 13 additions & 6 deletions

@@ -49,6 +49,7 @@
 torch.manual_seed(0)
 _DEVICE = get_current_accelerator_device()
 
+
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, in_features, out_features):
         super().__init__()
@@ -141,14 +142,16 @@ def test_fp8_linear_variants(
         )
 
     @unittest.skipIf(
-        _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+        _DEVICE == "cuda" and not is_sm_at_least_89(),
+        "Requires GPU with compute capability >= 8.9",
     )
     def test_invalid_granularity(self):
         with pytest.raises(ValueError, match="Invalid granularity specification"):
             Float8DynamicActivationFloat8WeightConfig(granularity="invalid")
 
     @unittest.skipIf(
-        _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+        _DEVICE == "cuda" and not is_sm_at_least_89(),
+        "Requires GPU with compute capability >= 8.9",
     )
     def test_mismatched_granularity(self):
         with pytest.raises(
@@ -160,7 +163,8 @@ def test_mismatched_granularity(self):
         )
 
     @unittest.skipIf(
-        _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+        _DEVICE == "cuda" and not is_sm_at_least_89(),
+        "Requires GPU with compute capability >= 8.9",
     )
     def test_unsupported_granularity(self):
         class UnsupportedGranularity:
@@ -356,7 +360,8 @@ def test_mm_float8dq_per_row(
 
     @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @unittest.skipIf(
-        _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+        _DEVICE == "cuda" and not is_sm_at_least_89(),
+        "Requires GPU with compute capability >= 8.9",
     )
     @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
     @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
@@ -399,7 +404,8 @@ def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype):
 
     @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @unittest.skipIf(
-        _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+        _DEVICE == "cuda" and not is_sm_at_least_89(),
+        "Requires GPU with compute capability >= 8.9",
    )
     @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
     @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
@@ -432,7 +438,8 @@ def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
 
     @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @unittest.skipIf(
-        _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+        _DEVICE == "cuda" and not is_sm_at_least_89(),
+        "Requires GPU with compute capability >= 8.9",
     )
     def test_dequantize_affine_float8_scale_broadcasting(self):
         """Test that scale broadcasting works correctly for block-wise quantization"""
