4949torch .manual_seed (0 )
5050_DEVICE = get_current_accelerator_device ()
5151
52+
5253class ToyLinearModel (torch .nn .Module ):
5354 def __init__ (self , in_features , out_features ):
5455 super ().__init__ ()
@@ -141,14 +142,16 @@ def test_fp8_linear_variants(
141142 )
142143
143144 @unittest .skipIf (
144- _DEVICE == "cuda" and not is_sm_at_least_89 (), "Requires GPU with compute capability >= 8.9"
145+ _DEVICE == "cuda" and not is_sm_at_least_89 (),
146+ "Requires GPU with compute capability >= 8.9" ,
145147 )
146148 def test_invalid_granularity (self ):
147149 with pytest .raises (ValueError , match = "Invalid granularity specification" ):
148150 Float8DynamicActivationFloat8WeightConfig (granularity = "invalid" )
149151
150152 @unittest .skipIf (
151- _DEVICE == "cuda" and not is_sm_at_least_89 (), "Requires GPU with compute capability >= 8.9"
153+ _DEVICE == "cuda" and not is_sm_at_least_89 (),
154+ "Requires GPU with compute capability >= 8.9" ,
152155 )
153156 def test_mismatched_granularity (self ):
154157 with pytest .raises (
@@ -160,7 +163,8 @@ def test_mismatched_granularity(self):
160163 )
161164
162165 @unittest .skipIf (
163- _DEVICE == "cuda" and not is_sm_at_least_89 (), "Requires GPU with compute capability >= 8.9"
166+ _DEVICE == "cuda" and not is_sm_at_least_89 (),
167+ "Requires GPU with compute capability >= 8.9" ,
164168 )
165169 def test_unsupported_granularity (self ):
166170 class UnsupportedGranularity :
@@ -356,7 +360,8 @@ def test_mm_float8dq_per_row(
356360
357361 @unittest .skipIf (not torch .accelerator .is_available (), "Need GPU available" )
358362 @unittest .skipIf (
359- _DEVICE == "cuda" and not is_sm_at_least_89 (), "Requires GPU with compute capability >= 8.9"
363+ _DEVICE == "cuda" and not is_sm_at_least_89 (),
364+ "Requires GPU with compute capability >= 8.9" ,
360365 )
361366 @common_utils .parametrize ("float8_dtype" , [torch .float8_e4m3fn , torch .float8_e5m2 ])
362367 @common_utils .parametrize ("output_dtype" , [torch .float32 , torch .bfloat16 ])
@@ -399,7 +404,8 @@ def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype):
399404
400405 @unittest .skipIf (not torch .accelerator .is_available (), "Need GPU available" )
401406 @unittest .skipIf (
402- _DEVICE == "cuda" and not is_sm_at_least_89 (), "Requires GPU with compute capability >= 8.9"
407+ _DEVICE == "cuda" and not is_sm_at_least_89 (),
408+ "Requires GPU with compute capability >= 8.9" ,
403409 )
404410 @common_utils .parametrize ("float8_dtype" , [torch .float8_e4m3fn , torch .float8_e5m2 ])
405411 @common_utils .parametrize ("output_dtype" , [torch .float32 , torch .bfloat16 ])
@@ -432,7 +438,8 @@ def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
432438
433439 @unittest .skipIf (not torch .accelerator .is_available (), "Need GPU available" )
434440 @unittest .skipIf (
435- _DEVICE == "cuda" and not is_sm_at_least_89 (), "Requires GPU with compute capability >= 8.9"
441+ _DEVICE == "cuda" and not is_sm_at_least_89 (),
442+ "Requires GPU with compute capability >= 8.9" ,
436443 )
437444 def test_dequantize_affine_float8_scale_broadcasting (self ):
438445 """Test that scale broadcasting works correctly for block-wise quantization"""
0 commit comments