
There should be metrics package #22439

Closed
kaszperro opened this issue Jul 2, 2019 · 54 comments
Labels
feature A request for a proper, new feature. · module: numpy Related to numpy support, and also numpy compatibility of our operators · quansight-nack High-prio issues that have been reviewed by Quansight and are judged to be not actionable. · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@kaszperro

kaszperro commented Jul 2, 2019

🚀 Feature

Why not implement some common metrics for evaluating models, with strong GPU acceleration?

Motivation

When I need to evaluate model accuracy (e.g. measure R-squared), I have to write it by hand, or use NumPy or another library (without GPU support). It'd be cleaner and simpler to have a dedicated package with a Pythonic API. It could also provide performance benefits thanks to low-level (C++) optimizations.

Pitch

I'd like to have a package such as torch.metrics; then I could do something like:

from torch.metrics import RSquaredMetric
metric = RSquaredMetric()
out = model(input)
r_squared = metric(out, reference)

Or I could do:

from torch.metrics import MSEMetric
metric = MSEMetric()
out = model(input)
mse = metric(out, reference)

Alternatives

For now, I can use some tensor manipulations or fall back to NumPy:

import numpy as np

# `val` holds the corresponding model outputs for this key as a NumPy array
ref_val = reference_batch[key].cpu().detach().numpy()
cor_matrix = np.corrcoef(val, ref_val, rowvar=False)
n = cor_matrix.shape[0] // 2
out_r2[key] = [cor_matrix[i][i + n] for i in range(n)]

If you agree, I'd like to contribute.

cc @ezyang @gchanan @zou3519 @albanD @mruberry

@albanD
Collaborator

albanD commented Jul 2, 2019

I'm not sure what the difference would be from the loss functions in torch.nn?

@kaszperro
Author

Correct me if I'm wrong, but loss functions are intended to support calling .backward() on the Tensors they return. For some metrics, like R-squared, quantiles, ..., that doesn't make sense or isn't possible. So metrics don't need to calculate gradients and so on.

@albanD
Copy link
Collaborator

albanD commented Jul 2, 2019

The autograd overhead is very small; even if you can call .backward(), you don't have to.
Also, if the input does not require gradients, or you're executing the function within a with torch.no_grad(): block, autograd is not enabled, so these functions will run as fast as a version that never computes gradients.
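For illustration (a minimal sketch, using a stand-in model and stand-in data): evaluating an existing loss function as a metric under torch.no_grad(), so no graph is built.

import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 1)                                         # stand-in model
data = [(torch.randn(32, 10), torch.randn(32, 1)) for _ in range(4)]   # stand-in eval batches

total_sq_err, count = 0.0, 0
with torch.no_grad():  # autograd disabled: no graph is built
    for x, y in data:
        pred = model(x)
        total_sq_err += F.mse_loss(pred, y, reduction="sum").item()
        count += y.numel()
print(total_sq_err / count)  # mean squared error over all batches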

I agree though that standard metrics could be implemented, either in tnt or in the main package?

@kaszperro
Author

I get your point, you are correct.

In my opinion, those metrics should be added to the main package because, as you pointed out, it could be done with the loss functions from torch.nn. If it were done in the tnt project, how would you decide which metric or 'loss' belongs in the main package and which in tnt?
Also, some are already in torch, like MSE, so it'd be inconsistent to use e.g. MSE from torch and RSquared from tnt.

@albanD albanD added feature A request for a proper, new feature. module: nn Related to torch.nn labels Jul 2, 2019
@kaszperro
Author

I'd also like to discuss where it should be implemented. As new loss functions? If so, what should, for example, quantiles look like?
I'd like it to be something like:

from ..... import Quantile
quantile = Quantile(0.5, axis=0)

tensor = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
quantile(tensor)
>>> [4, 5]
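A minimal sketch of such a Quantile callable, assuming torch.quantile is available (it was added in a later PyTorch release); the class itself is hypothetical:

import torch


class Quantile:
    """Hypothetical callable computing a quantile along a given axis."""

    def __init__(self, q, axis=0):
        self.q = q
        self.axis = axis

    def __call__(self, tensor):
        # torch.quantile requires a floating-point input
        return torch.quantile(tensor.double(), self.q, dim=self.axis)


quantile = Quantile(0.5, axis=0)
tensor = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
print(quantile(tensor))  # tensor([4., 5.], dtype=torch.float64)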

@albanD
Collaborator

albanD commented Jul 3, 2019

The place to put it I guess is to implement a function in torch.nn.functional that performs what you want. Then you can wrap it as a module in torch.nn.

I would say it's fine to have non differentiable functions in nn, what do you think @soumith ?

@izdeby izdeby added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jul 3, 2019
@soumith
Member

soumith commented Jul 4, 2019

I want to see a fuller proposal for metrics that expands on why metrics need to be metrics instead of loss functions in torch.nn.functional. Adding a namespace and a full package to PyTorch core is possible, but the onus is on proving that it needs to be in core, and its design needs to be more like a fully fleshed-out PEP at this stage rather than a short description.

I think a lot (or at least a significant number) of metrics actually belong in torchvision, torchtext and torchaudio, and whatever metrics are supposed to be generic probably have a strong case for being differentiable (but I get it, there are non-differentiable metrics as well).

@kaszperro
Author

It'd take some time to write a full proposal, but for now I'd like to propose some ideas and hear your opinion.

  1. It should have functionality for accumulating measurements. Imagine we have a huge dataset that we can't fit into memory all at once, so we need some way of feeding it batches of data and, when we are finished, calling something like .get() to receive the metric over all the data (see the sketch after this list). For example:

r_squared = RSquaredMetric()
for model_out, target_data in zip(model_outs, targets):
    r_squared.accumulate(model_out, target_data)
correlation_matrix = r_squared.get()

  2. It should not accumulate gradients or track the history of operations. Calling with torch.no_grad() internally will probably be fine for now.

  3. In my opinion, implementing them as losses wouldn't be ideal, because for some or most of them you can't call .backward(), and hardly anyone would use them for training purposes.
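A minimal sketch of the accumulation idea from point 1 (the class and method names follow the proposal above and are hypothetical):

import torch


class MSEMetric:
    """Hypothetical accumulating mean-squared-error metric."""

    def __init__(self):
        self.sum_sq_err = 0.0
        self.count = 0

    def accumulate(self, output, target):
        with torch.no_grad():  # no gradient history is recorded
            self.sum_sq_err += torch.sum((output - target) ** 2).item()
            self.count += target.numel()

    def get(self):
        return self.sum_sq_err / self.count


mse = MSEMetric()
for _ in range(4):  # stand-in for batches from a data loader
    out, ref = torch.randn(8, 3), torch.randn(8, 3)
    mse.accumulate(out, ref)
print(mse.get())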

@albanD
Collaborator

albanD commented Jul 4, 2019

What you present looks quite similar to the meters in tnt here, just with new types of metrics.
Is there a key feature missing from the meters that you're looking for?

@kaszperro
Author

I'm sorry, but I've only known about the existence of tnt since you mentioned it two days ago. Now, after briefly reading its documentation, I think its meters package meets my needs, since it provides functionality for accumulating batched data.

But there are only a few (really few) meters available, and probably none of them is useful to me. So now I think we need to implement more of them, including some of the losses from PyTorch (MSE, ...) and more.

@fmassa
Member

fmassa commented Jul 4, 2019

Check out Ignite, it is more feature-complete and maintained:

https://github.com/pytorch/ignite

@varunagrawal
Contributor

I was thinking about this recently as well and came across something similar in Tensorflow.

I believe the idea is to have metrics such as mAP, precision (at k), nDCG etc., easily available and with strong GPU and multi-dimensional tensor support, so that there is less reinventing of the wheel and users don't get hit by bugs when trying to implement multi-dimensional versions of them.

Perhaps extending ignite's metrics module is the right way to go? Thoughts?

@kaszperro
Author

@varunagrawal that's exactly the idea.

I believe extending Ignite's metrics is what the PyTorch members want. I might be wrong, but I still don't see why it should be in a separate project, as almost everyone using PyTorch wants to evaluate model accuracy.

@ezyang
Contributor

ezyang commented Jul 8, 2019

@jeffreyksmithjr says that this issue has been filed multiple times. Such metrics are common in analogous libraries; it's a weird user experience that we don't have them out of the box. (And it may become a requirement for some of our internal use cases.) We (in PyTorch core) may need to step up and do the design work for something like this. (In terms of internal priority, this might come after having non-neural-network baseline models working in PyTorch.)

@fmassa says, if we can find someone to work on this, that would be great. But it will take some time to come up with a good design (echoing what @soumith said). Moving ignite metrics into PyTorch core is definitely a good idea; it's just not clear that ignite's design has stabilized to the extent that we can actually do this. Part of the trouble is that here in core PyTorch, most of us are not writing code that uses metrics (though, @jeffreyksmithjr points out that we have plenty of internal people who have relevant experience here).

@jeffreyksmithjr
Contributor

I continue to feel that the omission of metrics is just a hole in our feature set that should be filled. But I totally agree that some metrics belong in domain libraries (torchvision, etc.) and some belong completely outside of PT. There is some base level of functionality that I think it could make a lot of sense for us to provide: the batteries that are obviously worth including.

Beyond the concerns of mine that @ezyang transcribed above, I would call out that we're increasing the amount of staffing of folks who are expected to produce materials like tutorials. Not having even basic metrics built in creates yet more weird devX for tutorial authors. In my book on machine learning systems, I always relied upon first-party library implementations of metrics to demonstrate any concept other than the single section where the reader was learning how to implement metrics code. It's just more comprehensible for a user to not need a whole separate component to check the performance of their trained model.

That said, I don't have strong feelings about how we would structure such a solution. I think we could try to ensure a level of modularity, if it was deemed valuable, so long as we eventually pointed users to something that didn't introduce a lot of new cognitive overhead (which I would argue that tnt and ignite do if you're literally only trying to call a single function like .precision() or whatever).

I'll take responsibility for seeing if we can find an internal team member who could at least get us to a draft design that makes sense to discuss. Low-detail issues like this one tend not to really move the conversation forward. We need an actual list of metrics to work from.

@varunagrawal
Contributor

@jeffreyksmithjr @ezyang I would be more than happy to help with the implementation.

@kaszperro
Author

@ezyang @jeffreyksmithjr I'm also willing to help.

@ezyang
Contributor

ezyang commented Jul 10, 2019 via email

@kaszperro
Author

@ezyang in what form should it be? As a comment to this issue? What should it contain?

@ezyang
Contributor

ezyang commented Jul 15, 2019

A comment on this issue seems like the best format we have right now. I think, as a first cut, a fuller description of the APIs that would be added, plus some overall description of organization and philosophy, would be good.

@kaszperro
Author

Basic idea

Let me quote @varunagrawal:

Implement metrics such as mAP, precision (at k), nDCG etc., easily available and with strong GPU and multi-dimensional tensor support, so that there is less reinventing of the wheel and users don't get hit by bugs when trying to implement multi-dimensional versions of them.

It's a somewhat weird user experience that you have to write them by hand or import another package (Ignite) to get such basic functionality. From my experience in pollution forecasting, I can tell that it's essential to compute something like numpy's corrcoef to check the correlation between the model output and the target (measured) output, and I had to implement it myself to get batch accumulation and GPU support.
As you have already implemented many metrics in Ignite, I think it won't be very time-consuming. In my opinion, many users will be pleased to get metrics in PyTorch, because only a minority of them know about the Ignite package.

Other Libraries

  • Tensorflow (v 1.x)

In TensorFlow there is a metrics package. Each metric returns two operations: the first calculates the metric output, the second updates the metric (e.g. accumulates over batches).
Let's look at some code (based on this tutorial):

# Placeholders to take in batches of data
tf_label = tf.placeholder(dtype=tf.int32, shape=[None])
tf_prediction = tf.placeholder(dtype=tf.int32, shape=[None])

# Define the metric and update operations
acc_metric, acc_metric_update = tf.metrics.accuracy(tf_label,
                                                      tf_prediction,
                                                      name="acc_metric")

tf.metrics.accuracy computes num_correct / num_items. It creates two hidden variables, which accumulate num_correct and num_items.

This is how to update the metric:

for i in range(n_batches):
        # Update the running variables on new batch of samples
        feed_dict={tf_label: labels[i], tf_prediction: predictions[i]}
        session.run(acc_metric_update, feed_dict=feed_dict)

Now we can actually compute our acc_metric:

score = session.run(acc_metric)

Because TensorFlow uses static graphs, it's a bit complicated; we have to use sessions.
But we can see that they only needed two operations: calculate the metric and accumulate the metric. It's also possible to reset the metric state by getting the variables created by a metric and re-initializing them:

# Get variables created by metric
running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="acc_metric")

# Define initializer to initialize/reset running variables
running_vars_initializer = tf.variables_initializer(var_list=running_vars)

# Reset variables
session.run(running_vars_initializer)

  • Tensorflow (v 2.x)

As TensorFlow 2.0 is still in beta, this can (probably?) change slightly. But for now each metric provides three methods (metrics are classes from now on):

  • reset_states() resets all of the metric's state variables,
  • result() computes the metric,
  • update_state(...) accumulates metric statistics.

Let's look at the Accuracy metric as before, but from the TensorFlow 2.x perspective:

# Define metric as a class
m = tf.keras.metrics.Accuracy()
for i in range(n_batches):
        # Update the metrics on new batch of samples
        m.update_state(labels[i],  predictions[i])

# Get accuracy:
m.result().numpy()
>>> 0.99

In my opinion, PyTorch's metrics should be implemented in a similar way to TensorFlow 2.x's: each metric is a class with three methods.
I believe that Ignite's metrics work in a similar fashion to the TensorFlow ones, so I do not intend to quote them here, as you know much better how they are constructed.

Summary

After looking at other libraries, and from my own experience, I came to the conclusion that there should be a new package called torch.metrics. Its base class should contain three methods:

  • reset() resets the current state
  • accumulate(*args, **kwargs) accumulates metric statistics
  • compute() computes the metric result and returns Tensor(s) with the data.

A metric does not need to accumulate gradients or remember the operation history. But it needs a deeper look whether to implement this with torch.no_grad() internally, whether the user should wrap accumulate calls in with torch.no_grad() themselves, or something else.

There is also the matter of which device to use for computations. Maybe there should be another method or constructor parameter that sets the device for all contained Tensors, something like the .to() known from nn.Module? It could also be done on the first .accumulate() call, depending on the devices of the given Tensors.
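A minimal sketch of that interface, using a hypothetical accuracy metric (torch.metrics does not exist; everything here, including the device handling on the first accumulate call, is part of the proposal):

import torch


class AccuracyMetric:
    """Hypothetical metric following the proposed reset/accumulate/compute interface."""

    def __init__(self, device=None):
        self.device = device
        self.reset()

    def reset(self):
        self.correct = 0
        self.total = 0

    def accumulate(self, predictions, targets):
        with torch.no_grad():  # no gradients, no operation history
            if self.device is None:  # infer the device from the first batch
                self.device = predictions.device
            predictions = predictions.to(self.device)
            targets = targets.to(self.device)
            self.correct += (predictions == targets).sum().item()
            self.total += targets.numel()

    def compute(self):
        return torch.tensor(self.correct / self.total)


acc = AccuracyMetric()
acc.accumulate(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))
acc.accumulate(torch.tensor([0, 0]), torch.tensor([0, 1]))
print(acc.compute())  # tensor(0.6000)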

I'm happy to contribute, and have already prototyped numpy corrcoef and numpy cov equivalents:

import numpy as np
import torch


class CovMetric:
    def __init__(self):
        self.x_mean = 0
        self.c = 0
        self.n = 0

    @staticmethod
    def __concat_input(x, y, rowvar):
        # Orient the inputs so that variables are rows, mirroring np.cov(x, y, rowvar=...).
        if not rowvar and x.shape[0] != 1:
            x = x.t()

        if y is not None:
            if not rowvar and y.shape[0] != 1:
                y = y.t()

            x = torch.cat((x, y), dim=0)

        return x

    def reset(self):
        self.x_mean = 0
        self.c = 0
        self.n = 0

    def accumulate(self, x, y=None, rowvar=True):
        # Chunked update of the running mean and covariance statistics.
        x = self.__concat_input(x, y, rowvar)
        self.n += x.size(1)

        xs = torch.sum(x, 1).unsqueeze(-1).expand_as(x)

        new_mean = self.x_mean + (xs - self.x_mean * x.size(1)) / self.n

        m1 = torch.sub(x, new_mean)
        m2 = torch.sub(x, self.x_mean)

        self.c += m1.mm(m2.t())
        self.x_mean = new_mean

    def compute(self):
        return self.c / (self.n - 1)


class CorrcoefMetric:
    def __init__(self):
        self.cov = CovMetric()

    def reset(self):
        self.cov.reset()

    def accumulate(self, x, y=None, rowvar=True):
        self.cov.accumulate(x, y, rowvar)

    def compute(self):
        c = self.cov.compute()
        # normalize covariance matrix
        d = torch.diag(c)
        stddev = torch.sqrt(d)
        c /= stddev[:, None]
        c /= stddev[None, :]

        return torch.clamp(c, -1.0, 1.0)


if __name__ == '__main__':
    N = 1024
    M = 100
    batch_size = 64
    mat1 = torch.rand((N, M), dtype=torch.float64)
    mat2 = torch.rand((N, M), dtype=torch.float64)

    cor = CorrcoefMetric()

    for i in range(N // batch_size):
        cor.accumulate(
            mat1[i * batch_size:(i + 1) * batch_size, ],
            mat2[i * batch_size:(i + 1) * batch_size, ],
            False
        )

    accumulated_cor = cor.compute().numpy()
    numpy_cor = np.corrcoef(mat1.numpy(), mat2.numpy(), False)
    print(np.allclose(numpy_cor, accumulated_cor))

>>> True

@prasunanand
Contributor

Is anyone working on this?

@varunagrawal
Contributor

@prasunanand I believe the consensus right now is to first come up with a proposal for this subpackage. This will be taken care of by someone on the PyTorch internal team, and only after that initial setup can we start adding things.

@kaszperro
Author

So is there any progress with this issue internally?

@williamFalcon
Contributor

@Darktex Are you thinking we enable this in Lightning? How would that look for you? We could add a flag called metrics or something and we'd call it at the right times.

But maybe I'm not understanding; it seems like you just want to import this package and call it at the end of your training_step or validation_step?

@Darktex

Darktex commented Feb 29, 2020

Reviving this thread to cross-link to the issue on Lightning here and to spur more discussion. Let's take Lightning as an example of a library on top of PyTorch that might want to handle some wrapping for the pretty printing. This would align well with the proposal I wrote a few comments above. I like reasoning from more concrete stuff, so let me try to refine that proposal in a more concrete manner (I am not married to anything here, I think it just helps ground things).

  1. Treat metrics as a subtype of nn.Module? We do that for cost functions, so it seems good practice to do that for metrics too - the extra caveat would be that they are not differentiable. See the later paragraph for a more detailed discussion of this.
  2. The sklearn replacement would probably live in F as usual, so e.g. nn.metrics.F1Score would eventually call F.f1_score (a rough sketch follows this list).
  3. The pretty printing, metrics reporting etc. live outside of PyTorch in client libs, for example Lightning.
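As a rough illustration of points 1 and 2 (neither nn.metrics.F1Score nor F.f1_score exists today; the binary F1 computation below is only an assumption to make the sketch self-contained):

import torch
import torch.nn as nn


class F1Score(nn.Module):
    """Hypothetical metric-as-Module: update() accumulates state, forward() returns the result."""

    def __init__(self):
        super().__init__()
        self.tp = self.fp = self.fn = 0

    def update(self, y_pred, y):
        # y_pred and y are 0/1 label tensors; accumulate confusion counts.
        with torch.no_grad():
            self.tp += ((y_pred == 1) & (y == 1)).sum().item()
            self.fp += ((y_pred == 1) & (y == 0)).sum().item()
            self.fn += ((y_pred == 0) & (y == 1)).sum().item()

    def forward(self):
        precision = self.tp / max(self.tp + self.fp, 1)
        recall = self.tp / max(self.tp + self.fn, 1)
        return torch.tensor(2 * precision * recall / max(precision + recall, 1e-12))


f1 = F1Score()
f1.update(torch.tensor([1, 0, 1, 1]), torch.tensor([1, 0, 0, 1]))
print(f1().item())  # 0.8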

The biggest question I have is whether we can find a way to support computing metrics nicely once we factor in DDP in a way that is completely independent. I like how Ignite is doing it and I would consider starting from their API for the one in PyTorch (potentially taking the implementation too and decoupling it from the Ignite specifics so everyone can use it!).

Here's a primer from their API:

class Metric(metaclass=ABCMeta):
    _required_output_keys = ("y_pred", "y")

    def __init__(self, output_transform=lambda x: x, device=None):
        self._output_transform = output_transform

        # Check device if distributed is initialized:
        if dist.is_available() and dist.is_initialized():

            # check if reset and update methods are decorated. Compute may not be decorated
            if not (hasattr(self.reset, "_decorated") and hasattr(self.update, "_decorated")):
                warnings.warn("{} class does not support distributed setting. Computed result is not collected "
                              "across all computing devices".format(self.__class__.__name__),
                              RuntimeWarning)
            if device is None:
                device = "cuda"
            device = torch.device(device)
        self._device = device
        self._is_reduced = False
        self.reset()

    @abstractmethod
    def reset(self):
        """
        Resets the metric to it's initial state.

        This is called at the start of each epoch.
        """
        pass

    @abstractmethod
    def update(self, output):
        """
        Updates the metric's state using the passed batch output.

        This is called once for each batch.

        Args:
            output: the is the output from the engine's process function.
        """
        pass

    @abstractmethod
    def compute(self):
        """
        Computes the metric based on it's accumulated state.

        This is called at the end of each epoch.

        Returns:
            Any: the actual quantity of interest.

        Raises:
            NotComputableError: raised when the metric cannot be computed.
        """
        pass

Source: ignite.metrics.metric.

To make them amenable to DDP, they use decorators to mark what they do during all_reduce and after it. To see a minimal implementation, take a look at their implementation of Recall: https://pytorch.org/ignite/_modules/ignite/metrics/recall.html#Recall.

Another good thing that Ignite does very well here is making it easy for others to write their own DDP-compatible metrics without dealing with the DDP internals themselves. For example, take a look at the VariableAccumulation metric which is a parent/mixin that can bootstrap other concrete metrics such as GeometricAverage.
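Not Ignite's actual code, but a simplified sketch of that parent/mixin idea under the assumption of a single scalar accumulator:

import torch


class VariableAccumulation:
    """Illustrative parent metric: folds each batch into one accumulator with a user-supplied op."""

    def __init__(self, op):
        self.op = op
        self.reset()

    def reset(self):
        self.accumulator = torch.tensor(0.0, dtype=torch.float64)
        self.num_examples = 0

    def update(self, x):
        x = x.detach().double()
        self.accumulator = self.op(self.accumulator, x)
        self.num_examples += x.numel()


class GeometricAverage(VariableAccumulation):
    """Sums log-values via the parent, then exponentiates the mean."""

    def __init__(self):
        super().__init__(op=lambda acc, x: acc + torch.log(x).sum())

    def compute(self):
        return torch.exp(self.accumulator / self.num_examples)


geo = GeometricAverage()
geo.update(torch.tensor([1.0, 2.0, 4.0]))
print(geo.compute())  # tensor(2.0000, dtype=torch.float64)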

For the sake of argument, let's say our metrics package looks exactly the same as Ignite's. The next question is how we are going to use it: Ignite is events-based, so they just declare what metrics they want and stuff happens behind the scenes, e.g. evaluator = create_supervised_evaluator(model, metrics={'accuracy': Accuracy(device=device)}, device=device). The major difference for us is that we'd want users to have much more control over how they are going to use these.

Let's write a sample eval loop:

macro_f1 = nn.F1Score("macro")

with torch.no_grad():
    for x, y in eval_dataloader:
        y_hat = F.softmax(model(x), dim=-1)
        macro_f1.update(y_hat, y)  # void method, just update state
        m = macro_f1()  # forward calls compute(), takes no args. Compute from state

        print(f"Macro f1: {m.item()}")  # or whatever else you want to do

With DDP, maybe this still works - normally we don't need to pass the loss through torch.nn.DistributedDataParallel. Worst case, let's say we do; this is still not too bad:

macro_f1 = nn.F1Score("macro")

torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
model = DistributedDataParallel(model, device_ids=[i], output_device=i)
macro_f1 = DistributedDataParallel(macro_f1, device_ids=[i], output_device=i)

with torch.no_grad():
    for x, y in eval_dataloader:
        y_hat = F.softmax(model(x), dim=-1)
        macro_f1.update(y_hat, y)  # void method, just update state
        m = macro_f1()  # forward calls compute(), takes no args. Compute from state

        print(f"Macro f1: {m.item()}")  # or whatever else you want to do

This is flexible, because we leave the responsibility of updating state and computing in the hands of the client, so they can do whatever they please. This will not limit researchers; quite the opposite: it will liberate them from writing metrics code that is unsexy and hard to do efficiently (for example on DDP). Trainer libraries can then do further magic by using events to update and compute for you at the right time.
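For what it's worth, here is a minimal sketch of how a metric could synchronize its own state across ranks with torch.distributed.all_reduce instead of being wrapped in DistributedDataParallel (the metric class and its counter state are assumptions for illustration):

import torch
import torch.distributed as dist


class DistributedAccuracy:
    """Hypothetical metric that reduces simple counter state across ranks in compute()."""

    def __init__(self, device="cpu"):  # typically "cuda" in a DDP setup
        self.device = device
        self.correct = 0
        self.total = 0

    def update(self, y_pred, y):
        self.correct += (y_pred == y).sum().item()
        self.total += y.numel()

    def compute(self):
        state = torch.tensor([self.correct, self.total], dtype=torch.float64, device=self.device)
        if dist.is_available() and dist.is_initialized():
            # Sum the per-rank counters so every rank sees the global result.
            dist.all_reduce(state, op=dist.ReduceOp.SUM)
        return (state[0] / state[1]).item()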

Any feedback?

@kaszperro
Author

That looks really nice.
I like the freedom it provides and the separation of the model from the metrics (unlike Lightning, which combines the model with evaluation).

And as I understand it, the sklearn-style functional replacements would be useful for small datasets, where you don't have to accumulate state because everything fits into memory, potentially improving performance?

@Darktex

Darktex commented Mar 3, 2020

Yes, and we can factor out components as needed so that they share as much code as possible

@netw0rkf10w

Let's write a sample eval loop:

macro_f1 = nn.F1Score("macro")

with torch.no_grad():
    for x, y in eval_dataloader:
        y_hat = F.softmax(model(x), dim=-1)
        macro_f1.update(y_hat, y)  # void method, just update state
        m = macro_f1()  # forward calls compute(), takes no args. Compute from state

        print(f"Macro f1: {m.item()}")  # or whatever else you want to do

With DDP, maybe this still works - normally we don't need to pass the loss through torch.nn.DistributedDataParallel. Worst case let's say we do, this is still not too bad:

macro_f1 = nn.F1Score("macro")

torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
model = DistributedDataParallel(model, device_ids=[i], output_device=i)
macro_f1 = DistributedDataParallel(macro_f1, device_ids=[i], output_device=i)

with torch.no_grad():
    for x, y in eval_dataloader:
        y_hat = F.softmax(model(x), dim=-1)
        macro_f1.update(y_hat, y)  # void method, just update state
        m = macro_f1()  # forward calls compute(), takes no args. Compute from state

        print(f"Macro f1: {m.item()}")  # or whatever else you want to do

Very nice! This is exactly what should be done!

@kaszperro
Author

So do we now have everything ready to start?
As I understand it, the steps are:

  1. Implement as many metrics as possible in torch.nn.functional
  2. Implement metrics using Ignite's API in torch.nn?

@Darktex proposed to put metrics in torch.nn. But shouldn't metrics be separated from torch.nn, e.g. torch.nn.metrics or torch.metrics?
They could be subclasses of torch.nn.Module, but if they were put directly into torch.nn, it'd cause confusion and mix-ups imo.
I think they are different enough that they deserve a separate module.
I could be wrong though.

@netw0rkf10w

I don't think the metrics should be separated from nn.

In tf.keras they have tf.keras.metrics, similar to tf.keras.losses or tf.keras.optimizers, but note that in PyTorch there's no torch.nn.losses.

Everything is inside nn (nn.CrossEntropyLoss is a loss, nn.Conv2d is a layer), so in my opinion, it makes more sense (and is less confusing) to do nn.MeanIoU() instead of torch.metrics.MeanIoU() (and torch.nn.metrics.MeanIoU() is even worse).

(You may ask: Then why do we have torch.optim.SGD instead of torch.nn.SGD? I have no idea.)

For the implementation, the obvious approach would be to implement the metrics in a new torch.nn module and, similar to the losses, have all metrics derive from some class _Metric(Module).
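A minimal sketch of what such a base class could look like (hypothetical; it just mirrors how the losses derive from a common base and reuses the forward-calls-compute idea from earlier in the thread):

import torch.nn as nn


class _Metric(nn.Module):
    """Hypothetical base class for metrics living next to the losses in torch.nn."""

    def reset(self):
        raise NotImplementedError

    def update(self, y_pred, y):
        raise NotImplementedError

    def compute(self):
        raise NotImplementedError

    def forward(self):
        # Calling the module returns the current result, as in the earlier eval-loop example.
        return self.compute()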

@soumith
Member

soumith commented Apr 29, 2020

  1. Metrics should be separate from torch.nn, because they are not differentiable.
  2. Metrics need to interact fluidly with torch.distributed, and hence often in non-trivial ways with the training loop.

The APIs for a metrics package, especially because of (2), are not obvious.
Both 1 and 2 make it nicer to have a separate metrics package that can eventually stabilize and come into core PyTorch.

@mruberry mruberry added module: numpy Related to numpy support, and also numpy compatibility of our operators and removed module: nn Related to torch.nn labels May 10, 2020
@Raikan10

Hi, I'm a student who is a beginner at PyTorch, but I understand what you are trying to do here. This comes from personal experience: when I started the MNIST tutorial it was kind of weird that I had to calculate accuracy myself. So I would love to help out, hopefully under some kind of mentorship!

@mruberry
Collaborator

Thanks @Raikan10! We'll post here as there are developments. If you're interested in contributing to PyTorch generally then you could also check out issues with the "OSS contribution wanted" label. See this query: https://github.com/pytorch/pytorch/issues?q=is%3Aissue+is%3Aopen+label%3A%22oss+contribution+wanted%22.

@williamFalcon
Contributor

williamFalcon commented May 18, 2020

@Raikan10 For the metrics package we are looking for help with some metrics in Lightning. A good place to start!
Ping @justusschock on the Lightning Slack:

https://join.slack.com/t/pytorch-lightning/shared_invite/enQtODU5ODIyNTUzODQwLTFkMDg5Mzc1MDBmNjEzMDgxOTVmYTdhYjA1MDdmODUyOTg2OGQ1ZWZkYTQzODhhNzdhZDA3YmNhMDhlMDY4YzQ

@arita37

arita37 commented May 20, 2020 via email

@Yura52
Contributor

Yura52 commented May 31, 2020

I definitely agree that ignite.metrics is an excellent package. I would like to share some thoughts on Ignite's Accuracy. There are two points that probably should be addressed while implementing Accuracy in a hypothetical torch.metrics:

Inconsistent input format for binary classification and multiclass problems

In the first case, Ignite's Accuracy expects labels as input, whilst in the second case it expects probabilities or logits. This was a big point of confusion for me.

No shortcuts for saying "I want to pass logits/probabilities as input"

Fundamentally, Accuracy is a metric that takes predicted and correct labels as input and returns the percentage of correct predictions as output. However, in practice neural networks trained for classification often return logits or probabilities. For example, in the case of binary classification, I have never written the following:

accuracy = Accuracy()

Instead, I always have to write:

accuracy = Accuracy(transform=lambda x: torch.round(torch.sigmoid(x)))
# either
accuracy = Accuracy(transform=lambda x: torch.round(x))

Suggested solution for both problems: let the user explicitly say in which form the input will be passed:

import enum
class Accuracy(...):
    class Mode(enum.Enum):
        LABELS = enum.auto()
        PROBABILITIES = enum.auto()
        LOGITS = enum.auto()

    def __init__(self, mode=Mode.LABELS, ...):
        ...

The suggested interface can also be extended to support custom thresholds by adding a __call__ method to the Mode class.

Yes, the mode argument serves as one more transform argument (see the original Ignite Accuracy). However, my hypothesis is that this shortcut will cover so many use cases and make user code so much more expressive that it may be worth it.
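A rough sketch of how such a mode could be applied to binary-classification outputs (the helper name and the 0.5 threshold are assumptions, not Ignite's API):

import enum

import torch


class Mode(enum.Enum):
    LABELS = enum.auto()
    PROBABILITIES = enum.auto()
    LOGITS = enum.auto()


def to_labels(y_pred, mode, threshold=0.5):
    """Convert raw binary-classification outputs to hard labels based on the declared mode."""
    if mode is Mode.LOGITS:
        y_pred = torch.sigmoid(y_pred)
    if mode in (Mode.LOGITS, Mode.PROBABILITIES):
        y_pred = (y_pred > threshold).long()
    return y_pred


logits = torch.tensor([-1.2, 0.3, 2.5])
print(to_labels(logits, Mode.LOGITS))  # tensor([0, 1, 1])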

cc: @vfdev-5

@vfdev-5
Collaborator

vfdev-5 commented May 31, 2020

@WeirdKeksButtonK this can be a nice enhancement for Ignite as well. Please open an issue in the Ignite repository.

@soumith The ignite.metrics module is rather stable and can be a reference for a potential future torch.metrics module.

The major difference with us is that we'd want users to have much more control about how they are going to use these.

@Darktex you also have the same control over that in Ignite, with almost the same API:

metric = ...
metric.reset()
for _ in range(n):
    metric.update(y_pred, y)
result = metric.compute()

Another point is about the distributed package: currently native torch distributed supports only the nccl, gloo and mpi backends. For users working with XLA, the aggregation methods would need to be adapted.

@enochkan

enochkan commented Nov 3, 2020

https://github.com/chinokenochkan/torch-metrics

Feel free to suggest more metrics; contributions are welcome!

@francois-rozet
Copy link

For those who are interested, I have implemented a package of IQA metrics:
https://github.com/francois-rozet/piqa

@edenlightning

Check out https://github.com/PyTorchLightning/metrics!
Over 25 implementations and a simple API to build your own metric, optimized for distributed training!

@arita37

arita37 commented Mar 24, 2021 via email

@ezyang
Contributor

ezyang commented Aug 12, 2021

For those following along, there is now a third-party torchmetrics package that may be helpful https://torchmetrics.readthedocs.io/en/latest/ EDIT: edenlightning already mentioned this!!

@ananthsub

Closing this issue out since torchmetrics and torcheval fill this need now.
