Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor API context managers #5142

Draft
wants to merge 49 commits into
base: branch-23.04
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
908e522
Do not require explicit "needs_self" option.
csadorf Nov 28, 2022
6ce6d20
Micro-refactor of api_context_managers module.
csadorf Nov 28, 2022
dc176df
Eliminate ProcessEnterBaseMixin class.
csadorf Nov 28, 2022
3321983
Remove ProcessEnter and ProcessReturn classes.
csadorf Nov 28, 2022
f79b637
WIP: Refactor internal api contexts to simplify cm hierarchy.
csadorf Dec 7, 2022
8cc4e22
Use ExitStack for stacking CMs in api_context CM.
csadorf Jan 19, 2023
fa151f7
Use GlobalSettings() more consistently.
csadorf Jan 19, 2023
30bd645
Further debugging.
csadorf Jan 19, 2023
f9765a3
Do not use global state for processing return values.
csadorf Jan 20, 2023
4b94a7b
Partially revert 4b102b9576abdd59fcc281514477030d10b8451b .
csadorf Jan 20, 2023
d6ff914
Fix unrelated style issues.
csadorf Jan 20, 2023
98998b2
Restore the _API_OUTPUT_DTYPE_OVERRIDE global.
csadorf Jan 20, 2023
c9776bf
Move override logic into calling function.
csadorf Jan 23, 2023
9d528a2
Convert solver_model.coef_ to CumlArray in more robust way.
csadorf Jan 23, 2023
10b3b35
Fix ElasticNet.
csadorf Jan 23, 2023
83c3fc1
Fix ARIMA.
csadorf Jan 24, 2023
e5e09f3
Fix ColumnTransformer
csadorf Jan 24, 2023
013a199
The internals.set_api_output_type does not change global output_type.
csadorf Jan 24, 2023
c0e4543
Fix DBSCAN.
csadorf Jan 24, 2023
d77a09e
Fix ExponentialSmoothing.
csadorf Jan 24, 2023
48b6782
Fix MBSGDClassifier.
csadorf Jan 24, 2023
ff78fd0
Fix MBSGDRegressor.
csadorf Jan 24, 2023
6a31e7b
Refactor the api_context() CM function.
csadorf Jan 24, 2023
b8984d2
Implement exit_internal_api().
csadorf Jan 24, 2023
1237c1e
Fix PCA.
csadorf Jan 24, 2023
80b704e
Fix entropy module.
csadorf Jan 24, 2023
a80da27
Fix multiclass.
csadorf Jan 24, 2023
b5c488b
Fix output_type override logic.
csadorf Jan 25, 2023
eb0abf3
Fix SVC.
csadorf Jan 25, 2023
f3e837a
Consolidate API context globals into dataclass.
csadorf Jan 25, 2023
78c077c
Fix kneighbors_graph.
csadorf Jan 25, 2023
1f020b7
Fix BaseRandomForestModel.
csadorf Jan 26, 2023
c72d720
Set set_output_type=True for auto-wrapped "fit" functions.
csadorf Feb 7, 2023
1aab6a5
Revert changes to test_api.py
csadorf Feb 8, 2023
11b0f18
Minor refactor of f8a3106ed8f43c8b3f948ea5db3f25857dc92fc1.
csadorf Feb 8, 2023
66adc3a
Fixup ARIMA (19a7d397a1513737a2daaa240c5b172b5bee6d9a).
csadorf Feb 8, 2023
5b17f6d
Partially revert 440d6abcbf5ba3bfd6617ed8ab517048c7f7a54b.
csadorf Feb 8, 2023
8bd3902
Move ApiContext global into _GlobalSettingsData.
csadorf Feb 8, 2023
798ee95
Auto-convert to CumlArray for API-internal functions.
csadorf Feb 8, 2023
1764a46
Fix bug in the generic conversion of sparse arrays.
csadorf Feb 8, 2023
0e1ac45
Consolidate internals namespace imports.
csadorf Feb 9, 2023
6fc2e18
Consolidate api decorator definitions.
csadorf Feb 9, 2023
369d610
Update TODO items.
csadorf Feb 9, 2023
b3d1a0c
Respect globally specified output_type also for internal api calls.
csadorf Feb 9, 2023
e81fad3
Remove obsolete output_dtype property from global settings.
csadorf Feb 9, 2023
689717b
Fix MBSGDClassifier.
csadorf Jan 24, 2023
a307b42
Fix MBSGDRegressor.
csadorf Jan 24, 2023
c907da2
Fix SVC.
csadorf Jan 25, 2023
b5fc198
Fix TfidfTransformer.
csadorf Feb 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def _list_indexing(X, key, key_dtype):


def _transform_one(transformer, X, y, weight, **fit_params):
res = transformer.transform(X).to_output('cupy')
res = transformer.transform(X).to_output("array")
# if we have a weight for this transformer, multiply output
if weight is None:
return res
Expand All @@ -316,16 +316,15 @@ def _fit_transform_one(transformer,
be multiplied by ``weight``.
"""
with _print_elapsed_time(message_clsname, message):
with cuml.using_output_type("cupy"):
transformer.accept_sparse = True
if hasattr(transformer, 'fit_transform'):
res = transformer.fit_transform(X, y, **fit_params)
else:
res = transformer.fit(X, y, **fit_params).transform(X)
transformer.accept_sparse = True
if hasattr(transformer, 'fit_transform'):
res = transformer.fit_transform(X, y, **fit_params)
else:
res = transformer.fit(X, y, **fit_params).transform(X)

if weight is None:
return res, transformer
return res * weight, transformer
return res.to_output("array") * weight, transformer


def _name_estimators(estimators):
Expand Down Expand Up @@ -899,6 +898,7 @@ def fit_transform(self, X, y=None) -> SparseCumlArray:
return np.zeros((X.shape[0], 0))

Xs, transformers = zip(*result)
Xs = [x.to_output("array") if hasattr(x, "to_output") else x for x in Xs]

# determine if concatenated output will be sparse or not
if any(issparse(X) for X in Xs):
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/common/array_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from dataclasses import dataclass, field
from cuml.internals.array import CumlArray
import cuml
from cuml.internals.global_settings import GlobalSettings
from cuml.internals.input_utils import (
input_to_cuml_array, determine_array_type
)
Expand Down Expand Up @@ -114,7 +114,7 @@ def __get__(self, instance, owner):
assert len(existing.values) > 0

# Get the global output type
output_type = cuml.global_settings.output_type
output_type = GlobalSettings().output_type

# First, determine if we need to call to_output at all
if (output_type == "mirror"):
Expand Down
1 change: 0 additions & 1 deletion python/cuml/decomposition/pca.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,6 @@ class PCA(UniversalBase,
'type': 'dense_sparse',
'description': 'Transformed values',
'shape': '(n_samples, n_components)'})
@cuml.internals.api_base_return_array_skipall
@enable_device_interop
def fit_transform(self, X, y=None) -> CumlArray:
"""
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/feature_extraction/_tfidf.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def fit(self, X) -> "TfidfTransformer":

return self

@cuml.internals.api_base_return_any_skipall
@cuml.internals.api_base_return_array()
def transform(self, X, copy=True):
"""Transform a count matrix to a tf or tf-idf representation

Expand Down Expand Up @@ -238,7 +238,7 @@ def transform(self, X, copy=True):

return X

@cuml.internals.api_base_return_any_skipall
@cuml.internals.api_base_return_array()
def fit_transform(self, X, copy=True):
"""
Fit TfidfTransformer to X, then transform X.
Expand Down
8 changes: 3 additions & 5 deletions python/cuml/internals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,18 @@
from cuml.internals.api_decorators import (
_deprecate_pos_args,
api_base_fit_transform,
api_base_return_any_skipall,
api_base_return_any,
api_base_return_array_skipall,
api_base_return_any_skipall,
api_base_return_array,
api_base_return_generic_skipall,
api_base_return_array_skipall,
api_base_return_generic,
api_base_return_generic_skipall,
api_base_return_sparse_array,
api_return_any,
api_return_array,
api_return_generic,
api_return_sparse_array,
exit_internal_api,
)
from cuml.internals.api_context_managers import (
in_internal_api,
set_api_output_dtype,
set_api_output_type,
Expand Down
25 changes: 25 additions & 0 deletions python/cuml/internals/api_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from dataclasses import dataclass


@dataclass
class ApiContext:

stack_level: int = 0
previous_output_type = None
output_type = None
output_dtype = None
Loading