[AOTInductor] Include constants in AOTInductor .so file. (#108473)
Summary:

Include constants in AOTInductor .so file.
Notable implementation differences:
1) Serialize constants with ctypes instead of the native torch.storage serialization.
2) Use the underlying for_blob instead of from_blob to construct the Tensor.
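
For illustration, a minimal sketch of the ctypes serialization path (mirroring the `_to_bytes` helper added to `torch/_inductor/codecache.py` below; the `torch.frombuffer` round-trip check is illustrative and not part of the commit):

```python
import ctypes

import torch


def tensor_to_bytes(t: torch.Tensor) -> bytes:
    # Read the raw bytes of the tensor's untyped storage via ctypes,
    # the same way the new _to_bytes helper in codecache.py does.
    storage = t.untyped_storage().cpu()
    raw = ctypes.cast(
        storage.data_ptr(), ctypes.POINTER(ctypes.c_ubyte * storage.nbytes())
    )
    return bytes(raw.contents)


weight = torch.arange(640, dtype=torch.float32).reshape(10, 64)
blob = tensor_to_bytes(weight)

# Reinterpreting the bytes recovers the original values; the C++ side does
# the analogous reconstruction with at::for_blob over the embedded buffer.
restored = torch.frombuffer(bytearray(blob), dtype=torch.float32).reshape(10, 64)
assert torch.equal(weight, restored)
```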

Test Plan:
Unit tests:
```
test/inductor/test_aot_inductor.py
```
fb:
MRS tests (https://fburl.com/gdoc/ffllzw72):
```
LOGLEVEL=DEBUG TORCHINDUCTOR_MAX_AUTOTUNE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 ../buck-out/v2/gen/fbcode/3408cf5f8424049a/hpc/new/models/feed/benchmark/__feed_lower_benchmark__/feed_lower_benchmark.par --load=manifold://ig_inference_model/tree/user/facebook/fblearner/predictor/966480198/289/gpu_lowering/input.predictor --skip-trt --sync-mode=0 --enable-aot-inductor
```

Previously failing buck tests:
```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/torch/fb/model_transform/experimental/benchmark/test:test_aot_inductor_benchmark -- --exact 'caffe2/torch/fb/model_transform/experimental/benchmark/test:test_aot_inductor_benchmark - test_aot_inductor_benchmark_oemae (caffe2.torch.fb.model_transform.experimental.benchmark.test.test_aot_inductor_benchmark.AOTInductorBenchmark)'
```

Differential Revision: D48927532
muchulee8 authored and facebook-github-bot committed Sep 5, 2023
1 parent 208fd1c commit 365e48c
Showing 15 changed files with 428 additions and 71 deletions.
4 changes: 2 additions & 2 deletions benchmarks/dynamo/common.py
@@ -1147,9 +1147,9 @@ def load(cls, model, example_inputs, eager_forward):

# Use a utility function for easier benchmarking
source = """
-#include <torch/csrc/inductor/aot_inductor_model.h>
+#include <torch/csrc/inductor/aot_inductor_model_container.h>
-torch::aot_inductor::AOTInductorModel model;
+torch::aot_inductor::AOTInductorModelContainer model(1);
void run(
const std::vector<at::Tensor>& input_tensors,
19 changes: 15 additions & 4 deletions test/cpp/aot_inductor/test.cpp
@@ -23,17 +23,28 @@ TEST(AotInductorTest, BasicTest) {
Net net;
net.to(torch::kCUDA);

+  // Fix the weights here; they must match exactly the ones in test.py.
+  torch::Tensor weights =
+      at::arange(640, at::dtype(at::kFloat).device(at::kCUDA));
+  weights = at::reshape(weights, {10, 64});
+  torch::Tensor bias = at::zeros({10}, at::dtype(at::kFloat).device(at::kCUDA));
+
+  for (const auto& pair : net.named_parameters()) {
+    if (pair.key().find("weight") != std::string::npos) {
+      pair.value().copy_(weights);
+    } else if (pair.key().find("bias") != std::string::npos) {
+      pair.value().copy_(bias);
+    }
+  }

torch::Tensor x =
    at::randn({32, 64}, at::dtype(at::kFloat).device(at::kCUDA));
torch::Tensor y =
    at::randn({32, 64}, at::dtype(at::kFloat).device(at::kCUDA));
torch::Tensor results_ref = net.forward(x, y);

-  // TODO: we need to provide an API to concatenate args and weights
   std::vector<torch::Tensor> inputs;
-  for (const auto& pair : net.named_parameters()) {
-    inputs.push_back(pair.value());
-  }
   inputs.push_back(x);
   inputs.push_back(y);

6 changes: 6 additions & 0 deletions test/cpp/aot_inductor/test.py
@@ -8,6 +8,12 @@ class Net(torch.nn.Module):
def __init__(self):
    super().__init__()
    self.fc = torch.nn.Linear(64, 10)
+    weights = torch.arange(640)
+    weights = torch.reshape(weights, (10, 64))
+
+    with torch.no_grad():
+        self.fc.weight.copy_(weights)
+        self.fc.bias.copy_(torch.zeros(10))

def forward(self, x, y):
    return self.fc(torch.sin(x) + torch.cos(y))
49 changes: 44 additions & 5 deletions test/inductor/test_aot_inductor.py
@@ -37,9 +37,9 @@ def load(cls, model, example_inputs, example_outputs, options=None):

# Use a utility function for easier testing
source = """
-#include <torch/csrc/inductor/aot_inductor_model.h>
+#include <torch/csrc/inductor/aot_inductor_model_container.h>
-torch::aot_inductor::AOTInductorModel model;
+torch::aot_inductor::AOTInductorModelContainer model(1);
void run(
const std::vector<at::Tensor>& input_tensors,
@@ -63,12 +63,10 @@ def run(cls, model, example_inputs, example_outputs, options=None):
optimized, exported, output_tensors, output_spec = AOTInductorModelRunner.load(
model, example_inputs, example_outputs, options
)
-param_buffer_values = list(exported.state_dict.values())
 flat_example_inputs = fx_pytree.tree_flatten_spec(
     example_inputs, exported.call_spec.in_spec
 )
-all_args = (*param_buffer_values, *flat_example_inputs)
-optimized(all_args, output_tensors)
+optimized(flat_example_inputs, output_tensors)
return pytree.tree_unflatten(output_tensors, output_spec)


@@ -91,6 +89,47 @@ def forward(self, x, y):
actual = AOTInductorModelRunner.run(model, example_inputs, expected)
self.assertTrue(same(actual, expected))

+def test_large(self):
+    class Repro(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = torch.randn(250112, 512, device="cuda")
+
+        def forward(self, x, y):
+            return x + torch.nn.functional.linear(y, self.weight)
+
+    model = Repro()
+    example_inputs = (
+        torch.randn(1, 250112, device="cuda"),
+        torch.randn(1, 512, device="cuda"),
+    )
+    expected = model(*example_inputs)
+    actual = AOTInductorModelRunner.run(model, example_inputs, expected)
+    self.assertTrue(same(actual, expected))
+
+def test_with_offset(self):
+    class Repro(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.orig_tensor = torch.randn(2, 15, 10, device="cuda")[0]
+            self.tensor = self.orig_tensor[5:, :]
+
+        def forward(self, x, y):
+            return (
+                x
+                + torch.nn.functional.linear(y, self.orig_tensor[:10, :])
+                + self.tensor
+            )
+
+    model = Repro()
+    example_inputs = (
+        torch.randn(10, 10, device="cuda"),
+        torch.randn(10, 10, device="cuda"),
+    )
+    expected = model(*example_inputs)
+    actual = AOTInductorModelRunner.run(model, example_inputs, expected)
+    self.assertTrue(same(actual, expected))
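
The `test_with_offset` case exercises constants that are views with a nonzero storage offset. Serializing `untyped_storage()` keeps the whole underlying buffer, so any view into it can be rebuilt on the C++ side; a small sketch of the invariant this relies on:

```python
import torch

base = torch.randn(2, 15, 10)
view = base[0][5:, :]  # shares storage with base, at a nonzero offset

# The view's untyped storage is the full buffer, not just the viewed slice,
# so serializing it preserves every byte an offset view may need.
assert view.storage_offset() > 0
assert view.untyped_storage().nbytes() == base.untyped_storage().nbytes()
```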

def test_missing_output(self):
    class Repro(torch.nn.Module):
        def __init__(self):
4 changes: 2 additions & 2 deletions test/inductor/test_inductor_freezing.py
@@ -307,8 +307,8 @@ def foo(mod, x):
# we unfuse the conv bias, but it should only have one constant in the kernel
if self.device == "cuda":
FileCheck().check_not(".run(").check("conv").check(".run(").check_same(
"constant"
).check_not("constant").check_next("return").run(code[0])
"frozen_param"
).check_not("frozen_param").check_next("return").run(code[0])

self.assertEqual(
out_optimized_for_infernece, out_eager, atol=1e-2, rtol=1e-2
45 changes: 40 additions & 5 deletions torch/_export/__init__.py
@@ -1,3 +1,4 @@
+import copy
import dataclasses
import io
import re
@@ -627,7 +628,6 @@ def aot_compile(
Returns:
Path to the generated shared library, and the exported program
"""
-from torch._inductor.compile_fx import compile_fx_aot
from torch._inductor.decomposition import select_decomp_table

global DECOMP_TABLE
@@ -636,11 +636,46 @@ def aot_compile(
# Reset the global value
DECOMP_TABLE = core_aten_decompositions()

-param_buffer_values = list(ep.state_dict.values())
 flat_example_inputs = fx_pytree.tree_flatten_spec(
     combine_args_kwargs(args, kwargs), ep.call_spec.in_spec  # type: ignore[arg-type]
 )
-all_args = (*param_buffer_values, *flat_example_inputs)

-so_path = torch._inductor.aot_compile(ep.graph_module, list(all_args), options)
-return so_path, ep
+unlifted_module = ep.module()
+unlifted_module.graph.set_codegen(torch.fx.CodeGen())  # type: ignore[attr-defined]
+unlifted_module.recompile()
+options = (
+    {"from_export": True}
+    if options is None
+    else {**options, "from_export": True}
+)
+so_path = torch._inductor.aot_compile(unlifted_module, flat_example_inputs, options)  # type: ignore[arg-type]
+
+user_inputs = []
+user_outputs = []
+for node in unlifted_module.graph.nodes:
+    if node.op == "placeholder":
+        user_inputs.append(node.name)
+    elif node.op == "output":
+        user_outputs = [arg.name for arg in node.args[0]]
+
+unlifted_ep = ExportedProgram(
+    unlifted_module,
+    unlifted_module.graph,
+    ExportGraphSignature(
+        [],
+        [],
+        user_inputs,
+        user_outputs,
+        {},
+        {},
+        {},
+        None,
+    ),
+    call_spec=copy.deepcopy(ep.call_spec),
+    state_dict={},
+    range_constraints=copy.deepcopy(ep.range_constraints),
+    equality_constraints=copy.deepcopy(ep.equality_constraints),
+    module_call_graph=ep.module_call_graph,
+)
+
+return so_path, unlifted_ep
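
For context, `ep.module()` is the "unlifting" step: export lifts parameters and buffers into explicit graph inputs, and `module()` re-inlines them so the compiled artifact owns its constants. A minimal sketch, written against the current public `torch.export` API (which postdates this commit):

```python
import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(64, 10)

    def forward(self, x):
        return self.fc(x)


ep = torch.export.export(M(), (torch.randn(32, 64),))
# The exported graph takes the weight and bias as lifted placeholders...
print([n.name for n in ep.graph.nodes if n.op == "placeholder"])
# ...while module() returns a runnable nn.Module that holds them again,
# leaving only the user input at the call boundary.
unlifted = ep.module()
out = unlifted(torch.randn(32, 64))
```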
118 changes: 101 additions & 17 deletions torch/_inductor/codecache.py
@@ -302,13 +302,13 @@ def get_lock_dir():
return lock_dir


-def code_hash(code, extra: str = ""):
-    hashing_str = code
+def code_hash(code: Union[str, bytes], extra: str = ""):
+    hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
     if extra != "":
-        hashing_str = hashing_str + "||" + extra
+        hashing_str = hashing_str + b"||" + extra.encode("utf-8")
     return (
         "c"
-        + base64.b32encode(hashlib.sha256(hashing_str.encode("utf-8")).digest())[:51]
+        + base64.b32encode(hashlib.sha256(hashing_str).digest())[:51]
         .decode("utf-8")
         .lower()
     )
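
A quick sanity check on the widened hash helper: `str` input (C++ source) is utf-8 encoded, `bytes` input (serialized constants) is hashed as-is, and both yield the same key shape (the inputs here are illustrative):

```python
from torch._inductor.codecache import code_hash

src_key = code_hash("#include <torch/torch.h>", extra="-O3 -shared")
bin_key = code_hash(b"\x00\x01\x02\x03")

# Both paths produce a "c"-prefixed, 52-character lowercase base32 key.
assert src_key.startswith("c") and len(src_key) == 52
assert bin_key.startswith("c") and len(bin_key) == 52
```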
@@ -656,6 +656,10 @@ def pick_vec_isa():
return invalid_vec_isa


+def get_compile_only(compile_only=True):
+    return "-c" if compile_only else ""


def get_shared(shared=True):
    return "-shared -fPIC" if shared else ""

@@ -888,6 +892,7 @@ def cpp_compile_command(
vec_isa: VecISA = invalid_vec_isa,
cuda=False,
aot_mode=False,
+compile_only=False,
):
ipaths, lpaths, libs, macros = get_include_and_linking_paths(
include_pytorch, vec_isa, cuda, aot_mode
@@ -918,11 +923,20 @@ def cpp_compile_command(
{use_custom_generated_macros()}
{use_fb_internal_macros()}
{use_standard_sys_dir_headers()}
+{get_compile_only(compile_only)}
-o {out_name}
""",
).strip()


+def run_command_and_check(cmd: str):
+    cmd = shlex.split(cmd)
+    try:
+        subprocess.check_call(cmd)
+    except subprocess.CalledProcessError as e:
+        raise exc.CppCompileError(cmd, e.output) from e


class CudaKernelParamCache:
    cache = dict()
    clear = staticmethod(cache.clear)
@@ -956,12 +970,41 @@ def compile(cls, graph, source_code, serialized_extern_kernel_nodes, cuda):
"i", "o", vec_isa=picked_vec_isa, cuda=cuda, aot_mode=graph.aot_mode
)
)
+if config.is_fbcode():
+    ld_command = build_paths.ld()
+    objcopy_command = build_paths.objcopy()
+else:
+    ld_command = "ld"
+    objcopy_command = "objcopy"
 key, input_path = write(
     source_code,
     "cpp",
     extra=cpp_command,
     specified_dir=config.aot_inductor.output_path,
 )

+def _to_bytes(t: torch.Tensor) -> bytes:
+    # This serializes the tensor's untyped_storage to bytes by accessing
+    # the raw data of the underlying structure.
+    import ctypes
+
+    t_cpu = t.untyped_storage().cpu()
+    raw_array = ctypes.cast(
+        t_cpu.data_ptr(), ctypes.POINTER(ctypes.c_ubyte * t_cpu.nbytes())
+    )
+
+    return bytes(raw_array.contents)
+
+aot_constants = b""
+for idx, tensor in enumerate(graph.constants.values()):
+    aot_constants += _to_bytes(tensor)
+
+consts_key, consts_path = write(
+    aot_constants,
+    "bin",
+    specified_dir=config.aot_inductor_output_path,
+)

if key not in cls.cache:
    from filelock import FileLock

@@ -978,20 +1021,61 @@ def compile(cls, graph, source_code, serialized_extern_kernel_nodes, cuda):
output_so = os.path.splitext(input_path)[0] + ".so"

if not os.path.exists(output_so):
-    cmd = shlex.split(
-        cpp_compile_command(
-            input=input_path,
-            output=output_so,
-            vec_isa=picked_vec_isa,
-            cuda=cuda,
-            aot_mode=graph.aot_mode,
-        )
-    )
+    output_o = os.path.splitext(input_path)[0] + ".o"
+    cmd = cpp_compile_command(
+        input=input_path,
+        output=output_o,
+        vec_isa=picked_vec_isa,
+        cuda=cuda,
+        aot_mode=graph.aot_mode,
+        compile_only=True,
+    )
+    log.debug("aot compilation command: %s", cmd)
+    run_command_and_check(cmd)
+
+    consts_o = os.path.splitext(consts_path)[0] + ".o"
+    cmd = f"{ld_command} -r -b binary -o {consts_o} {consts_path}"
+    run_command_and_check(cmd)
+    log.debug("aot constant binary command: %s", cmd)
+
+    cmd = (
+        f"{objcopy_command} --rename-section"
+        " .data=.lrodata,alloc,load,readonly,data,contents"
+        f" {consts_o} {consts_o}"
+    )
+    log.debug("aot constant obj command: %s", cmd)
+    run_command_and_check(cmd)
+
+    cmd = f"rm {consts_path}"
+    log.debug("aot constant bin removal command: %s", cmd)
+    run_command_and_check(cmd)
+
+    body = re.sub(r"[\W_]+", "_", consts_path)
+    symbol_list = []
+    symbol_list.append(
+        f"{objcopy_command} --redefine-sym _binary_{body}_start=_binary_constants_bin_start {consts_o}"
+    )
+    symbol_list.append(
+        f"{objcopy_command} --redefine-sym _binary_{body}_size=_binary_constants_bin_size {consts_o}"
+    )
+    symbol_list.append(
+        f"{objcopy_command} --redefine-sym _binary_{body}_end=_binary_constants_bin_end {consts_o}"
+    )
+    log.debug(
+        "aot constant binary redefine symbol: %s", " ".join(symbol_list)
+    )
+    for cmd in symbol_list:
+        run_command_and_check(cmd)
+
+    cmd = cpp_compile_command(
+        input=f"{output_o} {consts_o}",
+        output=output_so,
+        vec_isa=picked_vec_isa,
+        cuda=cuda,
+        aot_mode=graph.aot_mode,
+    )
-    log.debug("aot compilation command: %s", " ".join(cmd))
-    try:
-        subprocess.check_call(cmd)
-    except subprocess.CalledProcessError as e:
-        raise exc.CppCompileError(cmd, e.output) from e
+    log.debug("aot linkage command: %s", cmd)
+    run_command_and_check(cmd)
else:
    log.debug(
        "aot_inductor dynamic library already exist: %s", output_so
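
To see the embedding trick in isolation: `ld -r -b binary` wraps an arbitrary file in a relocatable object and synthesizes `_binary_<sanitized path>_{start,end,size}` symbols, and `objcopy --rename-section` moves the payload into a read-only large-data section. The `--redefine-sym` pass above exists because the real blob lives under a cache directory, so its auto-generated symbol names must be normalized to the fixed `_binary_constants_bin_*` names the generated code links against. A standalone sketch of the same commands (Linux binutils assumed; file names illustrative):

```python
import subprocess

consts_path = "constants.bin"
consts_o = "constants.o"

# Stand-in for the serialized tensor bytes written by write(..., "bin", ...).
with open(consts_path, "wb") as f:
    f.write(bytes(range(16)))

# Wrap the blob in a relocatable object; ld derives the
# _binary_constants_bin_{start,end,size} symbols from the file name.
subprocess.check_call(["ld", "-r", "-b", "binary", "-o", consts_o, consts_path])

# Move the payload from .data into .lrodata so it is mapped read-only.
subprocess.check_call(
    [
        "objcopy",
        "--rename-section",
        ".data=.lrodata,alloc,load,readonly,data,contents",
        consts_o,
        consts_o,
    ]
)

# The final cpp_compile_command link above feeds this .o together with the
# model's .o; at runtime the .so resolves the start/end symbols to find
# its constants.
print(subprocess.check_output(["nm", consts_o]).decode())
```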
