Skip to content

Commit

Permalink
Move ApiContext global into _GlobalSettingsData.
Browse files Browse the repository at this point in the history
And slightly adjust variable names.
  • Loading branch information
csadorf committed Feb 8, 2023
1 parent d0ab476 commit a8677ef
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 48 deletions.
25 changes: 25 additions & 0 deletions python/cuml/internals/api_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from dataclasses import dataclass


@dataclass
class ApiContext:

stack_level: int = 0
previous_output_type = None
output_type = None
output_dtype = None
68 changes: 24 additions & 44 deletions python/cuml/internals/api_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import inspect
import typing
import warnings
from dataclasses import dataclass

# TODO: Try to resolve circular import that makes this necessary:
from cuml.internals import input_utils as iu
from cuml.internals import logger
from cuml.internals.api_context import ApiContext
from cuml.internals.array_sparse import SparseCumlArray
from cuml.internals.constants import CUML_WRAPPED_FLAG
from cuml.internals.global_settings import GlobalSettings
Expand Down Expand Up @@ -103,61 +103,39 @@ def _using_mirror_output_type():
yield output_type


@contextlib.contextmanager
def _restore_dtype():
prev_output_dtype = GlobalSettings().output_dtype
try:
yield
finally:
GlobalSettings().output_dtype = prev_output_dtype


@dataclass
class ApiContext:

stack_level: int = 0
previous_output_type = None
output_type_override = None
output_dtype_override = None


_API_CONTEXT = ApiContext()
GlobalSettings()._api_context = ApiContext()


@contextlib.contextmanager
def api_context():
global _API_CONTEXT

_API_CONTEXT.stack_level += 1
GlobalSettings()._api_context.stack_level += 1

try:
if _API_CONTEXT.stack_level == 1:
if GlobalSettings()._api_context.stack_level == 1:
with contextlib.ExitStack() as stack:
_API_CONTEXT.output_type_override = None
_API_CONTEXT.output_dtype_override = None
GlobalSettings()._api_context.output_type = None
GlobalSettings()._api_context.output_dtype = None
stack.enter_context(cupy_using_allocator(rmm_cupy_allocator))
stack.enter_context(_restore_dtype())
_API_CONTEXT.previous_output_type =\
# stack.enter_context(_restore_dtype())
GlobalSettings()._api_context.previous_output_type =\
stack.enter_context(_using_mirror_output_type())
yield
else:
yield
finally:
_API_CONTEXT.stack_level -= 1
GlobalSettings()._api_context.stack_level -= 1


def in_internal_api():
return _API_CONTEXT.stack_level > 1
return GlobalSettings()._api_context.stack_level > 1


def set_api_output_type(output_type):
global _API_CONTEXT
_API_CONTEXT.output_type_override = output_type
GlobalSettings()._api_context.output_type = output_type


def set_api_output_dtype(output_dtype):
global _API_CONTEXT
_API_CONTEXT.output_dtype_override = output_dtype
GlobalSettings()._api_context.output_dtype = output_dtype


def _convert_to_cumlarray(ret_val):
Expand Down Expand Up @@ -333,16 +311,19 @@ def wrapper(*args, **kwargs):
return func(*args, **kwargs)

# Check for global output type override
global_api_context = GlobalSettings()._api_context
global_output_type = GlobalSettings().output_type
assert global_output_type in (None, "mirror", "input")
out_type_override = _API_CONTEXT.previous_output_type \
or _API_CONTEXT.output_type_override
if out_type_override not in (None, "mirror", "input"):
out_type = out_type_override
out_type = \
global_api_context.previous_output_type \
or global_api_context.output_type

if out_type not in (None, "mirror", "input"):
out_type = out_type
assert not out_type == "input"

# Check for global output dtype override
output_dtype = _API_CONTEXT.output_dtype_override \
output_dtype = global_api_context.output_dtype \
or output_dtype

return process_generic(ret, out_type, output_dtype)
Expand Down Expand Up @@ -409,14 +390,13 @@ def wrapper(*args, **kwargs):

@contextlib.contextmanager
def exit_internal_api():
global _API_CONTEXT
try:
previous_context = _API_CONTEXT
with using_output_type(_API_CONTEXT.previous_output_type):
_API_CONTEXT = ApiContext()
previous_context = GlobalSettings()._api_context
with using_output_type(previous_context.previous_output_type):
GlobalSettings()._api_context = ApiContext()
yield
finally:
_API_CONTEXT = previous_context
GlobalSettings()._api_context = previous_context


def mirror_args(
Expand Down
7 changes: 3 additions & 4 deletions python/cuml/internals/global_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import os
import threading
from cuml.internals.api_context import ApiContext
from cuml.internals.available_devices import is_cuda_available
from cuml.internals.device_type import DeviceType
from cuml.internals.logger import warn
Expand Down Expand Up @@ -44,17 +45,15 @@ def __init__(self):
default_device_type = DeviceType.host
default_memory_type = MemoryType.host
self.shared_state = {
'_api_context': ApiContext(),
'_output_type': None,
'_output_dtype': None,
'_device_type': default_device_type,
'_memory_type': default_memory_type,
'root_cm': None
}
else:
self.shared_state = {
'_api_context': ApiContext(),
'_output_type': None,
'_output_dtype': None,
'root_cm': None
}


Expand Down

0 comments on commit a8677ef

Please sign in to comment.