In [None]:
#|default_exp metrics

In [None]:
#|export
import inspect, time, math

import numpy as np
import sklearn.metrics as skm
import scipy.stats as scs

from fastcore.basics import class2attr
from fastcore.dispatch import cast

from fastprogress.core import format_time

from fastai.callback.core import Callback
from fastai.learner import AvgMetric, Learner, Metric, Recorder, _maybe_item
from fastai.metrics import AccumMetric
from fastai.torch_core import flatten_check

from fastxtend.imports import *

In [None]:
#|hide
import random
from nbdev.showdoc import *
from fastxtend.test_utils import *

# Metrics Extended
> A backwards compatible reimplementation of fastai metrics to increase usability and flexibility.

All fastxtend metrics are classes which inherit from fastai's `Metric` and run on `Learner` via a modified `Recorder` callback. 

There are three main metric types: `AvgMetricX`, `AccumMetricX`, and `AvgSmoothMetricX`. These correspond one-to-one with fastai's `AvgMetric`, `AccumMetric`, and `AvgSmoothMetric`.

To jump to the fastxtend metrics reference, click [here](#Metrics).

## Using a Metric

To use the accuracy metric, or any fastxtend metrics detailed below, create a `Learner` like normal (or task specific learner such as `vision_learner`, `text_classifier_learner`, etc) and add the metric(s) to the `metrics` argument:

```python
from fastai.vision.all import *
from fastxtend.vision.all import *

Learner(..., metrics=Accuracy())
```

Fastxtend metrics can be mixed with fastai metrics:

```python
Learner(..., metrics=[accuracy, Accuracy()])
```

Fastxtend metrics can be logged during training, validation, or both by setting the `log_metric` argument to `LogMetric.Train`, `LogMetric.Valid`, or `LogMetric.Both`. The sole exception is `AvgSmoothMetricX` which only logs during training.

> Note: By default, a fastxtend metric will log during validation. Fastai metrics can only log during validation. 

To log a fastxtend metric during training pass `LogMetric.Train` to `log_metric`:

```python
Learner(..., metrics=Accuracy(log_metric=LogMetric.Train))
```

Non-scikit-learn metrics can have the log type set via the `metric_type` argument to one of `MetricType.Avg`, `MetricType.Accum`, `MetricType.Smooth`, corresponding to `AvgMetricX`, `AccumMetricX`, and `AvgSmoothMetricX`, respectively.

To log a smooth metric on the training set and normal metric on the valid set:

```python
Learner(..., 
        metrics=[Accuracy(log_metric=LogMetric.Train, metric_type=MetricType.Smooth), 
                 Accuracy()])
```

Fastxtend metrics also support custom names via the `name` argument:

```python
Learner(..., metrics=Accuracy(name='metric_name'))
```

which will result in Accuracy logging under "metric_name" instead of the default "accuracy".

If a fastxtend metric is logged with multiple `MetricType`s, the fastxtend `Recorder` will automatically deduplicate the metric names, unless the metric's `name` argument is set, in which case fastxtend will not deduplicate that metric's name.

## Creating a Metric

`AvgMetricX`, `AccumMetricX`, and `AvgSmoothMetricX` all require `func`, which is a functional implementation of the metric. The signature of `func` should be `inp,targ` (where `inp` are the predictions of the model and `targ` the corresponding labels).

Fastxtend metrics can be logged during training, validation, or both by setting the `log_metric` argument to `LogMetric.Train`, `LogMetric.Valid`, or `LogMetric.Both`. The sole exception is `AvgSmoothMetricX` which only computes during training.

`AvgMetricX`, `AccumMetricX`, and `AvgSmoothMetricX` will automatically recognize and pass any `func`'s unique arguments to `func`.

> Important: Some metrics, like Root Mean Squared Error, will have incorrect results if passed to `AvgMetricX` via `MetricType.Avg`, as the mean of multiple batches of RMSE isn't equal to the RMSE of the whole dataset. For these metrics use `AccumMetricX` via `MetricType.Accum`.

An example of creating a fastxtend metric from a functional implementation:

```python
def example_accuracy(inp, targ):
    return (inp == targ).float().mean()

def ExampleAccuracy(dim_argmax=-1, log_metric=LogMetric.Valid, **kwargs):
    return AvgMetricX(example_accuracy, dim_argmax=dim_argmax, log_metric=log_metric, **kwargs)
```

Alternatively, use the `func_to_metric` convenience method to create the metric:

```python
def ExampleAccuracy(axis=-1, log_metric=LogMetric.Valid, **kwargs):
    return func_to_metric(example_accuracy, MetricType.Avg, True, axis=axis, log_metric=log_metric, **kwargs)
```

It is also possible to inherit directly from `MetricX` to create a fastxtend metric.

```python
class ExampleAccuracy(MetricX):
    def __init__(self, dim_argmax=-1, log_metric=LogMetric.Valid, **kwargs):
        super().__init__(dim_argmax=dim_argmax, log_metric=log_metric, **kwargs)

    def reset(self): self.preds,self.targs = [],[]

    def accumulate(self, learn):
        super().accumulate(learn)
        self.preds.append(learn.to_detach(self.pred))
        self.targs.append(learn.to_detach(self.targ))

    @property
    def value(self):
        if len(self.preds) == 0: return
        preds,targs = torch.cat(self.preds),torch.cat(self.targs)
        return (preds == targs).float().mean()
```
> Important: If your custom <code>MetricX</code> has state depending on tensors, don't forget to store it on the CPU to avoid any potential memory leaks.

## Additional Metrics Functionality

`MetricX`, and classes which inherit from `MetricX` such as `AvgMetricX`, `AccumMetricX`, and `AvgSmoothMetricX`, have optional helper functionality in `MetricX.accumulate` to assist in developing metrics.

For classification problems with a single label, predictions need to be transformed with a softmax then an argmax before being compared to the targets. Since a softmax doesn't change the order of the numbers, only the argmax needs to be applied. Pass along `dim_argmax` to have this done by `MetricX` (usually -1 will work pretty well). If the metric implementation requires probabilities and not predictions, pass `activation=ActivationType.Softmax`.

For classification problems with multiple labels, or if targets are one-hot encoded, predictions may need to pass through a sigmoid (if it wasn't included in the model) and then be compared to a given threshold (to decide between 0 and 1). This is done by `MetricX` by passing `activation=ActivationType.Sigmoid` and/or a value for `thresh`.

`AvgMetricX`, `AccumMetricX`, and `AvgSmoothMetricX` have two additional arguments to assist in creating metrics: `to_np` and `invert_arg`.

For example, if using a functional metric from sklearn.metrics, predictions and labels will need to be converted to numpy arrays with `to_np=True`. Also, scikit-learn metrics adopt the convention `y_true`, `y_preds` which is the opposite from fastai, so pass `invert_arg=True` to make `AvgMetricX`, `AccumMetricX`, and `AvgSmoothMetricX` do the inversion. Alternatively, use the `skm_to_fastxtend` convenience method to handle sklearn.metrics automatically.

## Extended Metric -

In [None]:
#|export
class LogMetric(Enum):
    "All possible logging types for `MetricX`"
    # Log the metric on the training set only
    Train = 1
    # Log the metric on the validation set only
    Valid = 2
    # Log the metric on both the training and validation sets
    Both = 3

In [None]:
#|export
class MetricType(Enum):
    "All possible types of `MetricX`"
    # Batch-size weighted average of per-batch values
    Avg = 1
    # Accumulate predictions and targets, then compute once
    Accum = 2
    # Exponentially weighted moving average of per-batch values
    Smooth = 3

In [None]:
#|export
class ActivationType(Enum):
    "All possible activation classes for `MetricX`"
    # Leave predictions untouched
    No = 1
    # Apply a sigmoid to the predictions
    Sigmoid = 2
    # Apply a softmax to the predictions
    Softmax = 3
    # Apply a softmax, then keep only the last class column
    BinarySoftmax = 4

### Metric -

In [None]:
#|export
class MetricX(Metric):
    "Blueprint for defining an extended metric with accumulate"
    # Class-level default; subclasses may override it and instances may override it via `log_metric`
    log_metric=LogMetric.Valid
    def __init__(self, dim_argmax=None, activation=ActivationType.No, thresh=None, log_metric=None, name=None):
        # Stores `dim_argmax`, `activation` and `thresh` as attributes
        store_attr(but='log_metric, name')
        # `None` means "keep the class-level default"
        self.log_metric = ifnone(log_metric, self.log_metric)
        self._name = ifnone(name, class2attr(self, 'MetricX'))

    def reset(self):
        "Reset inner state to prepare for new computation"
        pass

    def accumulate(self, learn):
        "Store targs and preds from `learn`, using activation function and argmax as appropriate"
        self.pred, self.targ = learn.pred, *learn.yb
        if self.activation in [ActivationType.Softmax, ActivationType.BinarySoftmax]:
            self.pred = F.softmax(self.pred, dim=self.dim_argmax)
            if self.activation == ActivationType.BinarySoftmax: self.pred = self.pred[:, -1]
        elif self.activation == ActivationType.Sigmoid: self.pred = torch.sigmoid(self.pred)
        # Compare against `None` so that `dim_argmax=0` is honoured instead of being treated as "unset"
        elif self.dim_argmax is not None: self.pred = self.pred.argmax(dim=self.dim_argmax)
        # Compare against `None` so that a threshold of 0 is honoured instead of being treated as "unset"
        if self.thresh is not None: self.pred = (self.pred >= self.thresh)

    @property
    def value(self):
        "The value of the metric"
        raise NotImplementedError

    @property
    def name(self):
        "Name of the metric: the custom name if provided, otherwise derived from the class name with `class2attr`"
        return self._name

    @name.setter
    def name(self, value): self._name = value

    def _split_kwargs(self, method, **kwargs):
        "Return the subset of `kwargs` whose keys are parameters of `method` that have a default value"
        # Identity check avoids calling `!=` on arbitrary default objects (e.g. arrays)
        args = {k for k,v in inspect.signature(method).parameters.items() if v.default is not inspect.Parameter.empty}
        return {k: v for k,v in kwargs.items() if k in args}

> Note: By default, a `MetricX` will only log during validation. Metrics can be individually set to run during training, validation, or both by passing `LogMetric.Train`, `LogMetric.Valid`, or `LogMetric.Both` to `log_metric`, respectively.

For classification problems with a single label, predictions need to be transformed with a softmax then an argmax before being compared to the targets. Since a softmax doesn't change the order of the numbers, only the argmax needs to be applied. Pass along `dim_argmax` to have this done by `MetricX` (usually -1 will work pretty well). If the metric implementation requires probabilities and not predictions, pass `activation=ActivationType.Softmax`.

For classification problems with multiple labels, or if targets are one-hot encoded, predictions may need to pass through a sigmoid (if it wasn't included in the model) and then be compared to a given threshold (to decide between 0 and 1). This is done by `MetricX` by passing `activation=ActivationType.Sigmoid` and/or a value for `thresh`.

Metrics can be simple averages (like accuracy) but sometimes their computation is a little bit more complex and can't be averaged over batches (like precision or recall), which is why we need a special `AccumMetricX` class for them. For simple functions that can be computed as averages over batches, we can use the class `AvgMetricX`, otherwise you'll need to implement the following methods.

> Note: If your custom <code>MetricX</code> has state depending on tensors, don't forget to store it on the CPU to avoid any potential memory leaks.

In [None]:
show_doc(MetricX.reset)

<h4 id="MetricX.reset" class="doc_header"><code>MetricX.reset</code><a href="__main__.py#L10" class="source_link" style="float:right">[source]</a></h4>

> <code>MetricX.reset</code>()

Reset inner state to prepare for new computation

In [None]:
show_doc(MetricX.accumulate)

<h4 id="MetricX.accumulate" class="doc_header"><code>MetricX.accumulate</code><a href="__main__.py#L14" class="source_link" style="float:right">[source]</a></h4>

> <code>MetricX.accumulate</code>(**`learn`**)

Store targs and preds from `learn`, using activation function and argmax as appropriate

In [None]:
show_doc(MetricX.value, name='MetricX.value')

<h4 id="MetricX.value" class="doc_header"><code>MetricX.value</code><a href="" class="source_link" style="float:right">[source]</a></h4>

The value of the metric

In [None]:
show_doc(MetricX.name, name='MetricX.name')

<h4 id="MetricX.name" class="doc_header"><code>MetricX.name</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Name of the `Metric`, camel-cased and with Metric removed. Or custom name if provided

### Metric Testing -

In [None]:
#|hide
#For testing: a fake learner and a metric that isn't an average
@delegates()
class TstLearner(Learner):
    "Bare-bones `Learner` for tests: only sets `pred`, `xb` and `yb`, without running `Learner.__init__`"
    def __init__(self,dls=None,model=None,**kwargs):
        self.pred = None
        self.xb = None
        self.yb = None

In [None]:
#|hide
def _l2_mean(x,y): return torch.sqrt((x.float()-y.float()).pow(2).mean())
def _mse(x,y):
    "Mean squared error between `x` and `y` after casting to float and passing both through `flatten_check`"
    flat_x, flat_y = flatten_check(x.float(), y.float())
    return F.mse_loss(flat_x, flat_y)

#Run a fake cycle over three batches of different sizes and return the value of met
def compute_val(met, x1, x2):
    "Feed `x1` (preds) and `x2` (targs) to `met` in slices of 6, 9 and 5 rows, then return `met.value`"
    met.reset()
    bounds = [0,6,15,20]
    learn = TstLearner()
    for start, stop in zip(bounds[:-1], bounds[1:]):
        learn.pred = x1[start:stop]
        learn.yb = (x2[start:stop],)
        met.accumulate(learn)
    return met.value

In [None]:
#|hide
def _test_metric(metrictype, metric):
    """Check a metric class against the functional `metric` it wraps.

    `metrictype` is called as `metrictype(func, **options)` and run through `compute_val`;
    each section below builds a fresh instance with one option set and compares the result
    to applying the same transformation by hand.
    """
    # Baseline: no options, the batched result must match `metric` on the full tensors
    x1,x2 = torch.randn(20,5),torch.randn(20,5)
    tst = metrictype(metric)
    test_close(compute_val(tst, x1, x2), metric(x1, x2))
    # For metrics exposing `preds`/`targs`, the stored values must equal the flattened inputs
    if hasattrs(tst, ['preds', 'targs']):
        test_eq(torch.cat(tst.preds), x1.view(-1))
        test_eq(torch.cat(tst.targs), x2.view(-1))

    #test argmax
    x1,x2 = torch.randn(20,5),torch.randint(0, 5, (20,))
    tst = metrictype(metric, dim_argmax=-1)
    test_close(compute_val(tst, x1, x2), metric(x1.argmax(dim=-1), x2))

    #test thresh
    x1,x2 = torch.randn(20,5),torch.randint(0, 2, (20,5)).bool()
    tst = metrictype(metric, thresh=0.5)
    test_close(compute_val(tst, x1, x2), metric((x1 >= 0.5), x2))

    #test sigmoid
    x1,x2 = torch.randn(20,5),torch.randn(20,5)
    tst = metrictype(metric, activation=ActivationType.Sigmoid)
    test_close(compute_val(tst, x1, x2), metric(torch.sigmoid(x1), x2))

    #test to_np: the wrapped function must receive numpy arrays for both arguments
    x1,x2 = torch.randn(20,5),torch.randn(20,5)
    tst = metrictype(lambda x,y: isinstance(x, np.ndarray) and isinstance(y, np.ndarray), to_np=True)
    assert compute_val(tst, x1, x2)

    #test invert_arg
    # `tst` here is still the `to_np` instance from the section above; it selects which check is used
    if hasattrs(tst, ['preds', 'targs']):
        # The function only looks at its first argument, so inverting must switch the result from x1 to x2
        x1,x2 = torch.randn(20,5),torch.randn(20,5)
        tst = metrictype(lambda x,y: torch.sqrt(x.pow(2).mean()))
        test_close(compute_val(tst, x1, x2), torch.sqrt(x1.pow(2).mean()))
        tst = metrictype(lambda x,y: torch.sqrt(x.pow(2).mean()), invert_arg=True)
        test_close(compute_val(tst, x1, x2), torch.sqrt(x2.pow(2).mean()))
    else:
        # The function is antisymmetric in its arguments, so inverting must flip the sign
        x1,x2 = torch.randn(20,5),torch.randn(20,5)
        tst = metrictype(lambda x,y: (x-y).mean())
        test_close(compute_val(tst, x1, x2), (x1-x2).mean())
        tst = metrictype(lambda x,y: (x-y).mean(), invert_arg=True)
        test_close(compute_val(tst, x1, x2), (x2-x1).mean())

### AvgMetric -

In [None]:
#|export
@delegates(MetricX)
class AvgMetricX(MetricX):
    "Average the values of `func` taking into account potential different batch sizes"
    def __init__(self, func, to_np=False, invert_arg=False, **kwargs):
        # `kwargs` is split twice: once for `MetricX.__init__` and once for `func`'s own keyword arguments
        super().__init__(**self._split_kwargs(MetricX.__init__, **kwargs))
        self.func, self.fkwargs = func, self._split_kwargs(func, **kwargs)
        self.to_np, self.invert_arg = to_np, invert_arg
        # Default name is the name of `func`, looking through one `.func` wrapper (e.g. a partial)
        self._name = ifnone(kwargs.get('name', None), self.func.func.__name__ if hasattr(self.func, 'func') else self.func.__name__)

    def reset(self):
        "Clear the running total and sample count"
        self.total,self.count = 0.,0

    def accumulate(self, learn):
        "Add the batch-size weighted value of `func` on the current batch to the running total"
        super().accumulate(learn)
        bs = find_bs(learn.yb)
        if self.to_np: self.pred,self.targ = learn.to_detach(self.pred).numpy(),learn.to_detach(self.targ).numpy()
        val = self.func(self.targ, self.pred, **self.fkwargs) if self.invert_arg else self.func(self.pred, self.targ, **self.fkwargs)
        # Detach tensor results, as `AvgLossX` does, so the running total does not keep
        # autograd history alive when the metric is logged during training
        if torch.is_tensor(val): val = learn.to_detach(val)
        self.total += val*bs
        self.count += bs

    @property
    def value(self):
        "Weighted average over all accumulated batches, or `None` if nothing was accumulated"
        return self.total/self.count if self.count != 0 else None

`func` is applied to each batch of predictions/targets and then averaged when the `value` attribute is asked for. The signature of `func` should be `inp,targ` (where `inp` are the predictions of the model and `targ` the corresponding labels).

> Important: Some metrics, like Root Mean Squared Error, will have incorrect results if passed to `AvgMetricX`, as the mean of multiple batches of RMSE isn't equal to the RMSE of the whole dataset. For these metrics use `AccumMetricX`.

If using a functional metric from sklearn.metrics, predictions and labels will need to be converted to numpy arrays with `to_np=True`. Also, scikit-learn metrics adopt the convention `y_true`, `y_preds` which is the opposite from fastai, so pass `invert_arg=True` to make `AvgMetricX`, `AccumMetricX`, and `AvgSmoothMetricX` do the inversion. Alternatively, use the `skm_to_fastxtend` convenience method to handle sklearn.metrics automatically.

By default, fastxtend's scikit-learn metrics use `AccumMetricX`.

In [None]:
#|hide
# Four equal batches of 25: the batch-weighted average must equal the metric on the full tensors
learn = synth_learner()
tst = AvgMetricX(lambda x,y: (x-y).abs().mean())
t,u = torch.randn(100),torch.randn(100)
tst.reset()
for i in range(0,100,25): 
    learn.pred,learn.yb = t[i:i+25],(u[i:i+25],)
    tst.accumulate(learn)
test_close(tst.value, (t-u).abs().mean())

# Run the shared option checks (argmax, thresh, sigmoid, to_np, invert_arg)
_test_metric(AvgMetricX, _mse)

### AccumMetric -

In [None]:
#|export
@delegates(MetricX)
class AccumMetricX(MetricX):
    "Stores predictions and targets on CPU in accumulate to perform final calculations with `func`."
    def __init__(self, func, to_np=False, invert_arg=False, flatten=True, **kwargs):
        metric_kwargs = self._split_kwargs(MetricX.__init__, **kwargs)
        super().__init__(**metric_kwargs)
        func_kwargs = self._split_kwargs(func, **kwargs)
        self.flatten = flatten
        self.func = func
        self.fkwargs = func_kwargs
        self.to_np = to_np
        self.invert_arg = invert_arg
        # Name of `func`, looking through one `.func` wrapper when present
        if hasattr(self.func, 'func'): default_name = self.func.func.__name__
        else: default_name = self.func.__name__
        self._name = ifnone(kwargs.get('name', None), default_name)

    def reset(self):
        "Empty the stored predictions and targets"
        self.targs = []
        self.preds = []

    def accumulate(self, learn):
        "Detach the (optionally activated/argmaxed) preds and targs from `learn` and store them"
        super().accumulate(learn)
        self.pred = learn.to_detach(self.pred)
        self.targ = learn.to_detach(self.targ)
        self.accum_values(self.pred, self.targ)

    def accum_values(self, preds, targs):
        "Append `preds` and `targs` to the stored lists, passing them through `flatten_check` when `flatten` is set"
        if self.flatten:
            preds,targs = flatten_check(preds,targs)
        self.preds.append(preds)
        self.targs.append(targs)

    def __call__(self, preds, targs):
        "Compute the metric for a single batch, discarding anything accumulated before"
        self.reset()
        self.accum_values(preds,targs)
        return self.value

    @property
    def value(self):
        "Apply `func` to all accumulated preds and targs; `None` when nothing has been accumulated"
        if not self.preds: return
        preds = torch.cat(self.preds)
        targs = torch.cat(self.targs)
        if self.to_np:
            preds,targs = preds.numpy(),targs.numpy()
        # `invert_arg` swaps the positional order to (targs, preds)
        ordered = (targs, preds) if self.invert_arg else (preds, targs)
        return self.func(*ordered, **self.fkwargs)

`func` is only applied to the accumulated predictions/targets when the `value` attribute is asked for (so at the end of a validation/training phase, in use with `Learner` and its `Recorder`). The signature of `func` should be `inp,targ` (where `inp` are the predictions of the model and `targ` the corresponding labels).

If using a functional metric from sklearn.metrics, predictions and labels will need to be converted to numpy arrays with `to_np=True`. Also, scikit-learn metrics adopt the convention `y_true`, `y_preds` which is the opposite from fastai, so pass `invert_arg=True` to make `AvgMetricX`, `AccumMetricX`, and `AvgSmoothMetricX` do the inversion. Alternatively, use the `skm_to_fastxtend` convenience method to handle sklearn.metrics automatically.

By default, fastxtend's scikit-learn metrics use `AccumMetricX`.

In [None]:
#|hide
# Run the shared option checks with a function that is not a simple per-batch average
_test_metric(AccumMetricX, _l2_mean)

### AvgSmoothMetric -

In [None]:
#|export
@delegates(MetricX, but='log_metric')
class AvgSmoothMetricX(MetricX):
    "Smooth average the values of `func` (exponentially weighted with `beta`). Only computed on training set."
    log_metric = LogMetric.Train
    def __init__(self, func, beta=0.98, to_np=False, invert_arg=False, **kwargs):
        # `kwargs` is split twice: once for `MetricX.__init__` and once for `func`'s own keyword arguments
        super().__init__(**self._split_kwargs(MetricX.__init__, **kwargs))
        self.func, self.fkwargs = func, self._split_kwargs(func, **kwargs)
        self.beta, self.to_np, self.invert_arg = beta, to_np, invert_arg
        # Default name is the name of `func`, looking through one `.func` wrapper (e.g. a partial)
        self._name = ifnone(kwargs.get('name', None), self.func.func.__name__ if hasattr(self.func, 'func') else self.func.__name__)

    def reset(self):
        "Clear the batch count and the running smoothed value"
        self.count,self.val = 0,tensor(0.)

    def accumulate(self, learn):
        "Fold the value of `func` on the current batch into the exponentially weighted running value"
        super().accumulate(learn)
        if self.to_np: self.pred,self.targ = learn.to_detach(self.pred).numpy(),learn.to_detach(self.targ).numpy()
        val = self.func(self.targ, self.pred, **self.fkwargs) if self.invert_arg else self.func(self.pred, self.targ, **self.fkwargs)
        # Same update as the `torch.lerp` branch, written out for non-tensor values:
        # new = beta*old + (1-beta)*val. The previous running value must be used on the left,
        # otherwise the expression collapses to `val` and no smoothing takes place.
        if self.to_np: self.val = self.val*self.beta + val*(1-self.beta)
        else: self.val = torch.lerp(to_detach(val, gather=False), self.val, self.beta)
        self.count += 1

    @property
    def value(self):
        "Running value divided by `1 - beta**count`, or `None` if nothing was accumulated"
        return self.val/(1-self.beta**self.count) if self.count != 0 else None

`func` is applied to each batch of predictions/targets, and the results are combined into an exponentially weighted moving average (controlled by `beta`), which is divided by `1 - beta**count` when the `value` attribute is asked for. The signature of `func` should be `inp,targ` (where `inp` are the predictions of the model and `targ` the corresponding labels).

If using a functional metric from sklearn.metrics, predictions and labels will need to be converted to numpy arrays with `to_np=True`. Also, scikit-learn metrics adopt the convention `y_true`, `y_preds` which is the opposite from fastai, so pass `invert_arg=True` to make `AvgMetricX`, `AccumMetricX`, and `AvgSmoothMetricX` do the inversion. Alternatively, use the `skm_to_fastxtend` convenience method to handle sklearn.metrics automatically.

In [None]:
#|hide
# Compare `AvgSmoothMetricX` after every batch against a hand-rolled moving average with beta=0.98
learn = synth_learner()
tst = AvgSmoothMetricX(lambda x,y: (x-y).abs().mean())
t,u = torch.randn(100),torch.randn(100)
tst.reset()
val = tensor(0.)
for i, j in enumerate(range(0,100,25)):
    learn.pred,learn.yb = t[j:j+25],(u[j:j+25],)
    tst.accumulate(learn)
    # Reference update, then the same `1 - beta**n` division the metric applies in `value`
    val = val*0.98 + ((t[j:j+25]-u[j:j+25]).abs().mean())*(1-0.98)
    test_close(val/(1-0.98**(i+1)), tst.value)

### AvgLoss -

In [None]:
#|export
class AvgLossX(MetricX):
    "Average the losses taking into account potential different batch sizes"
    def reset(self):
        "Clear the running total and sample count"
        self.total = 0.
        self.count = 0

    def accumulate(self, learn):
        "Add the detached mean loss of the current batch, weighted by its batch size"
        bs = find_bs(learn.yb)
        batch_loss = learn.to_detach(learn.loss.mean())
        self.total += batch_loss*bs
        self.count += bs

    @property
    def value(self):
        "Weighted average loss, or `None` if nothing was accumulated"
        if self.count == 0: return None
        return self.total/self.count

In [None]:
#|hide
# Equal batches of 25: the averaged loss must equal the mean of the full tensor.
# `learn` is not created in this cell; it is reused from an earlier cell.
tst = AvgLossX()
t = torch.randn(100)
tst.reset()
for i in range(0,100,25): 
    learn.yb,learn.loss = t[i:i+25],t[i:i+25].mean()
    tst.accumulate(learn)
test_close(tst.value, t.mean())

In [None]:
#|hide
#With varying batch size
# Batches of 30, 20, 10 and 40 samples: the weighting by batch size must still recover the global mean
tst.reset()
splits = [0, 30, 50, 60, 100]
for i in range(len(splits )-1): 
    learn.yb,learn.loss = t[splits[i]:splits[i+1]],t[splits[i]:splits[i+1]].mean()
    tst.accumulate(learn)
test_close(tst.value, t.mean())

### AvgSmoothLoss -

In [None]:
#|export
class AvgSmoothLossX(MetricX):
    "Smooth average of the losses (exponentially weighted with `beta`)"
    log_metric = LogMetric.Train
    def __init__(self, beta=0.98):
        # Run `MetricX.__init__` so `_name` (read by the `name` property) and the other
        # base attributes exist; `log_metric` keeps the class-level `LogMetric.Train`
        super().__init__()
        self.beta = beta

    def reset(self):
        "Clear the batch count and the running smoothed loss"
        self.count,self.val = 0,tensor(0.)

    def accumulate(self, learn):
        "Fold the detached mean loss of the current batch into the exponentially weighted running value"
        self.count += 1
        self.val = torch.lerp(to_detach(learn.loss.mean(), gather=False), self.val, self.beta)

    @property
    def value(self):
        "Running value divided by `1 - beta**count`"
        return self.val/(1-self.beta**self.count)

In [None]:
#|hide
# Compare `AvgSmoothLossX` after every batch against a hand-rolled moving average with beta=0.98.
# `learn` is not created in this cell; it is reused from an earlier cell.
tst = AvgSmoothLossX()
t = torch.randn(100)
tst.reset()
val = tensor(0.)
for i in range(4): 
    learn.loss = t[i*25:(i+1)*25].mean()
    tst.accumulate(learn)
    # Reference update, then the same `1 - beta**n` division the metric applies in `value`
    val = val*0.98 + t[i*25:(i+1)*25].mean()*(1-0.98)
    test_close(val/(1-0.98**(i+1)), tst.value)

### ValueMetric -

In [None]:
#|export
class ValueMetricX(MetricX):
    "Use to include a pre-calculated metric value (for instance calculated in a `Callback`) and returned by `func`"
    def __init__(self, func, name=None, log_metric=None):
        super().__init__(log_metric=log_metric)
        self.func = func
        # Name of `func`, looking through one `.func` wrapper when present
        if hasattr(self.func, 'func'): default_name = self.func.func.__name__
        else: default_name = self.func.__name__
        self._name = ifnone(name, default_name)

    @property
    def value(self):
        "Call `func` with no arguments and return its result"
        return self.func()

In [None]:
#|hide
def metric_value_fn():
    "Constant stand-in for a pre-calculated metric value"
    return 5e-3

# With an explicit name: the value comes from the function, the name from the argument
vm = ValueMetricX(metric_value_fn, 'custom_value_metric')
test_eq(vm.value, 5e-3)
test_eq(vm.name, 'custom_value_metric')

# Without a name: falls back to the function's own name
vm = ValueMetricX(metric_value_fn)
test_eq(vm.name, 'metric_value_fn')

## Recorder -
Patch `Recorder` to use fastxtend metrics.

In [None]:
#|exporti
def _dedup_metric_names(metrics, names):
    "Prefix default metric names that repeat within a logging phase with their metric type; updates and returns `names`"
    def _phase(m):
        # Metrics without `log_metric` and metrics logged in both phases are grouped with validation
        if not hasattr(m, 'log_metric'): return LogMetric.Valid
        return LogMetric.Valid if m.log_metric==LogMetric.Both else m.log_metric

    def _default_name(m):
        # The name the metric would have if the user did not supply one
        if hasattr(m, 'func'):
            fn = m.func
            return fn.func.__name__ if hasattr(fn, 'func') else fn.__name__
        return class2attr(m, 'MetricX' if isinstance(m, MetricX) else 'Metric')

    # Checked in order; the first matching family decides the prefix
    prefixes = (((AvgMetricX, AvgMetric), 'avg'),
                ((AccumMetricX, AccumMetric), 'accm'),
                ((AvgSmoothMetricX,), 'smth'))

    # Collect every name seen more than once for the same phase
    seen, repeated = set(), set()
    for key in zip(metrics.map(_phase), names):
        if key in seen: repeated.add(key[1])
        else: seen.add(key)

    for i in names.argwhere(lambda o: o in repeated):
        metric = metrics[i]
        default = _default_name(metric)
        # only deduplicate default metric names
        if hasattr(metric, 'name') and metric.name != default: continue
        for types, prefix in prefixes:
            if isinstance(metric, types):
                names[i] = f'{prefix}_{names[i]}'
                break
    return names

In [None]:
#|exporti
@patch
def __init__(self:Recorder, add_time=True, train_metrics=False, valid_metrics=True, beta=0.98):
    "Create the loss trackers; `train_metrics` and `valid_metrics` are accepted but not used here"
    store_attr('add_time')
    self.loss = AvgLossX()
    self.smooth_loss = AvgSmoothLossX(beta=beta)

@patch
def before_fit(self:Recorder):
    "Prepare state for training"
    self.lrs,self.iters,self.losses,self.values = [],[],[],[]
    names = self.metrics.attrgot('name')
    # Only rename when at least two metrics share a name
    if len(names.unique()) != len(names): 
        names = _dedup_metric_names(self.metrics, names)
    # Metrics without a `log_metric` attribute are treated as validation-only
    train = self.metrics.argwhere(lambda o: hasattr(o, 'log_metric') and o.log_metric != LogMetric.Valid)
    valid = self.metrics.argwhere(lambda o: not hasattr(o, 'log_metric') or o.log_metric != LogMetric.Train)
    self._train_metsX = self.metrics[train]
    self._valid_metsX = self.metrics[valid]
    train_names = names[train] 
    valid_names = names[valid]
    # The `train_`/`valid_` prefixes are only added when at least one metric logs during training
    if len(self._train_metsX) > 0:
        train_names = train_names.map('train_{}')
        valid_names = valid_names.map('valid_{}')
    smooth = self._train_metsX.argwhere(lambda o: isinstance(o, (AvgSmoothLossX, AvgSmoothMetricX)))
    self.smooth_mets  = self._train_metsX[smooth]
    self.smooth_names = train_names[smooth]
    self.train_names = L('train_loss') + train_names
    self.valid_names = L('valid_loss') + valid_names
    self.metric_names = L('epoch') + self.train_names + self.valid_names
    if self.add_time: self.metric_names.append('time')
    # Smoothed trackers are reset here, once per fit; `before_train` skips them
    self.smooth_loss.reset()
    self.loss.reset()
    self.smooth_mets.map(Self.reset())

@patch
def after_batch(self:Recorder):
    "Update all metrics and records lr and smooth loss in training"
    if len(self.yb) == 0: return
    active = self.train_mets() if self.training else self.valid_mets()
    for metric in active: metric.accumulate(self.learn)
    # Learning rate and smoothed loss are only tracked for training batches
    if self.training:
        self.lrs.append(self.opt.hypers[-1]['lr'])
        self.losses.append(self.smooth_loss.value)
        self.learn.smooth_loss = self.smooth_loss.value

@patch
def before_epoch(self:Recorder):
    "Clear the cancel flags, start the timer if `self.add_time=True` and start a new log row"
    self.cancel_train = False
    self.cancel_valid = False
    if self.add_time:
        self.start_epoch = time.time()
    # The log row starts with the epoch number (0 when `epoch` is not available)
    self.log = L(getattr(self, 'epoch', 0))

@patch
def before_train(self:Recorder):
    "Reset every training metric except the smoothed ones"
    smooth_types = (AvgSmoothLossX, AvgSmoothMetricX)
    for metric in self._train_metsX:
        if not isinstance(metric, smooth_types): metric.reset()

@patch
def before_validate(self:Recorder):
    "Reset the validation loss and every validation metric"
    for metric in self.valid_mets():
        metric.reset()

@patch
def after_train(self:Recorder):
    "Add the training loss and training metric values to the current log row"
    train_values = self.train_mets().map(_maybe_item)
    self.log += train_values

@patch
def after_validate(self:Recorder):
    "Add the validation loss and validation metric values to the current log row"
    valid_values = self.valid_mets().map(_maybe_item)
    self.log += valid_values

@patch
def after_epoch(self:Recorder):
    "Store and log the loss/metric values"
    # Everything after the epoch number is the record for this epoch
    self.learn.final_record = self.log[1:].copy()
    self.values.append(self.learn.final_record)
    if self.add_time:
        elapsed = time.time() - self.start_epoch
        self.log.append(format_time(elapsed))
    self.logger(self.log)
    self.iters.append(self.smooth_loss.count)

@patch
def train_mets(self:Recorder):
    "Smoothed loss followed by the training metrics, or an empty list when training was cancelled"
    cancelled = getattr(self, 'cancel_train', False)
    return L() if cancelled else L(self.smooth_loss) + self._train_metsX

@patch
def valid_mets(self:Recorder):
    "Loss followed by the validation metrics, or an empty list when validation was cancelled"
    cancelled = getattr(self, 'cancel_valid', False)
    return L() if cancelled else L(self.loss) + self._valid_metsX

In [None]:
#|hide
def tst_metric(inp, targ):
    "Mean squared error between `inp` and `targ`, used as the metric function in the recorder tests"
    loss = F.mse_loss(inp, targ)
    return loss

def TestAvg(log_metric=LogMetric.Valid):
    "Build an `AvgMetricX` around `tst_metric`"
    metric = AvgMetricX(tst_metric, log_metric=log_metric)
    return metric

def TestAccum(log_metric=LogMetric.Valid):
    "Build an `AccumMetricX` around `tst_metric`"
    metric = AccumMetricX(tst_metric, log_metric=log_metric)
    return metric

def TestSmooth():
    "Build an `AvgSmoothMetricX` around `tst_metric`"
    metric = AvgSmoothMetricX(tst_metric)
    return metric

In [None]:
#|hide
#Test printed output
# One validation metric: each printed row is [epoch, three numbers, 'mm:ss']
learn = synth_learner(n_trn=5, metrics=TestAvg())
# pat = r"[tensor\(\d.\d*\), tensor\(\d.\d*\), tensor\(\d.\d*\), 'dd:dd']"
pat = r"\[\d, \d+.\d+, \d+.\d+, \d+.\d+, '\d\d:\d\d'\]"
test_stdout(lambda: learn.fit(1), pat, regex=True)

In [None]:
#|hide
class TestRecorderCallback(Callback):
    "Test callback that recomputes what the patched `Recorder` should log and compares it at every step"
    order=51
    def before_fit(self): 
        "Copy the recorder settings, check the metrics and swap in `test_log` as the logger"
        self.add_time = self.recorder.add_time
        self.beta = self.recorder.smooth_loss.beta
        for m in self.metrics: assert isinstance(m, Metric)
        test_eq(self.recorder.smooth_loss.val, 0.)
        #To test what the recorder logs, we use a custom logger function.
        self.learn.logger = self.test_log
        self.old_smooth,self.count = tensor(0.),0
    
    def after_batch(self):
        "Check lrs, smoothed loss and metric counts against values tracked by this callback"
        if self.training:
            self.count += 1
            test_eq(len(self.recorder.lrs), self.count)
            test_eq(self.recorder.lrs[-1], self.opt.hypers[-1]['lr'])
            test_eq(len(self.recorder.losses), self.count)
            # Rebuild the expected smoothed loss from the previous one:
            # scale it back by (1 - beta**(count-1)), apply the update, then divide by (1 - beta**count)
            smooth = (1 - self.beta**(self.count-1)) * self.old_smooth * self.beta + self.loss * (1-self.beta)
            smooth /= 1 - self.beta**self.count
            test_close(self.recorder.losses[-1], smooth, eps=1e-4)
            test_close(self.smooth_loss, smooth, eps=1e-4)
            self.old_smooth = self.smooth_loss
        self.bs += find_bs(self.yb)
        if not self.training: test_eq(self.recorder.loss.count, self.bs)
        # Averaging metrics count samples; smoothed metrics count batches
        for m in self.recorder.train_mets() if self.training else self.recorder.valid_mets(): 
            if isinstance(m, (AvgMetricX, AvgMetric)): test_eq(m.count, self.bs)
            if isinstance(m, AvgSmoothMetricX): test_eq(m.count, self.count)
        self.losses.append(self.loss.detach().cpu())
    
    def before_epoch(self): 
        "Start the timer (when `add_time` is set) and the expected log row"
        if self.add_time: self.start_epoch = time.time()
        self.log = [self.epoch]

    def before_train(self):
        "Clear the per-phase counters and check that averaging train metrics were reset"
        self.bs = 0
        self.losses = []
        for m in self.recorder.train_mets(): 
            if isinstance(m, (AvgMetricX, AvgMetric)): test_eq(m.count, self.bs)
            
    def after_train(self):
        "Compute the expected training log entries and compare them with the recorder's log"
        val=0
        # When the last train metric is smoothed, the expected value is the moving average of the batch losses
        if isinstance(self.recorder.train_mets()[-1], AvgSmoothMetricX):
            for i, j in enumerate(self.losses):
                val = val*0.98 + j*(1-0.98)
            mean = val/(1-0.98**(i+1))
        else:
            mean = tensor(self.losses).mean()
        # The extra entry is only expected when a train metric is logged besides the smoothed loss
        self.log += [self.smooth_loss, mean] if len(self.recorder.train_mets()) > 1 else [self.smooth_loss]
        test_close(self.log, self.recorder.log, 1e-4)
        self.losses = []
    
    def before_validate(self):
        "Clear the per-phase counters and check that the validation loss and averaging metrics were reset"
        self.bs = 0
        self.losses = []
        for m in [self.recorder.loss] + self.recorder.valid_mets(): 
            if isinstance(m, (AvgMetricX, AvgMetric, AvgLossX)): test_eq(m.count, self.bs)
    
    def test_log(self, log, eps=1e-4):
        "Logger replacement: compare the recorder's `log` row with the expected one"
        res = tensor(self.losses).mean()
        # Expected validation loss and metric value are both the mean of the validation batch losses
        self.log += [res, res]
        if self.add_time: self.log.append(format_time(time.time() - self.start_epoch))
        test_close(log[:-1], self.log[:-1], eps)
        test_eq(log[-1], self.log[-1])

In [None]:
#|hide
# Validation-only metric: no train_/valid_ prefixes on the metric name
learn = synth_learner(n_trn=5, metrics=TestAvg(), cbs=TestRecorderCallback)
learn.fit(1)
test_eq(learn.recorder.metric_names, ['epoch', 'train_loss', 'valid_loss', 'tst_metric', 'time'])

# Metric logged in both phases: appears twice, with train_ and valid_ prefixes
learn = synth_learner(n_trn=5, metrics=TestAvg(log_metric=LogMetric.Both), cbs=TestRecorderCallback)
learn.fit(1)
test_eq(learn.recorder.metric_names, 
        ['epoch', 'train_loss', 'train_tst_metric', 'valid_loss', 'valid_tst_metric', 'time'])

# Two validation metrics sharing a default name: deduplicated with avg_/accm_ prefixes
learn = synth_learner(n_trn=5, metrics=[TestAvg(), TestAccum()], cbs=TestRecorderCallback)
learn.fit(1)
test_eq(learn.recorder.metric_names, 
        ['epoch', 'train_loss', 'valid_loss', 'avg_tst_metric', 'accm_tst_metric', 'time'])

# Deduplication combined with train/valid prefixes
learn = synth_learner(n_trn=5, metrics=[TestAvg(log_metric=LogMetric.Both), TestAccum()], cbs=TestRecorderCallback)
learn.fit(1)
test_eq(learn.recorder.metric_names, 
        ['epoch', 'train_loss', 'train_avg_tst_metric', 'valid_loss', 'valid_avg_tst_metric', 'valid_accm_tst_metric', 'time'])

# Without timing: no trailing 'time' column
learn = synth_learner(n_trn=5, metrics=TestAvg(), cbs=TestRecorderCallback)
learn.recorder.add_time=False
learn.fit(1)
test_eq(learn.recorder.metric_names, ['epoch', 'train_loss', 'valid_loss', 'tst_metric'])

# Same name in different phases (valid avg, train smooth): prefixes only, no deduplication
learn = synth_learner(n_trn=5, metrics=[TestAvg(), TestSmooth()], cbs=TestRecorderCallback)
learn.fit(1)
test_eq(learn.recorder.metric_names, ['epoch','train_loss','train_tst_metric','valid_loss','valid_tst_metric','time'])

In [None]:
#|hide
# test that fastai metrics still work
# Imported under an alias; this module defines its own `mse` further down
from fastai.metrics import mse as fastai_mse
# A plain fastai functional metric is expected to be logged as a single `mse` column
learn = synth_learner(n_trn=5, metrics=fastai_mse, cbs=TestRecorderCallback)
learn.fit(1)
test_eq(learn.recorder.metric_names, ['epoch','train_loss','valid_loss','mse','time'])

## Metrics

In [None]:
#|hide
#For testing: a fake learner and a metric that isn't an average
class TstLearner(Learner):
    "Bare `Learner` stand-in for metric tests that only carries `pred`, `xb` and `yb`"
    def __init__(self,dls=None,model=None,**kwargs):
        # `Learner.__init__` is not called; the arguments are accepted but ignored
        self.pred = None
        self.xb = None
        self.yb = None

In [None]:
#|hide
def _l2_mean(x,y):
    "Square root of the mean squared difference between `x` and `y`, both cast to float"
    diff = x.float() - y.float()
    return diff.pow(2).mean().sqrt()

#Feed `met` three fake batches (rows 0-6, 6-15 and 15-20 of `x1`/`x2`) and return its final value
def compute_val(met, x1, x2):
    met.reset()
    bounds = [0,6,15,20]
    learn = TstLearner()
    for start,stop in zip(bounds[:-1], bounds[1:]):
        learn.pred = x1[start:stop]
        # targets are stored as a one-element tuple
        learn.yb = (x2[start:stop],)
        met.accumulate(learn)
    return met.value

In [None]:
#|hide
# Redefine `_l2_mean` so that it argmaxes `x` over the last dim before comparing with `y`
def _l2_mean(x,y): return torch.sqrt((x.argmax(dim=-1).float()-y.float()).pow(2).mean())
x1,x2 = torch.randn(20,5),torch.randint(0, 5, (20,))
# The value accumulated over three batches should match one pass over the softmaxed predictions
tst = AccumMetricX(_l2_mean, dim_argmax=-1, flatten=False, activation=ActivationType.Softmax)
test_close(compute_val(tst, x1, x2), _l2_mean(F.softmax(x1, dim=-1), x2))

### Custom Metric Creation

In [None]:
#|export
@delegates(MetricX)
def func_to_metric(func, metric_type, is_class, thresh=None, axis=-1, activation=None, log_metric=LogMetric.Valid, **kwargs):
    "Wrap the function `func` in the fastxtend metric class selected by `metric_type`"

    def _func_name():
        # A `partial` has no `__name__` of its own, so fall back to the function it wraps
        return func.func.__name__ if hasattr(func, 'func') else  func.__name__

    # Only single-label classification (no `thresh`) takes an argmax over `axis`
    dim_argmax = axis if (is_class and thresh is None) else None
    if activation is None:
        # Default activation: sigmoid for thresholded classification, none otherwise
        if is_class and thresh is not None: activation = ActivationType.Sigmoid
        else:                               activation = ActivationType.No

    shared = dict(dim_argmax=dim_argmax, activation=activation, thresh=thresh)
    if metric_type==MetricType.Accum:
        return AccumMetricX(func, log_metric=log_metric, **shared, **kwargs)
    if metric_type==MetricType.Avg:
        return AvgMetricX(func, log_metric=log_metric, **shared, **kwargs)
    if metric_type==MetricType.Smooth:
        # `log_metric` is only validated here, it is not forwarded to `AvgSmoothMetricX`
        if log_metric!=LogMetric.Train:
            name = _func_name()
            raise ValueError(f'Error with {name}: AvgSmoothMetricX can only run on train. Set `log_metric` to LogMetric.Train.')
        return AvgSmoothMetricX(func, **shared, **kwargs)
    name = _func_name()
    raise ValueError(f"Unsupported `metric_type` {metric_type} for metric {name}.")

This is the quickest way to use a functional metric as a fastxtend metric.

`metric_type` is one of `MetricType.Avg`, `MetricType.Accum`, or `MetricType.Smooth` which set the metric to use `AvgMetricX`, `AccumMetricX`, or `AvgSmoothMetricX`, respectively. 

`is_class` indicates if you are in a classification problem or not. In this case:
- leaving `thresh` as `None` indicates it's a single-label classification problem and predictions will pass through an argmax over `axis` before being compared to the targets
- setting a value for `thresh` indicates it's a multi-label classification problem and predictions will pass through a sigmoid (which can be replaced or disabled via `activation`, e.g. `activation=ActivationType.No`) and be compared to `thresh` before being compared to the targets

If `is_class=False`, it indicates you are in a regression problem, and predictions are compared to the targets without being modified. In all cases, `kwargs` are extra keyword arguments passed to `func`.

> Important: Some metrics, like Root Mean Squared Error, will have incorrect results if passed to `AvgMetricX` via `MetricType.Avg`, as the mean of multiple batches of RMSE isn't equal to the RMSE of the whole dataset. For these metrics use `AccumMetricX` by setting `metric_type` to `MetricType.Accum`.

In [None]:
#|export
@delegates(MetricX)
def skm_to_fastxtend(func, is_class=True, thresh=None, axis=-1, activation=None, log_metric=LogMetric.Valid, **kwargs):
    "Wrap the scikit-learn metric `func` as an accumulated (`MetricType.Accum`) fastxtend metric"
    # `to_np` and `invert_arg` are always switched on for scikit-learn functions
    # NOTE(review): presumably numpy conversion and targets-first argument order - confirm against `AccumMetricX`
    return func_to_metric(func, metric_type=MetricType.Accum, is_class=is_class, thresh=thresh, axis=axis,
                          activation=activation, log_metric=log_metric, to_np=True, invert_arg=True, **kwargs)

This is the quickest way to use a scikit-learn metric as a fastxtend metric. It is the same as `func_to_metric` except that it always uses `AccumMetricX`, defaults `is_class` to `True`, and sets `to_np=True` and `invert_arg=True`.

In [None]:
#|hide
# Single-label case: the accumulated precision should equal sklearn's on argmaxed predictions
tst_single = skm_to_fastxtend(skm.precision_score)
x1,x2 = torch.randn(20,2),torch.randint(0, 2, (20,))
test_close(compute_val(tst_single, x1, x2), skm.precision_score(x2, x1.argmax(dim=-1)))

In [None]:
#|hide
# With `thresh` and the default activation, predictions go through a sigmoid before thresholding
tst_multi = skm_to_fastxtend(skm.precision_score, thresh=0.2)
x1,x2 = torch.randn(20),torch.randint(0, 2, (20,))
test_close(compute_val(tst_multi, x1, x2), skm.precision_score(x2, torch.sigmoid(x1) >= 0.2))

# With `ActivationType.No` the raw predictions are thresholded directly
tst_multi = skm_to_fastxtend(skm.precision_score, thresh=0.2, activation=ActivationType.No)
x1,x2 = torch.randn(20),torch.randint(0, 2, (20,))
test_close(compute_val(tst_multi, x1, x2), skm.precision_score(x2, x1 >= 0.2))

In [None]:
#|hide
# Regression case (`is_class=False`): compared against sklearn on the flattened tensors
tst_reg = skm_to_fastxtend(skm.r2_score, is_class=False)
x1,x2 = torch.randn(20,5),torch.randn(20,5)
test_close(compute_val(tst_reg, x1, x2), skm.r2_score(x2.view(-1), x1.view(-1)))

In [None]:
#|hide
# Calling the metric directly on tensors should give the same score (reuses `tst_reg`, `x1`, `x2` from the previous cell)
test_close(tst_reg(x1, x2), skm.r2_score(x2.view(-1), x1.view(-1)))

## Single-label classification

> Warning: All functions defined in this section are intended for single-label classification and targets that are not one-hot encoded. For multi-label problems or one-hot encoded targets, use the version suffixed with `Multi`.

> Warning: Many metrics in fastxtend are thin wrappers around sklearn functionality. However, sklearn metrics can handle Python lists of strings, amongst other things, whereas fastxtend metrics work with PyTorch, and thus require tensors. The arguments that are passed to metrics are those obtained after all transformations, such as categories being converted to indices, have occurred. This means that when you pass a label to a metric, for instance, you must pass indices, not strings. Strings can be converted to indices with `vocab.map_obj`.

In [None]:
#|exporti
def accuracy(inp, targ):
    "Fraction of entries of `inp` equal to `targ` once both have gone through `flatten_check`"
    flat_inp,flat_targ = flatten_check(inp, targ)
    hits = flat_inp == flat_targ
    return hits.float().mean()

In [None]:
#|export
def Accuracy(axis=-1, metric_type=MetricType.Avg, log_metric=LogMetric.Valid, **kwargs):
    "Accuracy for single-label classification: predictions are argmaxed over `axis` before being compared with `targ`"
    return func_to_metric(accuracy, metric_type=metric_type, is_class=True,
                          axis=axis, log_metric=log_metric, **kwargs)

In [None]:
#|hide
#For testing
def change_targ(targ, n, c):
    "Return a copy of `targ` in which `n` randomly picked entries are shifted by 1 to `c-1`, modulo `c`"
    changed = targ.clone()
    picked = torch.randperm(len(targ))[:n]
    for pos in picked:
        # a shift in [1, c-1] modulo c always lands on a different class when values lie in [0, c)
        shift = random.randint(1,c-1)
        changed[pos] = (changed[pos]+shift)%c
    return changed

In [None]:
#|hide
#For testing
def compute_single(met, x1, x2):
    "Reset `met`, accumulate a single fake batch with predictions `x1` and targets `x2`, and return its value"
    met.reset()
    fake_learn = TstLearner()
    fake_learn.pred = x1
    # targets are stored as a one-element tuple
    fake_learn.yb = (x2,)
    met.accumulate(fake_learn)
    return met.value

In [None]:
#|hide
x = torch.randn(4,5)
# Targets equal to the argmax give a perfect score
y = x.argmax(dim=1)
test_eq(compute_single(Accuracy(), x, y), 1)
# Changing 2 of the 4 targets halves the accuracy
y1 = change_targ(y, 2, 5)
test_eq(compute_single(Accuracy(),x,y1), 0.5)
# Stacking the perfect and the half-correct targets: 6 of 8 entries match
test_eq(compute_single(Accuracy(),x.unsqueeze(1).expand(4,2,5), torch.stack([y,y1], dim=1)), 0.75)

In [None]:
#|exporti
def error_rate(inp, targ):
    "Complement of `accuracy`: `1 - accuracy(inp, targ)`"
    acc = accuracy(inp, targ)
    return 1 - acc

In [None]:
#|export
def ErrorRate(axis=-1, metric_type=MetricType.Avg, log_metric=LogMetric.Valid, **kwargs):
    "Error rate (1 - accuracy) for single-label classification: predictions are argmaxed over `axis` first"
    return func_to_metric(error_rate, metric_type=metric_type, is_class=True,
                          axis=axis, log_metric=log_metric, **kwargs)

In [None]:
#|hide
x = torch.randn(4,5)
# Targets equal to the argmax give zero error
y = x.argmax(dim=1)
test_eq(compute_single(ErrorRate(), x,y), 0)
# Changing 2 of the 4 targets gives a 50% error rate
y1 = change_targ(y, 2, 5)
test_eq(compute_single(ErrorRate(), x,y1), 0.5)
# Stacking the perfect and the half-correct targets: 2 of 8 entries are wrong
test_eq(compute_single(ErrorRate(), x.unsqueeze(1).expand(4,2,5), torch.stack([y,y1], dim=1)), 0.25)

In [None]:
#|exporti
def top_k_accuracy(inp, targ, k=5, axis=-1):
    "Computes the Top-k accuracy (`targ` is in the top `k` predictions of `inp` along `axis`)"
    # Indices of the `k` largest predictions along the class axis
    topk_idx = inp.topk(k=k, dim=axis)[1]
    # Repeat each target `k` times along the same axis so it lines up with the top-k indices
    targ = targ.unsqueeze(dim=axis).expand_as(topk_idx)
    # Reduce over `axis` (not always the last dim) so that every sample yields a single 0/1 hit
    return (topk_idx == targ).sum(dim=axis).float().mean()

In [None]:
#|export
def TopKAccuracy(k=5, axis=-1, metric_type=MetricType.Avg, log_metric=LogMetric.Valid, **kwargs):
    "Top-k accuracy metric: a target counts as correct when it is among the `k` highest predictions along `axis`"
    # `k` and `axis` are bound into the function; `is_class=False` leaves the predictions un-argmaxed
    topk_func = partial(top_k_accuracy, k=k, axis=axis)
    return func_to_metric(topk_func, metric_type, False, log_metric=log_metric, **kwargs)

In [None]:
#|hide
x = torch.randn(6,5)
y = torch.arange(0,6)
# With 5 classes and k=5, every target in 0..4 is within the top 5
test_eq(compute_single(TopKAccuracy(), x[:5],y[:5]), 1)
# The sixth target (5) is not a valid class index, so only 5 of 6 samples can match
test_eq(compute_single(TopKAccuracy(), x, y), 5/6)

In [None]:
#|export
def APScoreBinary(axis=-1, average='macro', pos_label=1, sample_weight=None, log_metric=LogMetric.Valid, **kwargs):
    "Average precision (`sklearn.metrics.average_precision_score`) for single-label binary classification problems"
    # Arguments handed on for `skm.average_precision_score`
    skm_args = dict(average=average, pos_label=pos_label, sample_weight=sample_weight)
    return skm_to_fastxtend(skm.average_precision_score, axis=axis, activation=ActivationType.BinarySoftmax,
                            log_metric=log_metric, **skm_args, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html#sklearn.metrics.average_precision_score) for more details.

In [None]:
#|export
def BalancedAccuracy(axis=-1, sample_weight=None, adjusted=False, log_metric=LogMetric.Valid, **kwargs):
    "Balanced accuracy (`sklearn.metrics.balanced_accuracy_score`) for single-label binary classification problems"
    # Arguments handed on for `skm.balanced_accuracy_score`
    skm_args = dict(sample_weight=sample_weight, adjusted=adjusted)
    return skm_to_fastxtend(skm.balanced_accuracy_score, axis=axis, log_metric=log_metric, **skm_args, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.balanced_accuracy_score.html#sklearn.metrics.balanced_accuracy_score) for more details.

In [None]:
#|export
def BrierScore(axis=-1, sample_weight=None, pos_label=None, log_metric=LogMetric.Valid, **kwargs):
    "Brier score (`sklearn.metrics.brier_score_loss`) for single-label classification problems"
    # Arguments handed on for `skm.brier_score_loss`
    skm_args = dict(sample_weight=sample_weight, pos_label=pos_label)
    return skm_to_fastxtend(skm.brier_score_loss, axis=axis, log_metric=log_metric, **skm_args, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.brier_score_loss.html#sklearn.metrics.brier_score_loss) for more details.

In [None]:
#|export
def CohenKappa(axis=-1, labels=None, weights=None, sample_weight=None, log_metric=LogMetric.Valid, **kwargs):
    "Cohen's kappa (`sklearn.metrics.cohen_kappa_score`) for single-label classification problems"
    # Arguments handed on for `skm.cohen_kappa_score`
    skm_args = dict(labels=labels, weights=weights, sample_weight=sample_weight)
    return skm_to_fastxtend(skm.cohen_kappa_score, axis=axis, log_metric=log_metric, **skm_args, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.cohen_kappa_score.html#sklearn.metrics.cohen_kappa_score) for more details.

In [None]:
#|export
def F1Score(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None, log_metric=LogMetric.Valid, **kwargs):
    "F1 score (`sklearn.metrics.f1_score`) for single-label classification problems"
    # Arguments handed on for `skm.f1_score`
    skm_args = dict(labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)
    return skm_to_fastxtend(skm.f1_score, axis=axis, log_metric=log_metric, **skm_args, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score) for more details.

In [None]:
#|export
def FBeta(beta, axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None, 
          log_metric=LogMetric.Valid, **kwargs):
    "F-beta score (`sklearn.metrics.fbeta_score`) with weight `beta` for single-label classification problems"
    # Arguments handed on for `skm.fbeta_score`
    skm_args = dict(beta=beta, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)
    return skm_to_fastxtend(skm.fbeta_score, axis=axis, log_metric=log_metric, **skm_args, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fbeta_score.html#sklearn.metrics.fbeta_score) for more details.

In [None]:
#|export
def HammingLoss(axis=-1, sample_weight=None, log_metric=LogMetric.Valid, **kwargs):
    "Hamming loss (`sklearn.metrics.hamming_loss`) for single-label classification problems"
    score_func = skm.hamming_loss
    return skm_to_fastxtend(score_func, axis=axis, log_metric=log_metric, sample_weight=sample_weight, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.hamming_loss.html#sklearn.metrics.hamming_loss) for more details.

In [None]:
#|export
def Jaccard(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None, 
            log_metric=LogMetric.Valid, **kwargs):
    "Jaccard score (`sklearn.metrics.jaccard_score`) for single-label classification problems"
    # Arguments handed on for `skm.jaccard_score`
    skm_args = dict(labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)
    return skm_to_fastxtend(skm.jaccard_score, axis=axis, log_metric=log_metric, **skm_args, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.jaccard_score.html#sklearn.metrics.jaccard_score) for more details.

In [None]:
#|export
def Precision(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None, 
              log_metric=LogMetric.Valid, **kwargs):
    "Precision (`sklearn.metrics.precision_score`) for single-label classification problems"
    # Arguments handed on for `skm.precision_score`
    skm_args = dict(labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)
    return skm_to_fastxtend(skm.precision_score, axis=axis, log_metric=log_metric, **skm_args, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html#sklearn.metrics.precision_score) for more details.

In [None]:
#|export
def Recall(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None, 
           log_metric=LogMetric.Valid, **kwargs):
    "Recall (`sklearn.metrics.recall_score`) for single-label classification problems"
    # Arguments handed on for `skm.recall_score`
    skm_args = dict(labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)
    return skm_to_fastxtend(skm.recall_score, axis=axis, log_metric=log_metric, **skm_args, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html#sklearn.metrics.recall_score) for more details.

In [None]:
#|export
def RocAuc(axis=-1, average='macro', sample_weight=None, max_fpr=None, multi_class='ovr', 
           log_metric=LogMetric.Valid, **kwargs):
    "ROC AUC (`sklearn.metrics.roc_auc_score`) for single-label multiclass classification problems"
    # Only the two multiclass strategies are accepted here
    assert multi_class in ('ovr', 'ovo')
    # Arguments handed on for `skm.roc_auc_score`
    skm_args = dict(average=average, sample_weight=sample_weight, max_fpr=max_fpr, multi_class=multi_class)
    return skm_to_fastxtend(skm.roc_auc_score, axis=axis, activation=ActivationType.Softmax, flatten=False,
                            log_metric=log_metric, **skm_args, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score) for more details.

In [None]:
#|export
def RocAucBinary(axis=-1, average='macro', sample_weight=None, max_fpr=None, multi_class='raise', 
                 log_metric=LogMetric.Valid, **kwargs):
    "ROC AUC (`sklearn.metrics.roc_auc_score`) for single-label binary classification problems"
    # Arguments handed on for `skm.roc_auc_score`
    skm_args = dict(average=average, sample_weight=sample_weight, max_fpr=max_fpr, multi_class=multi_class)
    return skm_to_fastxtend(skm.roc_auc_score, axis=axis, activation=ActivationType.BinarySoftmax,
                            log_metric=log_metric, **skm_args, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score) for more details.

In [None]:
#|export
def MatthewsCorrCoef(sample_weight=None, log_metric=LogMetric.Valid, **kwargs):
    "Matthews correlation coefficient (`sklearn.metrics.matthews_corrcoef`) for single-label classification problems"
    score_func = skm.matthews_corrcoef
    return skm_to_fastxtend(score_func, log_metric=log_metric, sample_weight=sample_weight, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.matthews_corrcoef.html#sklearn.metrics.matthews_corrcoef) for more details.

## Multi-label classification

In [None]:
#|exporti
def accuracy_multi(inp, targ):
    "Fraction of entries where `inp` equals `targ` cast to bool, for `inp` and `targ` of the same size."
    flat_inp,flat_targ = flatten_check(inp,targ)
    hits = flat_inp==flat_targ.bool()
    return hits.float().mean()

In [None]:
#|export
def AccuracyMulti(thresh=0.5, sigmoid=True, metric_type=MetricType.Avg, log_metric=LogMetric.Valid, **kwargs):
    "Multi-label accuracy: predictions are thresholded at `thresh` (after a sigmoid unless `sigmoid=False`)"
    if sigmoid: activation = ActivationType.Sigmoid
    else:       activation = ActivationType.No
    return func_to_metric(accuracy_multi, metric_type, False, log_metric=log_metric,
                          thresh=thresh, activation=activation, **kwargs)

In [None]:
#|hide
#For testing
def change_1h_targ(targ, n):
    "Return a copy of `targ` with `n` randomly picked entries replaced by `1 - value`"
    flat = targ.clone().view(-1)
    for pos in torch.randperm(targ.numel())[:n]:
        flat[pos] = 1-flat[pos]
    # Restore the original shape before returning
    return flat.view(targ.shape)

In [None]:
#|hide
x = torch.randn(4,5)
# Default settings: sigmoid followed by a 0.5 threshold
y = (torch.sigmoid(x) >= 0.5).byte()
test_eq(compute_single(AccuracyMulti(),x,y), 1)
test_eq(compute_single(AccuracyMulti(),x,1-y), 0)
# Flipping 5 of the 20 targets leaves 75% correct
y1 = change_1h_targ(y, 5)
test_eq(compute_single(AccuracyMulti(),x,y1), 0.75)

#Different thresh
y = (torch.sigmoid(x) >= 0.2).byte()
test_eq(compute_single(AccuracyMulti(thresh=0.2),x,y), 1)
test_eq(compute_single(AccuracyMulti(thresh=0.2),x,1-y), 0)
y1 = change_1h_targ(y, 5)
test_eq(compute_single(AccuracyMulti(thresh=0.2),x,y1), 0.75)

#No sigmoid
# The raw predictions are compared with the default 0.5 threshold
y = (x >= 0.5).byte()
test_eq(compute_single(AccuracyMulti(sigmoid=False),x,y), 1)
test_eq(compute_single(AccuracyMulti(sigmoid=False),x,1-y), 0)
y1 = change_1h_targ(y, 5)
test_eq(compute_single(AccuracyMulti(sigmoid=False),x,y1), 0.75)

In [None]:
#|export
def APScoreMulti(sigmoid=True, average='macro', pos_label=1, sample_weight=None, 
                 log_metric=LogMetric.Valid, **kwargs):
    "Average precision (`sklearn.metrics.average_precision_score`) for multi-label classification problems"
    if sigmoid: activation = ActivationType.Sigmoid
    else:       activation = ActivationType.No
    # Arguments handed on for `skm.average_precision_score`
    skm_args = dict(average=average, pos_label=pos_label, sample_weight=sample_weight)
    return skm_to_fastxtend(skm.average_precision_score, activation=activation, flatten=False,
                            log_metric=log_metric, **skm_args, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html#sklearn.metrics.average_precision_score) for more details.

In [None]:
#|export
def BrierScoreMulti(thresh=0.5, sigmoid=True, sample_weight=None, pos_label=None, 
                    log_metric=LogMetric.Valid, **kwargs):
    "Brier score (`sklearn.metrics.brier_score_loss`) for multi-label classification problems"
    if sigmoid: activation = ActivationType.Sigmoid
    else:       activation = ActivationType.No
    # Arguments handed on for `skm.brier_score_loss`
    skm_args = dict(sample_weight=sample_weight, pos_label=pos_label)
    return skm_to_fastxtend(skm.brier_score_loss, thresh=thresh, activation=activation, flatten=False,
                            log_metric=log_metric, **skm_args, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.brier_score_loss.html#sklearn.metrics.brier_score_loss) for more details.

In [None]:
#|export
def F1ScoreMulti(thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='macro', sample_weight=None, 
                 log_metric=LogMetric.Valid, **kwargs):
    "F1 score (`sklearn.metrics.f1_score`) for multi-label classification problems"
    if sigmoid: activation = ActivationType.Sigmoid
    else:       activation = ActivationType.No
    # Arguments handed on for `skm.f1_score`
    skm_args = dict(labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)
    return skm_to_fastxtend(skm.f1_score, thresh=thresh, activation=activation, flatten=False,
                            log_metric=log_metric, **skm_args, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score) for more details.

In [None]:
#|export
def FBetaMulti(beta, thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='macro', sample_weight=None, 
               log_metric=LogMetric.Valid, **kwargs):
    "F-beta score (`sklearn.metrics.fbeta_score`) with weight `beta` for multi-label classification problems"
    if sigmoid: activation = ActivationType.Sigmoid
    else:       activation = ActivationType.No
    # Arguments handed on for `skm.fbeta_score`
    skm_args = dict(beta=beta, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)
    return skm_to_fastxtend(skm.fbeta_score, thresh=thresh, activation=activation, flatten=False,
                            log_metric=log_metric, **skm_args, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fbeta_score.html#sklearn.metrics.fbeta_score) for more details.

In [None]:
#|export
def HammingLossMulti(thresh=0.5, sigmoid=True, labels=None, sample_weight=None, 
                     log_metric=LogMetric.Valid, **kwargs):
    "Hamming loss (`sklearn.metrics.hamming_loss`) for multi-label classification problems"
    # NOTE: `labels` is accepted but is not forwarded to `skm.hamming_loss`
    if sigmoid: activation = ActivationType.Sigmoid
    else:       activation = ActivationType.No
    return skm_to_fastxtend(skm.hamming_loss, thresh=thresh, activation=activation, flatten=False,
                            log_metric=log_metric, sample_weight=sample_weight, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.hamming_loss.html#sklearn.metrics.hamming_loss) for more details.

In [None]:
#|export
def JaccardMulti(thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='macro', sample_weight=None, 
                 log_metric=LogMetric.Valid, **kwargs):
    "Jaccard score (`sklearn.metrics.jaccard_score`) for multi-label classification problems"
    if sigmoid: activation = ActivationType.Sigmoid
    else:       activation = ActivationType.No
    # Arguments handed on for `skm.jaccard_score`
    skm_args = dict(labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)
    return skm_to_fastxtend(skm.jaccard_score, thresh=thresh, activation=activation, flatten=False,
                            log_metric=log_metric, **skm_args, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.jaccard_score.html#sklearn.metrics.jaccard_score) for more details.

In [None]:
#|export
def MatthewsCorrCoefMulti(thresh=0.5, sigmoid=True, sample_weight=None, log_metric=LogMetric.Valid, **kwargs):
    "Matthews correlation coefficient (`sklearn.metrics.matthews_corrcoef`) for multi-label classification problems"
    if sigmoid: activation = ActivationType.Sigmoid
    else:       activation = ActivationType.No
    return skm_to_fastxtend(skm.matthews_corrcoef, thresh=thresh, activation=activation, flatten=False,
                            log_metric=log_metric, sample_weight=sample_weight, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.matthews_corrcoef.html#sklearn.metrics.matthews_corrcoef) for more details.

In [None]:
#|export
def PrecisionMulti(thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='macro', sample_weight=None, 
                   log_metric=LogMetric.Valid, **kwargs):
    "Precision (`sklearn.metrics.precision_score`) for multi-label classification problems"
    if sigmoid: activation = ActivationType.Sigmoid
    else:       activation = ActivationType.No
    # Arguments handed on for `skm.precision_score`
    skm_args = dict(labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)
    return skm_to_fastxtend(skm.precision_score, thresh=thresh, activation=activation, flatten=False,
                            log_metric=log_metric, **skm_args, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html#sklearn.metrics.precision_score) for more details.

In [None]:
#|export
def RecallMulti(thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='macro', sample_weight=None, 
                log_metric=LogMetric.Valid, **kwargs):
    "Recall (`sklearn.metrics.recall_score`) for multi-label classification problems"
    if sigmoid: activation = ActivationType.Sigmoid
    else:       activation = ActivationType.No
    # Arguments handed on for `skm.recall_score`
    skm_args = dict(labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)
    return skm_to_fastxtend(skm.recall_score, thresh=thresh, activation=activation, flatten=False,
                            log_metric=log_metric, **skm_args, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html#sklearn.metrics.recall_score) for more details.

In [None]:
#|export
def RocAucMulti(sigmoid=True, average='macro', sample_weight=None, max_fpr=None, log_metric=LogMetric.Valid, **kwargs):
    "ROC AUC (`sklearn.metrics.roc_auc_score`) for multi-label binary classification problems"
    if sigmoid: activation = ActivationType.Sigmoid
    else:       activation = ActivationType.No
    # Arguments handed on for `skm.roc_auc_score`
    skm_args = dict(average=average, sample_weight=sample_weight, max_fpr=max_fpr)
    return skm_to_fastxtend(skm.roc_auc_score, activation=activation, flatten=False,
                            log_metric=log_metric, **skm_args, **kwargs)

In [None]:
#|hide
# Every row of `x` holds the same five values, so the expected AUC is 0.5
roc_auc_metric = RocAucMulti(sigmoid=False)
x,y = torch.tensor([np.arange(start=0, stop=0.2, step=0.04)]*20), torch.tensor([0, 0, 1, 1]).repeat(5)
assert compute_val(roc_auc_metric, x, y) == 0.5

  x,y = torch.tensor([np.arange(start=0, stop=0.2, step=0.04)]*20), torch.tensor([0, 0, 1, 1]).repeat(5)


See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score) for more details.

## Regression

In [None]:
#|exporti
def mse(inp,targ):
    "Mean of the squared differences between `inp` and `targ` after `flatten_check`."
    flat_inp,flat_targ = flatten_check(inp,targ)
    return F.mse_loss(flat_inp, flat_targ)

In [None]:
#|export
def MSE(metric_type=MetricType.Avg, log_metric=LogMetric.Valid, **kwargs):
    "Mean squared error metric for regression (predictions are compared to the targets unmodified)."
    return func_to_metric(mse, metric_type=metric_type, is_class=False, log_metric=log_metric, **kwargs)

In [None]:
#|hide
# Single-batch MSE should match the direct formula
x1,x2 = torch.randn(4,5),torch.randn(4,5)
test_close(compute_single(MSE(),x1,x2), (x1-x2).pow(2).mean())

In [None]:
#|exporti
def rmse(inp, targ):
    "Root mean squared error between `inp` and `targ`."
    mean_sq_err = F.mse_loss(inp, targ)
    return mean_sq_err.sqrt()

In [None]:
#|export
def RMSE(log_metric=LogMetric.Valid, **kwargs):
    "Root mean squared error between `inp` and `targ`."
    # Always accumulated (`MetricType.Accum`): the mean of per-batch RMSE values is not the RMSE of the whole dataset
    return func_to_metric(rmse, MetricType.Accum, False, log_metric=log_metric, **kwargs)

In [None]:
#|hide
# RMSE accumulated over three batches should equal the RMSE of the full tensors
x1,x2 = torch.randn(20,5),torch.randn(20,5)
test_eq(compute_val(RMSE(), x1, x2), torch.sqrt(F.mse_loss(x1,x2)))

In [None]:
#|exporti
def mae(inp,targ):
    "Mean of the absolute differences between `inp` and `targ` after `flatten_check`."
    flat_inp,flat_targ = flatten_check(inp,targ)
    abs_err = (flat_inp - flat_targ).abs()
    return abs_err.mean()

In [None]:
#|export
def MAE(metric_type=MetricType.Avg, log_metric=LogMetric.Valid, **kwargs):
    "Mean absolute error metric for regression (predictions are compared to the targets unmodified)."
    return func_to_metric(mae, metric_type=metric_type, is_class=False, log_metric=log_metric, **kwargs)

In [None]:
#|hide
# Single-batch MAE should match the direct formula
x1,x2 = torch.randn(4,5),torch.randn(4,5)
test_eq(compute_single(MAE(),x1,x2), torch.abs(x1-x2).mean())

In [None]:
#|exporti
def msle(inp, targ):
    "Mean squared error between `log(1 + inp)` and `log(1 + targ)` after `flatten_check`."
    flat_inp,flat_targ = flatten_check(inp,targ)
    log_inp = torch.log(1 + flat_inp)
    log_targ = torch.log(1 + flat_targ)
    return F.mse_loss(log_inp, log_targ)

In [None]:
#|export
def MSLE(metric_type=MetricType.Avg, log_metric=LogMetric.Valid, **kwargs):
    "Mean squared logarithmic error metric for regression (predictions are compared to the targets unmodified)."
    return func_to_metric(msle, metric_type=metric_type, is_class=False, log_metric=log_metric, **kwargs)

In [None]:
#|hide
x1,x2 = torch.randn(4,5),torch.randn(4,5)
# ReLU makes the inputs non-negative so that `log(1 + x)` is defined
x1,x2 = torch.relu(x1),torch.relu(x2)
test_close(compute_single(MSLE(), x1,x2), (torch.log(x1+1)-torch.log(x2+1)).pow(2).mean())

In [None]:
#|exporti
def exp_rmspe(inp,targ):
    "Root mean squared percentage error between `exp(inp)` and `exp(targ)`, relative to `exp(targ)`"
    exp_inp = torch.exp(inp)
    exp_targ = torch.exp(targ)
    pct_err = (exp_targ - exp_inp)/exp_targ
    return pct_err.pow(2).mean().sqrt()

In [None]:
#|export
def ExpRMSE(log_metric=LogMetric.Valid, **kwargs):
    "Root mean square percentage error of the exponential of predictions and targets"
    # Always built as an accumulated metric (`MetricType.Accum`); `is_class=False` leaves predictions unmodified
    return func_to_metric(exp_rmspe, metric_type=MetricType.Accum, is_class=False, log_metric=log_metric, **kwargs)

In [None]:
#|hide
# Accumulated over three batches, the value should equal the RMSPE of the exponentiated full tensors
x1,x2 = torch.randn(20,5),torch.randn(20,5)
test_eq(compute_val(ExpRMSE(), x1, x2), torch.sqrt((((torch.exp(x2) - torch.exp(x1))/torch.exp(x2))**2).mean()))

In [None]:
#|export
def ExplainedVariance(sample_weight=None, log_metric=LogMetric.Valid, **kwargs):
    "Explained variance (`sklearn.metrics.explained_variance_score`) between predictions and targets"
    score_func = skm.explained_variance_score
    return skm_to_fastxtend(score_func, is_class=False, log_metric=log_metric,
                            sample_weight=sample_weight, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.explained_variance_score.html#sklearn.metrics.explained_variance_score) for more details.

In [None]:
#|export
def R2Score(sample_weight=None, log_metric=LogMetric.Valid, **kwargs):
    "R2 score (`sklearn.metrics.r2_score`) between predictions and targets"
    score_func = skm.r2_score
    return skm_to_fastxtend(score_func, is_class=False, log_metric=log_metric,
                            sample_weight=sample_weight, **kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html#sklearn.metrics.r2_score) for more details.

In [None]:
#|export
def PearsonCorrCoef(dim_argmax=None, log_metric=LogMetric.Valid, **kwargs):
    "Pearson correlation coefficient (`scipy.stats.pearsonr`) for regression problem"
    def pearsonr(x,y):
        # Keep only the first element of scipy's result
        result = scs.pearsonr(x,y)
        return result[0]
    return AccumMetricX(pearsonr, dim_argmax=dim_argmax, invert_arg=False,
                        log_metric=log_metric, **kwargs)

See the [scipy documentation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.pearsonr.html?highlight=pearson#scipy.stats.pearsonr) for more details.

In [None]:
#|hide
# The value accumulated over three batches should equal scipy's result on the full tensors
x = torch.randint(-999, 999,(20,))
y = torch.randint(-999, 999,(20,))
test_eq(compute_val(PearsonCorrCoef(), x, y), scs.pearsonr(x.view(-1), y.view(-1))[0])

In [None]:
#|export
def SpearmanCorrCoef(dim_argmax=None, axis=0, nan_policy='propagate', log_metric=LogMetric.Valid, **kwargs):
    "Spearman correlation coefficient (`scipy.stats.spearmanr`) for regression problem"
    def spearmanr(a,b=None,**kwargs):
        # Keep only the first element of scipy's result
        result = scs.spearmanr(a,b,**kwargs)
        return result[0]
    # `axis` and `nan_policy` are bound into the scipy call up front
    bound_func = partial(spearmanr, axis=axis, nan_policy=nan_policy)
    return AccumMetricX(bound_func, dim_argmax=dim_argmax, invert_arg=False, log_metric=log_metric, **kwargs)

See the [scipy documentation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.spearmanr.html?highlight=spearman#scipy.stats.spearmanr) for more details.

In [None]:
#|hide
# The value accumulated over three batches should equal scipy's result on the full tensors
x = torch.randint(-999, 999,(20,))
y = torch.randint(-999, 999,(20,))
test_eq(compute_val(SpearmanCorrCoef(), x, y), scs.spearmanr(x.view(-1), y.view(-1))[0])

## Segmentation

In [None]:
#|hide
# Show which tensor type a model returns when its input is a `TensorImage` (displayed as the cell output)
from fastai.vision.models import resnet18
model = resnet18()
x = cast(torch.rand(1,3,128,128), TensorImage)
type(model(x))

fastai.torch_core.TensorImage

In [None]:
#|exporti
def foreground_acc(inp, targ, bkg_idx=0, axis=1):
    "Accuracy over the target entries that differ from `bkg_idx`, for multiclass segmentation"
    # NOTE: `axis` is not used inside this function
    plain_targ = cast(targ.squeeze(1), TensorBase)
    foreground = plain_targ != bkg_idx
    matches = inp[foreground]==plain_targ[foreground]
    return matches.float().mean()

In [None]:
#|export
def ForegroundAcc(bkg_idx=0, axis=1, metric_type=MetricType.Avg, log_metric=LogMetric.Valid, **kwargs):
    "Non-background accuracy for multiclass segmentation: predictions are argmaxed over `axis`, targets equal to `bkg_idx` are ignored"
    return func_to_metric(foreground_acc, metric_type=metric_type, is_class=True, axis=axis,
                          log_metric=log_metric, bkg_idx=bkg_idx, **kwargs)

In [None]:
#|hide
x = cast(torch.randn(4,5,3,3), TensorImage)
# Targets are the argmax over the class dim, with a singleton dim re-inserted at position 1
y = cast(x, TensorMask).argmax(dim=1)[:,None]
test_eq(compute_single(ForegroundAcc(),x,y), 1)
y[0] = 0 #the 0s are ignored so we get the same value
test_eq(compute_single(ForegroundAcc(),x,y), 1)

In [None]:
#|export
class Dice(MetricX):
    "Dice coefficient metric for binary target in segmentation"
    def __init__(self, axis=1, log_metric=LogMetric.Valid, **kwargs):
        super().__init__(dim_argmax=axis, log_metric=log_metric, **kwargs)

    def reset(self):
        "Zero the running `inter` and `union` sums"
        self.inter = 0
        self.union = 0

    def accumulate(self, learn):
        "Add the current batch's sums to the running totals"
        super().accumulate(learn)
        self.pred,self.targ = flatten_check(self.pred, self.targ)
        pred, targ = self.pred, self.targ
        # `inter` sums pred*targ; `union` sums pred+targ
        self.inter += (pred*targ).float().sum().item()
        self.union += (pred+targ).float().sum().item()

    @property
    def value(self):
        "`2*inter/union`, or `None` while `union` is not positive"
        if self.union > 0: return 2. * self.inter/self.union
        return None

In [None]:
#|hide
# Recompute the Dice score by hand from the argmax predictions and compare with the metric
x1 = cast(torch.randn(20,2,3,3), TensorImage)
x2 = cast(torch.randint(0, 2, (20, 3, 3)), TensorMask)
pred = x1.argmax(1)
inter = (pred*x2).float().sum().item()
union = (pred+x2).float().sum().item()
test_eq(compute_val(Dice(), x1, x2), 2*inter/union)

In [None]:
#|export
class DiceMulti(MetricX):
    "Averaged Dice metric (Macro F1) for multiclass target in segmentation"
    def __init__(self, axis=1, log_metric=LogMetric.Valid, **kwargs):
        super().__init__(dim_argmax=axis, log_metric=log_metric, **kwargs)

    def reset(self):
        "Start over with empty per-class `inter` and `union` totals"
        self.inter = {}
        self.union = {}

    def accumulate(self, learn):
        "Add the current batch's per-class sums to the running totals"
        super().accumulate(learn)
        self.pred,self.targ = flatten_check(self.pred, self.targ)
        # The number of classes is read from the raw prediction's size along `dim_argmax`
        n_classes = learn.pred.shape[self.dim_argmax]
        for c in range(n_classes):
            # Binary indicators for class `c` in the prediction and in the target
            p = torch.where(self.pred == c, 1, 0)
            t = torch.where(self.targ == c, 1, 0)
            c_inter = (p*t).float().sum().item()
            c_union = (p+t).float().sum().item()
            self.inter[c] = self.inter.get(c, 0) + c_inter
            self.union[c] = self.union.get(c, 0) + c_union

    @property
    def value(self):
        "Mean of the per-class `2*inter/union` scores; classes whose `union` is not positive count as NaN and are ignored"
        binary_dice_scores = [2.*self.inter[c]/self.union[c] if self.union[c] > 0 else np.nan
                              for c in self.inter]
        return np.nanmean(np.array(binary_dice_scores, dtype=float))

The `DiceMulti` metric implements the "Averaged F1: arithmetic mean over harmonic means" described in [this publication](https://arxiv.org/pdf/1911.03347.pdf).

In [None]:
#|hide
# Three-class prediction whose class-0 channel is always the largest
x1a = torch.ones(20,1,1,1)
x1b = torch.clone(x1a)*0.5
x1c = torch.clone(x1a)*0.3
x1 = torch.cat((x1a,x1b,x1c),dim=1)   # Prediction: 20xClass0
x2 = torch.zeros(20,1,1)              # Target: 20xClass0
test_eq(compute_val(DiceMulti(), x1, x2), 1.)

x2 = torch.ones(20,1,1)               # Target: 20xClass1
test_eq(compute_val(DiceMulti(), x1, x2), 0.)

# Mixed target: only class 0 has any overlap with the prediction
x2a = torch.zeros(10,1,1)
x2b = torch.ones(5,1,1)
x2c = torch.ones(5,1,1) * 2
x2 = torch.cat((x2a,x2b,x2c),dim=0)   # Target: 10xClass0, 5xClass1, 5xClass2
dice1 = (2*10)/(2*10+10)              # Dice: 2*TP/(2*TP+FP+FN)
dice2 = 0
dice3 = 0
test_eq(compute_val(DiceMulti(), x1, x2), (dice1+dice2+dice3)/3)

In [None]:
#|export
class JaccardCoeff(Dice):
    "Implementation of the Jaccard coefficient that is lighter in RAM"
    @property
    def value(self):
        "`inter/(union-inter)`, or `None` while `union` is not positive"
        if not self.union > 0: return None
        return self.inter/(self.union-self.inter)

In [None]:
#|hide
# Recompute the Jaccard coefficient by hand from the argmax predictions and compare with the metric
x1 = cast(torch.randn(20,2,3,3), TensorImage)
x2 = cast(torch.randint(0, 2, (20, 3, 3)), TensorMask)
pred = x1.argmax(1)
inter = (pred*x2).float().sum().item()
union = (pred+x2).float().sum().item()
test_eq(compute_val(JaccardCoeff(), x1, x2), inter/(union-inter))

## NLP

In [None]:
#|export
class CorpusBLEUMetric(MetricX):
    "BLEU Metric calculated over the validation corpus"
    def __init__(self, vocab_sz=5000, axis=-1, log_metric=LogMetric.Valid, name='CorpusBLEU', **kwargs):
        super().__init__(log_metric=log_metric, name=name, **kwargs)
        self.axis, self.vocab_sz = axis, vocab_sz
        self.pred_len,self.targ_len,self.samp_idx,self.corrects,self.counts, = 0,0,0,[0]*4,[0]*4

    def reset(self):
        "Zero the accumulated lengths and the per-order (1- to 4-gram) statistics"
        self.pred_len,self.targ_len,self.corrects,self.counts = 0,0,[0]*4,[0]*4

    class NGram():
        "Hashable wrapper around a token sequence so n-grams can be counted with `Counter`"
        def __init__(self, ngram, max_n=5000): self.ngram,self.max_n = ngram,max_n
        def __eq__(self, other):
            if len(self.ngram) != len(other.ngram): return False
            return np.all(np.array(self.ngram) == np.array(other.ngram))
        # Hash treats the tokens as digits of a base-`max_n` number
        def __hash__(self): return int(sum([o * self.max_n**i for i,o in enumerate(self.ngram)]))

    def get_grams(self, x, n, max_n=5000):
        "Return `x` itself when `n==1`, otherwise every length-`n` window of `x` wrapped in `NGram`"
        return x if n==1 else [self.NGram(x[i:i+n], max_n=max_n) for i in range(len(x)-n+1)]

    def get_correct_ngrams(self, pred, targ, n, max_n=5000):
        "Return the clipped number of `pred` n-grams also found in `targ`, and the total number of `pred` n-grams"
        pred_grams,targ_grams = self.get_grams(pred, n, max_n=max_n),self.get_grams(targ, n, max_n=max_n)
        pred_cnt,targ_cnt = Counter(pred_grams),Counter(targ_grams)
        # Each predicted n-gram is credited at most as often as it occurs in the target
        return sum([min(c, targ_cnt[g]) for g,c in pred_cnt.items()]),len(pred_grams)

    def accumulate(self, learn):
        "Add the n-gram statistics of each (prediction, target) pair in the batch; does nothing while `learn.training`"
        if learn.training: return None
        last_output = learn.pred.argmax(dim=self.axis)
        last_target = learn.y
        for pred,targ in zip(last_output.cpu().numpy(),last_target.cpu().numpy()):
            self.pred_len += len(pred)
            self.targ_len += len(targ)
            smooth_mteval = 1
            for i in range(4):
                c,t = self.get_correct_ngrams(pred, targ, i+1, max_n=self.vocab_sz)
                if c == 0:
                    smooth_mteval *= 2
                    c = 1 / smooth_mteval    # exp smoothing, method 3 from http://acl2014.org/acl2014/W14-33/pdf/W14-3346.pdf
                self.corrects[i] += c
                self.counts[i]   += t

    @property
    def value(self):
        "Corpus BLEU from the accumulated statistics, or `None` if no n-grams have been counted"
        # `self.counts` is a list, so it must be inspected element-wise: `self.counts == 0` is never true
        if not any(self.counts): return None
        elif max(self.corrects) == 0: return 0.0
        else:
            # An order with no candidate n-grams (sequences shorter than n) gets zero precision
            # instead of raising ZeroDivisionError, which makes the overall score 0
            precs = [c/t if t > 0 else 0.0 for c,t in zip(self.corrects,self.counts)]
            # Brevity penalty applies only when the predictions are shorter than the targets
            len_penalty = math.exp(1 - self.targ_len/self.pred_len) if self.pred_len < self.targ_len else 1
            return len_penalty * ((precs[0]*precs[1]*precs[2]*precs[3]) ** 0.25)

In [None]:
#|hide
def create_vcb_emb(pred, targ):
    # create vocab "embedding" for predictions
    # vocab size is one more than the largest token id found in either tensor
    vcb_sz = torch.cat([pred, targ]).max().item() + 1
    n_seqs, seq_len = pred.size()[0], pred.size()[1]
    pred_emb = torch.zeros(n_seqs, seq_len, vcb_sz)
    # write a 1 at each token's index along the vocab dimension, one sequence at a time
    for row, tokens in enumerate(pred):
        pred_emb[row].scatter_(1, tokens.view(len(tokens), 1), 1)
    return pred_emb

def compute_bleu_val(met, x1, x2):
    "Reset `met`, accumulate (`x1`, `x2`) as a non-training batch `len(x1)` times, and return `met.value`"
    met.reset()
    learn = TstLearner()
    learn.training = False
    n_passes = len(x1)
    for _ in range(n_passes):
        learn.pred = x1
        learn.yb = (x2,)
        met.accumulate(learn)
    return met.value

# Single sentence pair: prediction differs from the target in two tokens
targ = torch.tensor([[1,2,3,4,5,6,1,7,8]]) 
pred = torch.tensor([[1,9,3,4,5,6,1,10,8]])
pred_emb = create_vcb_emb(pred, targ)
test_close(compute_bleu_val(CorpusBLEUMetric(), pred_emb, targ), 0.48549)

# Same sentence pair repeated twice: the expected score is unchanged
targ = torch.tensor([[1,2,3,4,5,6,1,7,8],[1,2,3,4,5,6,1,7,8]]) 
pred = torch.tensor([[1,9,3,4,5,6,1,10,8],[1,9,3,4,5,6,1,10,8]])
pred_emb = create_vcb_emb(pred, targ)
test_close(compute_bleu_val(CorpusBLEUMetric(), pred_emb, targ), 0.48549)

The BLEU metric was introduced in [this article](https://www.aclweb.org/anthology/P02-1040) as a way to evaluate the performance of translation models. It's based on the precision of n-grams in your prediction compared to your target. See the [fastai NLP course BLEU notebook](https://github.com/fastai/course-nlp/blob/master/bleu_metric.ipynb) for a more detailed description of BLEU.

The smoothing used in the precision calculation is the same as in [SacreBLEU](https://github.com/mjpost/sacrebleu/blob/32c54cdd0dfd6a9fadd5805f2ea189ac0df63907/sacrebleu/sacrebleu.py#L540-L542), which in turn is "method 3" from the [Chen & Cherry, 2014](http://acl2014.org/acl2014/W14-33/pdf/W14-3346.pdf) paper.

In [None]:
#|export
class Perplexity(AvgLossX):
    "Perplexity (exponential of cross-entropy loss) for Language Models"
    @property
    def value(self):
        "Exponential of `total/count`, or `None` while `count` is zero"
        if self.count == 0: return None
        return torch.exp(self.total/self.count)

    @property
    def name(self):
        return "perplexity"

# Ready-made module-level instance
perplexity = Perplexity()

In [None]:
#|hide
# Accumulate the loss over three unevenly sized slices and compare with the
# exponential of the cross-entropy computed on the full tensors
x1,x2 = torch.randn(20,5),torch.randint(0, 5, (20,))
tst = perplexity
tst.reset()
vals = [0,6,15,20]
learn = TstLearner()
for i in range(3): 
    learn.yb = (x2[vals[i]:vals[i+1]],)
    learn.loss = F.cross_entropy(x1[vals[i]:vals[i+1]],x2[vals[i]:vals[i+1]])
    tst.accumulate(learn)
test_close(tst.value, torch.exp(F.cross_entropy(x1,x2)))

## LossMetrics -

In [None]:
#|export
class LossMetric(AvgMetricX):
    "Create a metric from `loss_func.attr` named `nm`"
    def __init__(self, attr, nm=None, log_metric=LogMetric.Valid, **kwargs):
        super().__init__(noop, log_metric=log_metric, **kwargs)
        store_attr('attr,nm')

    def accumulate(self, learn):
        "Add the loss function's `attr` value, weighted by batch size, to the running totals"
        bs = find_bs(learn.yb)
        # Falls back to 0 when the loss function has no such attribute
        attr_val = getattr(learn.loss_func, self.attr, 0)
        self.total += learn.to_detach(attr_val)*bs
        self.count += bs

    @property
    def name(self):
        "`nm` when one was given, otherwise the attribute name"
        return self.nm if self.nm is not None else self.attr

In [None]:
#|export
def LossMetrics(attrs, nms=None):
    "List of `LossMetric` for each of `attrs` and `nms`"
    # Both arguments may be given as comma-separated strings
    if isinstance(attrs, str): attrs = attrs.split(',')
    if nms is None: nms = attrs
    elif isinstance(nms, str): nms = nms.split(',')
    return [LossMetric(attr, nm) for attr,nm in zip(attrs, nms)]

In [None]:
#|hide
class CombineL1L2(Module):
    "Loss that sums L1 and MSE, keeping each term on the module as `l1` and `l2`"
    def forward(self, out, targ):
        l1, l2 = F.l1_loss(out, targ), F.mse_loss(out, targ)
        # Stored as attributes so each term can be read back from the loss function
        self.l1, self.l2 = l1, l2
        return l1 + l2

In [None]:
#|hide
# Train a synthetic learner for two epochs while logging the loss function's `l1` and `l2` attributes as metrics
with less_random():
    learn = synth_learner(metrics=LossMetrics('l1,l2'))
    learn.loss_func = CombineL1L2()
    learn.fit(2)

[0, 9.458013534545898, 8.1944580078125, 2.3591325283050537, 5.835325241088867, '00:00']
[1, 8.172794342041016, 5.658956050872803, 1.8951740264892578, 3.763782024383545, '00:00']


## Logger Patches -

In [None]:
#|exporti
try:
    import wandb
    from fastai.callback.wandb import WandbCallback
    from fastai.callback.training import FetchPredsCallback

    # Keep the original `before_fit` (saved only once) so the patched version below can call it
    if not hasattr(WandbCallback,'_metrics_before_fit'): WandbCallback._metrics_before_fit = WandbCallback.before_fit

    @patch
    def before_fit(self:WandbCallback):
        "Run the original `before_fit`, then note whether the recorder has any smooth metrics to log"
        self._metrics_before_fit()
        self.log_smooth = len(self.recorder.smooth_names) > 0

    @patch
    def after_batch(self:WandbCallback):
        "Log hyper-parameters and training loss"
        if self.training:
            self._wandb_step += 1
            self._wandb_epoch += 1/self.n_iter
            hypers = {f'{k}_{i}':v for i,h in enumerate(self.opt.hypers) for k,v in h.items()}
            # Smooth metric values are logged in the same call as the hyper-parameters
            if self.log_smooth:
                for n,m in zip(self.recorder.smooth_names, self.recorder.smooth_mets): hypers[n]=m.value
            wandb.log({'epoch': self._wandb_epoch, 'train_loss': to_detach(self.smooth_loss.clone()), 'raw_loss': to_detach(self.loss.clone()), **hypers}, step=self._wandb_step)

    @patch
    def after_epoch(self:WandbCallback):
        "Log validation loss and custom metrics & log prediction samples"
        # Correct any epoch rounding error and overwrite value
        self._wandb_epoch = round(self._wandb_epoch)
        wandb.log({'epoch': self._wandb_epoch}, step=self._wandb_step)
        # Log sample predictions
        if self.log_preds:
            try:
                self.log_predictions(self.learn.fetch_preds.preds)
            except Exception as e:
                self.log_preds = False
                self.remove_cb(FetchPredsCallback)
                print(f'WandbCallback was not able to get prediction samples -> {e}')
        # Entries skipped here: loss/epoch/time columns and the smooth metrics logged per batch in `after_batch`
        skip_names = ['train_loss', 'epoch', 'time']+self.recorder.smooth_names
        wandb.log({n:s for n,s in zip(self.recorder.metric_names, self.recorder.log) if n not in skip_names}, step=self._wandb_step)
except Exception:
    # Best effort: leave `WandbCallback` unpatched if wandb or its fastai callback is unavailable.
    # `Exception` (not a bare `except`) lets KeyboardInterrupt/SystemExit propagate.
    pass

In [None]:
#|exporti
try:
    import tensorboard
    from fastai.callback.tensorboard import TensorBoardCallback, tensorboard_log

    # Keep the original `before_fit` (saved only once) so the patched version below can call it
    if not hasattr(TensorBoardCallback,'_metrics_before_fit'): TensorBoardCallback._metrics_before_fit = TensorBoardCallback.before_fit

    @patch
    def before_fit(self:TensorBoardCallback):
        "Run the original `before_fit`, then note whether the recorder has any smooth metrics to log"
        self._metrics_before_fit()
        self.log_smooth = len(self.recorder.smooth_names) > 0

    @patch
    def after_batch(self:TensorBoardCallback):
        "Write the smoothed training loss, optimizer hyper-parameters and smooth metrics for this iteration"
        self.writer.add_scalar('train_loss', self.smooth_loss, self.train_iter)
        for i,h in enumerate(self.opt.hypers):
            for k,v in h.items(): self.writer.add_scalar(f'{k}_{i}', v, self.train_iter)
        if self.log_smooth:
            # `smooth_mets` holds metric objects (the wandb patch logs `m.value` from the same list),
            # so write each metric's value rather than the metric itself
            for n,m in zip(self.recorder.smooth_names, self.recorder.smooth_mets):
                self.writer.add_scalar(f'{n}', m.value, self.train_iter)

    @patch
    def after_epoch(self:TensorBoardCallback):
        "Write the epoch's recorder values and, if `log_preds` is set, sample predictions"
        # Entries skipped here: loss/time columns and the smooth metrics written per batch in `after_batch`
        skip_names = ['train_loss', 'time']+self.recorder.smooth_names
        # The slice drops the first two and the last entry of the recorder's names and log
        for n,v in zip(self.recorder.metric_names[2:-1], self.recorder.log[2:-1]):
            if n not in skip_names:
                self.writer.add_scalar(n, v, self.train_iter)
        if self.log_preds:
            b = self.dls.valid.one_batch()
            self.learn.one_batch(0, b)
            # Apply the loss function's `activation` and `decodes` when it defines them
            preds = getattr(self.loss_func, 'activation', noop)(self.pred)
            out = getattr(self.loss_func, 'decodes', noop)(preds)
            x,y,its,outs = self.dls.valid.show_results(b, out, show=False, max_n=self.n_preds)
            tensorboard_log(x, y, its, outs, self.writer, self.train_iter)
except Exception:
    # Best effort: leave `TensorBoardCallback` unpatched if tensorboard or its fastai callback is unavailable.
    # `Exception` (not a bare `except`) lets KeyboardInterrupt/SystemExit propagate.
    pass