# Metric支持设置index指定用于计算评估值的输入

问题描述：目前 MindSpore 的 Metrics 模块在调用时一般传入两个值，分别是 logits 和 labels。因此，当用户自定义的网络有多个输出时，用户需要显式地指定 logits 和 labels 在输入中对应的位置，或者自定义一个 Metrics 度量来控制模型的输出，这往往会给用户的使用带来困扰。

## 解决方案 1

```python
# 2. Define a custom accuracy metric and leave `eval_indexes` unset.
# The inputs of the custom accuracy are the outputs of the evaluation network, so we use inputs[0] and inputs[2] to update the metric.
from typing import List, Optional


class CustomAccuracy(nn.Metric):
    """Accuracy metric that selects its prediction/label inputs by position.

    The evaluation network may return more than two outputs; ``indexes`` tells
    this metric which of them to read, so ``eval_indexes`` on ``Model`` can be
    left unset.

    Args:
        eval_type (str): Forwarded to the base metric's constructor.
            Default: 'classification'.
        indexes (Optional[List[int]]): Positions ``[pred_index, label_index]``
            into the inputs passed to ``update``. ``None`` or an empty list
            selects the first two inputs, i.e. ``[0, 1]``.
    """

    def __init__(self, eval_type: str = 'classification', indexes: Optional[List[int]] = None):
        super(CustomAccuracy, self).__init__(eval_type)
        # Without a fallback, the default ``None`` would only fail later inside
        # ``update`` with "'NoneType' object is not subscriptable".
        # NOTE(review): if the base class exposes a read-only ``indexes``
        # property (as proposed in solution 2), this assignment would raise;
        # store the positions under another attribute name in that case.
        self.indexes = indexes if indexes else [0, 1]

    def update(self, *inputs):
        # Pick the prediction and the label out of the network outputs.
        y_pred = inputs[self.indexes[0]]
        y_label = inputs[self.indexes[1]]

        # Placeholder: the accumulation logic is omitted in this example.
        pass

# CustomAccuracy itself picks outputs 0 (prediction) and 2 (label) of the
# evaluation network, so `eval_indexes` on Model can stay empty.
custom_accuracy = CustomAccuracy(indexes=[0, 2])
model = Model(..., eval_network=eval_network, metrics={'accuracy': custom_accuracy})
```

## 解决方案 2

```python
# An example of modifying nn.Accuracy
class Accuracy(nn.Metric):
    """Sketch of ``nn.Accuracy`` with positional input selection enabled.

    ``update`` is wrapped by ``rearrange_inputs``, so it only sees the inputs
    at the positions held in ``self.indexes`` (assumed to be provided by the
    base class together with ``set_indexes``, as in ``RearrangeInputsDemo``).
    """

    def __init__(self, eval_type='classification'):
        super(Accuracy, self).__init__(eval_type)

    @rearrange_inputs
    def update(self, *inputs):
        """Consume the (possibly rearranged) inputs; body omitted in this sketch."""
        return None

# Reuse the built-in metric and let it pick outputs 0 and 2 of the evaluation network.
accuracy = nn.Accuracy().set_indexes([0, 2])
model = Model(..., eval_network=eval_network, metrics={'accuracy': accuracy})
```

### 具体实现

In [1]:
import functools
from typing import List, Optional
from abc import ABCMeta, abstractmethod

In [2]:
def rearrange_inputs(func):
    """
    Decorate an ``update`` method so it receives only the inputs selected by ``self.indexes``.

    When ``self.indexes`` is missing, ``None`` or empty, the positional inputs are
    forwarded unchanged. Otherwise ``func`` is called with ``inputs[i]`` for each
    ``i`` in ``self.indexes``, in that order, so inputs can be reordered, dropped
    or repeated.

    Args:
        func (Callable): The method to wrap; its first parameter must be ``self``.

    Returns:
        Callable, the wrapped method. It keeps ``func``'s name and docstring
        (via ``functools.wraps``) and forwards keyword arguments unchanged.

    Raises:
        IndexError: If an entry of ``self.indexes`` is out of range for the given inputs.
    """
    @functools.wraps(func)
    def wrapper(self, *inputs, **kwargs):
        # Tolerate objects that do not define ``indexes`` at all, instead of
        # failing with AttributeError; they behave as if no indexes were set.
        indexes = getattr(self, 'indexes', None)
        if indexes:
            inputs = tuple(inputs[i] for i in indexes)
        return func(self, *inputs, **kwargs)
    return wrapper

In [3]:
class RearrangeInputsDemo:
    """Minimal metric stand-in used to exercise ``rearrange_inputs``.

    ``update`` echoes back whatever it receives, which makes the effect of the
    decorator directly observable.
    """

    def __init__(self):
        # ``None`` means "forward the inputs unchanged".
        self._indexes = None

    @property
    def indexes(self):
        """Positions selected by ``rearrange_inputs``, or ``None`` if unset."""
        try:
            return self._indexes
        except AttributeError:
            return None

    def set_indexes(self, indexes):
        """Remember ``indexes`` and return ``self`` so the call can be chained."""
        self._indexes = indexes
        return self

    @rearrange_inputs
    def update(self, *inputs):
        """Return the inputs exactly as they arrive after rearrangement."""
        return inputs

In [4]:
# test_rearrange_inputs_without_arrange()
# Without indexes, update() returns the inputs in their original order.
mini_decorator = RearrangeInputsDemo()
outs = mini_decorator.update(5, 9)
assert outs == (5, 9)
print(f'out: {outs}')

out: (5, 9)


In [5]:
# test_rearrange_inputs_with_arrange()
# indexes [1, 0] swaps the two inputs.
mini_decorator = RearrangeInputsDemo().set_indexes([1, 0])
outs = mini_decorator.update(5, 9)
assert outs == (9, 5)
print(f'out: {outs}')

out: (9, 5)


In [6]:
# test_rearrange_inputs_with_multi_inputs()
# indexes [1, 3] keeps only the second and fourth inputs.
mini_decorator = RearrangeInputsDemo().set_indexes([1, 3])
outs = mini_decorator.update(0, 9, 0, 5)
assert outs == (9, 5)
print(f'out: {outs}')

out: (9, 5)
