
Commit ac03072

Fix inference tests and add export test (#613)
1 parent 2d33197 commit ac03072


test/float8/test_inference_flows.py

Lines changed: 77 additions & 4 deletions
@@ -5,20 +5,24 @@
 # LICENSE file in the root directory of this source tree.
 import copy
 import io
+import os
 import random
 import unittest
 
 import pytest
-
-from torchao.utils import TORCH_VERSION_AFTER_2_4
+from unittest.mock import patch
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_4,
+    unwrap_tensor_subclass,
+)
 
 if not TORCH_VERSION_AFTER_2_4:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torchao.float8.config import ScalingType
+from torch.export._trace import _export as _export_private
 from torchao.float8.float8_linear_utils import convert_to_float8_training
 from torchao.float8.float8_tensor import Float8Tensor
 from torchao.float8.float8_utils import compute_error
@@ -53,6 +57,11 @@ def reset_parameters(self):
 
 
 class TestHPTrainToFP8LinearInference:
+    @pytest.fixture(autouse=True)
+    def setup_mock(self):
+        with patch("torch._dynamo.config.cache_size_limit", 20):
+            yield
+
     def base_test_mlp_transform(self, base_mlp, quantized_mlp, input_tensor):
         with torch.no_grad():
             base_output = base_mlp(input_tensor)
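
The autouse setup_mock fixture added above (and to the other test classes below) temporarily patches torch._dynamo.config.cache_size_limit, presumably so the parametrized compile tests do not trip Dynamo's recompile limit. A minimal standalone sketch of the same pattern, with illustrative class and test names; only the patched config path comes from the diff:

    import pytest
    from unittest.mock import patch

    class TestWithRaisedCacheLimit:
        # autouse=True applies the fixture to every test in the class;
        # the patch is reverted when the with-block exits after each test.
        @pytest.fixture(autouse=True)
        def setup_mock(self):
            with patch("torch._dynamo.config.cache_size_limit", 20):
                yield  # the test body runs here with the override active

        def test_sees_patched_limit(self):
            import torch._dynamo.config as dynamo_config
            assert dynamo_config.cache_size_limit == 20
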
@@ -126,7 +135,9 @@ def test_static_fp8_mlp(self, compile_backend, dtype):
 
         # Compile the models
         compiled_original_mlp = torch.compile(
-            original_mlp, backend=compile_backend, fullgraph=True
+            original_mlp,
+            backend=compile_backend,
+            fullgraph=True,
         )
         compiled_static_fp8_mlp = torch.compile(
             static_fp8_mlp, backend=compile_backend, fullgraph=True
@@ -172,6 +183,11 @@ def test_weight_only_fp8_mlp(self, compile_backend, dtype):
 
 
 class TestFP8TrainToFP8LinearInference:
+    @pytest.fixture(autouse=True)
+    def setup_mock(self):
+        with patch("torch._dynamo.config.cache_size_limit", 20):
+            yield
+
     def train(self, model: nn.Module, dtype: torch.dtype):
         model.train()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
@@ -241,5 +257,62 @@ def test_fp8_save_and_load(self, dtype: torch.dtype):
         assert torch.all(og_out == new_out).item()
 
 
+class TestFP8Export:
+    @pytest.fixture(autouse=True)
+    def setup_mock(self):
+        with patch("torch._dynamo.config.cache_size_limit", 20):
+            yield
+
+    @unittest.skipIf(
+        not torch.cuda.is_available() or not is_H100,
+        "CUDA not available or on non H100 machine",
+    )
+    def test_fp8_export(self):
+        export_model = FeedForward().to("cuda")
+        quant_config = QuantConfig(ActivationCasting.DYNAMIC)
+        quantize_to_float8(export_model, quant_config)
+        batch_size = 4
+        num_tokens = 1024
+        embedding_dim = 4096
+
+        inp = torch.randn(
+            batch_size, num_tokens, embedding_dim, device="cuda", dtype=torch.float32
+        )
+        example_args = (inp,)
+
+        fp8_compile_model = copy.deepcopy(export_model)
+        fp8_compile_model = torch.compile(fp8_compile_model)
+        fp8_compile_out = fp8_compile_model(*example_args)
+
+        # Export model with subclass weights
+
+        export_model = unwrap_tensor_subclass(export_model)
+
+        # Export the model
+        exported_model = _export_private(
+            export_model,
+            example_args,
+            strict=False,
+            pre_dispatch=False,
+        )
+
+        so_path = None
+        try:
+            # Compile the exported program to a .so using AOTInductor
+            with torch.no_grad():
+                so_path = torch._inductor.aot_compile(
+                    exported_model.module(), example_args
+                )
+
+            # Load and run the .so file in Python
+            res = torch._export.aot_load(so_path, device="cuda")(example_args)
+            torch.testing.assert_close(fp8_compile_out, res)
+
+        finally:
+            # Cleanup: remove the .so file
+            if so_path and os.path.exists(so_path):
+                os.remove(so_path)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
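
For context, the new TestFP8Export test exercises an export-then-AOT-compile flow: flatten the float8 tensor-subclass weights with unwrap_tensor_subclass, export the module non-strictly through the private _export entry point, compile the exported graph to a shared library with AOTInductor, then load and run the .so from Python. A minimal sketch of that flow, using a plain nn.Linear as an illustrative stand-in for the quantized FeedForward; the private torch._export / torch._inductor entry points are the ones the test itself imports and may change across PyTorch versions:

    import torch
    import torch.nn as nn
    from torch.export._trace import _export as _export_private
    from torchao.utils import unwrap_tensor_subclass

    # Illustrative stand-ins; the test uses a float8-quantized FeedForward on CUDA.
    model = nn.Linear(16, 16).to("cuda")
    example_args = (torch.randn(4, 16, device="cuda"),)

    # 1. Flatten any tensor-subclass weights into plain tensors so export can trace them.
    model = unwrap_tensor_subclass(model)

    # 2. Non-strict export via the private entry point.
    exported = _export_private(model, example_args, strict=False, pre_dispatch=False)

    # 3. AOT-compile the exported graph to a .so with AOTInductor.
    with torch.no_grad():
        so_path = torch._inductor.aot_compile(exported.module(), example_args)

    # 4. Load the compiled artifact back into Python and run it.
    out = torch._export.aot_load(so_path, device="cuda")(example_args)
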
