Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions devtools/bundled_program/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from typing import Dict, List, Optional, Sequence, Type, Union

import executorch.devtools.bundled_program.schema as bp_schema
from pyre_extensions import none_throws

import executorch.exir.schema as core_schema

import torch
Expand Down Expand Up @@ -44,10 +42,12 @@ class BundledProgram:

def __init__(
self,
executorch_program: Optional[Union[
ExecutorchProgram,
ExecutorchProgramManager,
]],
executorch_program: Optional[
Union[
ExecutorchProgram,
ExecutorchProgramManager,
]
],
method_test_suites: Sequence[MethodTestSuite],
pte_file_path: Optional[str] = None,
):
Expand All @@ -59,18 +59,24 @@ def __init__(
pte_file_path: The path to pte file to deserialize program if executorch_program is not provided.
"""
if not executorch_program and not pte_file_path:
raise RuntimeError("Either executorch_program or pte_file_path must be provided")
raise RuntimeError(
"Either executorch_program or pte_file_path must be provided"
)

if executorch_program and pte_file_path:
raise RuntimeError("Only one of executorch_program or pte_file_path can be used")
raise RuntimeError(
"Only one of executorch_program or pte_file_path can be used"
)

method_test_suites = sorted(method_test_suites, key=lambda x: x.method_name)
if executorch_program:
self._assert_valid_bundle(executorch_program, method_test_suites)
self.executorch_program: Optional[Union[
ExecutorchProgram,
ExecutorchProgramManager,
]] = executorch_program
self.executorch_program: Optional[
Union[
ExecutorchProgram,
ExecutorchProgramManager,
]
] = executorch_program
self._pte_file_path: Optional[str] = pte_file_path

self.method_test_suites = method_test_suites
Expand All @@ -88,7 +94,8 @@ def serialize_to_schema(self) -> bp_schema.BundledProgram:
if self.executorch_program:
program = self._extract_program(self.executorch_program)
else:
with open(none_throws(self._pte_file_path), "rb") as f:
assert self._pte_file_path is not None
with open(self._pte_file_path, "rb") as f:
p_bytes = f.read()
program = _deserialize_pte_binary(p_bytes)

Expand Down
16 changes: 12 additions & 4 deletions devtools/bundled_program/test/test_bundle_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

# pyre-strict

import tempfile
import unittest
from typing import List
import tempfile

import executorch.devtools.bundled_program.schema as bp_schema

import torch
Expand Down Expand Up @@ -73,7 +74,8 @@ def test_bundled_program(self) -> None:
bundled_program.serialize_to_schema().program,
bytes(_serialize_pte_binary(executorch_program.executorch_program)),
)



def test_bundled_program_from_pte(self) -> None:
executorch_program, method_test_suites = get_common_executorch_program()

Expand All @@ -82,11 +84,17 @@ def test_bundled_program_from_pte(self) -> None:
with open(executorch_model_path, "wb") as f:
f.write(executorch_program.buffer)

bundled_program = BundledProgram(executorch_program=None, method_test_suites=method_test_suites, pte_file_path=executorch_model_path)
bundled_program = BundledProgram(
executorch_program=None,
method_test_suites=method_test_suites,
pte_file_path=executorch_model_path,
)

method_test_suites = sorted(method_test_suites, key=lambda t: t.method_name)

for plan_id in range(len(executorch_program.executorch_program.execution_plan)):
for plan_id in range(
len(executorch_program.executorch_program.execution_plan)
):
bundled_plan_test = (
bundled_program.serialize_to_schema().method_test_suites[plan_id]
)
Expand Down
Loading