Closed
Labels
module: aotinductor, module: floatx (formerly float8), triaged
Description
🐛 Describe the bug
The problem is with the AOTI intermediate value debug printer when the inputs are FP8: the min/max statistics it computes hit CUDA reduction kernels that are not implemented for Float8_e4m3fn.
repro:
import torch
import torch._inductor.config as config

config.aot_inductor.debug_intermediate_value_printer = "2"
config.aot_inductor.filtered_kernel_names = "triton_poi_fused__to_copy_add_0"

class Model(torch.nn.Module):
    def forward(self, x):
        x = x.to(torch.float)
        return x + 1

model = Model().cuda()
x = torch.randn(10).cuda().to(torch.float8_e4m3fn)
ep = torch.export.export(model, (x,))
path = torch._inductor.aoti_compile_and_package(ep)
aot_model = torch._inductor.aoti_load_package(path)
aot_model(x)
print("done")
logs:
[ CUDAFloat8_e4m3fnType{10} ]
Number of elements: 10
Dtype: c10::Float8_e4m3fn
Mean value: -0.124023
Min value: Error: "min_all_cuda" not implemented for 'Float8_e4m3fn'
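The failure can be reproduced outside the printer: calling a min/max reduction directly on a CUDA FP8 tensor hits the same missing kernel. Below is a minimal sketch (assuming a CUDA build); the float32 upcast shown is one possible way for the printer to compute stats on FP8 intermediates, not the actual AOTI implementation:

import torch

x = torch.randn(10, device="cuda").to(torch.float8_e4m3fn)

# Same missing kernel the debug printer trips over:
try:
    x.min()
except RuntimeError as e:
    print(e)  # "min_all_cuda" not implemented for 'Float8_e4m3fn'

# Upcasting before the reduction sidesteps the missing FP8 kernel,
# so the printer could summarize FP8 tensors via a float32 copy:
x32 = x.to(torch.float32)
print("min:", x32.min().item(), "max:", x32.max().item())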
Versions
trunk
cc @yanbing-j @vkuzo @albanD @kadeng @penguinwu @desertfire @chenyang78 @yushangdi @benjaminglass1 @chauhang @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4