backends/apple/coreml/compiler/torch_ops.py (16 additions, 5 deletions)

@@ -175,11 +175,22 @@ def dequantize_codebook(context, node):

    # Assert codebook is as expected. codebook.dim() = codes.dim() + 2
    assert len(codebook.shape) == 4, "Only rank 4 inputs are supported for codebook"
-    assert codebook.shape[0] == 1, "Only grouped_channel granularity is supported"
-    n_luts = codebook.shape[1]
-    assert (
-        codes.shape[1] % n_luts == 0
-    ), "codes.shape[1] must be divisible by codebook.shape[1]"
+    assert (codebook.shape[0] == 1) or (
+        codebook.shape[1] == 1
+    ), "Only grouped_channel granularity is supported"
+    if codebook.shape[0] == 1:
+        # LUT is per column group
+        n_luts = codebook.shape[1]
+        assert (
+            codes.shape[1] % n_luts == 0
+        ), "codes.shape[1] must be divisible by codebook.shape[1]"
+    else:
+        # LUT is per row group
+        n_luts = codebook.shape[0]
+        assert (
+            codes.shape[0] % n_luts == 0
+        ), "codes.shape[0] must be divisible by codebook.shape[0]"

    assert codebook.shape[2] == 2**nbits
    assert codebook.shape[3] == 1, "Only scalar look up values are supported"

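The compiler change distinguishes the two grouped_channel layouts by which leading codebook axis is singleton: for codes of shape [n_rows, n_cols], a per-column-group codebook has shape [1, n_luts, 2**nbits, 1], while a per-row-group codebook has shape [n_luts, 1, 2**nbits, 1]. A minimal sketch of the row-group case, with illustrative shapes and a hypothetical dequant_per_grouped_row helper (not code from this PR):

import torch

nbits = 2
codes = torch.randint(0, 2**nbits, (16, 8))  # [n_rows, n_cols]
codebook = torch.randn(4, 1, 2**nbits, 1)    # 4 LUTs, one per group of 4 rows

def dequant_per_grouped_row(codes, codebook):
    # Each contiguous block of rows shares one look-up table.
    n_luts = codebook.shape[0]
    rows_per_lut = codes.shape[0] // n_luts
    out = torch.empty(codes.shape)
    for i in range(n_luts):
        lut = codebook[i, 0, :, 0]  # [2**nbits] scalar look-up values
        rows = slice(i * rows_per_lut, (i + 1) * rows_per_lut)
        out[rows] = lut[codes[rows]]
    return out

dequantized = dequant_per_grouped_row(codes, codebook)  # [16, 8] float tensor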
backends/apple/coreml/test/test_torch_ops.py (62 additions, 5 deletions)

@@ -35,7 +35,7 @@ def _coreml_partitioner(self):

    def _get_test_model(self):
        model = torch.nn.Sequential(
-            torch.nn.Embedding(64, 128), torch.nn.Linear(128, 128), torch.nn.ReLU()
+            torch.nn.Embedding(64, 128), torch.nn.Linear(128, 256), torch.nn.ReLU()
        )
        example_inputs = (torch.LongTensor([0]),)
        return model, example_inputs
@@ -158,7 +158,7 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
        et_prog = delegated_program.to_executorch()
        self._compare_outputs(et_prog, model, example_inputs)

-    def test_dequantize_codebook_linear(self):
+    def test_dequantize_codebook_linear_per_grouped_col(self):
        model, example_inputs = self._get_test_model()
        quantize_(
            model,
@@ -185,7 +185,34 @@ def test_dequantize_codebook_linear(self):
        et_prog = delegated_program.to_executorch()
        self._compare_outputs(et_prog, model, example_inputs)

-    def test_dequantize_codebook_embedding(self):
+    def test_dequantize_codebook_linear_per_grouped_row(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            CodebookWeightOnlyConfig(dtype=torch.uint2, block_size=[16, -1]),
+        )
+        ep = torch.export.export(model, example_inputs)
+        assert "torch.ops.quant.dequantize_codebook.default" in ep.graph_module.code
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+
+        assert (
+            "executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default"
+            in format_delegated_graph(delegated_program.exported_program().graph_module)
+        )
+
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+    def test_dequantize_codebook_embedding_per_grouped_col(self):
        model, example_inputs = self._get_test_model()
        quantize_(
            model,
@@ -213,6 +240,34 @@ def test_dequantize_codebook_embedding(self):
        et_prog = delegated_program.to_executorch()
        self._compare_outputs(et_prog, model, example_inputs)

+    def test_dequantize_codebook_embedding_per_grouped_row(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            CodebookWeightOnlyConfig(dtype=torch.uint3, block_size=[16, -1]),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        ep = torch.export.export(model, example_inputs)
+        assert "torch.ops.quant.dequantize_codebook.default" in ep.graph_module.code
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+
+        assert (
+            "executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default"
+            in format_delegated_graph(delegated_program.exported_program().graph_module)
+        )
+
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+

if __name__ == "__main__":
    test_runner = TestTorchOps()
@@ -221,5 +276,7 @@ def test_dequantize_codebook_embedding(self):
    test_runner.test_dequantize_affine_c4w_embedding()
    test_runner.test_dequantize_affine_c4w_linear()
    test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()
-    test_runner.test_dequantize_codebook_linear()
-    test_runner.test_dequantize_codebook_embedding()
+    test_runner.test_dequantize_codebook_linear_per_grouped_col()
+    test_runner.test_dequantize_codebook_linear_per_grouped_row()
+    test_runner.test_dequantize_codebook_embedding_per_grouped_col()
+    test_runner.test_dequantize_codebook_embedding_per_grouped_row()
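Note that _get_test_model's Linear was widened from 128 to 256 output features, presumably so the weight's row and column dimensions differ and a transposed grouping would fail loudly. For the new per-grouped-row linear test, the expected shapes under the layout convention asserted in torch_ops.py work out as follows (a sketch; the variable names are illustrative):

weight_shape = (256, 128)  # torch.nn.Linear(128, 256).weight is [out, in]
block_size = (16, -1)      # 16 rows per LUT, -1 spans all columns
nbits = 2                  # torch.uint2

n_luts = weight_shape[0] // block_size[0]  # 256 // 16 == 16
codebook_shape = (n_luts, 1, 2**nbits, 1)  # (16, 1, 4, 1)

# Mirrors the divisibility check added in dequantize_codebook:
assert weight_shape[0] % codebook_shape[0] == 0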
third-party/ao (1 addition, 1 deletion)

Submodule ao updated 289 files