-
Notifications
You must be signed in to change notification settings - Fork 696
[aoti-et] Store weights outside of .so #15180
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e7a15a3
263e3b5
b13cc3f
60535b9
a3ab05c
d05a305
feff90f
89c0357
1b2200e
426cf90
a51d12e
fdd8333
a5a9a54
952e548
79ff7c9
d62a701
989def3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1 @@ | ||
| 53a2908a10f414a2f85caa06703a26a40e873869 | ||
| e6f766c7d750d40603eee3f66c5915bac606b3ea |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,15 +27,6 @@ | |
|
|
||
| namespace executorch::backends::cuda { | ||
|
|
||
| #define LOAD_SYMBOL(handle, member, name, so_handle) \ | ||
| do { \ | ||
| auto symbol_res = get_function(so_handle, #name); \ | ||
| if (!symbol_res.ok()) { \ | ||
| return symbol_res.error(); \ | ||
| } \ | ||
| handle->member = reinterpret_cast<name##Func>(symbol_res.get()); \ | ||
| } while (0) | ||
|
|
||
| using namespace std; | ||
| using namespace aoti; | ||
|
|
||
|
|
@@ -61,29 +52,37 @@ class ET_EXPERIMENTAL CudaBackend final | |
| Error load_function_pointers_into_handle( | ||
| void* so_handle, | ||
| AOTIDelegateHandle* handle) const { | ||
| LOAD_SYMBOL( | ||
| handle, | ||
| create_with_device, | ||
| AOTInductorModelContainerCreateWithDevice, | ||
| so_handle); | ||
| #define LOAD_SYMBOL(member, name) \ | ||
| do { \ | ||
| auto symbol_res = get_function(so_handle, #name); \ | ||
| if (!symbol_res.ok()) { \ | ||
| return symbol_res.error(); \ | ||
| } \ | ||
| handle->member = reinterpret_cast<name##Func>(symbol_res.get()); \ | ||
| } while (0) | ||
|
|
||
| LOAD_SYMBOL(create_with_device, AOTInductorModelContainerCreateWithDevice); | ||
|
|
||
| LOAD_SYMBOL( | ||
| handle, delete_container, AOTInductorModelContainerDelete, so_handle); | ||
| LOAD_SYMBOL(delete_container, AOTInductorModelContainerDelete); | ||
|
|
||
| LOAD_SYMBOL( | ||
| handle, | ||
| get_num_inputs, | ||
| AOTInductorModelContainerGetNumInputs, | ||
| so_handle); | ||
| LOAD_SYMBOL(get_num_inputs, AOTInductorModelContainerGetNumInputs); | ||
|
|
||
| LOAD_SYMBOL( | ||
| handle, | ||
| get_num_outputs, | ||
| AOTInductorModelContainerGetNumOutputs, | ||
| so_handle); | ||
| LOAD_SYMBOL(get_num_outputs, AOTInductorModelContainerGetNumOutputs); | ||
|
|
||
| LOAD_SYMBOL(handle, run, AOTInductorModelContainerRun, so_handle); | ||
| LOAD_SYMBOL(run, AOTInductorModelContainerRun); | ||
| #undef LOAD_SYMBOL | ||
|
|
||
| auto symbol_res = | ||
| get_function(so_handle, "AOTInductorModelUpdateConstantsFromBlob"); | ||
| if (symbol_res.ok()) { | ||
| handle->update_constants_from_blob = | ||
| reinterpret_cast<AOTInductorModelUpdateConstantsFromBlobFunc>( | ||
| symbol_res.get()); | ||
| } else { | ||
| ET_LOG( | ||
| Info, | ||
| "Failed to load AOTInductorModelUpdateConstantsFromBlob. This .so is probably compiled on an old version of torch (<2.9.0)"); | ||
| } | ||
| return Error::Ok; | ||
| } | ||
|
|
||
|
|
@@ -112,13 +111,13 @@ class ET_EXPERIMENTAL CudaBackend final | |
| method_name.empty() ? "so_blob" : method_name + "_so_blob"; | ||
|
|
||
| const NamedDataMap* named_data_map = context.get_named_data_map(); | ||
| auto aoti_cuda_buffer = named_data_map->get_data(so_blob_key.c_str()); | ||
| auto aoti_dso_buffer = named_data_map->get_data(so_blob_key.c_str()); | ||
| ET_CHECK_OR_RETURN_ERROR( | ||
| aoti_cuda_buffer.ok(), | ||
| aoti_dso_buffer.ok(), | ||
| Internal, | ||
| "Failed to get data for key %s: 0x%x", | ||
| so_blob_key.c_str(), | ||
| static_cast<uint32_t>(aoti_cuda_buffer.error())); | ||
| static_cast<uint32_t>(aoti_dso_buffer.error())); | ||
|
|
||
| // Generate dynamic temporary file path | ||
| filesystem::path temp_dir = filesystem::temp_directory_path(); | ||
|
|
@@ -132,19 +131,21 @@ class ET_EXPERIMENTAL CudaBackend final | |
| ET_LOG( | ||
| Info, | ||
| "Writing %zu bytes to %s", | ||
| aoti_cuda_buffer->size(), | ||
| aoti_dso_buffer->size(), | ||
| so_path.c_str()); | ||
|
|
||
| outfile.write( | ||
| static_cast<const char*>(aoti_cuda_buffer->data()), | ||
| aoti_cuda_buffer->size()); | ||
| static_cast<const char*>(aoti_dso_buffer->data()), | ||
| aoti_dso_buffer->size()); | ||
|
|
||
| ET_CHECK_OR_RETURN_ERROR( | ||
| outfile, AccessFailed, "Failed to write to file %s", so_path.c_str()); | ||
|
|
||
| // Finish writing the file to disk | ||
| outfile.close(); | ||
|
|
||
| // Free the buffer immediately after writing to disk | ||
| aoti_dso_buffer->Free(); | ||
| // Load the lib | ||
| Result<void*> lib_handle_res = load_library(so_path); | ||
| if (!lib_handle_res.ok()) { | ||
|
|
@@ -172,6 +173,19 @@ class ET_EXPERIMENTAL CudaBackend final | |
|
|
||
| handle->container_handle = container_handle; | ||
|
|
||
| // Look into named data map for constant data | ||
| std::string weights_blob_key = | ||
| method_name.empty() ? "weights_blob" : method_name + "_weights_blob"; | ||
| auto buffer_res = named_data_map->get_data(weights_blob_key.c_str()); | ||
| if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) { | ||
| ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str()); | ||
| const void* weights_blob = buffer_res->data(); | ||
| // Feed the weights blob into the container. Under the hood it's copying | ||
larryliu0820 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| // weights, so we should free the buffer immediately. | ||
|
Comment on lines
+183
to
+184
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The weights are mmapped, so this isn't halving the maximum amount of weights we can handle, right? Even so, it seems unfortunate that we have to copy them and therefore can't simply keep them mmapped; peak CPU memory now needs to hold them, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Good point, let me test this on my RTX 5080.
Yeah, it would be good if AOTI could just take it without copying. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I might be missing something here, but I assume you're mmapping into CPU memory, right? AOTI copies it into CUDA memory, and since we're running this on CUDA, we have to copy it to CUDA at some point. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there an mmap equivalent on CUDA? If so, on the ET side we can create a data loader to load directly into CUDA memory. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess you can use GPUDirect Storage. |
||
| ET_CHECK_OK_OR_RETURN_ERROR(handle->update_constants_from_blob( | ||
| handle->container_handle, static_cast<const uint8_t*>(weights_blob))); | ||
| buffer_res->Free(); | ||
| } | ||
larryliu0820 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| // Create a CUDA stream for asynchronous execution | ||
| cudaStream_t cuda_stream; | ||
| ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream)); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check whether the existence/non-existence of blob_path matches the options.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only check that blob_path is None if the package_constant_on_disk_format option is set to "binary_blob".