-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathtorchmetric_kid.py
282 lines (230 loc) · 11.9 KB
/
torchmetric_kid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.nn import Module
from torchmetrics.image.fid import NoTrainInceptionV3
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE
# Maps the doctest targets of this module to the optional packages their examples import.
# NOTE(review): presumably read by a doctest plugin (e.g. pytest-doctestplus) to skip the
# ``KernelInceptionDistance`` / ``KID`` doctests when ``torch_fidelity`` is not installed - confirm.
__doctest_requires__ = {("KernelInceptionDistance", "KID"): ["torch_fidelity"]}
def maximum_mean_discrepancy(k_xx: Tensor, k_xy: Tensor, k_yy: Tensor) -> Tensor:
    """Compute the squared maximum mean discrepancy from precomputed kernel matrices.

    Adapted from `KID Score`_

    The diagonals of ``k_xx`` and ``k_yy`` (each sample compared with itself) are excluded from the
    within-set averages, while every entry of ``k_xy`` contributes to the cross term.

    Args:
        k_xx: ``[m, m]`` kernel matrix of the first sample set evaluated against itself.
        k_xy: ``[m, n]`` kernel matrix of the first sample set evaluated against the second.
        k_yy: ``[n, n]`` kernel matrix of the second sample set evaluated against itself.

    Returns:
        Scalar tensor ``sum_offdiag(k_xx) / (m * (m - 1)) + sum_offdiag(k_yy) / (n * (n - 1))
        - 2 * sum(k_xy) / (m * n)``.

    .. note:: A sample set with a single element has no off-diagonal entries, so its normaliser is
        zero and the division yields a non-finite value rather than raising.
    """
    m = k_xx.shape[0]
    # Each set is normalised by its own sample count, so the estimate stays correct when the two
    # sets differ in size. For equally sized sets this is the same quantity as dividing the summed
    # within-set terms by ``m * (m - 1)`` and the cross term by ``m ** 2``.
    n = k_yy.shape[0]

    # Row sums minus the diagonal leave only the off-diagonal contributions.
    kt_xx_sum = (k_xx.sum(dim=-1) - torch.diag(k_xx)).sum()
    kt_yy_sum = (k_yy.sum(dim=-1) - torch.diag(k_yy)).sum()
    k_xy_sum = k_xy.sum(dim=0).sum()

    within = kt_xx_sum / (m * (m - 1)) + kt_yy_sum / (n * (n - 1))
    cross = 2 * k_xy_sum / (m * n)
    return within - cross
def poly_kernel(f1: Tensor, f2: Tensor, degree: int = 3, gamma: Optional[float] = None, coef: float = 1.0) -> Tensor:
    """Evaluate the polynomial kernel ``(gamma * f1 @ f2.T + coef) ** degree``.

    Adapted from `KID Score`_

    Args:
        f1: first feature matrix; its second dimension is used as the feature size.
        f2: second feature matrix, multiplied in transposed form.
        degree: exponent applied to the shifted, scaled product.
        gamma: scale applied to the product; ``None`` selects ``1.0 / f1.shape[1]``.
        coef: constant added before exponentiation.

    Returns:
        The kernel matrix between the rows of ``f1`` and the rows of ``f2``.
    """
    scale = 1.0 / f1.shape[1] if gamma is None else gamma
    gram = f1 @ f2.T
    return (gram * scale + coef) ** degree
def poly_mmd(
    f_real: Tensor, f_fake: Tensor, degree: int = 3, gamma: Optional[float] = None, coef: float = 1.0
) -> Tensor:
    """Compute the maximum mean discrepancy of two feature sets under a polynomial kernel.

    Adapted from `KID Score`_

    Args:
        f_real: features of the real samples.
        f_fake: features of the fake samples.
        degree: degree forwarded to :func:`poly_kernel`.
        gamma: scale forwarded to :func:`poly_kernel`.
        coef: bias forwarded to :func:`poly_kernel`.

    Returns:
        The value returned by :func:`maximum_mean_discrepancy` for the three kernel matrices.
    """
    kernel_args = (degree, gamma, coef)
    k_real_real = poly_kernel(f_real, f_real, *kernel_args)
    k_fake_fake = poly_kernel(f_fake, f_fake, *kernel_args)
    k_real_fake = poly_kernel(f_real, f_fake, *kernel_args)
    return maximum_mean_discrepancy(k_real_real, k_real_fake, k_fake_fake)
class KernelInceptionDistance(Metric):
    r"""Calculate Kernel Inception Distance (KID), which is used to assess the quality of generated images.

    .. math::
        KID = MMD(f_{real}, f_{fake})^2

    where :math:`MMD` is the maximum mean discrepancy and :math:`f_{real}, f_{fake}` are extracted features
    from real and fake images, see [1] for more details. In particular, calculating the MMD requires the
    evaluation of a polynomial kernel function :math:`k`

    .. math::
        k(x,y) = (\gamma * x^T y + coef)^{degree}

    which controls the distance between two features. In practice the MMD is calculated over a number of
    subsets to be able to get both the mean and the standard deviation of KID.

    Using the default feature extraction (Inception v3 using the original weights from [2]), the input is
    expected to be mini-batches of 3-channel RGB images of shape (3 x H x W) with dtype uint8. All images
    will be resized to 299 x 299 which is the size of the original training data.

    .. note:: using this metric with the default feature extractor requires that ``torch-fidelity``
        is installed. Either install as ``pip install torchmetrics[image]`` or
        ``pip install torch-fidelity``

    .. note:: the ``forward`` method can be used but ``compute_on_step`` is disabled by default (opposite of
        all other metrics) as this metric does not really make sense to calculate on a single batch. This
        means that by default ``forward`` will just call ``update`` underneath.

    Args:
        feature: Either a str, an integer or an ``nn.Module``:

            - a str or integer will indicate the inceptionv3 feature layer to choose. Can be one of the
              following: 'logits_unbiased', 64, 192, 768, 2048
            - an ``nn.Module`` for using a custom feature extractor. Expects that its forward method returns
              an ``[N,d]`` matrix where ``N`` is the batch size and ``d`` is the feature size.

        subsets: Number of subsets to calculate the mean and standard deviation scores over
        subset_size: Number of randomly picked samples in each subset
        degree: Degree of the polynomial kernel function
        gamma: Scale-length of polynomial kernel. If set to ``None`` it will be automatically set to
            ``1.0 / feature_size``
        coef: Bias term in the polynomial kernel.
        reset_real_features: Whether to also reset the real features. Since in many cases the real dataset
            does not change, the features can be cached to avoid recomputing them, which is costly. Set this
            to ``False`` if your dataset does not change.
        compute_on_step:
            Forward only calls ``update()`` and returns None if this is set to False.

            .. deprecated:: v0.8
                Argument has no use anymore and will be removed v0.9.

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    References:
        [1] Demystifying MMD GANs
        Mikołaj Bińkowski, Danica J. Sutherland, Michael Arbel, Arthur Gretton
        https://arxiv.org/abs/1801.01401

        [2] GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium,
        Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, Sepp Hochreiter
        https://arxiv.org/abs/1706.08500

    Raises:
        ModuleNotFoundError:
            If ``feature`` is a ``str`` or ``int`` (default settings) and ``torch-fidelity`` is not installed
        ValueError:
            If ``feature`` is a ``str`` or ``int`` not in ``['logits_unbiased', 64, 192, 768, 2048]``
        TypeError:
            If ``feature`` is neither a ``str``, an ``int`` nor an ``nn.Module``
        ValueError:
            If ``subsets`` is not an integer larger than 0
        ValueError:
            If ``subset_size`` is not an integer larger than 0
        ValueError:
            If ``degree`` is not an integer larger than 0
        ValueError:
            If ``gamma`` is neither ``None`` nor a float larger than 0
        ValueError:
            If ``coef`` is not a float larger than 0
        ValueError:
            If ``reset_real_features`` is not a ``bool``

    Example:
        >>> import torch
        >>> _ = torch.manual_seed(123)
        >>> from torchmetrics.image.kid import KernelInceptionDistance
        >>> kid = KernelInceptionDistance(subset_size=50)
        >>> # generate two slightly overlapping image intensity distributions
        >>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
        >>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
        >>> kid.update(imgs_dist1, real=True)
        >>> kid.update(imgs_dist2, real=False)
        >>> kid_mean, kid_std = kid.compute()
        >>> print((kid_mean, kid_std))
        (tensor(0.0337), tensor(0.0023))

    """
    # Per-batch feature tensors accumulated by ``update``; registered as metric states in ``__init__``.
    real_features: List[Tensor]
    fake_features: List[Tensor]
    higher_is_better: bool = False
    is_differentiable: bool = False

    def __init__(
        self,
        feature: Union[str, int, torch.nn.Module] = 2048,
        subsets: int = 100,
        subset_size: int = 1000,
        degree: int = 3,
        gamma: Optional[float] = None,  # type: ignore
        coef: float = 1.0,
        reset_real_features: bool = True,
        compute_on_step: Optional[bool] = None,
        **kwargs: Dict[str, Any],
    ) -> None:
        super().__init__(compute_on_step=compute_on_step, **kwargs)

        rank_zero_warn(
            "Metric `Kernel Inception Distance` will save all extracted features in buffer."
            " For large datasets this may lead to large memory footprint.",
            UserWarning,
        )

        # Resolve the feature extractor: a named InceptionV3 layer or a user supplied module.
        if isinstance(feature, (str, int)):
            if not _TORCH_FIDELITY_AVAILABLE:
                raise ModuleNotFoundError(
                    "Kernel Inception Distance metric requires that `Torch-fidelity` is installed."
                    " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
                )
            valid_int_input = ("logits_unbiased", 64, 192, 768, 2048)
            if feature not in valid_int_input:
                raise ValueError(
                    f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}."
                )
            self.inception: Module = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)])
        elif isinstance(feature, Module):
            self.inception = feature
        else:
            raise TypeError("Got unknown input to argument `feature`")

        # Validate the sub-sampling and kernel hyper-parameters before storing them.
        if not (isinstance(subsets, int) and subsets > 0):
            raise ValueError("Argument `subsets` expected to be integer larger than 0")
        self.subsets = subsets

        if not (isinstance(subset_size, int) and subset_size > 0):
            raise ValueError("Argument `subset_size` expected to be integer larger than 0")
        self.subset_size = subset_size

        if not (isinstance(degree, int) and degree > 0):
            raise ValueError("Argument `degree` expected to be integer larger than 0")
        self.degree = degree

        if gamma is not None and not (isinstance(gamma, float) and gamma > 0):
            raise ValueError("Argument `gamma` expected to be `None` or float larger than 0")
        self.gamma = gamma

        if not (isinstance(coef, float) and coef > 0):
            raise ValueError("Argument `coef` expected to be float larger than 0")
        self.coef = coef

        if not isinstance(reset_real_features, bool):
            raise ValueError("Argument `reset_real_features` expected to be a bool")
        self.reset_real_features = reset_real_features

        # states for extracted features
        self.add_state("real_features", [], dist_reduce_fx=None)
        self.add_state("fake_features", [], dist_reduce_fx=None)

    def update(self, imgs: Tensor, real: bool) -> None:  # type: ignore
        """Update the state with extracted features.

        Args:
            imgs: tensor with images fed to the feature extractor
            real: bool indicating if ``imgs`` belong to the real or the fake distribution
        """
        features = self.inception(imgs)

        if real:
            self.real_features.append(features)
        else:
            self.fake_features.append(features)

    def compute(self) -> Tuple[Tensor, Tensor]:
        """Calculate KID score based on accumulated extracted features from the two distributions.

        Returns a tuple of mean and standard deviation of KID scores calculated on subsets of extracted
        features.

        Implementation inspired by `Fid Score`_

        Raises:
            ValueError:
                If fewer real samples than ``subset_size`` have been accumulated
            ValueError:
                If fewer fake samples than ``subset_size`` have been accumulated
        """
        real_features = dim_zero_cat(self.real_features)
        fake_features = dim_zero_cat(self.fake_features)

        n_samples_real = real_features.shape[0]
        if n_samples_real < self.subset_size:
            raise ValueError("Argument `subset_size` should be smaller than the number of samples")
        n_samples_fake = fake_features.shape[0]
        if n_samples_fake < self.subset_size:
            raise ValueError("Argument `subset_size` should be smaller than the number of samples")

        kid_scores_ = []
        for _ in range(self.subsets):
            # Draw an independent random subset (without replacement) from each distribution.
            perm = torch.randperm(n_samples_real)
            f_real = real_features[perm[: self.subset_size]]
            perm = torch.randperm(n_samples_fake)
            f_fake = fake_features[perm[: self.subset_size]]

            o = poly_mmd(f_real, f_fake, self.degree, self.gamma, self.coef)
            kid_scores_.append(o)
        kid_scores = torch.stack(kid_scores_)
        return kid_scores.mean(), kid_scores.std(unbiased=False)

    def reset(self) -> None:
        """Reset the metric state, keeping the real features when ``reset_real_features`` is ``False``."""
        if self.reset_real_features:
            super().reset()
            return

        # remove temporarily to avoid resetting
        value = self._defaults.pop("real_features")
        try:
            super().reset()
        finally:
            # Always restore the default, otherwise a failing base-class reset would leave
            # ``_defaults`` permanently missing the ``real_features`` entry.
            self._defaults["real_features"] = value