
[AOTI][Debug logger] Min value: Error: "min_all_cuda" not implemented for 'Float8_e4m3fn' #149008

@henrylhtsang

Description

🐛 Describe the bug

The problem is with the AOTI intermediate-value debug logger when the input tensor dtype is FP8.

repro:

```python
import torch
import torch._inductor.config as config

config.aot_inductor.debug_intermediate_value_printer = "2"
config.aot_inductor.filtered_kernel_names = "triton_poi_fused__to_copy_add_0"


class Model(torch.nn.Module):
    def forward(self, x):
        x = x.to(torch.float)
        return x + 1


model = Model().cuda()
x = torch.randn(10).cuda().to(torch.float8_e4m3fn)
ep = torch.export.export(model, (x,))
path = torch._inductor.aoti_compile_and_package(ep)
aot_model = torch._inductor.aoti_load_package(path)
aot_model(x)

print("done")
```

logs:

```
[ CUDAFloat8_e4m3fnType{10} ]
Number of elements: 10
Dtype: c10::Float8_e4m3fn
Mean value: -0.124023
Min value: Error: "min_all_cuda" not implemented for 'Float8_e4m3fn'
```

Versions

trunk

cc @yanbing-j @vkuzo @albanD @kadeng @penguinwu @desertfire @chenyang78 @yushangdi @benjaminglass1 @chauhang @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

Metadata

Labels

- module: aotinductor (aot inductor)
- module: floatx (formerly float8) (For torch.float8_e5m2 and torch.float8_e4m3 and other sub-8-bit float types)
- triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
