Skip to content

Commit

Permalink
Update the tff.program API to be async, using asyncio.
Browse files Browse the repository at this point in the history
* Updated the `tff.program` API to be async:
  * `MaterializableValueReference.get_value`
  * `tff.program.materialize_value`
  * `ReleaseManager.release`
  * `versions`, `load`, `load_latest` and `save` on `ProgramStateManager`

Basically, this is all the API that is used by program logic that accepts a `MaterializableValueReference`.

See [asyncio](https://docs.python.org/3/library/asyncio.html) for more information.

PiperOrigin-RevId: 441286145
  • Loading branch information
michaelreneer authored and tensorflow-copybara committed Apr 12, 2022
1 parent 763157e commit a98b5ed
Show file tree
Hide file tree
Showing 23 changed files with 557 additions and 305 deletions.
4 changes: 4 additions & 0 deletions tensorflow_federated/python/program/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ py_test(
srcs_version = "PY3",
deps = [
":file_utils",
":test_utils",
"@absl_py//absl/testing:absltest",
"@absl_py//absl/testing:parameterized",
"@org_tensorflow//tensorflow:tensorflow_py",
Expand Down Expand Up @@ -211,6 +212,7 @@ py_library(
":data_source",
":federated_context",
":value_reference",
"//tensorflow_federated/python/common_libs:async_utils",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/api:computation_base",
Expand All @@ -230,6 +232,7 @@ py_test(
deps = [
":federated_context",
":native_platform",
":test_utils",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
Expand All @@ -255,6 +258,7 @@ py_test(
srcs_version = "PY3",
deps = [
":program_state_manager",
":test_utils",
"@absl_py//absl/testing:absltest",
],
)
Expand Down
36 changes: 18 additions & 18 deletions tensorflow_federated/python/program/file_program_state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
* encodes files in the same way as `tf.io.gfile`
"""

import asyncio
import os
import os.path
from typing import Any, List, Optional, Union
Expand Down Expand Up @@ -95,18 +96,18 @@ def __init__(self,
self._keep_total = keep_total
self._keep_first = keep_first

def versions(self) -> Optional[List[int]]:
async def versions(self) -> Optional[List[int]]:
"""Returns a list of saved versions or `None`.
Returns:
A list of saved versions or `None` if there is no saved program state.
"""
if not tf.io.gfile.exists(self._root_dir):
if not await file_utils.exists(self._root_dir):
return None
versions = []
# Due to tensorflow/issues/19378, we cannot use `tf.io.gfile.glob` here
# because it returns directory contents recursively on Windows.
entries = tf.io.gfile.listdir(self._root_dir)
entries = await file_utils.listdir(self._root_dir)
for entry in entries:
if entry.startswith(self._prefix):
version = self._get_version_for_path(entry)
Expand Down Expand Up @@ -152,7 +153,7 @@ def _get_path_for_version(self, version: int) -> str:
basename = f'{self._prefix}{version}'
return os.path.join(self._root_dir, basename)

def load(self, version: int, structure: Any) -> Any:
async def load(self, version: int, structure: Any) -> Any:
"""Returns the program state for the given `version`.
Args:
Expand All @@ -170,10 +171,10 @@ def load(self, version: int, structure: Any) -> Any:
py_typecheck.check_type(version, int)

path = self._get_path_for_version(version)
if not tf.io.gfile.exists(path):
if not await file_utils.exists(path):
raise program_state_manager.ProgramStateManagerStateNotFoundError(
f'No program state found for version: {version}')
flattened_state = file_utils.read_saved_model(path)
flattened_state = await file_utils.read_saved_model(path)
try:
program_state = tree.unflatten_as(structure, flattened_state)
except ValueError as e:
Expand All @@ -185,27 +186,26 @@ def load(self, version: int, structure: Any) -> Any:
logging.info('Program state loaded: %s', path)
return program_state

def _remove(self, version: int):
async def _remove(self, version: int):
"""Removes program state for the given `version`."""
py_typecheck.check_type(version, int)

path = self._get_path_for_version(version)
if tf.io.gfile.exists(path):
tf.io.gfile.rmtree(path)
if await file_utils.exists(path):
await file_utils.rmtree(path)
logging.info('Program state removed: %s', path)

def _remove_old_program_state(self):
async def _remove_old_program_state(self):
"""Removes old program state."""
if self._keep_total <= 0:
return
versions = self.versions()
versions = await self.versions()
if versions is not None and len(versions) > self._keep_total:
start = 1 if self._keep_first else 0
stop = start - self._keep_total
for version in versions[start:stop]:
self._remove(version)
await asyncio.gather(*[self._remove(v) for v in versions[start:stop]])

def save(self, program_state: Any, version: int):
async def save(self, program_state: Any, version: int):
"""Saves `program_state` for the given `version`.
Args:
Expand All @@ -222,10 +222,10 @@ def save(self, program_state: Any, version: int):
py_typecheck.check_type(version, int)

path = self._get_path_for_version(version)
if tf.io.gfile.exists(path):
if await file_utils.exists(path):
raise program_state_manager.ProgramStateManagerStateAlreadyExistsError(
f'Program state already exists for version: {version}')
materialized_state = value_reference.materialize_value(program_state)
materialized_state = await value_reference.materialize_value(program_state)
flattened_state = tree.flatten(materialized_state)
file_utils.write_saved_model(flattened_state, path)
self._remove_old_program_state()
await file_utils.write_saved_model(flattened_state, path)
await self._remove_old_program_state()

0 comments on commit a98b5ed

Please sign in to comment.