# Metadata Routing

[Metadata Routing — scikit\-learn 1\.4\.0 documentation](https://scikit-learn.org/stable/auto_examples/miscellaneous/plot_metadata_routing.html#sphx-glr-auto-examples-miscellaneous-plot-metadata-routing-py)

- meta-estimatorを通してestimatorにmetadataを転送する
- 2つの重要な概念:
    - routers: ほとんどの場合meta-estimatorである。与えられたdataやmetadataを他のオブジェクトやestimatorに転送する。
    - consumers: 与えられたmetadataを受け取り、使用するオブジェクト。
- あるオブジェクトがrouterとconsumerの両方となることは可能

## Set Up

In [1]:
import warnings
from pprint import pprint

import numpy as np

from sklearn import set_config
from sklearn.base import (
    BaseEstimator,
    ClassifierMixin,
    MetaEstimatorMixin,
    RegressorMixin,
    TransformerMixin,
    clone,
)
from sklearn.linear_model import LinearRegression
from sklearn.utils import metadata_routing
from sklearn.utils.metadata_routing import (
    MetadataRouter,
    MethodMapping,
    get_routing_for_object,
    process_routing,
)
from sklearn.utils.validation import check_is_fitted

In [2]:
# Toy dataset shared by every example below: 100 samples x 4 features.
n_samples, n_features = 100, 4
# Fixed seed so that the recorded outputs are reproducible.
rng = np.random.RandomState(42)
X = rng.rand(n_samples, n_features)
# Binary target labels (0 or 1).
y = rng.randint(0, 2, size=n_samples)
# Per-sample metadata to be routed through the estimators:
# integer group ids in [0, 10) and two independent sample-weight vectors in [0, 1).
my_groups = rng.randint(0, 10, size=n_samples)
my_weights = rng.rand(n_samples)
my_other_weights = rng.rand(n_samples)

In [3]:
# Turn on scikit-learn's (opt-in) metadata routing globally; every
# set_*_request call in this notebook relies on it being enabled.
set_config(enable_metadata_routing=True)

In [4]:
def check_metadata(obj, **kwargs):
    """Report, for each keyword argument, whether *obj* received it.

    Debugging helper used by the example estimators. Metadata that arrived is
    printed together with its length; metadata left as ``None`` is reported
    as missing.

    Parameters
    ----------
    obj : object
        The receiving object; only its class name is used in the messages.
    **kwargs
        Metadata name -> value. Non-``None`` values must support ``len()``.
    """
    for name, data in kwargs.items():
        owner = obj.__class__.__name__
        if data is None:
            print(f"{name} is None in {owner}.")
            continue
        print(f"Received {name} of length = {len(data)} in {owner}.")

In [5]:
def print_routing(obj):
    """Pretty-print the metadata routing information of *obj*.

    Calls ``obj.get_metadata_routing()`` and prints the serialized (nested
    dict) form with :func:`pprint.pprint`.

    Note: ``_serialize`` is a private scikit-learn method; it is used here
    for display purposes only.
    """
    routing_info = obj.get_metadata_routing()
    as_dict = routing_info._serialize()
    pprint(as_dict)

## Estimators

In [6]:
class ExampleClassifier(ClassifierMixin, BaseEstimator):
    """Toy classifier that reports which metadata it receives.

    ``fit`` consumes ``sample_weight`` and ``predict`` consumes ``groups``.
    The request methods (``set_fit_request``, ``set_predict_request``) are
    derived from these signatures, so the parameter names matter.
    """

    def fit(self, X, y, sample_weight=None):
        """Report ``sample_weight`` and mark the estimator as fitted."""
        check_metadata(self, sample_weight=sample_weight)
        # Every classifier has to expose ``classes_`` after fitting.
        self.classes_ = np.array([0, 1])
        return self

    def predict(self, X, groups=None):
        """Report ``groups`` and predict 1.0 for every row of ``X``."""
        check_metadata(self, groups=groups)
        # Intentionally trivial: a constant prediction of 1.
        n_rows = len(X)
        return np.ones(n_rows)

In [7]:
print_routing(ExampleClassifier())

{'fit': {'sample_weight': None},
 'predict': {'groups': None},
 'score': {'sample_weight': None}}


In [8]:
est = (
    ExampleClassifier()
    .set_fit_request(sample_weight=False)
    .set_predict_request(groups=True)
    .set_score_request(sample_weight=False)
)
display(est)
print_routing(est)

{'fit': {'sample_weight': False},
 'predict': {'groups': True},
 'score': {'sample_weight': False}}


```{note}
meta-estimatorの中で使用されない場合(estimator単体で`fit`や`predict`を呼ぶ場合)、`set_{method}_request`で設定された値は無視され、渡されたmetadataはそのまま受け取られる
```

In [9]:
est = ExampleClassifier()
est.fit(X, y, sample_weight=my_weights)
est.predict(X[:3, :], groups=my_groups)

Received sample_weight of length = 100 in ExampleClassifier.
Received groups of length = 100 in ExampleClassifier.


array([1., 1., 1.])

metadataを転送するためだけのmeta-estimatorを定義する

In [10]:
class MetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
    """Meta-estimator whose only job is to forward metadata to a sub-estimator.

    It consumes no metadata itself. Everything passed to ``fit`` / ``predict``
    is validated against the sub-estimator's requests and routed to it.

    Parameters
    ----------
    estimator : estimator object
        The wrapped classifier. It is cloned in ``fit``; the fitted copy is
        stored as ``estimator_``.
    """

    def __init__(self, estimator):
        self.estimator = estimator

    def get_metadata_routing(self):
        """Describe how metadata flows to the sub-estimator.

        ``"one-to-one"`` maps each method onto the method of the same name:
        metadata given to this object's ``fit`` goes to the sub-estimator's
        ``fit``, ``predict`` to ``predict``, and so on.
        """
        return MetadataRouter(owner=self.__class__.__name__).add(
            estimator=self.estimator, method_mapping="one-to-one"
        )

    def fit(self, X, y, **fit_params):
        """Fit a clone of ``estimator`` with the metadata routed to its ``fit``."""
        # ``get_routing_for_object`` returns a copy of this object's router,
        # so it is safe to use without mutating the original.
        routing = get_routing_for_object(self)
        # Meta-estimators must validate their metadata themselves: names that
        # no object requests are rejected with a TypeError.
        routing.validate_metadata(params=fit_params, method="fit")
        # ``caller`` is *this* object's method. The result holds, per child and
        # per child method, only the metadata that child requested.
        routed = routing.route_params(params=fit_params, caller="fit")

        sub_estimator = clone(self.estimator)
        self.estimator_ = sub_estimator.fit(X, y, **routed.estimator.fit)
        self.classes_ = self.estimator_.classes_
        return self

    def predict(self, X, **predict_params):
        """Predict with the fitted sub-estimator, forwarding routed metadata."""
        check_is_fitted(self)
        routing = get_routing_for_object(self)
        routing.validate_metadata(params=predict_params, method="predict")
        routed = routing.route_params(params=predict_params, caller="predict")
        return self.estimator_.predict(X, **routed.estimator.predict)

In [11]:
print(get_routing_for_object.__doc__)

Get a ``Metadata{Router, Request}`` instance from the given object.

    This function returns a
    :class:`~sklearn.utils.metadata_routing.MetadataRouter` or a
    :class:`~sklearn.utils.metadata_routing.MetadataRequest` from the given input.

    This function always returns a copy or an instance constructed from the
    input, such that changing the output of this function will not change the
    original object.

    .. versionadded:: 1.3

    Parameters
    ----------
    obj : object
        - If the object is already a
            :class:`~sklearn.utils.metadata_routing.MetadataRequest` or a
            :class:`~sklearn.utils.metadata_routing.MetadataRouter`, return a copy
            of that.
        - If the object provides a `get_metadata_routing` method, return a copy
            of the output of that method.
        - Returns an empty :class:`~sklearn.utils.metadata_routing.MetadataRequest`
            otherwise.

    Returns
    -------
    obj : MetadataRequest or Metada

In [12]:
# `sample_weight` is requested by the sub-estimator, so the meta-estimator
# forwards it to ExampleClassifier.fit.
est = MetaClassifier(estimator=ExampleClassifier().set_fit_request(sample_weight=True))
est.fit(X, y, sample_weight=my_weights)

Received sample_weight of length = 100 in ExampleClassifier.


In [13]:
# If no metadata is passed, nothing is forwarded (the sub-estimator sees None).
est.fit(X, y)

sample_weight is None in ExampleClassifier.


In [14]:
# Passing metadata under a name that no object requests raises a TypeError.
try:
    est.fit(X, y, test=my_weights)
except TypeError as e:
    print(e)

MetaClassifier.fit got unexpected argument(s) {'test'}, which are not requested metadata in any object.


In [15]:
# Passing metadata (groups) whose request was never set on the consumer
# (neither True nor False) raises a ValueError.
try:
    est.fit(X, y, sample_weight=my_weights).predict(X, groups=my_groups)
except ValueError as e:
    print(e)

Received sample_weight of length = 100 in ExampleClassifier.
[groups] are passed but are not explicitly set as requested or not for ExampleClassifier.predict


In [16]:
# Passing metadata that was explicitly marked as not requested (groups=False)
# raises a TypeError: no object accepts `groups` in predict.
est = MetaClassifier(
    estimator=ExampleClassifier()
    .set_fit_request(sample_weight=True)
    .set_predict_request(groups=False)
)
try:
    est.fit(X, y, sample_weight=my_weights).predict(X[:3, :], groups=my_groups)
except TypeError as e:
    print(e)

Received sample_weight of length = 100 in ExampleClassifier.
MetaClassifier.predict got unexpected argument(s) {'groups'}, which are not requested metadata in any object.


### `fit`の中では何が行われてる?

In [32]:
est = MetaClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight=True),
)

In [37]:
request_router = get_routing_for_object(est)
type(request_router)

sklearn.utils._metadata_requests.MetadataRouter

In [49]:
request_router.owner

'MetaClassifier'

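- `mapping`は、meta-estimator(caller)のどのメソッドが呼ばれたときに子estimator(callee)のどのメソッドへmetadataを渡すか、という対応を定義している(`"one-to-one"`では`fit`→`fit`、`predict`→`predict`のように同名メソッド同士が対応する)
- `router`には、子estimatorの各メソッドがrequestしているmetadataと、その設定値(`True`/`False`/`None`/alias)が入っている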

In [42]:
pprint(request_router._serialize())

{'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'partial_fit', 'caller': 'partial_fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'predict_proba',
                            'caller': 'predict_proba'},
                           {'callee': 'predict_log_proba',
                            'caller': 'predict_log_proba'},
                           {'callee': 'decision_function',
                            'caller': 'decision_function'},
                           {'callee': 'score', 'caller': 'score'},
                           {'callee': 'split', 'caller': 'split'},
                           {'callee': 'transform', 'caller': 'transform'},
                           {'callee': 'inverse_transform',
                            'caller': 'inverse_transform'},
                           {'callee': 'fit_transform',
                            'caller': 'fit_transform'},

In [50]:
fit_params = dict(sample_weight=None)
request_router.validate_metadata(params=fit_params, method="fit")

In [54]:
try:
    fit_params = dict(other_key=None)
    request_router.validate_metadata(params=fit_params, method="fit")
except TypeError as e:
    print(e)

MetaClassifier.fit got unexpected argument(s) {'other_key'}, which are not requested metadata in any object.


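`MetadataRouter.route_params`では、`params`にmeta-estimatorが受け取ったmetadataを、`caller`にそのmetadataを受け取ったmeta-estimator自身のメソッド名(例: `"fit"`)を指定する。戻り値は、子オブジェクトごと・(mappingで対応付けられた)calleeメソッドごとに、そのオブジェクトがrequestしたmetadataだけを振り分けたものになる(例: `routed_params.estimator.fit`)。値が`None`のmetadataは転送されない。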

In [61]:
fit_params = dict(sample_weight=None)
routed_params = request_router.route_params(params=fit_params, caller="fit")
display(routed_params)

{'estimator': {'fit': {}}}

In [62]:
fit_params = dict(sample_weight=my_weights)
routed_params = request_router.route_params(params=fit_params, caller="fit")
display(routed_params)

{'estimator': {'fit': {'sample_weight': array([0.55543171, 0.76898742, 0.94476573, 0.84964739, 0.2473481 ,
          0.45054414, 0.12915942, 0.95405103, 0.60617463, 0.22864281,
          0.67170068, 0.61812824, 0.35816272, 0.11355759, 0.6715732 ,
          0.5203077 , 0.77231839, 0.5201635 , 0.8521815 , 0.55190684,
          0.56093797, 0.8766536 , 0.40348287, 0.13401523, 0.02878268,
          0.75513726, 0.62030955, 0.70407977, 0.21296416, 0.13637148,
          0.01454467, 0.35058756, 0.58991769, 0.39224405, 0.43747492,
          0.90415869, 0.34825547, 0.51398949, 0.78365301, 0.39654278,
          0.6220867 , 0.86236371, 0.94952062, 0.14707348, 0.92658763,
          0.49211629, 0.25824439, 0.45913576, 0.98003258, 0.49261809,
          0.32875161, 0.63340085, 0.24014562, 0.07586333, 0.12887972,
          0.12804584, 0.15190269, 0.13882717, 0.64087474, 0.18188008,
          0.34566728, 0.89678841, 0.47396164, 0.66755774, 0.17231987,
          0.19228902, 0.04086862, 0.16893506, 0.27859

---

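- aliased metadata
    - `set_{method}_request`に`True`ではなく文字列を渡すと、その文字列がmeta-estimatorに渡すときのmetadata名(alias)になる
    - 例: `set_fit_request(sample_weight="aliased_sample_weight")`とした場合、meta-estimatorの`fit`には`aliased_sample_weight=...`として渡す(元の名前`sample_weight`で渡すとエラーになる)
    - 複数のestimatorに別々のmetadata(例: `sample_weight1`, `sample_weight2`)を渡したいときに使う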

In [17]:
est = MetaClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight="aliased_sample_weight")
)
est.fit(X, y, aliased_sample_weight=my_weights)

Received sample_weight of length = 100 in ExampleClassifier.


In [18]:
try:
    est.fit(X, y, sample_weight=my_weights)
except TypeError as e:
    print(e)

MetaClassifier.fit got unexpected argument(s) {'sample_weight'}, which are not requested metadata in any object.


metadata routingの仕組み

- consumerが`set_{method_name}_request`で必要なmetadataを指定する
- routerがそのmetadataを手渡す

In [19]:
print_routing(est)

{'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'partial_fit', 'caller': 'partial_fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'predict_proba',
                            'caller': 'predict_proba'},
                           {'callee': 'predict_log_proba',
                            'caller': 'predict_log_proba'},
                           {'callee': 'decision_function',
                            'caller': 'decision_function'},
                           {'callee': 'score', 'caller': 'score'},
                           {'callee': 'split', 'caller': 'split'},
                           {'callee': 'transform', 'caller': 'transform'},
                           {'callee': 'inverse_transform',
                            'caller': 'inverse_transform'},
                           {'callee': 'fit_transform',
                            'caller': 'fit_transform'},

In [20]:
meta_est = MetaClassifier(estimator=est).fit(X, y, aliased_sample_weight=my_weights)

Received sample_weight of length = 100 in ExampleClassifier.


## Router and Consumer

- 考える状況:
    - meta-estimatorがmetadataを使用する
    - そのmetadataは内包するestimatorにも転送する

In [21]:
class RouterConsumerClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
    """Meta-estimator that both consumes and routes metadata.

    ``fit`` uses ``sample_weight`` itself (reported via ``check_metadata``)
    and also forwards it, together with any other requested metadata, to the
    wrapped estimator.

    Parameters
    ----------
    estimator : estimator object
        The wrapped classifier. It is cloned in ``fit``; the fitted copy is
        stored as ``estimator_``.
    """

    def __init__(self, estimator):
        self.estimator = estimator

    def get_metadata_routing(self):
        """Return a router holding this object's own requests and its child's.

        ``add_self_request`` registers what this object consumes (here
        ``sample_weight`` in ``fit``); ``add`` routes every method one-to-one
        to the sub-estimator.
        """
        router = (
            MetadataRouter(owner=self.__class__.__name__)
            .add_self_request(self)
            .add(estimator=self.estimator, method_mapping="one-to-one")
        )
        return router

    def fit(self, X, y, sample_weight=None, **fit_params):
        """Fit a clone of ``estimator``, consuming and forwarding metadata.

        Parameters
        ----------
        X, y : array-like
            Training data and targets.
        sample_weight : array-like or None, default=None
            Consumed by this object and, if requested by the sub-estimator,
            forwarded to its ``fit``. It defaults to ``None`` so the estimator
            can be fitted without weights. This includes use inside a parent
            router whose routed params do not contain ``sample_weight``.
        **fit_params : dict
            Other metadata, validated and routed to the sub-estimator.

        Raises
        ------
        ValueError
            If ``estimator`` is None.
        """
        if self.estimator is None:
            raise ValueError("estimator cannot be None!")

        check_metadata(self, sample_weight=sample_weight)

        # ``sample_weight`` is an explicit parameter, so it is not part of
        # ``fit_params``; put it back so it is validated and routed as well.
        if sample_weight is not None:
            fit_params["sample_weight"] = sample_weight

        # meta-estimators are responsible for validating the given metadata
        request_router = get_routing_for_object(self)
        request_router.validate_metadata(params=fit_params, method="fit")
        # map the given metadata to what the underlying estimator requested
        params = request_router.route_params(params=fit_params, caller="fit")
        self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
        self.classes_ = self.estimator_.classes_
        return self

    def predict(self, X, **predict_params):
        """Predict with the fitted sub-estimator, forwarding routed metadata."""
        check_is_fitted(self)
        # same as in ``fit``: validate the metadata, then route it
        request_router = get_routing_for_object(self)
        request_router.validate_metadata(params=predict_params, method="predict")
        params = request_router.route_params(params=predict_params, caller="predict")
        return self.estimator_.predict(X, **params.estimator.predict)

In [22]:
est = RouterConsumerClassifier(estimator=ExampleClassifier())
print_routing(est)

{'$self_request': {'fit': {'sample_weight': None},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'partial_fit', 'caller': 'partial_fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'predict_proba',
                            'caller': 'predict_proba'},
                           {'callee': 'predict_log_proba',
                            'caller': 'predict_log_proba'},
                           {'callee': 'decision_function',
                            'caller': 'decision_function'},
                           {'callee': 'score', 'caller': 'score'},
                           {'callee': 'split', 'caller': 'split'},
                           {'callee': 'transform', 'caller': 'transform'},
                           {'callee': 'inverse_transform',
                            'caller': 'inverse_transform'},
     

In [23]:
est = RouterConsumerClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight=True)
)
print_routing(est)

{'$self_request': {'fit': {'sample_weight': None},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'partial_fit', 'caller': 'partial_fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'predict_proba',
                            'caller': 'predict_proba'},
                           {'callee': 'predict_log_proba',
                            'caller': 'predict_log_proba'},
                           {'callee': 'decision_function',
                            'caller': 'decision_function'},
                           {'callee': 'score', 'caller': 'score'},
                           {'callee': 'split', 'caller': 'split'},
                           {'callee': 'transform', 'caller': 'transform'},
                           {'callee': 'inverse_transform',
                            'caller': 'inverse_transform'},
     

In [24]:
est = RouterConsumerClassifier(estimator=ExampleClassifier()).set_fit_request(
    sample_weight=True
)
print_routing(est)

{'$self_request': {'fit': {'sample_weight': True},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'partial_fit', 'caller': 'partial_fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'predict_proba',
                            'caller': 'predict_proba'},
                           {'callee': 'predict_log_proba',
                            'caller': 'predict_log_proba'},
                           {'callee': 'decision_function',
                            'caller': 'decision_function'},
                           {'callee': 'score', 'caller': 'score'},
                           {'callee': 'split', 'caller': 'split'},
                           {'callee': 'transform', 'caller': 'transform'},
                           {'callee': 'inverse_transform',
                            'caller': 'inverse_transform'},
     

In [25]:
est = RouterConsumerClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight="clf_sample_weight"),
).set_fit_request(sample_weight="meta_clf_sample_weight")
print_routing(est)

{'$self_request': {'fit': {'sample_weight': 'meta_clf_sample_weight'},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'partial_fit', 'caller': 'partial_fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'predict_proba',
                            'caller': 'predict_proba'},
                           {'callee': 'predict_log_proba',
                            'caller': 'predict_log_proba'},
                           {'callee': 'decision_function',
                            'caller': 'decision_function'},
                           {'callee': 'score', 'caller': 'score'},
                           {'callee': 'split', 'caller': 'split'},
                           {'callee': 'transform', 'caller': 'transform'},
                           {'callee': 'inverse_transform',
                            'caller': 'invers

In [26]:
est.fit(X, y, sample_weight=my_weights, clf_sample_weight=my_other_weights)

Received sample_weight of length = 100 in RouterConsumerClassifier.
Received sample_weight of length = 100 in ExampleClassifier.


In [27]:
est = RouterConsumerClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight="aliased_sample_weight")
).set_fit_request(sample_weight=True)
print_routing(est)

{'$self_request': {'fit': {'sample_weight': True},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'partial_fit', 'caller': 'partial_fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'predict_proba',
                            'caller': 'predict_proba'},
                           {'callee': 'predict_log_proba',
                            'caller': 'predict_log_proba'},
                           {'callee': 'decision_function',
                            'caller': 'decision_function'},
                           {'callee': 'score', 'caller': 'score'},
                           {'callee': 'split', 'caller': 'split'},
                           {'callee': 'transform', 'caller': 'transform'},
                           {'callee': 'inverse_transform',
                            'caller': 'inverse_transform'},
     

## Simple Pipeline

In [28]:
class SimplePipeline(ClassifierMixin, BaseEstimator):
    """Minimal two-step pipeline: a transformer followed by a classifier.

    Metadata routing is declared explicitly in ``get_metadata_routing``:

    - ``transformer``: this object's ``fit`` routes to the transformer's
      ``fit`` *and* ``transform``; this object's ``predict`` routes to the
      transformer's ``transform``.
    - ``classifier``: one-to-one (``fit`` -> ``fit``, ``predict`` ->
      ``predict``, ...).

    Parameters
    ----------
    transformer : transformer object
        Cloned and fitted in ``fit``; the fitted copy is ``transformer_``.
    classifier : classifier object
        Cloned and fitted on the transformed data; the fitted copy is
        ``classifier_``.
    """

    # Constructor parameters that have no default value.
    _required_parameters = ["transformer", "classifier"]

    def __init__(self, transformer, classifier):
        self.transformer = transformer
        self.classifier = classifier

    def get_metadata_routing(self):
        """Return the router describing where each step's metadata goes."""
        router = (
            MetadataRouter(owner=self.__class__.__name__)
            .add(
                transformer=self.transformer,
                method_mapping=MethodMapping()
                .add(callee="fit", caller="fit")
                # the transformer also transforms X during our fit ...
                .add(callee="transform", caller="fit")
                # ... and during our predict
                .add(callee="transform", caller="predict"),
            )
            .add(classifier=self.classifier, method_mapping="one-to-one")
        )
        return router

    def fit(self, X, y, **fit_params):
        """Fit the transformer, transform ``X``, then fit the classifier.

        ``fit_params`` are routed by ``process_routing``; each step only
        receives the metadata it requested.
        """
        params = process_routing(self, "fit", **fit_params)

        self.transformer_ = clone(self.transformer).fit(X, y, **params.transformer.fit)
        X_transformed = self.transformer_.transform(X, **params.transformer.transform)

        self.classifier_ = clone(self.classifier).fit(
            X_transformed, y, **params.classifier.fit
        )
        return self

    def predict(self, X, **predict_params):
        """Transform ``X`` with the fitted transformer, then predict.

        Raises
        ------
        NotFittedError
            If called before ``fit``.
        """
        # Fail with a clear NotFittedError rather than an AttributeError on
        # ``transformer_`` when predict is called before fit.
        check_is_fitted(self)
        params = process_routing(self, "predict", **predict_params)

        X_transformed = self.transformer_.transform(X, **params.transformer.transform)
        return self.classifier_.predict(X_transformed, **params.classifier.predict)

In [29]:
class ExampleTransformer(TransformerMixin, BaseEstimator):
    """Identity transformer that reports which metadata it receives.

    ``fit`` consumes ``sample_weight`` and ``transform`` consumes ``groups``;
    the input data is returned unchanged.
    """

    def fit(self, X, y, sample_weight=None):
        """Report ``sample_weight``; nothing is learned."""
        check_metadata(self, sample_weight=sample_weight)
        return self

    def transform(self, X, groups=None):
        """Report ``groups`` and return ``X`` as is."""
        check_metadata(self, groups=groups)
        return X

    def fit_transform(self, X, y, sample_weight=None, groups=None):
        """Fit, then transform ``X``, handing each step its own metadata."""
        fitted = self.fit(X, y, sample_weight=sample_weight)
        return fitted.transform(X, groups=groups)

In [30]:
# A two-step pipeline in which every consumer explicitly declares the
# metadata it wants.
est = SimplePipeline(
    transformer=ExampleTransformer()
    # we want the transformer's fit to receive sample_weight
    .set_fit_request(sample_weight=True)
    # we want transformer's transform to receive groups
    .set_transform_request(groups=True),
    classifier=RouterConsumerClassifier(
        estimator=ExampleClassifier()
        # we want this sub-estimator to receive sample_weight in fit
        .set_fit_request(sample_weight=True)
        # but not groups in predict
        .set_predict_request(groups=False),
    ).set_fit_request(
        # and we want the meta-estimator to receive sample_weight as well
        sample_weight=True
    ),
)
# SimplePipeline maps both its fit and its predict onto the transformer's
# transform, so `groups` reaches ExampleTransformer.transform in both calls.
# NOTE(review): predict runs on X[:3] but all 100 groups are passed; a real
# use would pass my_groups[:3] — confirm this is intentional for the demo.
est.fit(X, y, sample_weight=my_weights, groups=my_groups).predict(
    X[:3], groups=my_groups
)

Received sample_weight of length = 100 in ExampleTransformer.
Received groups of length = 100 in ExampleTransformer.
Received sample_weight of length = 100 in RouterConsumerClassifier.
Received sample_weight of length = 100 in ExampleClassifier.
Received groups of length = 100 in ExampleTransformer.
groups is None in ExampleClassifier.


array([1., 1., 1.])