fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py (2 changes: 1 addition & 1 deletion)

@@ -289,7 +289,7 @@ def triton_quantize_mx4_unpack(
         stochastic_casting (bool): Whether to use stochastic casting.
 
     Returns:
-        torch.Tensor: [M / 2] mx4 scaled tensor packed into in8
+        torch.Tensor: [M / 2] mx4 scaled tensor packed into uint8
        torch.Tensor: [M / group_size] mx4 shared exponents into int8
 
     eg.
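
For orientation, a minimal usage sketch of the corrected return contract. The import path is derived from the file path above; the call signature (an input tensor plus a group_size keyword) and the CUDA placement are assumptions for illustration, and only the return dtypes come from the docstring being fixed.

# Hypothetical usage sketch (not from this PR). Assumes
# triton_quantize_mx4_unpack takes a float tensor and a group_size
# keyword and returns (packed_values, shared_exponents) as the
# docstring describes.
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
    triton_quantize_mx4_unpack,
)

x = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
packed, shared_exp = triton_quantize_mx4_unpack(x, group_size=32)

print(packed.dtype)      # torch.uint8 per the corrected docstring (two fp4 values per byte)
print(shared_exp.dtype)  # int8 shared exponents, per the docstring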
fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py (4 changes: 2 additions & 2 deletions)

@@ -2385,7 +2385,7 @@ def quantize(self, x, w):
 
     def compute(self, xq, wq, x_scale, w_scale, global_scale):
         return torch.ops.fbgemm.f4f4bf16(
-            xq, wq, x_scale, w_scale, global_scale=global_scale, use_mx=False
+            xq, wq, x_scale, w_scale, global_scale=global_scale
         )
 
     def quantize_and_compute(self, x, w):

@@ -2471,7 +2471,7 @@ def quantize(self, x, w):
 
     def compute(self, xq, wq, x_scale, w_scale, global_scale):
         return torch.ops.fbgemm.f4f4bf16(
-            xq, wq, x_scale, w_scale, global_scale=global_scale, use_mx=False
+            xq, wq, x_scale, w_scale, global_scale=global_scale
         )
 
     def quantize_and_compute(self, x, w):
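
Both benchmark classes drop the explicit use_mx=False and let f4f4bf16 run with its default behavior; the diff alone does not say whether the keyword was removed from the op or its default now matches. A hedged sketch of the resulting call shape follows; the surrounding names are illustrative, and only the f4f4bf16 invocation itself comes from the diff.

# Illustrative sketch (not from this PR): xq, wq, and the scale
# tensors are assumed to come from a matching fp4 quantization step.
import torch

def f4f4bf16_compute(xq, wq, x_scale, w_scale, global_scale):
    # use_mx=False is no longer passed explicitly; the op is invoked
    # with its default setting after this change.
    return torch.ops.fbgemm.f4f4bf16(
        xq, wq, x_scale, w_scale, global_scale=global_scale
    )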