Skip to content

Commit

Permalink
refactor new patching utils
Browse files Browse the repository at this point in the history
  • Loading branch information
speediedan committed May 17, 2024
1 parent 004f873 commit 5643aeb
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 28 deletions.
2 changes: 2 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
**/__pycache__
**/.pyc
**/.classpath
**/.dockerignore
**/.env
Expand Down
2 changes: 1 addition & 1 deletion src/fts_examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
_HF_AVAILABLE = module_available("transformers") and module_available("datasets")
_SP_AVAILABLE = module_available("sentencepiece")

from fts_examples.stable.dep_patch_shim import _ACTIVE_PATCHES # noqa: E402, F401
from fts_examples.stable.patching.dep_patch_shim import _ACTIVE_PATCHES # noqa: E402, F401
Empty file.
31 changes: 31 additions & 0 deletions src/fts_examples/stable/patching/_patch_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import operator
import sys
from typing import Callable
import importlib.metadata
from packaging.version import Version
from functools import lru_cache


@lru_cache
def lwt_compare_version(package: str, op: Callable, version: str, use_base_version: bool = True,
local_version: str = None) -> bool:
try:
pkg_version = Version(importlib.metadata.version(package))
except (importlib.metadata.PackageNotFoundError):
return False
except TypeError:
# possibly mocked by Sphinx so needs to return True to generate summaries
return True
if local_version:
if not operator.eq(local_version, pkg_version.local):
return False
if use_base_version:
pkg_version = Version(pkg_version.base_version)
return op(pkg_version, Version(version))


def _prepare_module_ctx(module_path, orig_globals):
_orig_file = orig_globals.pop('__file__')
orig_globals.update(vars(sys.modules.get(module_path)))
orig_globals['__file__'] = _orig_file
return orig_globals
Original file line number Diff line number Diff line change
@@ -1,27 +1,7 @@
import operator
import sys
from typing import NamedTuple, Tuple, Callable
import importlib.metadata
from packaging.version import Version
from functools import lru_cache


@lru_cache
def lwt_compare_version(package: str, op: Callable, version: str, use_base_version: bool = True,
local_version: str = None) -> bool:
try:
pkg_version = Version(importlib.metadata.version(package))
except (importlib.metadata.PackageNotFoundError):
return False
except TypeError:
# possibly mocked by Sphinx so needs to return True to generate summaries
return True
if local_version:
if not operator.eq(local_version, pkg_version.local):
return False
if use_base_version:
pkg_version = Version(pkg_version.base_version)
return op(pkg_version, Version(version))
from fts_examples.stable.patching._patch_utils import lwt_compare_version


class DependencyPatch(NamedTuple):
Expand All @@ -39,7 +19,7 @@ def _dep_patch_repr(self):


def _patch_unsupported_numpy_arrow_extractor():
from fts_examples.stable.patched_numpyarrowextractor import NumpyArrowExtractor
from fts_examples.stable.patching.patched_numpyarrowextractor import NumpyArrowExtractor
# since the TorchFormatter and NumpyFormatter classes are already defined we need to patch both definitions
# to use our patched `NumpyArrowExtractor`
for old_mod, stale_ref in zip(['torch_formatter', 'np_formatter'], ['TorchFormatter', 'NumpyFormatter']):
Expand All @@ -48,7 +28,7 @@ def _patch_unsupported_numpy_arrow_extractor():


def _patch_triton():
from fts_examples.stable.patched_triton_jit_fn_init import _new_init
from fts_examples.stable.patching.patched_triton_jit_fn_init import _new_init
target_mod = 'triton.runtime.jit'
sys.modules.get(target_mod).__dict__.get('JITFunction').__init__ = _new_init

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import sys
from fts_examples.stable.patching._patch_utils import _prepare_module_ctx


globals().update(_prepare_module_ctx('datasets.formatting.formatting', globals()))

globals().update(vars(sys.modules.get('datasets.formatting.formatting')))

# we ignore these for the entire file since we're using our global namespace trickeration to patch
# ruff: noqa: F821
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import sys
from fts_examples.stable.patching._patch_utils import _prepare_module_ctx
import re
from triton.runtime import jit # noqa: F401


globals().update(vars(sys.modules.get('triton.runtime.jit')))
globals().update(_prepare_module_ctx('triton.runtime.jit', globals()))

# we ignore these for the entire file since we're using our global namespace trickeration to patch
# ruff: noqa: F821
Expand Down

0 comments on commit 5643aeb

Please sign in to comment.