
Inception Score #2053

Merged: 50 commits, Jun 25, 2021
Commits
e5630b2
IS first implementation
gucifer Jun 12, 2021
4b2a20c
IS tests
gucifer Jun 12, 2021
03e4836
mpyp fix
gucifer Jun 12, 2021
23c4abc
Merge branch 'pytorch:master' into IS
gucifer Jun 12, 2021
4c0258b
Changed name
gucifer Jun 12, 2021
275b3f6
Merge branch 'IS' of https://github.com/gucifer/ignite into IS
gucifer Jun 12, 2021
c7b8e0d
Changed IS to inception_score
gucifer Jun 13, 2021
bc23b4e
Merge branch 'master' into IS
sdesrozis Jun 13, 2021
939e3fb
Remove torch precision
gucifer Jun 13, 2021
2b86a9d
Merge branch 'IS' of https://github.com/gucifer/ignite into IS
gucifer Jun 13, 2021
7577292
Error output changes
gucifer Jun 14, 2021
2653047
Docs update
gucifer Jun 14, 2021
91e0ca6
torchified everything
gucifer Jun 15, 2021
05bb5e6
Merge FID
gucifer Jun 18, 2021
3494ef8
Common Inception Class
gucifer Jun 19, 2021
0228c1c
Mypy fix
gucifer Jun 19, 2021
0d0de5e
rename test file
gucifer Jun 19, 2021
c75fd40
Docs changes, tests added
gucifer Jun 20, 2021
a40386f
Removed static methods
gucifer Jun 20, 2021
7bcbcfd
New changes and tests
gucifer Jun 21, 2021
b86f24e
IS formula updates
gucifer Jun 21, 2021
b679ca4
Added default test
gucifer Jun 21, 2021
b372537
Cleaner if else stmt and removed InceptionModel
gucifer Jun 21, 2021
56ede91
Better device handling
gucifer Jun 21, 2021
47dba8e
Device handling
gucifer Jun 21, 2021
4cf087b
Callable to torch.nn.Module
gucifer Jun 21, 2021
5cd3c2c
Update ignite/metrics/__init__.py
gucifer Jun 21, 2021
4b6f130
Merge branch 'master' of https://github.com/pytorch/ignite into IS
gucifer Jun 21, 2021
e717c16
GAN Metric Refeactor
gucifer Jun 22, 2021
eb4ba5e
Test update for coverage
gucifer Jun 22, 2021
06ed0ed
Docs change
gucifer Jun 22, 2021
b0ca920
Mypy fix
gucifer Jun 22, 2021
31932fb
Docs change
gucifer Jun 23, 2021
665ac78
Safer Base Inception Class
gucifer Jun 23, 2021
7d964f0
Base class changes
gucifer Jun 23, 2021
5416312
removed default args
gucifer Jun 23, 2021
10b48d4
Better check_input
gucifer Jun 23, 2021
e0c241c
Docs fix
gucifer Jun 23, 2021
3f64353
Redesign
gucifer Jun 24, 2021
2f2b183
Docs fix
gucifer Jun 24, 2021
9c2a094
Feature Extraction Function
gucifer Jun 24, 2021
d999011
Better naming and checking feature shape
gucifer Jun 24, 2021
279412d
Merge branch 'master' into IS
sdesrozis Jun 24, 2021
4913a83
Tests for Base Inception Class
gucifer Jun 24, 2021
a3ed1ed
Tests
gucifer Jun 24, 2021
7ab7711
small fix for coverage
Jun 25, 2021
7fec3c0
Merge branch 'master' into IS
sdesrozis Jun 25, 2021
ed0923b
move model to metric device - detach inputs - minor fix
Jun 25, 2021
6bdd92d
Update ignite/metrics/gan/inception_score.py
sdesrozis Jun 25, 2021
fb69446
Update inception_score.py
vfdev-5 Jun 25, 2021
1 change: 1 addition & 0 deletions docs/source/metrics.rst
@@ -334,6 +334,7 @@ Complete list of metrics
Rouge
RougeL
RougeN
InceptionScore
FID

Helpers for customizing metrics
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
@@ -6,6 +6,7 @@
from ignite.metrics.fbeta import Fbeta
from ignite.metrics.frequency import Frequency
from ignite.metrics.gan.fid import FID
from ignite.metrics.gan.inception_score import InceptionScore
from ignite.metrics.loss import Loss
from ignite.metrics.mean_absolute_error import MeanAbsoluteError
from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
@@ -41,6 +42,7 @@
"FID",
"GeometricAverage",
"IoU",
"InceptionScore",
"mIoU",
"JaccardIndex",
"MultiLabelConfusionMatrix",
2 changes: 2 additions & 0 deletions ignite/metrics/gan/__init__.py
@@ -1,5 +1,7 @@
from ignite.metrics.gan.fid import FID
from ignite.metrics.gan.inception_score import InceptionScore

__all__ = [
"InceptionScore",
"FID",
]
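With these exports in place, both GAN metrics can be used like any other ignite metric. A minimal usage sketch for attaching the new metric to an evaluation engine (the evaluation step and batch data below are hypothetical placeholders, and the default feature extractor is assumed to need torchvision and pretrained Inception weights):

.. code-block:: python

    import torch

    from ignite.engine import Engine
    from ignite.metrics.gan import InceptionScore

    # hypothetical evaluation step: return generated images as y_pred
    def eval_step(engine, batch):
        return torch.rand(4, 3, 299, 299)  # stand-in for generator output

    evaluator = Engine(eval_step)
    InceptionScore(device="cpu").attach(evaluator, "is")

    state = evaluator.run(range(2), max_epochs=1)  # two dummy batches
    print(state.metrics["is"])

The same attach pattern applies to FID, whose update expects a (train_images, test_images) pair.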
124 changes: 57 additions & 67 deletions ignite/metrics/gan/fid.py
@@ -4,7 +4,8 @@

import torch

from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
from ignite.metrics.gan.utils import InceptionModel, _BaseInceptionMetric
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce

__all__ = [
"FID",
@@ -47,26 +48,7 @@ def fid_score(
return float(diff.dot(diff).item() + torch.trace(sigma1) + torch.trace(sigma2) - 2 * tr_covmean)
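For reference, the visible return line is the Fréchet distance between two Gaussians fitted to the train and test features; ``tr_covmean`` is presumably the trace of the matrix square root computed in the collapsed lines above:

.. math::

    d^2\bigl((\mu_1, \Sigma_1), (\mu_2, \Sigma_2)\bigr)
    = \lVert \mu_1 - \mu_2 \rVert_2^2
    + \operatorname{Tr}(\Sigma_1) + \operatorname{Tr}(\Sigma_2)
    - 2\,\operatorname{Tr}\bigl((\Sigma_1 \Sigma_2)^{1/2}\bigr)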


class InceptionExtractor:
def __init__(self) -> None:
try:
from torchvision import models
except ImportError:
raise RuntimeError("This module requires torchvision to be installed.")
self.model = models.inception_v3(pretrained=True)
self.model.fc = torch.nn.Identity()
self.model.eval()

@torch.no_grad()
def __call__(self, data: torch.Tensor) -> torch.Tensor:
if data.dim() != 4:
raise ValueError(f"Inputs should be a tensor of dim 4, got {data.dim()}")
if data.shape[1] != 3:
raise ValueError(f"Inputs should be a tensor with 3 channels, got {data.shape}")
return self.model(data)


class FID(Metric):
class FID(_BaseInceptionMetric):
r"""Calculates Frechet Inception Distance.

.. math::
@@ -90,10 +72,15 @@ class FID(Metric):
__ https://github.com/mseitzer/pytorch-fid

Args:
num_features: number of features, must be defined if the parameter ``feature_extractor`` is also defined.
Otherwise, default value is 2048.
feature_extractor: a callable for extracting the features from the input data. If neither num_features nor
feature_extractor are defined, default value is ``InceptionExtractor``.
num_features: number of features predicted by the model, or the size of the reduced feature
vector of the image. Default value is 2048.
feature_extractor: a torch Module for extracting the features from the input data.
It returns a tensor of shape (batch_size, num_features).
If neither ``num_features`` nor ``feature_extractor`` is defined, an ImageNet-pretrained
Inception model is used by default. If only ``num_features`` is defined but ``feature_extractor``
is not, ``feature_extractor`` is assigned the identity function.
Please note that the model will be implicitly moved to the device given in the ``device``
argument.
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
@@ -110,7 +97,7 @@ class FID(Metric):
import torch
from ignite.metrics.gan import FID

y_pred, y = torch.rand(10, 2048), torch.rand(10, 2048)
y_pred, y = torch.rand(10, 3, 299, 299), torch.rand(10, 3, 299, 299)
m = FID()
m.update((y_pred, y))
print(m.compute())
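The docstring above also allows precomputed features: passing only ``num_features`` leaves ``feature_extractor`` as the identity, so ``update`` can receive ``(batch_size, num_features)`` tensors directly. A sketch of that use, assuming the base class accepts such inputs as described:

.. code-block:: python

    import torch
    from ignite.metrics.gan import FID

    # hypothetical: features already extracted elsewhere, 2048-d per image
    m = FID(num_features=2048)  # feature_extractor falls back to the identity
    train_feats, test_feats = torch.rand(16, 2048), torch.rand(16, 2048)
    m.update((train_feats, test_feats))
    print(m.compute())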
@@ -121,7 +108,7 @@
def __init__(
self,
num_features: Optional[int] = None,
feature_extractor: Optional[Callable] = None,
feature_extractor: Optional[torch.nn.Module] = None,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
) -> None:
@@ -136,26 +123,24 @@ def __init__(
except ImportError:
raise RuntimeError("This module requires scipy to be installed.")

# default is inception
if num_features is None and feature_extractor is None:
num_features = 2048
feature_extractor = InceptionExtractor()
elif num_features is None:
raise ValueError("Argument num_features should be defined")
elif feature_extractor is None:
self._feature_extractor = lambda x: x
feature_extractor = self._feature_extractor

if num_features <= 0:
raise ValueError(f"Argument num_features must be greater to zero, got: {num_features}")
self._num_features = num_features
self._feature_extractor = feature_extractor
num_features = 1000
feature_extractor = InceptionModel(return_features=False)

self._eps = 1e-6
super(FID, self).__init__(output_transform=output_transform, device=device)

super(FID, self).__init__(
num_features=num_features,
feature_extractor=feature_extractor,
output_transform=output_transform,
device=device,
)

@staticmethod
def _online_update(features: torch.Tensor, total: torch.Tensor, sigma: torch.Tensor) -> None:

total += features

if LooseVersion(torch.__version__) <= LooseVersion("1.7.0"):
sigma += torch.ger(features, features)
else:
@@ -165,65 +150,70 @@ def _get_covariance(self, sigma: torch.Tensor, total: torch.Tensor) -> torch.Tensor:
r"""
Calculates covariance from mean and sum of products of variables
"""

if LooseVersion(torch.__version__) <= LooseVersion("1.7.0"):
sub_matrix = torch.ger(total, total)
else:
sub_matrix = torch.outer(total, total)

sub_matrix = sub_matrix / self._num_examples
return (sigma - sub_matrix) / (self._num_examples - 1)

@staticmethod
def _check_feature_input(train: torch.Tensor, test: torch.Tensor) -> None:
for feature in [train, test]:
if feature.dim() != 2:
raise ValueError(f"Features must be a tensor of dim 2, got: {feature.dim()}")
if feature.shape[0] == 0:
raise ValueError(f"Batch size should be greater than one, got: {feature.shape[0]}")
if feature.shape[1] == 0:
raise ValueError(f"Feature size should be greater than one, got: {feature.shape[1]}")
if train.shape[0] != test.shape[0] or train.shape[1] != test.shape[1]:
raise ValueError(
f"Number of Training Features and Testing Features should be equal ({train.shape} != {test.shape})"
)
return (sigma - sub_matrix) / (self._num_examples - 1)

@reinit__is_reduced
def reset(self) -> None:
self._train_sigma = torch.zeros((self._num_features, self._num_features), dtype=torch.float64).to(self._device)
self._train_total = torch.zeros(self._num_features, dtype=torch.float64).to(self._device)
self._test_sigma = torch.zeros((self._num_features, self._num_features), dtype=torch.float64).to(self._device)
self._test_total = torch.zeros(self._num_features, dtype=torch.float64).to(self._device)
self._num_examples = 0

self._train_sigma = torch.zeros(
(self._num_features, self._num_features), dtype=torch.float64, device=self._device
)

self._train_total = torch.zeros(self._num_features, dtype=torch.float64, device=self._device)

self._test_sigma = torch.zeros(
(self._num_features, self._num_features), dtype=torch.float64, device=self._device
)

self._test_total = torch.zeros(self._num_features, dtype=torch.float64, device=self._device)
self._num_examples: int = 0

super(FID, self).reset()

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:

# Extract the features from the outputs
train_features = self._feature_extractor(output[0].detach()).to(self._device)
test_features = self._feature_extractor(output[1].detach()).to(self._device)
train, test = output
train_features = self._extract_features(train)
test_features = self._extract_features(test)

# Check the feature shapes
self._check_feature_input(train_features, test_features)
if train_features.shape[0] != test_features.shape[0] or train_features.shape[1] != test_features.shape[1]:
raise ValueError(
f"""
Number of Training Features and Testing Features should be equal ({train_features.shape} != {test_features.shape})
"""
)

# Updates the mean and covariance for the train features
for i, features in enumerate(train_features, start=self._num_examples + 1):
for features in train_features:
self._online_update(features, self._train_total, self._train_sigma)

# Updates the mean and covariance for the test features
for i, features in enumerate(test_features, start=self._num_examples + 1):
for features in test_features:
self._online_update(features, self._test_total, self._test_sigma)

self._num_examples += train_features.shape[0]

@sync_all_reduce("_num_examples", "_train_total", "_test_total", "_train_sigma", "_test_sigma")
def compute(self) -> float:

fid = fid_score(
mu1=self._train_total / self._num_examples,
mu2=self._test_total / self._num_examples,
sigma1=self._get_covariance(self._train_sigma, self._train_total),
sigma2=self._get_covariance(self._test_sigma, self._test_total),
eps=self._eps,
)

if torch.isnan(torch.tensor(fid)) or torch.isinf(torch.tensor(fid)):
warnings.warn("The product of covariance of train and test features is out of bounds.")

return fid
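For intuition, a small standalone sketch of the accumulation scheme that ``update`` and ``_get_covariance`` implement: a running sum of feature vectors plus a running sum of outer products, from which the mean and the unbiased covariance are recovered at compute time. Names are illustrative, and a recent PyTorch (with ``torch.outer`` and ``torch.cov``) is assumed:

.. code-block:: python

    import torch

    def accumulate(features):                 # features: (n, d)
        n, d = features.shape
        total = torch.zeros(d, dtype=torch.float64)
        sigma = torch.zeros(d, d, dtype=torch.float64)
        for f in features.double():           # mirrors FID._online_update
            total += f
            sigma += torch.outer(f, f)
        return total, sigma, n

    def covariance(sigma, total, n):          # mirrors FID._get_covariance
        return (sigma - torch.outer(total, total) / n) / (n - 1)

    x = torch.rand(100, 8)
    total, sigma, n = accumulate(x)
    mean = total / n
    assert torch.allclose(covariance(sigma, total, n), torch.cov(x.double().T))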
115 changes: 115 additions & 0 deletions ignite/metrics/gan/inception_score.py
@@ -0,0 +1,115 @@
from typing import Callable, Optional, Union

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.gan.utils import InceptionModel, _BaseInceptionMetric

# These decorators help with distributed settings
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce

__all__ = ["InceptionScore"]


class InceptionScore(_BaseInceptionMetric):
r"""Calculates Inception Score.

.. math::
\text{IS(G)} = \exp\left(\frac{1}{N}\sum_{i=1}^{N} D_{KL}\left(p(y|x^{(i)}) \parallel \hat{p}(y)\right)\right)

where :math:`p(y|x)` is the conditional class distribution predicted for a generated image :math:`x`,
:math:`\hat{p}(y)` is the marginal class distribution over all generated images, :math:`G` refers to the
generator and :math:`D_{KL}` is the KL divergence between the two distributions.

More details can be found in `Barratt et al. 2018`__.

__ https://arxiv.org/pdf/1801.01973.pdf


Args:
num_features: number of features predicted by the model or number of classes of the model. Default
value is 1000.
feature_extractor: a torch Module for predicting the probabilities from the input data.
It returns a tensor of shape (batch_size, num_features).
If neither ``num_features`` nor ``feature_extractor`` is defined, an ImageNet-pretrained
Inception model is used by default. If only ``num_features`` is defined but ``feature_extractor``
is not, ``feature_extractor`` is assigned the identity function.
Please note that the model will be implicitly moved to the device given in the
``device`` argument.
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``y_pred``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.

Example:

.. code-block:: python

from ignite.metrics.gan import InceptionScore
import torch

images = torch.rand(10, 3, 299, 299)

m = InceptionScore()
m.update(images)
print(m.compute())

.. versionadded:: 0.5.0
"""

def __init__(
self,
num_features: Optional[int] = None,
feature_extractor: Optional[torch.nn.Module] = None,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
) -> None:

if num_features is None and feature_extractor is None:
num_features = 1000
feature_extractor = InceptionModel(return_features=False)

self._eps = 1e-16

super(InceptionScore, self).__init__(
num_features=num_features,
feature_extractor=feature_extractor,
output_transform=output_transform,
device=device,
)

@reinit__is_reduced
def reset(self) -> None:

self._num_examples = 0

self._prob_total = torch.zeros(self._num_features, dtype=torch.float64, device=self._device)
self._total_kl_d = torch.zeros(self._num_features, dtype=torch.float64, device=self._device)

super(InceptionScore, self).reset()

@reinit__is_reduced
def update(self, output: torch.Tensor) -> None:

probabilities = self._extract_features(output)

self._num_examples += probabilities.shape[0]

self._prob_total += torch.sum(probabilities, 0).to(self._device)
self._total_kl_d += torch.sum(probabilities * torch.log(probabilities + self._eps), 0).to(self._device)

@sync_all_reduce("_num_examples", "_prob_total", "_total_kl_d")
def compute(self) -> torch.Tensor:

if self._num_examples == 0:
raise NotComputableError("InceptionScore must have at least one example before it can be computed.")

mean_probs = self._prob_total / self._num_examples
excess_entropy = self._prob_total * torch.log(mean_probs + self._eps)
avg_kl_d = torch.sum(self._total_kl_d - excess_entropy) / self._num_examples

return torch.exp(avg_kl_d)
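For reference, the arithmetic in ``update`` and ``compute`` reduces to the following standalone sketch on a precomputed probability matrix (the function name is illustrative, not part of the PR):

.. code-block:: python

    import torch

    def inception_score_from_probs(probs, eps=1e-16):
        # probs: (n_samples, n_classes), each row is p(y|x) and sums to 1
        n = probs.shape[0]
        prob_total = probs.sum(dim=0)                                  # as in update()
        total_kl_d = (probs * torch.log(probs + eps)).sum(dim=0)
        mean_probs = prob_total / n                                    # as in compute()
        excess_entropy = prob_total * torch.log(mean_probs + eps)
        return torch.exp(torch.sum(total_kl_d - excess_entropy) / n)

    probs = torch.softmax(torch.randn(64, 1000), dim=1)
    print(inception_score_from_probs(probs))  # diffuse random predictions give a low score

The ``eps`` term matches ``self._eps = 1e-16`` above and guards against ``log(0)`` when a class receives zero probability.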