diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py
index cc9eb9f02ee..c0f60529895 100644
--- a/examples/apple/coreml/llama/export.py
+++ b/examples/apple/coreml/llama/export.py
@@ -1,6 +1,8 @@
-# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
-
-# pyre-strict
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
 
 
 import argparse
@@ -24,55 +26,7 @@
 
 sys.path.insert(0, ".")
 from llama_transformer import InputManager, load_model
-
-
-class SplitLinearModule(torch.nn.Module):
-    def __init__(self, in_features, out_features, target_split_size, max_splits):
-        super(SplitLinearModule, self).__init__()
-        num_splits = max(out_features // target_split_size, 1)
-        if num_splits > max_splits:
-            num_splits = max_splits
-
-        self.split_size = out_features // num_splits
-        self.split_remainder = out_features % num_splits
-        self.splits = torch.nn.ModuleList(
-            [torch.nn.Linear(in_features, self.split_size) for _ in range(num_splits)]
-        )
-        print(
-            f"Splitting out_features={out_features} into {num_splits} of size {self.split_size}"
-        )
-        if self.split_remainder > 0:
-            print(
-                f"Warning: remainder {self.split_remainder} after splitting out_features={out_features} into {num_splits} of size {self.split_size}"
-            )
-            self.splits.append(torch.nn.Linear(in_features, self.split_remainder))
-
-    def split_sizes(self):
-        return [split.out_features for split in self.splits]
-
-    def forward(self, x):
-        return torch.cat([split(x) for split in self.splits], dim=-1)
-
-
-def replace_linear_with_split_linear(model, target_split_size, max_splits):
-    for name, module in model.named_children():
-        if isinstance(module, torch.nn.Linear):
-            new_module = SplitLinearModule(
-                module.in_features, module.out_features, target_split_size, max_splits
-            )
-            split_sizes = new_module.split_sizes()
-            if module.bias is not None:
-                split_bias = module.bias.split(split_sizes)
-            split_weights = module.weight.split(split_sizes, dim=0)
-            for i, split in enumerate(new_module.splits):
-                split.weight = torch.nn.Parameter(split_weights[i])
-                if module.bias is not None:
-                    split.bias = torch.nn.Parameter(split_bias[i])
-                else:
-                    split.bias = None
-            setattr(model, name, new_module)
-        else:
-            replace_linear_with_split_linear(module, target_split_size, max_splits)
+from utils import replace_linear_with_split_linear
 
 
 def main() -> None:
@@ -175,7 +129,13 @@ def main() -> None:
 
     if export_args.target_split_size is not None:
         replace_linear_with_split_linear(
-            model, export_args.target_split_size, export_args.max_splits
+            model,
+            out_target_split_size=export_args.target_split_size,
+            out_max_splits=export_args.max_splits,
+            # I have not found splitting on in_features to be beneficial,
+            # and it often leads to OOM so I set in_max_splits to 1
+            in_target_split_size=1,
+            in_max_splits=1,
         )
 
     model.eval()
@@ -241,6 +201,7 @@ def main() -> None:
         ep,
         preserve_ops=[
             torch.ops.aten.scaled_dot_product_attention.default,
+            # preserve norm op for numerical stability
             torch.ops.aten.linalg_vector_norm.default,
         ],
         compile_config=EdgeCompileConfig(
diff --git a/examples/apple/coreml/llama/readme.md b/examples/apple/coreml/llama/readme.md
index a9efedf6bbe..14dff0c8580 100644
--- a/examples/apple/coreml/llama/readme.md
+++ b/examples/apple/coreml/llama/readme.md
@@ -38,8 +38,9 @@ The runner can also be used to run an eager model to compare with CoreML numerics
 We are actively experimenting with different settings. But here are ones that we've found work well for Llama1B on iPhone 15 Pro:
-* Set use_cache_list
-* Split linear layers with target_split_size=1024, max_splits=8
-* Use seq_length=32 or seq_length=64, both of which offer reasonable tradeoffs for prefill and decode performance. seq_length=32 is better at decode and seq_length=64 is better at prefill.
-
-In our tests, we set max_seq_length=1024, but if your application allows for it, performance can improve with max_seq_length=512 or by keeping max_seq_length=1024 and setting cache_size=512-seq_length.
+* Set use_cache_list.
+* Use seq_length = 32, which offers a good balance between prefill and decode performance.
+* Split out_features in linear layers with target_split_size=1024, max_splits=8.
+* For ANE, set dtype = fp16 and coreml-quantize = c4w. This requires doing QAT on Llama1B for good accuracy.
+* Set embedding-quantize to "4,32".
+* Set max_seq_length to 128, 256, 512, 1024, or 2048, depending on how much context your application needs. Note that performance degrades as max_seq_length increases. More specifically, it degrades as cache_size grows, and the best experience may require a good cache eviction policy. The Python runner in run.py uses a last-in-last-out policy (the oldest cache entries are evicted first) when cache_size is specified.
 
 
diff --git a/examples/apple/coreml/llama/run.py b/examples/apple/coreml/llama/run.py
index 501aaee07ed..de22794dee1 100644
--- a/examples/apple/coreml/llama/run.py
+++ b/examples/apple/coreml/llama/run.py
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 import argparse
 import sys
 
diff --git a/examples/apple/coreml/llama/test.py b/examples/apple/coreml/llama/test.py
new file mode 100644
index 00000000000..895cf2e1cce
--- /dev/null
+++ b/examples/apple/coreml/llama/test.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+
+sys.path.insert(0, ".")
+import copy
+
+import torch
+from utils import replace_linear_with_split_linear
+
+
+def get_split_model(
+    model,
+    out_target_split_size=1,
+    out_max_splits=1,
+    in_target_split_size=1,
+    in_max_splits=1,
+):
+    model_copy = copy.deepcopy(model)
+    replace_linear_with_split_linear(
+        model_copy,
+        out_target_split_size,
+        out_max_splits,
+        in_target_split_size,
+        in_max_splits,
+    )
+    return model_copy
+
+
+def test_split_model():
+    inputs = torch.randn(10, 5, 1, 512)
+
+    model = torch.nn.Sequential(*[torch.nn.Linear(512, 1024, bias=False)])
+    model1 = get_split_model(model, 64, 2, 64, 1000)
+    model2 = get_split_model(model, 64, 2, 64, 1)
+    model3 = get_split_model(model, 64, 1, 64, 1000)
+
+    assert torch.allclose(model(inputs), model1(inputs), atol=1e-5)
+    assert torch.allclose(model(inputs), model2(inputs), atol=1e-5)
+    assert torch.allclose(model(inputs), model3(inputs), atol=1e-5)
+
+
+if __name__ == "__main__":
+    test_split_model()
diff --git a/examples/apple/coreml/llama/utils.py b/examples/apple/coreml/llama/utils.py
new file mode 100644
index 00000000000..1e5a842fed5
--- /dev/null
+++ b/examples/apple/coreml/llama/utils.py
@@ -0,0 +1,116 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+class SplitLinearModule(torch.nn.Module):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        out_target_split_size=1,
+        out_max_splits=1,
+        in_target_split_size=1,
+        in_max_splits=1,
+    ):
+        super(SplitLinearModule, self).__init__()
+        self.out_split_sizes = self._get_split_sizes(
+            out_features, out_target_split_size, out_max_splits
+        )
+        self.in_split_sizes = self._get_split_sizes(
+            in_features, in_target_split_size, in_max_splits
+        )
+        print(
+            f"Splitting out_features={out_features} into {len(self.out_split_sizes)} of size {self.out_split_sizes[0]}."
+        )
+        print(
+            f"Splitting in_features={in_features} into {len(self.in_split_sizes)} of size {self.in_split_sizes[0]}."
+        )
+
+        # self.ops contains a list of linear ops for different pieces of the output matrix
+        # The index of an op at (in_idx, out_idx) is given by self.op_index(in_idx, out_idx)
+        self.ops = torch.nn.ModuleList()
+        for idx_out, s_out in enumerate(self.out_split_sizes):
+            for idx_in, s_in in enumerate(self.in_split_sizes):
+                assert len(self.ops) == self.op_index(idx_in, idx_out)
+                self.ops.append(torch.nn.Linear(s_in, s_out, bias=False))
+
+    def op_index(self, in_index, out_index):
+        idx = out_index * len(self.in_split_sizes) + in_index
+        return idx
+
+    def _get_split_sizes(self, n_features, target_split_size, max_splits):
+        num_splits = max(n_features // target_split_size, 1)
+        if num_splits > max_splits:
+            num_splits = max_splits
+
+        split_size = n_features // num_splits
+        split_remainder = n_features % num_splits
+        if split_remainder > 0:
+            raise ValueError(
+                f"Cannot split {n_features} with target_split_size={target_split_size} and max_splits={max_splits} because it leaves a remainder of {split_remainder}."
+            )
+
+        ret = [split_size for _ in range(num_splits)]
+        return ret
+
+    def set_params(self, weight):
+        split_weights = []
+        for w_out in weight.split(self.out_split_sizes, dim=0):
+            for w in w_out.split(self.in_split_sizes, dim=1):
+                split_weights.append(w)
+
+        for i, split in enumerate(self.ops):
+            split.weight = torch.nn.Parameter(split_weights[i])
+
+    def forward(self, x):
+        if len(self.in_split_sizes) == 1:
+            out_chunks = [op(x) for op in self.ops]
+        else:
+            x_splits = x.split(self.in_split_sizes, dim=-1)
+            out_chunks = [
+                torch.sum(
+                    torch.stack(
+                        [
+                            self.ops[self.op_index(in_idx, out_idx)].forward(
+                                x_splits[in_idx]
+                            )
+                            for in_idx in range(len(self.in_split_sizes))
+                        ],
+                    ),
+                    dim=0,
+                )
+                for out_idx in range(len(self.out_split_sizes))
+            ]
+
+        return torch.concat(out_chunks, dim=-1)
+
+
+def replace_linear_with_split_linear(
+    model, out_target_split_size, out_max_splits, in_target_split_size, in_max_splits=1
+):
+    for name, module in model.named_children():
+        if isinstance(module, torch.nn.Linear):
+            assert module.bias is None, "SplitLinearModule does not support bias"
+            new_module = SplitLinearModule(
+                module.in_features,
+                module.out_features,
+                out_target_split_size,
+                out_max_splits,
+                in_target_split_size,
+                in_max_splits,
+            )
+            new_module.set_params(module.weight)
+            setattr(model, name, new_module)
+        else:
+            replace_linear_with_split_linear(
+                module,
+                out_target_split_size,
+                out_max_splits,
+                in_target_split_size,
+                in_max_splits,
+            )
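
Note on the decomposition utils.py implements: SplitLinearModule computes y = x @ W.T blockwise. The rows of W (out_features) and its columns (in_features) are split into a grid of blocks; each output block is the sum of the partial products over the corresponding row of blocks, and the output blocks are concatenated at the end. Below is a minimal standalone sketch of that identity, assuming bias-free linears as the patch does; W, x, out_split_sizes, and in_split_sizes are illustrative names, not identifiers from the patch.

import torch

# Illustrative split sizes; they must divide the features evenly
# (utils.py raises ValueError when a split leaves a remainder).
out_split_sizes = [64, 64]  # out_features = 128, split in two
in_split_sizes = [32, 32]  # in_features = 64, split in two

W = torch.randn(sum(out_split_sizes), sum(in_split_sizes))
x = torch.randn(10, sum(in_split_sizes))

# Reference: a single dense, bias-free linear, y = x @ W^T.
y_ref = x @ W.t()

# Blockwise: split W along dim 0 then dim 1, in the same order
# SplitLinearModule.set_params walks the weight; split x along its last
# dim; sum partial products per output block; concatenate the blocks.
row_blocks = [w.split(in_split_sizes, dim=1) for w in W.split(out_split_sizes, dim=0)]
x_splits = x.split(in_split_sizes, dim=-1)
y_blocks = [
    sum(x_splits[i] @ blocks[i].t() for i in range(len(in_split_sizes)))
    for blocks in row_blocks
]
y = torch.cat(y_blocks, dim=-1)

assert torch.allclose(y_ref, y, atol=1e-5)

This is the same equivalence test.py checks through the module interface; the sketch only makes the underlying matrix identity explicit.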