5 changes: 2 additions & 3 deletions backends/transforms/fuse_conv_with_clamp.py
@@ -6,10 +6,9 @@
 
 import sys
 
+import executorch.backends.vulkan.custom_ops_lib  # noqa
+
 import torch
-from executorch.backends.vulkan._passes.custom_ops_defs import (  # noqa
-    conv_with_clamp_op,
-)
 
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
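Note: the hunk above swaps a symbol-specific import for a module import kept only for its side effect of registering the custom ops. A minimal sketch of that registration-by-import pattern, assuming only the standard torch.library API — the "demo" namespace and "add_one" op below are hypothetical and are not the actual contents of custom_ops_lib:

# Hypothetical module demo_ops.py: registers an op when imported.
import torch

lib = torch.library.Library("demo", "DEF")  # made-up namespace for illustration
lib.define("add_one(Tensor x) -> Tensor")


def add_one_impl(x: torch.Tensor) -> torch.Tensor:
    return x + 1


lib.impl("add_one", add_one_impl, "CompositeExplicitAutograd")

# A consumer only needs `import demo_ops  # noqa` for the op to become
# callable as torch.ops.demo.add_one; no symbol from the module is used,
# which is why the import carries a noqa marker.
print(torch.ops.demo.add_one(torch.zeros(2)))  # tensor([1., 1.])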
2 changes: 1 addition & 1 deletion backends/transforms/targets.bzl
@@ -70,7 +70,7 @@ def define_common_targets():
         deps = [
             ":utils",
             "//caffe2:torch",
-            "//executorch/backends/vulkan/_passes:custom_ops_defs",
+            "//executorch/backends/vulkan:custom_ops_lib",
             "//executorch/exir:pass_base",
             "//executorch/exir:sym_util",
             "//executorch/exir/dialects:lib",
4 changes: 2 additions & 2 deletions backends/vulkan/CMakeLists.txt
@@ -83,14 +83,14 @@ set(vulkan_standard_shaders_cpp ${generated_spv_cpp})
 set(SCHEMA_INCLUDE_DIR ${CMAKE_BINARY_DIR}/schema/include)
 
 set(GENERATED_HEADER
-    ${SCHEMA_INCLUDE_DIR}/executorch/backends/vulkan/schema_generated.h
+    ${SCHEMA_INCLUDE_DIR}/executorch/backends/vulkan/serialization/schema_generated.h
 )
 
 add_custom_command(
   OUTPUT ${GENERATED_HEADER}
   COMMAND
     ${FLATC_EXECUTABLE} --cpp --cpp-std c++11 --scoped-enums -o
-      "${SCHEMA_INCLUDE_DIR}/executorch/backends/vulkan/" ${_vulkan_schema__srcs}
+      "${SCHEMA_INCLUDE_DIR}/executorch/backends/vulkan/serialization/" ${_vulkan_schema__srcs}
   WORKING_DIRECTORY ${EXECUTORCH_ROOT}
   COMMENT "Generating vulkan_schema headers"
   VERBATIM
33 changes: 0 additions & 33 deletions backends/vulkan/TARGETS
@@ -1,37 +1,4 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 load(":targets.bzl", "define_common_targets")
 
 oncall("executorch")
 
 define_common_targets(is_fbcode = True)
-
-runtime.python_library(
-    name = "vulkan_preprocess",
-    srcs = [
-        "serialization/vulkan_graph_builder.py",
-        "serialization/vulkan_graph_schema.py",
-        "serialization/vulkan_graph_serialize.py",
-        "vulkan_preprocess.py",
-    ],
-    resources = [
-        "serialization/schema.fbs",
-    ],
-    visibility = [
-        "//executorch/...",
-        "//executorch/vulkan/...",
-        "@EXECUTORCH_CLIENTS",
-    ],
-    deps = [
-        "//executorch/backends/transforms:addmm_mm_to_linear",
-        "//executorch/backends/transforms:fuse_batch_norm_with_conv",
-        "//executorch/backends/transforms:fuse_conv_with_clamp",
-        "//executorch/backends/transforms:fuse_dequant_linear",
-        "//executorch/backends/transforms:fuse_view_copy",
-        "//executorch/backends/transforms:remove_clone_ops",
-        "//executorch/backends/vulkan/_passes:vulkan_passes",
-        "//executorch/exir:graph_module",
-        "//executorch/exir/_serialize:_bindings",
-        "//executorch/exir/_serialize:lib",
-        "//executorch/exir/backend:backend_details",
-    ],
-)
27 changes: 1 addition & 26 deletions backends/vulkan/_passes/TARGETS
@@ -3,31 +3,6 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
 oncall("executorch")
 
-runtime.python_library(
-    name = "custom_ops_defs",
-    srcs = [
-        "custom_ops_defs.py",
-    ],
-    visibility = [
-        "//executorch/...",
-        "@EXECUTORCH_CLIENTS",
-    ],
-    deps = [
-        "//caffe2:torch",
-    ],
-)
-
-python_unittest(
-    name = "test_custom_ops",
-    srcs = [
-        "test_custom_ops.py",
-    ],
-    deps = [
-        ":custom_ops_defs",
-        "//caffe2:torch",
-    ],
-)
-
 runtime.python_library(
     name = "insert_prepack_nodes",
     srcs = ["insert_prepack_nodes.py"],

@@ -62,7 +37,7 @@ runtime.python_library(
         "//executorch/backends/...",
     ],
     deps = [
-        ":custom_ops_defs",
+        "//executorch/backends/vulkan:custom_ops_lib",
         "//pytorch/ao:torchao",
     ]
 )
41 changes: 14 additions & 27 deletions backends/vulkan/_passes/insert_prepack_nodes.py
@@ -6,39 +6,27 @@
 
 # pyre-strict
 
-from typing import List
-
-import executorch.backends.vulkan._passes.custom_ops_defs  # noqa
+import executorch.backends.vulkan.custom_ops_lib  # noqa
 
 import torch
 
+from executorch.backends.vulkan.op_registry import handles_own_prepacking
+
 from executorch.exir.dialects._ops import ops as exir_ops
 
 from torch._export.utils import is_buffer, is_param
 from torch.export import ExportedProgram
 
-USES_WEIGHTS: List[torch._ops.OpOverload] = [
-    exir_ops.edge.aten.embedding.default,
-    exir_ops.edge.aten.convolution.default,
-    exir_ops.edge.et_vk.conv_with_clamp.default,
-    exir_ops.edge.aten.linear.default,
-    exir_ops.edge.aten._weight_int8pack_mm.default,
-    exir_ops.edge.et_vk.linear_weight_int4.default,
-    exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
-    exir_ops.edge.aten.native_layer_norm.default,
-    "llama::sdpa_with_kv_cache",
-]
-
-
 def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
     """
     Insert `et_vk.prepack` nodes for constant tensors in the graph. The prepack operator
     is responsible for transferring the tensor data, which is serialized with the model,
     to a GPU tensor object during the prepacking stage of model execution.
 
-    Some operators, listed in `USES_WEIGHTS` above, are performance sensitive and will
-    prefer to handle prepacking within the operator. For these ops, the constant tensor
-    data will be passed directly as an argument into the operator implementation.
+    Some operators are performance sensitive and will prefer to handle prepacking within
+    the operator. For these ops, the constant tensor data will be passed directly as an
+    argument into the operator implementation.
     """
@@ -58,22 +46,21 @@ def is_param_node(node: torch.fx.Node) -> bool:
         or is_constant(node)
     )
 
-    def is_non_weight_param_tensor(node: torch.fx.Node) -> bool:
+    def prepack_not_required(node: torch.fx.Node) -> bool:
         if not is_param_node(node):
-            return False
+            return True
 
         for user in node.users:
-            if user.op == "call_function" and (
-                # pyre-ignore [16]
-                user.target in USES_WEIGHTS
-                or user.target.name() in USES_WEIGHTS
+            if user.op == "call_function" and handles_own_prepacking(
+                # pyre-ignore
+                user.target
             ):
-                return False
+                return True
 
-        return True
+        return False
 
     for node in program.graph_module.graph.nodes:
-        if not is_non_weight_param_tensor(node):
+        if prepack_not_required(node):
             continue
 
         with program.graph_module.graph.inserting_after(node):
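Note: prepack_not_required is the exact negation of the old is_non_weight_param_tensor, with the hardcoded USES_WEIGHTS membership test replaced by the handles_own_prepacking query against the op registry. A hedged sketch of invoking the pass standalone — TinyModel and the torch.export.export call are illustrative only; in practice the pass runs inside the Vulkan delegate's lowering flow:

# Hypothetical usage sketch; TinyModel is made up, and insert_prepack_nodes
# normally runs as part of Vulkan backend lowering rather than standalone.
import torch
from torch.export import export

from executorch.backends.vulkan._passes.insert_prepack_nodes import (
    insert_prepack_nodes,
)


class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight


program = export(TinyModel(), (torch.randn(2, 4),))
# Constant tensors whose users do not handle their own prepacking now flow
# through an inserted et_vk.prepack node before reaching their consumers.
program = insert_prepack_nodes(program)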
6 changes: 2 additions & 4 deletions backends/vulkan/_passes/int4_weight_only_quantizer.py
@@ -1,13 +1,11 @@
 import logging
 from typing import Any, Callable, Dict, Optional, Type
 
+import executorch.backends.vulkan.custom_ops_lib  # noqa
+
 import torch
 import torch.nn.functional as F
 
-from executorch.backends.vulkan._passes.custom_ops_defs import (  # noqa
-    linear_weight_int4_op,
-)
-
 from torchao.quantization.GPTQ import _check_linear_int4_k
 from torchao.quantization.unified import Quantizer
 from torchao.quantization.utils import groupwise_affine_quantize_tensor
124 changes: 0 additions & 124 deletions backends/vulkan/_passes/test_custom_ops.py

This file was deleted.
