
Commit e7f9beb

Update on "Remove sharded ckpt from export_llama"
Sharded checkpoint isn't used anymore; removing it and simplifying export_llama.

Differential Revision: [D87828518](https://our.internmc.facebook.com/intern/diff/D87828518/)

[ghstack-poisoned]
2 parents e835c63 + 8bd7088 commit e7f9beb

76 files changed: +3036 −765 lines changed


.github/workflows/pull.yml

Lines changed: 11 additions & 2 deletions
@@ -862,15 +862,24 @@ jobs:
 # Install Node.js and Emscripten
 source .ci/scripts/setup-emscripten.sh
 
+export PNPM_VERSION=10.24.0
+
+curl -fsSL https://get.pnpm.io/install.sh | env PNPM_VERSION=$PNPM_VERSION SHELL="$(which bash)" sh -
+
+export PNPM_HOME="$HOME/.local/share/pnpm"
+export PATH="$PNPM_HOME:$PATH"
+
+pnpm --version
+
 # Test selective build
 bash scripts/build_wasm_tests.sh ${{ matrix.enable-etdump }}
 
 # Install Jest
 cd cmake-out-wasm/extension/wasm/test
-npm install --save-dev jest
+pnpm add -D jest@30.2.0 --ignore-scripts
 
 # Run unit test
-npm test
+pnpm test
 
 unittest-nxp-neutron:
   uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main

backends/aoti/aoti_backend.py

Lines changed: 9 additions & 24 deletions
@@ -9,7 +9,7 @@
 import typing
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Any, Dict, List, Optional, Set
+from typing import Any, Dict, List, Set
 
 import torch
 from executorch.backends.aoti.passes.replace_view_copy_with_view import (
@@ -91,39 +91,24 @@ def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]
 )
 
 def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
-    self,
-    kernel: str,
-    args: list[str],
-    device: str,
-    *,
-    debug_args: Optional[list[str]] = None,
-    debug_handle: Optional[int] = None,
-):
+    self, kernel: str, *args: Any, **kwargs: Any
+) -> None:
     if kernel not in supported_kernels:
         missing_fallback_kernels.add(kernel)
 
-    original_generate_c_shim_extern_kernel_call(
-        self,
-        kernel,
-        args,
-        device,
-        debug_args=debug_args,
-        debug_handle=debug_handle,
+    return original_generate_c_shim_extern_kernel_call(
+        self, kernel, *args, **kwargs
     )
 
 def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels(
-    self,
-    op_overload,
-    raw_args,
-    output_args,
-    raw_outputs,
-):
+    self, op_overload: Any, *args: Any, **kwargs: Any
+) -> None:
     kernel_name = getattr(op_overload, "_name", str(op_overload))
     if kernel_name not in supported_kernels:
         missing_fallback_kernels.add(kernel_name)
 
-    original_generate_fallback_kernel_with_runtime_lookup_aot(
-        self, op_overload, raw_args, output_args, raw_outputs
+    return original_generate_fallback_kernel_with_runtime_lookup_aot(
+        self, op_overload, *args, **kwargs
    )
 
 CppWrapperCpu.generate_c_shim_extern_kernel_call = (
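
The rewritten hooks above no longer mirror Inductor's full parameter lists; they record the kernel name and then forward everything else untouched to the original method. A minimal sketch of that wrap-and-delegate pattern follows; the names (collect_and_delegate, supported, missing) are illustrative, not the actual ExecuTorch API:

    from typing import Any, Callable, Set

    def collect_and_delegate(
        original: Callable[..., Any],
        supported: Set[str],
        missing: Set[str],
    ) -> Callable[..., Any]:
        """Wrap a wrapper-codegen method so unsupported kernels are recorded first."""

        def patched(self: Any, kernel: str, *args: Any, **kwargs: Any) -> Any:
            if kernel not in supported:
                missing.add(kernel)
            # Forwarding *args/**kwargs keeps the patch valid even if the upstream
            # method gains or loses keyword-only parameters between releases.
            return original(self, kernel, *args, **kwargs)

        return patched

Accepting *args/**kwargs is what lets the patched methods survive upstream signature changes such as the debug_args/debug_handle keywords seen in the deleted lines.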

backends/arm/quantizer/quantization_annotator.py

Lines changed: 17 additions & 5 deletions
@@ -16,7 +16,6 @@
 
 import torch
 import torch.fx
-import torch.nn.functional as F
 from executorch.backends.arm.common.debug import get_node_debug_info
 from executorch.backends.arm.common.type import ensure_type
 from executorch.backends.arm.quantizer import QuantizationConfig
@@ -477,7 +476,11 @@ def get_quant_properties(  # noqa: C901
 def any_or_hardtanh_min_zero(n: Node):
     """Return True for any op or hardtanh with ``min_val == 0``."""
     # Check that if the node is a hardtanh, its min_val is zero
-    return n.target != torch.ops.aten.hardtanh.default or n.args[1] == 0
+    return (
+        n.target
+        not in (torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default)
+        or n.args[1] == 0
+    )
 
 if _match_pattern(
     node,
@@ -487,11 +490,14 @@ def any_or_hardtanh_min_zero(n: Node):
         torch.ops.aten.conv2d.default,
         torch.ops.aten.conv2d.padding,
     ],
-    [torch.ops.aten.batch_norm.default, F.batch_norm],
+    [
+        torch.ops.aten.batch_norm.default,
+    ],
     [
         torch.ops.aten.relu.default,
         torch.ops.aten.relu_.default,
         torch.ops.aten.hardtanh.default,
+        torch.ops.aten.hardtanh_.default,
     ],
 ],
 filter_fn=any_or_hardtanh_min_zero,
@@ -510,6 +516,7 @@ def any_or_hardtanh_min_zero(n: Node):
     torch.ops.aten.relu.default,
     torch.ops.aten.relu_.default,
     torch.ops.aten.hardtanh.default,
+    torch.ops.aten.hardtanh_.default,
 ):
     quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
 
@@ -521,7 +528,9 @@ def any_or_hardtanh_min_zero(n: Node):
         torch.ops.aten.conv2d.default,
         torch.ops.aten.conv2d.padding,
     ],
-    [torch.ops.aten.batch_norm.default, F.batch_norm],
+    [
+        torch.ops.aten.batch_norm.default,
+    ],
 ],
 ):
     if node.target in (
@@ -534,7 +543,9 @@ def any_or_hardtanh_min_zero(n: Node):
         _QuantProperty(1, weight_qspec, mark_annotated=True),
         _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
     ]
-elif node.target in [torch.ops.aten.batch_norm.default, F.batch_norm]:
+elif node.target in [
+    torch.ops.aten.batch_norm.default,
+]:
     quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
 elif _match_pattern(
     node,
@@ -549,6 +560,7 @@ def any_or_hardtanh_min_zero(n: Node):
         torch.ops.aten.relu.default,
         torch.ops.aten.relu_.default,
         torch.ops.aten.hardtanh.default,
+        torch.ops.aten.hardtanh_.default,
     ],
 ],
 any_or_hardtanh_min_zero,
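
For clarity, the updated predicate accepts every node except a (possibly in-place) hardtanh whose min_val is non-zero. A standalone sketch of that behaviour, using a hypothetical FakeNode stand-in rather than a real torch.fx.Node:

    from dataclasses import dataclass
    from typing import Any, Tuple

    import torch

    @dataclass
    class FakeNode:
        """Hypothetical stand-in for torch.fx.Node: only `target` and `args` matter here."""
        target: Any
        args: Tuple[Any, ...] = ()

    HARDTANH_OPS = (torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default)

    def any_or_hardtanh_min_zero(n: FakeNode) -> bool:
        # True for any other op, or for a hardtanh/hardtanh_ whose min_val (args[1]) is 0.
        return n.target not in HARDTANH_OPS or n.args[1] == 0

    assert any_or_hardtanh_min_zero(FakeNode(torch.ops.aten.relu.default))
    assert any_or_hardtanh_min_zero(FakeNode(torch.ops.aten.hardtanh_.default, (None, 0, 6)))
    assert not any_or_hardtanh_min_zero(FakeNode(torch.ops.aten.hardtanh.default, (None, -1.0, 1.0)))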

backends/arm/test/misc/test_bn_relu_folding_qat.py

Lines changed: 70 additions & 17 deletions
@@ -6,13 +6,13 @@
 from typing import Tuple
 
 import torch
-import torch.nn.functional as F
 from executorch.backends.arm.quantizer.arm_quantizer import (
     get_symmetric_quantization_config,
     TOSAQuantizer,
 )
-from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT
+from executorch.backends.arm.tosa import TosaSpecification
 
 from executorch.backends.xnnpack.test.tester.tester import Quantize
 from torch import nn
@@ -21,51 +21,104 @@
 input_t1 = Tuple[torch.Tensor]  # Input x
 
 
-class ConvModule(torch.nn.Module):
+class Conv2dModule(torch.nn.Module):
     input_shape = (1, 28, 28)
     batch_size = 64
     test_data: input_t1 = (torch.randn(batch_size, *input_shape),)
 
-    def __init__(self, batch_norm: bool = True) -> None:
+    def __init__(self, batch_norm: bool = True, inplace: bool = False) -> None:
         super().__init__()
         self.conv = torch.nn.Conv2d(1, 16, 3, stride=2)
         self.bn = nn.BatchNorm2d(num_features=16) if batch_norm else nn.Identity()
+        self.relu = nn.ReLU(inplace=inplace)
 
     def forward(self, x: torch.Tensor):
         x = self.conv(x)
         x = self.bn(x)
-        x = F.relu(x)
+        x = self.relu(x)
+
+        return x
+
+
+class Conv1dModule(torch.nn.Module):
+    input_shape = (3, 10)
+    batch_size = 2
+    test_data: input_t1 = (torch.randn(batch_size, *input_shape),)
+
+    def __init__(self, batch_norm: bool = True, inplace: bool = False) -> None:
+        super().__init__()
+        self.conv = torch.nn.Conv1d(3, 8, 5, padding=2)
+        self.bn = nn.BatchNorm1d(num_features=8) if batch_norm else nn.Identity()
+        self.relu = nn.ReLU(inplace=inplace)
+
+    def forward(self, x: torch.Tensor):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.relu(x)
 
         return x
 
 
 models = {
     # name : (model, is_per_channel)
-    "conv_bn_relu_per_channel": (ConvModule(batch_norm=True), True),
-    "conv_relu_per_channel": (ConvModule(batch_norm=False), True),
-    "conv_bn_relu_per_tensor": (ConvModule(batch_norm=True), False),
-    "conv_relu_per_tensor": (ConvModule(batch_norm=False), False),
+    "conv1d_bn_relu_per_channel": (Conv1dModule(batch_norm=True), True),
+    "conv1d_relu_per_channel": (Conv1dModule(batch_norm=False), True),
+    "conv1d_bn_relu_per_tensor": (Conv1dModule(batch_norm=True), False),
+    "conv1d_relu_per_tensor": (Conv1dModule(batch_norm=False), False),
+    "conv2d_bn_relu_per_channel": (Conv2dModule(batch_norm=True), True),
+    "conv2d_relu_per_channel": (Conv2dModule(batch_norm=False), True),
+    "conv2d_bn_relu_per_tensor": (Conv2dModule(batch_norm=True), False),
+    "conv2d_relu_per_tensor": (Conv2dModule(batch_norm=False), False),
+    "conv1d_bn_relu_inplace_per_channel": (
+        Conv1dModule(batch_norm=True, inplace=True),
+        True,
+    ),
+    "conv1d_relu_inplace_per_channel": (
+        Conv1dModule(batch_norm=False, inplace=True),
+        True,
+    ),
+    "conv1d_bn_relu_inplace_per_tensor": (
+        Conv1dModule(batch_norm=True, inplace=True),
+        False,
+    ),
+    "conv1d_relu_inplace_per_tensor": (
+        Conv1dModule(batch_norm=False, inplace=True),
+        False,
+    ),
+    "conv2d_bn_relu_inplace_per_channel": (
+        Conv2dModule(batch_norm=True, inplace=True),
+        True,
+    ),
+    "conv2d_relu_inplace_per_channel": (
+        Conv2dModule(batch_norm=False, inplace=True),
+        True,
+    ),
+    "conv2d_bn_relu_inplace_per_tensor": (
+        Conv2dModule(batch_norm=True, inplace=True),
+        False,
+    ),
+    "conv2d_relu_inplace_per_tensor": (
+        Conv2dModule(batch_norm=False, inplace=True),
+        False,
+    ),
 }
 
 
-@common.parametrize("test_data", models)
+@common.parametrize(
+    "test_data",
+    models,
+)
 def test_qat_tosa_INT(test_data):
     model, per_channel = test_data
     pipeline = TosaPipelineINT[input_t1](model, model.test_data, [], [], qtol=1)
-    tosa_version = conftest.get_option("tosa_version")
-    tosa_profiles = {
-        "1.0": common.TosaSpecification.create_from_string("TOSA-1.0+INT"),
-    }
-    tosa_spec = tosa_profiles[tosa_version]
-    quantizer = TOSAQuantizer(tosa_spec)
+    quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
     pipeline.change_args(
         "quantize",
         Quantize(
            quantizer=quantizer,
            quantization_config=get_symmetric_quantization_config(
                is_qat=True, is_per_channel=per_channel
            ),
-           is_qat=True,
        ),
    )
    pipeline.run()

backends/arm/test/ops/test_layer_norm.py

Lines changed: 47 additions & 0 deletions
@@ -137,3 +137,50 @@ def test_native_layer_norm_vgf_INT(test_data):
         tosa_version="TOSA-1.0+INT",
     )
     pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_native_layer_norm_tosa_INT_a16w8(test_data):
+    """Test layer_norm with int16 I/O quantization for TOSA INT."""
+    test_input, model = test_data()
+    pipeline = TosaPipelineINT[input_t](
+        model,
+        test_input,
+        "torch.ops.aten.sub.Tensor",  # check for sub op in decomposition
+        symmetric_io_quantization=True,
+        tosa_extensions=["int16"],
+        epsilon=2**16,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone300
+def test_native_layer_norm_16a8w_u55_INT16(test_data):
+    """Test layer_norm with int16 I/O quantization for U55"""
+    test_input, model = test_data()
+    pipeline = EthosU55PipelineINT[input_t](
+        model,
+        test_input,
+        "torch.ops.aten.sub.Tensor",
+        symmetric_io_quantization=True,
+        a16w8_quantization=True,
+        epsilon=2**16,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone320
+def test_native_layer_norm_16a8w_u85_INT16(test_data):
+    """Test layer_norm with int16 I/O quantization for U85"""
+    test_input, model = test_data()
+    pipeline = EthosU85PipelineINT[input_t](
+        model,
+        test_input,
+        "torch.ops.aten.sub.Tensor",
+        symmetric_io_quantization=True,
+        a16w8_quantization=True,
+        epsilon=2**16,
+    )
+    pipeline.run()
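
These new cases exercise a16w8 quantization, i.e. int16 activations with int8 weights. As a rough reference only (not the Arm quantizer's exact parameter choices), symmetric int16 quantization maps max(|x|) onto the top of the signed 16-bit range:

    import torch

    QMAX = 2**15 - 1  # 32767, extreme of the symmetric int16 range

    def symmetric_int16_quantize(x: torch.Tensor):
        """Illustrative symmetric int16 quantization of a tensor."""
        scale = max(float(x.abs().max()) / QMAX, 1e-12)  # guard against all-zero input
        q = torch.clamp(torch.round(x / scale), -QMAX, QMAX).to(torch.int16)
        return q, scale

    x = torch.randn(2, 8) * 3.0
    q, scale = symmetric_int16_quantize(x)
    reconstructed = q.float() * scale  # dequantize; step size is ~256x finer than int8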
