diff --git a/sktime/datatypes/_convert.py b/sktime/datatypes/_convert.py index c9ccd2f7d57..785185d8b5e 100644 --- a/sktime/datatypes/_convert.py +++ b/sktime/datatypes/_convert.py @@ -101,8 +101,9 @@ def convert( obj : object to convert - any type, should comply with mtype spec for as_scitype from_type : str - the type to convert "obj" to, a valid mtype string valid mtype strings, with explanation, are in datatypes.MTYPE_REGISTER - to_type : str - the type to convert "obj" to, a valid mtype string - valid mtype strings, with explanation, are in datatypes.MTYPE_REGISTER + to_type : str - the mtype to convert "obj" to, a valid mtype string + or list of str, this specifies admissible types for conversion to; + if list, will convert to first mtype of the same scitype as from_mtype as_scitype : str, optional - name of scitype the object "obj" is considered as default = inferred from from_type valid scitype strings, with explanation, are in datatypes.SCITYPE_REGISTER @@ -127,9 +128,14 @@ def convert( if obj is None: return None + # if to_type is a list, we do the following: + # if on the list, then don't do a conversion (convert to from_type) + # if not on the list, we find and convert to first mtype that has same scitype + to_type = _get_first_mtype_of_same_scitype( + from_mtype=from_type, to_mtypes=to_type, varname="to_type" + ) + # input type checks - if not isinstance(to_type, str): - raise TypeError("to_type must be a str") if not isinstance(from_type, str): raise TypeError("from_type must be a str") if as_scitype is None: @@ -186,8 +192,9 @@ def convert_to( Parameters ---------- obj : object to convert - any type, should comply with mtype spec for as_scitype - to_type : str - the type to convert "obj" to, a valid mtype string - or list of str, this specifies admissible types for conversion to + to_type : str - the mtype to convert "obj" to, a valid mtype string + or list of str, this specifies admissible types for conversion to; + if list, will convert to first 
mtype of the same scitype as obj valid mtype strings, with explanation, are in datatypes.MTYPE_REGISTER as_scitype : str, optional - name of scitype the object "obj" is considered as pre-specifying the scitype reduces the number of checks done in type inference @@ -240,25 +247,6 @@ def convert_to( from_type = infer_mtype(obj=obj, as_scitype=as_scitype) as_scitype = mtype_to_scitype(from_type) - # if to_type is a list, we do the following: - # if on the list, then don't do a conversion (convert to from_type) - # if not on the list, we find and convert to first mtype that has same scitype - if isinstance(to_type, list): - # no conversion of from_type is in the list - if from_type in to_type: - to_type = from_type - # otherwise convert to first element of same scitype - else: - same_scitype_mtypes = [ - mtype for mtype in to_type if mtype_to_scitype(mtype) == as_scitype - ] - if len(same_scitype_mtypes) == 0: - raise TypeError( - "to_type contains no mtype compatible with the scitype of obj," - f"which is {as_scitype}" - ) - to_type = same_scitype_mtypes[0] - converted_obj = convert( obj=obj, from_type=from_type, @@ -271,6 +259,42 @@ def convert_to( return converted_obj +def _get_first_mtype_of_same_scitype(from_mtype, to_mtypes, varname="to_mtypes"): + """Return first mtype in list mtypes that has same scitype as from_mtype. 
+
+    Parameters
+    ----------
+    from_mtype : str - mtype of object to convert from
+    to_mtypes : str or list of str - admissible mtype(s) to convert to
+    varname : str - name of variable to_mtypes, for error message
+
+    Returns
+    -------
+    to_type : str - first mtype in to_mtypes that has same scitype as from_mtype
+    """
+    if isinstance(to_mtypes, str):
+        return to_mtypes
+
+    if not isinstance(to_mtypes, list):
+        raise TypeError(f"{varname} must be a str or a list of str")
+
+    # no conversion if from_mtype is in the list
+    if from_mtype in to_mtypes:
+        return from_mtype
+    # otherwise convert to first element of same scitype
+    scitype = mtype_to_scitype(from_mtype)
+    same_scitype_mtypes = [
+        mtype for mtype in to_mtypes if mtype_to_scitype(mtype) == scitype
+    ]
+    if len(same_scitype_mtypes) == 0:
+        raise TypeError(
+            f"{varname} contains no mtype compatible with the scitype of obj, "
+            f"which is {scitype}"
+        )
+    to_type = same_scitype_mtypes[0]
+    return to_type
+
+
 def _conversions_defined(scitype: str):
     """Return an indicator matrix which conversions are defined for scitype.
 
diff --git a/sktime/datatypes/_series_as_panel/_convert.py b/sktime/datatypes/_series_as_panel/_convert.py
index 9300ca4d195..685dc76924b 100644
--- a/sktime/datatypes/_series_as_panel/_convert.py
+++ b/sktime/datatypes/_series_as_panel/_convert.py
@@ -21,7 +21,7 @@
 from sktime.datatypes import convert_to, scitype
 
 
-def convert_Series_to_Panel(obj, store=None):
+def convert_Series_to_Panel(obj, store=None, return_to_mtype=False):
     """Convert series to a single-series panel.
 
     Adds a dummy dimension to the series.
@@ -35,6 +35,10 @@ def convert_Series_to_Panel(obj, store=None):
     Parameters
     ----------
     obj: an object of scitype Series, of mtype pd.DataFrame, pd.Series, or np.ndarray.
+ store: dict, optional + converter store for back-conversion + return_to_mtype: bool, optional (default=False) + if True, also returns the str of the mtype converted to Returns ------- @@ -46,7 +50,10 @@ def convert_Series_to_Panel(obj, store=None): obj = pd.DataFrame(obj) if isinstance(obj, pd.DataFrame): - return [obj] + if return_to_mtype: + return [obj], "df-list" + else: + return [obj] if isinstance(obj, np.ndarray): if len(obj.shape) == 2: @@ -55,18 +62,23 @@ def convert_Series_to_Panel(obj, store=None): # numpy3D = (instances, variables, time) obj = np.expand_dims(obj, 0) obj = np.swapaxes(obj, 1, 2) + obj_mtype = "numpy3D" elif len(obj.shape) == 1: # from numpy1D to numpy3D # numpy1D = (time) # numpy3D = (instances, variables, time) obj = np.expand_dims(obj, (0, 1)) + obj_mtype = "numpy3D" else: raise ValueError("if obj is np.ndarray, must be of dim 1 or 2") - return obj + if return_to_mtype: + return obj, obj_mtype + else: + return obj -def convert_Panel_to_Series(obj, store=None): +def convert_Panel_to_Series(obj, store=None, return_to_mtype=False): """Convert single-series panel to a series. Removes panel index from the single-series panel to obtain a series. @@ -78,6 +90,10 @@ def convert_Panel_to_Series(obj, store=None): Parameters ---------- obj: an object of scitype Panel, of mtype pd-multiindex, numpy3d, or df-list. 
+    store: dict, optional
+        converter store for back-conversion
+    return_to_mtype: bool, optional (default=False)
+        if True, also returns the str of the mtype converted to
 
     Returns
     -------
@@ -86,12 +102,16 @@ def convert_Panel_to_Series(obj, store=None):
     """
     if isinstance(obj, list):
         if len(obj) == 1:
-            return obj[0]
+            if return_to_mtype:
+                return obj[0], "pd.DataFrame"
+            else:
+                return obj[0]
         else:
             raise ValueError("obj must be of length 1")
 
     if isinstance(obj, pd.DataFrame):
-        obj.index = obj.index.droplevel(level=0)
+        obj = obj.droplevel(level=0)  # new frame, caller's index is left intact
+        obj_mtype = "pd.DataFrame"
 
     if isinstance(obj, np.ndarray):
         if obj.ndim != 3 or obj.shape[0] != 1:
@@ -101,11 +121,15 @@ def convert_Panel_to_Series(obj, store=None):
         # numpy3D = (instances, variables, time)
         obj = np.reshape(obj, (obj.shape[1], obj.shape[2]))
         obj = np.swapaxes(obj, 0, 1)
+        obj_mtype = "np.ndarray"
 
-    return obj
+    if return_to_mtype:
+        return obj, obj_mtype
+    else:
+        return obj
 
 
-def convert_Series_to_Hierarchical(obj, store=None):
+def convert_Series_to_Hierarchical(obj, store=None, return_to_mtype=False):
     """Convert series to a single-series hierarchical object.
 
     Adds two dimensions to the series to obtain a 3-level MultiIndex, 2 levels added.
@@ -117,6 +141,10 @@ def convert_Series_to_Hierarchical(obj, store=None):
     Parameters
     ----------
     obj: an object of scitype Series, of mtype pd.DataFrame, pd.Series, or np.ndarray.
+ store: dict, optional + converter store for back-conversion + return_to_mtype: bool, optional (default=False) + if True, also returns the str of the mtype converted to Returns ------- @@ -128,10 +156,14 @@ def convert_Series_to_Hierarchical(obj, store=None): obj_df["__level2"] = 0 obj_df = obj_df.set_index(["__level1", "__level2"], append=True) obj_df = obj_df.reorder_levels([1, 2, 0]) - return obj_df + + if return_to_mtype: + return obj_df, "pd_multiindex_hier" + else: + return obj_df -def convert_Hierarchical_to_Series(obj, store=None): +def convert_Hierarchical_to_Series(obj, store=None, return_to_mtype=False): """Convert single-series hierarchical object to a series. Removes two dimensions to obtain a series, by removing 2 levels from MultiIndex. @@ -143,6 +175,10 @@ def convert_Hierarchical_to_Series(obj, store=None): Parameters ---------- obj: an object of scitype Hierarchical. + store: dict, optional + converter store for back-conversion + return_to_mtype: bool, optional (default=False) + if True, also returns the str of the mtype converted to Returns ------- @@ -151,10 +187,14 @@ def convert_Hierarchical_to_Series(obj, store=None): obj_df = convert_to(obj, to_type="pd_multiindex_hier", as_scitype="Hierarchical") obj_df = obj_df.copy() obj_df.index = obj_df.index.get_level_values(-1) - return obj_df + if return_to_mtype: + return obj_df, "pd.DataFrame" + else: + return obj_df -def convert_Panel_to_Hierarchical(obj, store=None): + +def convert_Panel_to_Hierarchical(obj, store=None, return_to_mtype=False): """Convert panel to a single-panel hierarchical object. Adds a dimensions to the panel to obtain a 3-level MultiIndex, 1 level is added. @@ -166,6 +206,10 @@ def convert_Panel_to_Hierarchical(obj, store=None): Parameters ---------- obj: an object of scitype Panel. 
+    store: dict, optional
+        converter store for back-conversion
+    return_to_mtype: bool, optional (default=False)
+        if True, also returns the str of the mtype converted to
 
     Returns
     -------
@@ -176,10 +220,14 @@ def convert_Panel_to_Hierarchical(obj, store=None):
     obj_df["__level2"] = 0
     obj_df = obj_df.set_index(["__level2"], append=True)
     obj_df = obj_df.reorder_levels([2, 0, 1])
-    return obj_df
+
+    if return_to_mtype:
+        return obj_df, "pd_multiindex_hier"
+    else:
+        return obj_df
 
 
-def convert_Hierarchical_to_Panel(obj, store=None):
+def convert_Hierarchical_to_Panel(obj, store=None, return_to_mtype=False):
     """Convert single-series hierarchical object to a series.
 
     Removes one dimensions to obtain a panel, by removing 1 level from MultiIndex.
@@ -190,7 +238,11 @@ def convert_Hierarchical_to_Panel(obj, store=None):
 
     Parameters
     ----------
-    obj: an object of scitype Hierarchical.
+    obj: an object of scitype Hierarchical
+    store: dict, optional
+        converter store for back-conversion
+    return_to_mtype: bool, optional (default=False)
+        if True, also returns the str of the mtype converted to
 
     Returns
     -------
@@ -199,10 +251,20 @@ def convert_Hierarchical_to_Panel(obj, store=None):
     obj_df = convert_to(obj, to_type="pd_multiindex_hier", as_scitype="Hierarchical")
     obj_df = obj_df.copy()
-    obj_df.index = obj_df.index.get_level_values([-2, -1])
-    return obj_df
+    obj_df.index = obj_df.index.droplevel(list(range(obj_df.index.nlevels - 2)))
+
+    if return_to_mtype:
+        return obj_df, "pd-multiindex"
+    else:
+        return obj_df
 
 
-def convert_to_scitype(obj, to_scitype, from_scitype=None, store=None):
+def convert_to_scitype(
+    obj,
+    to_scitype,
+    from_scitype=None,
+    store=None,
+    return_to_mtype=False,
+):
     """Convert single-series or single-panel between mtypes.
 
     Assumes input is conformant with one of the mtypes
@@ -218,6 +280,8 @@ def convert_to_scitype(obj, to_scitype, from_scitype=None, store=None):
         scitype that obj is of, and being converted from
         if avoided, function will skip type inference from obj
     store : dict, optional. Converter store for back-conversion.
+ return_to_mtype: bool, optional (default=False) + if True, also returns the str of the mtype converted to Returns ------- @@ -239,4 +303,4 @@ def convert_to_scitype(obj, to_scitype, from_scitype=None, store=None): func_name = f"convert_{from_scitype}_to_{to_scitype}" func = eval(func_name) - return func(obj, store=store) + return func(obj, store=store, return_to_mtype=return_to_mtype) diff --git a/sktime/transformations/base.py b/sktime/transformations/base.py index 68e7be46fbe..56b7977c93b 100644 --- a/sktime/transformations/base.py +++ b/sktime/transformations/base.py @@ -58,6 +58,7 @@ class name: BaseTransformer VectorizedDF, check_is_mtype, check_is_scitype, + convert, convert_to, mtype_to_scitype, update_data, @@ -927,6 +928,7 @@ def _scitype_A_higher_B(scitypeA, scitypeB): # checking X X_metadata_required = ["is_univariate"] + X_valid, msg, X_metadata = check_is_scitype( X, scitype=ALLOWED_SCITYPES, @@ -998,6 +1000,7 @@ def _scitype_A_higher_B(scitypeA, scitypeB): raise TypeError("y " + msg_invalid_input) y_scitype = y_metadata["scitype"] + y_mtype = y_metadata["mtype"] else: # y_scitype is used below - set to None if y is None @@ -1023,7 +1026,9 @@ def _scitype_A_higher_B(scitypeA, scitypeB): as_scitype = "Panel" else: as_scitype = "Hierarchical" - X = convert_to_scitype(X, to_scitype=as_scitype, from_scitype=X_scitype) + X, X_mtype = convert_to_scitype( + X, to_scitype=as_scitype, from_scitype=X_scitype, return_to_mtype=True + ) X_scitype = as_scitype # then pass to case 1, which we've reduced to, X now has inner scitype @@ -1032,8 +1037,9 @@ def _scitype_A_higher_B(scitypeA, scitypeB): # and does not require vectorization because of cols (multivariate) if not requires_vectorization: # converts X - X_inner = convert_to( + X_inner = convert( X, + from_type=X_mtype, to_type=X_inner_mtype, store=metadata["_converter_store_X"], store_behaviour="reset", @@ -1041,8 +1047,9 @@ def _scitype_A_higher_B(scitypeA, scitypeB): # converts y, returns None if y is None 
if y_inner_mtype != ["None"] and y is not None: - y_inner = convert_to( + y_inner = convert( y, + from_type=y_mtype, to_type=y_inner_mtype, as_scitype=y_scitype, ) @@ -1145,11 +1152,17 @@ def _convert_output(self, X, metadata, inverse=False): # we cannot convert back to pd.Series, do pd.DataFrame instead then # this happens only for Series, not Panel if X_input_scitype == "Series": + if X_input_mtype == "pd.Series": + Xt_metadata_required = ["is_univariate"] + else: + Xt_metadata_required = [] + valid, msg, metadata = check_is_mtype( Xt, ["pd.DataFrame", "pd.Series", "np.ndarray"], - return_metadata=True, + return_metadata=Xt_metadata_required, ) + if not valid: raise TypeError( f"_transform output of {type(self)} does not comply " @@ -1157,9 +1170,20 @@ def _convert_output(self, X, metadata, inverse=False): " for mtype specifications. Returned error message:" f" {msg}. Returned object: {Xt}" ) - if not metadata["is_univariate"] and X_input_mtype == "pd.Series": + if X_input_mtype == "pd.Series" and not metadata["is_univariate"]: X_output_mtype = "pd.DataFrame" - + # Xt_mtype = metadata["mtype"] + # else: + # Xt_mtype = X_input_mtype + + # Xt = convert( + # Xt, + # from_type=Xt_mtype, + # to_type=X_output_mtype, + # as_scitype=X_input_scitype, + # store=_converter_store_X, + # store_behaviour="freeze", + # ) Xt = convert_to( Xt, to_type=X_output_mtype,