Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metric: CLIPIQA #348

Merged
merged 49 commits into from
Jun 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
b85274a
initial implementation
snk4tr Feb 28, 2023
dd3c154
fix data range for results bench
snk4tr Mar 1, 2023
37a7ebc
add eval on koniq
snk4tr Apr 5, 2023
4ff87b0
fix implementation so it corresponds to the official one
snk4tr Apr 5, 2023
3515329
add tokenizer, remove clip dependency
snk4tr Apr 19, 2023
969dc40
fix model loading
snk4tr Apr 19, 2023
5bbfe7e
simplification of loading of the clip model
snk4tr Apr 19, 2023
c31bdbd
docs for the main clip-iqa file
snk4tr Apr 19, 2023
a5d9083
inference with bs > 1 + extra docs
snk4tr Apr 23, 2023
b50b389
Merge branch 'master' into feat/clip-iqa
snk4tr Apr 23, 2023
6328da1
handle images with channels first format
snk4tr May 1, 2023
f387bdc
add evaluation on the LIVE-itW dataset
snk4tr May 1, 2023
ac19f1b
simplify
snk4tr May 1, 2023
001afd1
fix wrong probs scaling
snk4tr May 1, 2023
cd79eb1
clip-iqa tests
snk4tr May 1, 2023
9088962
Merge branch 'master' into feat/clip-iqa
snk4tr May 1, 2023
3d77bb6
fix some flake8 errors
snk4tr May 1, 2023
30ea548
fix some flake8 errors
snk4tr May 1, 2023
cd0ce7f
fix some flake8 errors
snk4tr May 1, 2023
e3bbbb4
fix some flake8 errors
snk4tr May 1, 2023
4652e9a
fix some flake8 errors
snk4tr May 1, 2023
9ce6065
update workflow and refactor tokenizer
snk4tr May 1, 2023
bb8e2c8
fix some mypy errors
snk4tr May 1, 2023
f91465c
fix some errors
snk4tr May 1, 2023
5c5b841
fix errors
snk4tr May 1, 2023
86f02da
remove ftfy package
snk4tr May 1, 2023
a658434
replace tokenizer and all related logic with pre-computed tokens
snk4tr May 2, 2023
7f6e97e
update torchvision versions
snk4tr May 2, 2023
2ca08c2
benchmarks table
snk4tr May 2, 2023
1fbad82
add some changes from the review
snk4tr May 20, 2023
d35f5b8
more fixes and the test for the donwloader
snk4tr May 20, 2023
28ecc1d
more fixes + _validate_input extention + more tests
snk4tr May 20, 2023
a7b82c9
Fix readthedocs pipeline
denproc May 23, 2023
88fc321
address some comments, add tests
snk4tr May 31, 2023
26c86aa
incorporate the last comment-related changes
snk4tr May 31, 2023
eee3b6b
flake8
snk4tr May 31, 2023
1e51e95
adjust downloading for the case when sha hash is not present
snk4tr May 31, 2023
4ff2ae6
update torchvision version so torch tensors support `min_all`
snk4tr May 31, 2023
825e449
downgrade torchvision
snk4tr May 31, 2023
65762a7
address review comments
snk4tr Jun 4, 2023
06853ec
address flake and code smells
snk4tr Jun 4, 2023
053a370
address flake
snk4tr Jun 4, 2023
fb94ab9
address flake
snk4tr Jun 4, 2023
53dbf17
freeze scikit-image version, the newest version changes default shape…
snk4tr Jun 4, 2023
9091499
+
snk4tr Jun 4, 2023
dcc1f76
+
snk4tr Jun 4, 2023
a036d99
undo fake smell change
snk4tr Jun 4, 2023
04f8e73
address comments
snk4tr Jun 5, 2023
2fd9b4d
+
snk4tr Jun 5, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci-mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Check types with mypy
run: |
python3 -m pip install mypy
python3 -m pip install mypy types-setuptools
# stop the build if there are Python syntax errors or undefined names
python3 -m mypy piq/ --allow-redefinition
7 changes: 5 additions & 2 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
matrix:
include:
- python-version: "3.7"
torchvision-version: "0.6.1"
torchvision-version: "0.10.0"
- python-version: "3.7"
torchvision-version: "0.14.1"
- python-version: "3.10"
Expand All @@ -40,6 +40,9 @@ jobs:
${{ runner.os }}-pip-
${{ runner.os }}-
- name: Install dependencies
# It is important to freeze scikit-image to version < 0.21.0 because there they change the default number of
# dimensions in loaded images, which breaks image loading-related tests
# Might update the version in the future but tests need to be adjusted accordingly.
run: |
python -m pip install --upgrade pip setuptools wheel
pip install torchvision==${{ matrix.torchvision-version }}
Expand All @@ -49,7 +52,7 @@ jobs:
tensorflow \
libsvm \
pybrisque \
scikit-image \
"scikit-image<=0.20.0" \
pandas \
tqdm
pip install --upgrade scipy
Expand Down
21 changes: 19 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ Benchmark
---------

As part of our library we provide `code to benchmark <tests/results_benchmark.py>`_ all metrics on a set of common Mean Opinon Scores databases.
Currently we support `TID2013`_, `KADID10k`_ and `PIPAL`_.
Currently we support several Full-Reference (`TID2013`_, `KADID10k`_ and `PIPAL`_) and No-Reference (`KonIQ10k`_ and `LIVE-itW`_) datasets.
You need to download them separately and provide path to images as an argument to the script.

Here is an example how to evaluate SSIM and MS-SSIM metrics on TID2013 dataset:
Expand All @@ -228,7 +228,7 @@ Here is an example how to evaluate SSIM and MS-SSIM metrics on TID2013 dataset:

python3 tests/results_benchmark.py --dataset tid2013 --metrics SSIM MS-SSIM --path ~/datasets/tid2013 --batch_size 16

Below we provide a comparison between `Spearman's Rank Correlation cCoefficient <https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient>`_ (SRCC) values obtained with PIQ and reported in surveys.
Below we provide a comparison between `Spearman's Rank Correlation Coefficient <https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient>`_ (SRCC) values obtained with PIQ and reported in surveys.
Closer SRCC values indicate the higher degree of agreement between results of computations on given datasets.
We do not report `Kendall rank correlation coefficient <https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient>`_ (KRCC)
as it is highly correlated with SRCC and provides limited additional information.
Expand All @@ -237,6 +237,8 @@ as it's highly dependent on fitting method and is biased towards simple examples

For metrics that can take greyscale or colour images, ``c`` means chromatic version.

Full-Reference (FR) Datasets
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
=========== =========================== =========================== ===========================
\ TID2013 KADID10k PIPAL
----------- --------------------------- --------------------------- ---------------------------
Expand Down Expand Up @@ -264,13 +266,25 @@ LPIPS-VGG 0.67 / 0.67 `DISTS`_ 0.72 / - 0.57 / 0.
PieAPP 0.84 / 0.88 `DISTS`_ 0.87 / - 0.70 / 0.71 `PIPAL`_
DISTS 0.81 / 0.83 `DISTS`_ 0.88 / - 0.62 / 0.66 `PIPAL`_
BRISQUE 0.37 / 0.84 `Eval2019`_ 0.33 / 0.53 `KADID10k`_ 0.21 / -
CLIP-IQA 0.50 / - 0.48 / - 0.26 / -
IS 0.26 / - 0.25 / - 0.09 / -
FID 0.67 / - 0.66 / - 0.18 / -
KID 0.42 / - 0.66 / - 0.12 / -
MSID 0.21 / - 0.32 / - 0.01 / -
GS 0.37 / - 0.37 / - 0.02 / -
=========== =========================== =========================== ===========================

No-Reference (NR) Datasets
^^^^^^^^^^^^^^^^^^^^^^^^^^
=========== =========================== ===========================
\ KonIQ10k LIVE-itW
----------- --------------------------- ---------------------------
Source PIQ / Reference PIQ / Reference
=========== =========================== ===========================
BRISQUE 0.22 / - 0.31 / -
CLIP-IQA 0.68 / 0.68 `CLIP-IQA off`_ 0.64 / 0.64 `CLIP-IQA off`_
=========== =========================== ===========================

.. _TID2013: http://www.ponomarenko.info/tid2013.htm
.. _KADID10k: http://database.mmsp-kn.de/kadid-10k-database.html
.. _Eval2019: https://ieeexplore.ieee.org/abstract/document/8847307/
Expand All @@ -280,6 +294,9 @@ GS 0.37 / - 0.37 / - 0.02 / -
.. _HaarPSI: https://arxiv.org/abs/1607.06140
.. _PIPAL: https://arxiv.org/pdf/2011.15002.pdf
.. _IW-SSIM: https://ieeexplore.ieee.org/document/7442122
.. _KonIQ10k: http://database.mmsp-kn.de/koniq-10k-database.html
.. _LIVE-itW: https://live.ece.utexas.edu/research/ChallengeDB/index.html
.. _CLIP-IQA off: https://github.com/IceClear/CLIP-IQA

Unlike FR and NR IQMs, designed to compute an image-wise distance, the DB metrics compare distributions of *sets* of images.
To address these problems, we adopt a different way of computing the DB IQMs proposed in `<https://arxiv.org/abs/2203.07809>`_.
Expand Down
1 change: 1 addition & 0 deletions piq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@
from .pieapp import PieAPP
from .dss import dss, DSSLoss
from .iw_ssim import information_weighted_ssim, InformationWeightedSSIMLoss
from .clip_iqa import CLIPIQA
128 changes: 128 additions & 0 deletions piq/clip_iqa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
r"""This module implements CLIP-IQA metric in PyTorch.

The metric is proposed in:
"Exploring CLIP for Assessing the Look and Feel of Images"
by Jianyi Wang, Kelvin C.K. Chan and Chen Change Loy.
AAAI 2023.
https://arxiv.org/abs/2207.12396

This implementation is inspired by the offisial implementation but avoids using MMCV and MMEDIT libraries.
Ref url: https://github.com/IceClear/CLIP-IQA
"""
import os
import torch

from torch.nn.modules.loss import _Loss
from typing import Union

from piq.feature_extractors import clip
from piq.utils.common import download_tensor, _validate_input


OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
TOKENS_URL = "https://github.com/photosynthesis-team/piq/releases/download/v0.7.1/clipiqa_tokens.pt"


class CLIPIQA(_Loss):
r"""Creates a criterion that measures image quality based on a general notion of text-to-image similarity
learned by the CLIP[1] model during its large-scale pre-training on a large dataset with paired texts and images.

The method is based on the idea that two antonyms ("Good photo" and "Bad photo") can be used as anchors in the
text embedding space representing good and bad images in terms of their image quality.

After the anchors are defined, one can use them to determine the quality of a given image in the following way:
1. Compute the image embedding of the image of interest using the pre-trained CLIP model;
2. Compute the text embeddings of the selected anchor antonyms;
3. Compute the angle (cosine similarity) between the image embedding (1) and both text embeddings (2);
4. Compute the Softmax of cosine similarities (3) -> CLIP-IQA[2] score.

This method is proposed to eliminate the linguistic ambiguity of the naive approach
(using a single prompt, e.g., "Good photo").

This method has an extension called CLIP-IQA+[2] proposed in the same research paper.
It uses the same approach but also fine-tunes the CLIP weights using the CoOp[3] fine-tuning algorithm.

Note:
The initial computation of the metric is performed in `float32` and other dtypes (i.e. `float16`, `float64`)
are not supported. We preserve this behaviour for reproducibility perposes. Also, at the time of writing
conv2d is not supported for `float16` tensors on CPU.

Warning:
In order to avoid implicit dtype conversion and normalization of input tensors, they are copied.
Note that it may consume extra memory, which might be noticible on large batch sizes.

Args:
data_range: Maximum value range of images (usually 1.0 or 255).

Examples:
>>> from piq import CLIPIQA
>>> clipiqa = CLIPIQA()
>>> x = torch.rand(1, 3, 224, 224)
>>> score = clipiqa(x)

References:
[1] Radford, Alec, et al. "Learning transferable visual models from natural language supervision."
International conference on machine learning. PMLR, 2021.
[2] Wang, Jianyi, Kelvin CK Chan, and Chen Change Loy. "Exploring CLIP for Assessing the Look
and Feel of Images." arXiv preprint arXiv:2207.12396 (2022).
[3] Zhou, Kaiyang, et al. "Learning to prompt for vision-language models." International
Journal of Computer Vision 130.9 (2022): 2337-2348.
"""
def __init__(self, data_range: Union[float, int] = 1.) -> None:
super().__init__()

self.feature_extractor = clip.load().eval()
for param in self.feature_extractor.parameters():
param.requires_grad = False

# Pre-computed tokens for prompt pairs: "Good photo.", "Bad photo.".
tokens = download_tensor(TOKENS_URL, os.path.expanduser("~/.cache/clip"))

anchors = self.feature_extractor.encode_text(tokens).float()
anchors = anchors / anchors.norm(dim=-1, keepdim=True)

self.data_range = float(data_range)
default_mean = torch.tensor(OPENAI_CLIP_MEAN).view(1, 3, 1, 1)
default_std = torch.tensor(OPENAI_CLIP_STD).view(1, 3, 1, 1)
self.logit_scale = self.feature_extractor.logit_scale.exp()

# Take advantage of Torch buffers. CLIPIQA.to(device) will move these to the device as well.
self.register_buffer("anchors", anchors)
self.register_buffer("default_mean", default_mean)
self.register_buffer("default_std", default_std)

def forward(self, x_input: torch.Tensor) -> torch.Tensor:
r"""Computation of CLIP-IQA metric for a given image :math:`x`.

Args:
x: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(C, H, W)`.
snk4tr marked this conversation as resolved.
Show resolved Hide resolved
The metric is designed in such a way that it expects:
- 3D or 4D PyTorch tensors;
- These tensors are have any ranges of values between 0 and 255;
- These tensros have channels first format.

Returns:
The value of CLI-IQA score in [0, 1] range.
"""
_validate_input([x_input], dim_range=(3, 4), data_range=(0., 255.), check_for_channels_first=True)

x = x_input.clone()
x = x.float() / self.data_range
snk4tr marked this conversation as resolved.
Show resolved Hide resolved
x = (x - self.default_mean) / self.default_std

# Device for nn.Module cannot be cached through the buffer so it has to be done here.
self.feature_extractor = self.feature_extractor.to(x)

with torch.no_grad():
image_features = self.feature_extractor.encode_image(x, pos_embedding=False).float()

# Normalized features.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)

# Cosine similarity as logits.
logits_per_image = self.logit_scale * image_features @ self.anchors.t()

probs = logits_per_image.reshape(logits_per_image.shape[0], -1, 2).softmax(dim=-1)
result = probs[..., 0]
return result.detach()
Loading
Loading