diff --git a/.github/workflows/ci-mypy.yml b/.github/workflows/ci-mypy.yml index a47742c1..b20979fd 100644 --- a/.github/workflows/ci-mypy.yml +++ b/.github/workflows/ci-mypy.yml @@ -24,6 +24,6 @@ jobs: python-version: ${{ matrix.python-version }} - name: Check types with mypy run: | - python3 -m pip install mypy + python3 -m pip install mypy types-setuptools # stop the build if there are Python syntax errors or undefined names python3 -m mypy piq/ --allow-redefinition diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 39e793e0..ecb4c9e5 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -18,7 +18,7 @@ jobs: matrix: include: - python-version: "3.7" - torchvision-version: "0.6.1" + torchvision-version: "0.10.0" - python-version: "3.7" torchvision-version: "0.14.1" - python-version: "3.10" @@ -40,6 +40,9 @@ jobs: ${{ runner.os }}-pip- ${{ runner.os }}- - name: Install dependencies + # It is important to freeze scikit-image to version < 0.21.0 because there they change the default number of + # dimensions in loaded images, which breaks image loading-related tests + # Might update the version in the future but tests need to be adjusted accordingly. run: | python -m pip install --upgrade pip setuptools wheel pip install torchvision==${{ matrix.torchvision-version }} @@ -49,7 +52,7 @@ jobs: tensorflow \ libsvm \ pybrisque \ - scikit-image \ + "scikit-image<=0.20.0" \ pandas \ tqdm pip install --upgrade scipy diff --git a/README.rst b/README.rst index a7b7666b..92de39ee 100644 --- a/README.rst +++ b/README.rst @@ -219,7 +219,7 @@ Benchmark --------- As part of our library we provide `code to benchmark `_ all metrics on a set of common Mean Opinon Scores databases. -Currently we support `TID2013`_, `KADID10k`_ and `PIPAL`_. +Currently we support several Full-Reference (`TID2013`_, `KADID10k`_ and `PIPAL`_) and No-Reference (`KonIQ10k`_ and `LIVE-itW`_) datasets. You need to download them separately and provide path to images as an argument to the script. Here is an example how to evaluate SSIM and MS-SSIM metrics on TID2013 dataset: @@ -228,7 +228,7 @@ Here is an example how to evaluate SSIM and MS-SSIM metrics on TID2013 dataset: python3 tests/results_benchmark.py --dataset tid2013 --metrics SSIM MS-SSIM --path ~/datasets/tid2013 --batch_size 16 -Below we provide a comparison between `Spearman's Rank Correlation cCoefficient `_ (SRCC) values obtained with PIQ and reported in surveys. +Below we provide a comparison between `Spearman's Rank Correlation Coefficient `_ (SRCC) values obtained with PIQ and reported in surveys. Closer SRCC values indicate the higher degree of agreement between results of computations on given datasets. We do not report `Kendall rank correlation coefficient `_ (KRCC) as it is highly correlated with SRCC and provides limited additional information. @@ -237,6 +237,8 @@ as it's highly dependent on fitting method and is biased towards simple examples For metrics that can take greyscale or colour images, ``c`` means chromatic version. +Full-Reference (FR) Datasets +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ =========== =========================== =========================== =========================== \ TID2013 KADID10k PIPAL ----------- --------------------------- --------------------------- --------------------------- @@ -264,6 +266,7 @@ LPIPS-VGG 0.67 / 0.67 `DISTS`_ 0.72 / - 0.57 / 0. 
PieAPP 0.84 / 0.88 `DISTS`_ 0.87 / - 0.70 / 0.71 `PIPAL`_ DISTS 0.81 / 0.83 `DISTS`_ 0.88 / - 0.62 / 0.66 `PIPAL`_ BRISQUE 0.37 / 0.84 `Eval2019`_ 0.33 / 0.53 `KADID10k`_ 0.21 / - +CLIP-IQA 0.50 / - 0.48 / - 0.26 / - IS 0.26 / - 0.25 / - 0.09 / - FID 0.67 / - 0.66 / - 0.18 / - KID 0.42 / - 0.66 / - 0.12 / - @@ -271,6 +274,17 @@ MSID 0.21 / - 0.32 / - 0.01 / - GS 0.37 / - 0.37 / - 0.02 / - =========== =========================== =========================== =========================== +No-Reference (NR) Datasets +^^^^^^^^^^^^^^^^^^^^^^^^^^ +=========== =========================== =========================== + \ KonIQ10k LIVE-itW +----------- --------------------------- --------------------------- + Source PIQ / Reference PIQ / Reference +=========== =========================== =========================== +BRISQUE 0.22 / - 0.31 / - +CLIP-IQA 0.68 / 0.68 `CLIP-IQA off`_ 0.64 / 0.64 `CLIP-IQA off`_ +=========== =========================== =========================== + .. _TID2013: http://www.ponomarenko.info/tid2013.htm .. _KADID10k: http://database.mmsp-kn.de/kadid-10k-database.html .. _Eval2019: https://ieeexplore.ieee.org/abstract/document/8847307/ @@ -280,6 +294,9 @@ GS 0.37 / - 0.37 / - 0.02 / - .. _HaarPSI: https://arxiv.org/abs/1607.06140 .. _PIPAL: https://arxiv.org/pdf/2011.15002.pdf .. _IW-SSIM: https://ieeexplore.ieee.org/document/7442122 +.. _KonIQ10k: http://database.mmsp-kn.de/koniq-10k-database.html +.. _LIVE-itW: https://live.ece.utexas.edu/research/ChallengeDB/index.html +.. _CLIP-IQA off: https://github.com/IceClear/CLIP-IQA Unlike FR and NR IQMs, designed to compute an image-wise distance, the DB metrics compare distributions of *sets* of images. To address these problems, we adopt a different way of computing the DB IQMs proposed in ``_. diff --git a/piq/__init__.py b/piq/__init__.py index 525aa470..a169ab74 100644 --- a/piq/__init__.py +++ b/piq/__init__.py @@ -22,3 +22,4 @@ from .pieapp import PieAPP from .dss import dss, DSSLoss from .iw_ssim import information_weighted_ssim, InformationWeightedSSIMLoss +from .clip_iqa import CLIPIQA diff --git a/piq/clip_iqa.py b/piq/clip_iqa.py new file mode 100644 index 00000000..e51c7b17 --- /dev/null +++ b/piq/clip_iqa.py @@ -0,0 +1,128 @@ +r"""This module implements CLIP-IQA metric in PyTorch. + +The metric is proposed in: +"Exploring CLIP for Assessing the Look and Feel of Images" +by Jianyi Wang, Kelvin C.K. Chan and Chen Change Loy. +AAAI 2023. +https://arxiv.org/abs/2207.12396 + +This implementation is inspired by the offisial implementation but avoids using MMCV and MMEDIT libraries. +Ref url: https://github.com/IceClear/CLIP-IQA +""" +import os +import torch + +from torch.nn.modules.loss import _Loss +from typing import Union + +from piq.feature_extractors import clip +from piq.utils.common import download_tensor, _validate_input + + +OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711) +TOKENS_URL = "https://github.com/photosynthesis-team/piq/releases/download/v0.7.1/clipiqa_tokens.pt" + + +class CLIPIQA(_Loss): + r"""Creates a criterion that measures image quality based on a general notion of text-to-image similarity + learned by the CLIP[1] model during its large-scale pre-training on a large dataset with paired texts and images. + + The method is based on the idea that two antonyms ("Good photo" and "Bad photo") can be used as anchors in the + text embedding space representing good and bad images in terms of their image quality. 
+
+    After the anchors are defined, one can use them to determine the quality of a given image in the following way:
+    1. Compute the image embedding of the image of interest using the pre-trained CLIP model;
+    2. Compute the text embeddings of the selected anchor antonyms;
+    3. Compute the angle (cosine similarity) between the image embedding (1) and both text embeddings (2);
+    4. Compute the Softmax of cosine similarities (3) -> CLIP-IQA[2] score.
+
+    This method is proposed to eliminate the linguistic ambiguity of the naive approach
+    (using a single prompt, e.g., "Good photo").
+
+    This method has an extension called CLIP-IQA+[2] proposed in the same research paper.
+    It uses the same approach but also fine-tunes the CLIP weights using the CoOp[3] fine-tuning algorithm.
+
+    Note:
+        The initial computation of the metric is performed in `float32` and other dtypes (i.e. `float16`, `float64`)
+        are not supported. We preserve this behaviour for reproducibility purposes. Also, at the time of writing
+        conv2d is not supported for `float16` tensors on CPU.
+
+    Warning:
+        In order to avoid implicit dtype conversion and normalization of input tensors, they are copied.
+        Note that it may consume extra memory, which might be noticeable on large batch sizes.
+
+    Args:
+        data_range: Maximum value range of images (usually 1.0 or 255).
+
+    Examples:
+        >>> from piq import CLIPIQA
+        >>> clipiqa = CLIPIQA()
+        >>> x = torch.rand(1, 3, 224, 224)
+        >>> score = clipiqa(x)
+
+    References:
+        [1] Radford, Alec, et al. "Learning transferable visual models from natural language supervision."
+        International conference on machine learning. PMLR, 2021.
+        [2] Wang, Jianyi, Kelvin CK Chan, and Chen Change Loy. "Exploring CLIP for Assessing the Look
+        and Feel of Images." arXiv preprint arXiv:2207.12396 (2022).
+        [3] Zhou, Kaiyang, et al. "Learning to prompt for vision-language models." International
+        Journal of Computer Vision 130.9 (2022): 2337-2348.
+    """
+    def __init__(self, data_range: Union[float, int] = 1.) -> None:
+        super().__init__()
+
+        self.feature_extractor = clip.load().eval()
+        for param in self.feature_extractor.parameters():
+            param.requires_grad = False
+
+        # Pre-computed tokens for prompt pairs: "Good photo.", "Bad photo.".
+        tokens = download_tensor(TOKENS_URL, os.path.expanduser("~/.cache/clip"))
+
+        anchors = self.feature_extractor.encode_text(tokens).float()
+        anchors = anchors / anchors.norm(dim=-1, keepdim=True)
+
+        self.data_range = float(data_range)
+        default_mean = torch.tensor(OPENAI_CLIP_MEAN).view(1, 3, 1, 1)
+        default_std = torch.tensor(OPENAI_CLIP_STD).view(1, 3, 1, 1)
+        self.logit_scale = self.feature_extractor.logit_scale.exp()
+
+        # Take advantage of Torch buffers. CLIPIQA.to(device) will move these to the device as well.
+        self.register_buffer("anchors", anchors)
+        self.register_buffer("default_mean", default_mean)
+        self.register_buffer("default_std", default_std)
+
+    def forward(self, x_input: torch.Tensor) -> torch.Tensor:
+        r"""Computation of CLIP-IQA metric for a given image :math:`x`.
+
+        Args:
+            x_input: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(C, H, W)`.
+                The metric is designed in such a way that it expects:
+                - 3D or 4D PyTorch tensors;
+                - These tensors can have any range of values between 0 and 255;
+                - These tensors have channels first format.
+
+        Returns:
+            The value of the CLIP-IQA score in the [0, 1] range.
+ """ + _validate_input([x_input], dim_range=(3, 4), data_range=(0., 255.), check_for_channels_first=True) + + x = x_input.clone() + x = x.float() / self.data_range + x = (x - self.default_mean) / self.default_std + + # Device for nn.Module cannot be cached through the buffer so it has to be done here. + self.feature_extractor = self.feature_extractor.to(x) + + with torch.no_grad(): + image_features = self.feature_extractor.encode_image(x, pos_embedding=False).float() + + # Normalized features. + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + + # Cosine similarity as logits. + logits_per_image = self.logit_scale * image_features @ self.anchors.t() + + probs = logits_per_image.reshape(logits_per_image.shape[0], -1, 2).softmax(dim=-1) + result = probs[..., 0] + return result.detach() diff --git a/piq/feature_extractors/clip.py b/piq/feature_extractors/clip.py new file mode 100644 index 00000000..189dc7a7 --- /dev/null +++ b/piq/feature_extractors/clip.py @@ -0,0 +1,635 @@ +import hashlib +import os +import warnings + +import torch +import torch.nn.functional as F + +from torch import nn +from typing import Tuple, Union, Optional +from collections import OrderedDict +from urllib.request import urlopen +from urllib.error import URLError, HTTPError + +from piq.utils.common import is_sha256_hash + + +# We use the same model as OpenAI but store our own snapshot of it for reproducibility perposes. +PIQ_CLIP_MODEL_PATH = ( + "https://github.com/photosynthesis-team/piq/releases/download/v0.7.1/RN50.pt" +) +OPENIQA_CLIP_MODEL_PATH = ( + "https://openaipublic.azureedge.net/clip/models/" + "afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt" +) + + +def _download(url: str, root: str) -> str: + r"""Downloads model's weights and caches them. If already downloaded - loads from cache. + Performs required SHA checksum verifications. + + Args: + url: Web or file system path. + root: Absolute or relative path of the cache folder. + + Returns: + Absolute or relative path of the model's weights. + """ + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if is_sha256_hash(expected_sha256): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn( + f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" + ) + else: + return download_target + + with urlopen(url) as source, open(download_target, "wb") as output: + while True: + buff = source.read(8192) + if not buff: + break + + output.write(buff) + + # Perform hash check iff hash is actually present. + if is_sha256_hash(expected_sha256): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match") + + return download_target + + +def load() -> nn.Module: + r"""Load a CLIP model. + + Returns: + Initialized CLIP model. + """ + # We use our snapshot by default and use OpenAI link as a backup in case of some trouble. 
+ try: + model_path = _download(PIQ_CLIP_MODEL_PATH, os.path.expanduser("~/.cache/clip")) + except (URLError, HTTPError): + model_path = _download(OPENIQA_CLIP_MODEL_PATH, os.path.expanduser("~/.cache/clip")) + + with open(model_path, "rb") as f: + model = torch.jit.load(f, map_location="cpu").eval() + + model = build_model(model.state_dict()) + model.float() + return model + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # All conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1. + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # Downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1. + self.downsample = nn.Sequential( + OrderedDict( + [ + ("-1", nn.AvgPool2d(stride)), + ( + "0", + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False, + ), + ), + ("1", nn.BatchNorm2d(planes * self.expansion)), + ] + ) + ) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: Optional[int] = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + self.spacial_dim = spacial_dim + self.embed_dim = embed_dim + + def forward(self, x, return_token=False, pos_embedding=True): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + if pos_embedding: + positional_embedding_resize = ( + F.interpolate( + self.positional_embedding.unsqueeze(0).unsqueeze(0), + size=(x.size(0), x.size(2)), + mode="bicubic", + ) + .squeeze(0) + .squeeze(0) + ) + x = x + positional_embedding_resize[:, None, :].to(x.dtype) # (HW+1)NC + + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + + if return_token: + return x[0], x[1:] + else: + 
return x[0] + + +class ModifiedResNet(nn.Module): + r"""A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1. + - The final pooling layer is a QKV attention instead of an average pool. + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # The 3-layer stem. + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # Residual layers. + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # The ResNet feature dimension. + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x, return_token=False, pos_embedding=True): + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + if return_token: + x, tokens = self.attnpool(x, return_token, pos_embedding) + return x, tokens + else: + x = self.attnpool(x, return_token, pos_embedding) + return x + + +class LayerNorm(nn.LayerNorm): + r"""Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + r"""Modified version of GeLU activation function.""" + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)), + ] + ) + ) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = ( + self.attn_mask.to(dtype=x.dtype, device=x.device) + if self.attn_mask is not None + else None + ) + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + 
self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor, return_token=False, pos_embedding=True): + x = self.conv1(x) # Shape = [*, width, grid, grid]. + x = x.reshape(x.shape[0], x.shape[1], -1) # Shape = [*, width, grid ** 2]. + x = x.permute(0, 2, 1) # Shape = [*, grid ** 2, width]. + x = torch.cat( + [ + self.class_embedding.to(x.dtype) + + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device + ), + x, + ], + dim=1, + ) # shape = [*, grid ** 2 + 1, width] + + if pos_embedding: + positional_embedding_resize = ( + F.interpolate( + self.positional_embedding.unsqueeze(0).unsqueeze(0), + size=(x.size(1), x.size(2)), + mode="bicubic", + ) + .squeeze(0) + .squeeze(0) + ) + x = x + positional_embedding_resize.to(x.dtype) + + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + token = self.ln_post(x[:, 1:, :]) + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + if return_token: + return x, token + else: + return x + + +class CLIP(nn.Module): + r"""General class of CLIP model. Supports various backbones. + Taken from the original implementation by Open AI: https://github.com/openai/CLIP. 
+ """ + + def __init__( + self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, ...], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width, + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask(), + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log()) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [ + self.visual.layer1, + self.visual.layer2, + self.visual.layer3, + self.visual.layer4, + ]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self): + # Lazily create causal attention mask, with full attention between the vision tokens. + # PyTorch uses additive attention mask; fill with -inf. 
+ mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image, pos_embedding): + return self.visual(image.type(self.dtype), pos_embedding=pos_embedding) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # Take features from the eot embedding (eot_token is the highest number in each sequence). + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text, pos_embedding=True, text_features=None): + image_features = self.encode_image(image, pos_embedding) + if text_features is None: + text_features = self.encode_text(text) + + # Normalized features. + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + # Cosine similarity as logits. + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # Shape = [global_batch_size, global_batch_size]. + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + r"""Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [ + *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], + "in_proj_bias", + "bias_k", + "bias_v", + ]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + r"""Builds CLIP model based on a pre-loaded checkpoint. + Supports ViT and CNN backbones. + + Args: + state_dict: A pre-loaded checkpoint of torch.nn.Module. 
+ """ + vit = "visual.proj" in state_dict + + vision_layers: Union[Tuple[int, ...], int] + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [ + k + for k in state_dict.keys() + if k.startswith("visual.") and k.endswith(".attn.in_proj_weight") + ] + ) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + vision_layers = tuple( + len( + set( + k.split(".")[2] + for k in state_dict + if k.startswith(f"visual.layer{b}") + ) + ) + for b in [1, 2, 3, 4] + ) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, + vision_layers, + vision_width, + vision_patch_size, + context_length, + vocab_size, + transformer_width, + transformer_heads, + transformer_layers, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/piq/utils/common.py b/piq/utils/common.py index 1ceb336a..b0608e09 100644 --- a/piq/utils/common.py +++ b/piq/utils/common.py @@ -1,8 +1,10 @@ import torch import re +import os import warnings from typing import Tuple, List, Optional, Union, Dict, Any +from urllib.request import urlopen SEMVER_VERSION_PATTERN = re.compile( r""" @@ -62,8 +64,8 @@ def _validate_input( tensors: List[torch.Tensor], dim_range: Tuple[int, int] = (0, -1), data_range: Tuple[float, float] = (0., -1.), - # size_dim_range: Tuple[float, float] = (0., -1.), size_range: Optional[Tuple[int, int]] = None, + check_for_channels_first: bool = False ) -> None: r"""Check that input(-s) satisfies the requirements Args: @@ -99,6 +101,11 @@ def _validate_input( f'Expected values to be greater or equal to {data_range[0]}, got {t.min()}' assert t.max() <= data_range[1], \ f'Expected values to be lower or equal to {data_range[1]}, got {t.max()}' + + if check_for_channels_first: + channels_last = t.shape[-1] in {1, 2, 3} + assert not channels_last, "Expected tensor to have channels first format, but got channels last. \ + Please permute channels (e.g. t.permute(0, 3, 1, 2) for 4D tensors) and rerun." def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor: @@ -156,3 +163,36 @@ def _parse_version(version: Union[str, bytes]) -> Tuple[int, ...]: release = tuple(int(i) for i in match.group("release").split(".")) return release + + +def download_tensor(url: str, root: str, map_location: str = 'cpu') -> torch.Tensor: + r"""Downloads torch tensor and caches it. If already downloaded - loads from cache. + + Args: + url: Web or file system path. + root: Absolute or relative path of the cache folder. + + Returns: + Loaded torch tesnor. 
+ """ + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + download_target = os.path.join(root, filename) + if os.path.isfile(download_target): + return torch.load(download_target, map_location=map_location) + + with urlopen(url) as source, open(download_target, "wb") as output: + while True: + buff = source.read(8192) + if not buff: + break + + output.write(buff) + + return torch.load(download_target, map_location=map_location) + + +def is_sha256_hash(string: str) -> Optional[re.Match]: + """ Checks whether the provided sting is a valid SHA256 hash. """ + pattern = re.compile("^[a-fA-F0-9]{64}$") + return pattern.match(string) diff --git a/requirements.txt b/requirements.txt index 6b4e0d03..08a8a9be 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -torchvision>=0.6.1,!=0.9.0 +torchvision>=0.10.0 diff --git a/tests/results_benchmark.py b/tests/results_benchmark.py index 3113a46b..2d637b58 100644 --- a/tests/results_benchmark.py +++ b/tests/results_benchmark.py @@ -4,13 +4,14 @@ import piq import tqdm import torch +import types import argparse import functools import torchvision import pandas as pd -from typing import List, Callable, Tuple +from typing import List, Callable, Tuple, Optional from pathlib import Path from skimage.io import imread from scipy.stats import spearmanr, kendalltau @@ -18,6 +19,11 @@ from dataclasses import dataclass from torch import nn from itertools import chain +from scipy.io import loadmat + + +OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711) @dataclass @@ -32,7 +38,7 @@ def __post_init__(self): f'Provide one of: {valid_categories}' -torch.multiprocessing.set_sharing_strategy('file_system') +torch.multiprocessing.set_sharing_strategy("file_system") METRICS = { # Full-reference @@ -68,23 +74,30 @@ def __post_init__(self): "DSS": Metric(name="DSS", functor=functools.partial(piq.dss, data_range=255., reduction='none'), category='FR'), # No-reference - "BRISQUE": Metric(name="BRISQUE", functor=functools.partial(piq.brisque, data_range=255., reduction='none'), - category='NR'), + "BRISQUE": Metric( + name="BRISQUE", + functor=functools.partial(piq.brisque, data_range=255.0, reduction="none"), + category="NR", + ), + "CLIPIQA": Metric(name="CLIPIQA", functor=piq.CLIPIQA(data_range=255), category="NR"), # Distribution-based - "IS": Metric(name="IS", functor=piq.IS(distance='l1'), category='DB'), - "FID": Metric(name="FID", functor=piq.FID(), category='DB'), - "GS": Metric(name="GS", functor=piq.GS(), category='DB'), - "KID": Metric(name="KID", functor=piq.KID(), category='DB'), - "MSID": Metric(name="MSID", functor=piq.MSID(), category='DB'), - "PR": Metric(name="PR", functor=piq.PR(), category='DB') + "IS": Metric(name="IS", functor=piq.IS(distance="l1"), category="DB"), + "FID": Metric(name="FID", functor=piq.FID(), category="DB"), + "GS": Metric(name="GS", functor=piq.GS(), category="DB"), + "KID": Metric(name="KID", functor=piq.KID(), category="DB"), + "MSID": Metric(name="MSID", functor=piq.MSID(), category="DB"), + "PR": Metric(name="PR", functor=piq.PR(), category="DB"), } -METRIC_CATEGORIES = {cat: [k for k, v in METRICS.items() if v.category == cat] for cat in ['FR', 'NR', 'DB']} +METRIC_CATEGORIES = { + cat: [k for k, v in METRICS.items() if v.category == cat] + for cat in ["FR", "NR", "DB"] +} class TID2013(Dataset): - r""" A class to evaluate on the KADID10k dataset. + r"""A class to evaluate on the KADID10k dataset. Note that the class is callable. 
The values are returned as a result of calling the __getitem__ method. Args: @@ -97,20 +110,20 @@ class TID2013(Dataset): _filename = "mos_with_names.txt" def __init__(self, root: Path = "datasets/tid2013") -> None: - assert root.exists(), \ - "You need to download TID2013 dataset first. Check http://www.ponomarenko.info/tid2013" + assert ( + root.exists() + ), "You need to download TID2013 dataset first. Check http://www.ponomarenko.info/tid2013" df = pd.read_csv( - root / self._filename, - sep=' ', - names=['score', 'dist_img'], - header=None + root / self._filename, sep=" ", names=["score", "dist_img"], header=None + ) + df["ref_img"] = df["dist_img"].apply( + lambda x: f"reference_images/{(x[:3] + x[-4:]).upper()}" ) - df["ref_img"] = df["dist_img"].apply(lambda x: f"reference_images/{(x[:3] + x[-4:]).upper()}") df["dist_img"] = df["dist_img"].apply(lambda x: f"distorted_images/{x}") - self.scores = df['score'].to_numpy() - self.df = df[["dist_img", 'ref_img', 'score']] + self.scores = df["score"].to_numpy() + self.df = df[["dist_img", "ref_img", "score"]] self.root = root def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -129,7 +142,7 @@ def __len__(self) -> int: class KADID10k(TID2013): - r""" A class to evaluate on the KADID10k dataset. + r"""A class to evaluate on the KADID10k dataset. One can get the dataset via the direct link: https://datasets.vqa.mmsp-kn.de/archives/kadid10k.zip. Note that the class is callable. The values are returned as a result of calling the __getitem__ method. @@ -152,13 +165,13 @@ def __init__(self, root: Path = "datasets/kadid10k") -> None: self.df = pd.read_csv(root / self._filename) self.df.rename(columns={"dmos": "score"}, inplace=True) self.scores = self.df["score"].to_numpy() - self.df = self.df[["dist_img", 'ref_img', 'score']] + self.df = self.df[["dist_img", "ref_img", "score"]] self.root = root / "images" class PIPAL(TID2013): - r""" A class to evaluate on the train set of the PIPAL dataset. + r"""A class to evaluate on the train set of the PIPAL dataset. Note that the class is callable. The values are returned as a result of calling the __getitem__ method. Args: @@ -168,7 +181,6 @@ class PIPAL(TID2013): y: image without distortion in [0, 1] range score: MOS score for this pair of images """ - def __init__(self, root: Path = Path("data/raw/pipal")) -> None: assert root.exists(), \ "You need to download PIPAL dataset. Check https://www.jasongt.com/projectpages/pipal.html" @@ -188,14 +200,126 @@ def __init__(self, root: Path = Path("data/raw/pipal")) -> None: df["dist_img"] = df["dist_img"].apply(lambda x: f"Train_Dist/{x}") self.scores = df["score"].to_numpy() - self.df = df[["dist_img", 'ref_img', 'score']] + self.df = df[["dist_img", "ref_img", "score"]] + self.root = root + + +class KonIQ10k(Dataset): + r"""A class to evaluate on the train/test/train+test the KonIQ dataset. + http://database.mmsp-kn.de/koniq-10k-database.html + + Note that the class is callable. The values are returned as a result of calling the __getitem__ method. + + Args: + root: Root directory path. + transform: Something that pre-processes dataset images. + subset: Which part of the dataset to use. + Options: "train", "test", "all". "all" means train + test. + Returns: + x: image with some kind of distortion in [0, 1] range. + y: dummy variable, used for compatibility with FR datasets. + score: MOS score of the image quality. 
+ """ + _filename_scores = "koniq10k_scores.csv" + _filename_dists = "koniq10k_distributions_sets.csv" + + def __init__(self, root: Path, transform: Optional[Callable] = None, subset: str = 'test') -> None: + supported_subsets = ["train", "test", "all"] + assert subset in supported_subsets, f"Unknown subset [{subset}], choose one of {supported_subsets}." + assert root.exists(), "You need to download KonIQ-10k dataset first." + + self.root = root + self.initial_image_size = "1024x768" + df1 = pd.read_csv(root / self._filename_scores) + df2 = pd.read_csv(root / self._filename_dists) + self.df = df1.merge(df2, on=["image_name"]) + if not subset == "all": + self.df = self.df[self.df.set == subset].reset_index() + + self.df["image_name"] = self.df["image_name"].apply( + lambda x: f"{self.initial_image_size}/{x}" + ) + self.scores = self.df["MOS_zscore"].to_numpy() + + self.transform = transform + + def __getitem__(self, index: int) -> Tuple[torch.Tensor, None, float]: + x_path = self.root / self.df.at[index, "image_name"] + score = self.scores[index] + + x = imread(x_path) + y = torch.rand(1) + + if self.transform is not None: + x = self.transform(x) + return x, y, score + + x = torch.from_numpy(x).float() + x = x.permute(2, 0, 1) + return x, y, score + + def __len__(self) -> int: + return len(self.df) + + +class LIVEitW(KonIQ10k): + r"""A class to evaluate on the train/test/train+test the LIVE-in-the-Wild dataset. + https://live.ece.utexas.edu/research/ChallengeDB/index.html + + Note that the class is callable. The values are returned as a result of calling the __getitem__ method. + + WARNING: This dataset contains images with different spatial resolutions. + Hence, inference with bs > 1 may cause runtime errors. + + Args: + root: Root directory path. + transform: Something that pre-processes dataset images. + subset: Which part of the dataset to use. + Options: "train", "test", "all". "all" means train + test. + Returns: + x: image with some kind of distortion in [0, 1] range. + y: dummy variable, used for compatibility with FR datasets. + score: MOS score of the image quality. + """ + _filename_names = "AllImages_release.mat" + _filename_mos = "AllMOS_release.mat" + + def __init__(self, root: Path, transform: Optional[Callable] = None, subset: str = 'test') -> None: + supported_subsets = ["train", "test", "all"] + assert subset in supported_subsets, f"Unknown subset [{subset}], choose one of {supported_subsets}." + assert root.exists(), "You need to download LIVEitW dataset first." + + labels_folder = "Data" + names = loadmat(root / labels_folder / self._filename_names) + mos = loadmat(root / labels_folder / self._filename_mos) + + images_folder = "Images" + n_train_images = 7 # There are only 7 images in the train set that are placed in different folder. 
+ train_paths = [root / images_folder / "trainingImages" / n[0][0] + for n in names["AllImages_release"]][:n_train_images] + test_paths = [root / images_folder / n[0][0] for n in names["AllImages_release"]][n_train_images:] + scores = mos["AllMOS_release"][0] + + if subset == "train": + self.df = pd.DataFrame().from_dict({"image_name": train_paths}) + self.scores = scores[:n_train_images] + elif subset == "test": + self.df = pd.DataFrame().from_dict({"image_name": test_paths}) + self.scores = scores[n_train_images:] + else: + self.df = pd.DataFrame().from_dict({"image_name": train_paths + test_paths}) + self.scores = scores + self.root = root + self.transform = transform DATASETS = { "tid2013": TID2013, "kadid10k": KADID10k, "pipal": PIPAL, + "koniq10k": KonIQ10k, + "liveitw": LIVEitW, } @@ -236,14 +360,14 @@ def eval_metric(loader: DataLoader, metric: Metric, device: str, feature_extract def determine_compute_function(metric_category: str) -> Callable: return { - 'FR': compute_full_reference, - 'NR': compute_no_reference, - 'DB': compute_distribution_based + "FR": compute_full_reference, + "NR": compute_no_reference, + "DB": compute_distribution_based, }[metric_category] def get_feature_extractor(feature_extractor_name: str, device: str) -> nn.Module: - r""" A factory to initialize feature extractor from its name. """ + r"""A factory to initialize feature extractor from its name.""" if feature_extractor_name == "vgg16": return torchvision.models.vgg16(pretrained=True, progress=True).features.to(device) elif feature_extractor_name == "vgg19": @@ -255,17 +379,34 @@ def get_feature_extractor(feature_extractor_name: str, device: str) -> nn.Module raise ValueError(f"Wrong feature extractor name {feature_extractor_name}") -def compute_full_reference(metric_functor: Callable, distorted_images: torch.Tensor, - reference_images: torch.Tensor, _, __) -> torch.Tensor: +def compute_full_reference( + metric_functor: Callable, + distorted_images: torch.Tensor, + reference_images: torch.Tensor, + _, + __, +) -> torch.Tensor: return metric_functor(distorted_images, reference_images).cpu() -def compute_no_reference(metric_functor: Callable, distorted_images: torch.Tensor, _, __, ___) -> torch.Tensor: +def compute_no_reference( + metric_functor: Callable, + distorted_images: torch.Tensor, + _, + device: str, + ___) -> torch.Tensor: + if not isinstance(metric_functor, types.FunctionType): + metric_functor = metric_functor.to(device) + return metric_functor(distorted_images).cpu() -def extract_features(distorted_patches: torch.Tensor, feature_extractor: nn.Module, feature_extractor_name: str, - reference_patches: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def extract_features( + distorted_patches: torch.Tensor, + feature_extractor: nn.Module, + feature_extractor_name: str, + reference_patches: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: distorted_features, reference_features = [], [] with torch.no_grad(): if feature_extractor_name == "inception": @@ -285,16 +426,22 @@ def extract_features(distorted_patches: torch.Tensor, feature_extractor: nn.Modu def normalize_tensor(tensor: torch.Tensor) -> torch.Tensor: - r""" Map tensor values to [0, 1] """ + r"""Map tensor values to [0, 1]""" return (tensor - tensor.min()) / (tensor.max() - tensor.min()) -def compute_distribution_based(metric_functor: Callable, distorted_images: torch.Tensor, - reference_images: torch.Tensor, device: str, feature_extractor_name: str) \ - -> torch.Tensor: - feature_extractor = 
get_feature_extractor(feature_extractor_name=feature_extractor_name, device=device) +def compute_distribution_based( + metric_functor: Callable, + distorted_images: torch.Tensor, + reference_images: torch.Tensor, + device: str, + feature_extractor_name: str, +) -> torch.Tensor: + feature_extractor = get_feature_extractor( + feature_extractor_name=feature_extractor_name, device=device + ) - if feature_extractor_name == 'inception': + if feature_extractor_name == "inception": distorted_images = normalize_tensor(distorted_images) reference_images = normalize_tensor(reference_images) @@ -306,8 +453,9 @@ def compute_distribution_based(metric_functor: Callable, distorted_images: torch distorted_patches = distorted_patches.view(-1, *distorted_patches.shape[-3:]) reference_patches = reference_patches.view(-1, *reference_patches.shape[-3:]) - distorted_features, reference_features = extract_features(distorted_patches, feature_extractor, - feature_extractor_name, reference_patches) + distorted_features, reference_features = extract_features( + distorted_patches, feature_extractor, feature_extractor_name, reference_patches + ) return metric_functor(distorted_features, reference_features).cpu() @@ -327,8 +475,14 @@ def crop_patches(images: torch.Tensor, size: int = 64, stride: int = 32) -> torc return patches -def main(dataset_name: str, path: Path, metrics: List[str], batch_size: int, device: str, feature_extractor: str) \ - -> None: +def main( + dataset_name: str, + path: Path, + metrics: List[str], + batch_size: int, + device: str, + feature_extractor: str, +) -> None: # Init dataset and dataloader dataset = DATASETS[dataset_name](root=path) loader = DataLoader(dataset, batch_size=batch_size, num_workers=4) @@ -337,7 +491,7 @@ def main(dataset_name: str, path: Path, metrics: List[str], batch_size: int, dev if metrics[0] in METRIC_CATEGORIES: metrics = METRIC_CATEGORIES[metrics[0]] - if metrics[0] == 'all': + if metrics[0] == "all": metrics = list(chain(*METRIC_CATEGORIES.values())) for metric_name in metrics: diff --git a/tests/test_clip_iqa.py b/tests/test_clip_iqa.py new file mode 100644 index 00000000..cc723246 --- /dev/null +++ b/tests/test_clip_iqa.py @@ -0,0 +1,115 @@ +import torch +import pytest + +from PIL import Image +from piq import CLIPIQA +from torchvision.transforms import PILToTensor +from torch.nn.modules.loss import _Loss + + +@pytest.fixture(scope='module') +def x_grey() -> torch.Tensor: + return torch.rand(3, 1, 96, 96) + + +@pytest.fixture(scope='module') +def x_rgb() -> torch.Tensor: + return torch.rand(3, 3, 96, 96) + + +@pytest.fixture(scope='module') +def clipiqa() -> _Loss: + return CLIPIQA(data_range=255) + + +# ================== Test class: `CLIPIQA` ================== +def test_clip_iqa_works_with_grey_channels_last(clipiqa: _Loss, x_grey: torch.Tensor, device: str) -> None: + clipiqa = clipiqa.to(device) + clipiqa(x_grey.to(device)) + + +def test_clip_iqa_fails_with_gray_channels_first(clipiqa: _Loss, x_grey: torch.Tensor, device: str) -> None: + clipiqa = clipiqa.to(device) + x_grey = x_grey.permute(0, 2, 3, 1) + with pytest.raises(AssertionError): + clipiqa(x_grey.to(device)) + + +def test_clip_iqa_works_with_rgb_channels_last(clipiqa: _Loss, x_rgb: torch.Tensor, device: str) -> None: + clipiqa = clipiqa.to(device) + clipiqa(x_rgb.to(device)) + + +def test_clip_iqa_fails_with_rgb_channels_first(clipiqa: _Loss, x_rgb: torch.Tensor, device: str) -> None: + clipiqa = clipiqa.to(device) + x_rgb = x_rgb.permute(0, 2, 3, 1) + with pytest.raises(AssertionError): + 
clipiqa(x_rgb.to(device)) + + +def test_clip_iqa_values_rgb(clipiqa: _Loss, device: str) -> None: + """Reference values are obtained by running the following script on the selected images: + https://github.com/IceClear/CLIP-IQA/blob/v2-3.8/demo/clipiqa_single_image_demo.py + """ + clipiqa = clipiqa.to(device) + paths_scores = {'tests/assets/i01_01_5.bmp': 0.45898438, + 'tests/assets/I01.BMP': 0.89160156} + for path, of_score in paths_scores.items(): + img = Image.open(path) + x_rgb = PILToTensor()(img) + x_rgb = x_rgb.float()[None] + score = clipiqa(x_rgb.to(device)) + score_official = torch.tensor([of_score], dtype=torch.float, device=device) + assert torch.isclose(score, score_official, rtol=1e-2), \ + f'Expected values to be equal to baseline, got {score.item()} and {score_official}' + + +def test_clip_iqa_input_dtype_does_not_change(clipiqa: _Loss, x_rgb: torch.Tensor, device: str) -> None: + clipiqa = clipiqa.to(device) + x_rgb = x_rgb[0][None] + optional_data_types = torch.float16, torch.float64 + + for op_type in optional_data_types: + x_rgb = x_rgb.type(op_type).to(device) + clipiqa(x_rgb) + assert x_rgb.dtype == op_type, \ + f'Expect {op_type} dtype to be preserved, got {x_rgb.dtype}' + + +def test_clip_iqa_dims_work(clipiqa: _Loss, device: str) -> None: + clipiqa = clipiqa.to(device) + x_3dims = [torch.rand((3, 96, 96)), torch.rand((3, 128, 128)), torch.rand((3, 160, 160))] + for x in x_3dims: + clipiqa(x.to(device)) + + x_4dims = [torch.rand((3, 3, 96, 96)), torch.rand((4, 3, 128, 128)), torch.rand((5, 3, 160, 160))] + for x in x_4dims: + clipiqa(x.to(device)) + + +def test_clip_iqa_results_equal_for_3_and_4_dims(clipiqa: _Loss, device: str) -> None: + clipiqa = clipiqa.to(device) + x = torch.rand((3, 128, 128)) + x_copy = x[None] + x_result = clipiqa(x.to(device)) + x_copy_result = clipiqa(x_copy.to(device)) + assert torch.isclose(x_result, x_copy_result, rtol=1e-2), \ + f'Expected values to be equal, got {x_result} and {x_copy_result}' + + +def test_clip_iqa_dims_does_not_work(clipiqa: _Loss, device: str) -> None: + clipiqa = clipiqa.to(device) + x_2dims = [torch.rand((96, 96)), torch.rand((128, 128)), torch.rand((160, 160))] + with pytest.raises(AssertionError): + for x in x_2dims: + clipiqa(x.to(device)) + + x_1dims = [torch.rand((96)), torch.rand((128)), torch.rand((160))] + with pytest.raises(AssertionError): + for x in x_1dims: + clipiqa(x.to(device)) + + x_5dims = [torch.rand((1, 3, 3, 96, 96)), torch.rand((2, 4, 3, 128, 128)), torch.rand((1, 5, 3, 160, 160))] + with pytest.raises(AssertionError): + for x in x_5dims: + clipiqa(x.to(device)) diff --git a/tests/test_utils.py b/tests/test_utils.py index d869ff54..f2c678e4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,12 @@ import torch import pytest +import os +import hashlib +import re import numpy as np -from piq.utils import _validate_input, _reduce, _parse_version +from piq.utils.common import _validate_input, _reduce, _parse_version, download_tensor, is_sha256_hash @pytest.fixture(scope='module') @@ -77,7 +80,23 @@ def test_works_on_two_not_5d_tensors(tensor_1d: torch.Tensor) -> None: def test_breaks_if_tensors_have_different_n_dims(tensor_2d: torch.Tensor, tensor_5d: torch.Tensor) -> None: with pytest.raises(AssertionError): - _validate_input([tensor_2d, tensor_5d], dim_range=(2, 5)) + _validate_input([tensor_2d, tensor_5d], dim_range=(2, 5), check_for_channels_first=True) + + +def test_breaks_if_wrong_channel_order() -> None: + with pytest.raises(AssertionError): + 
_validate_input([torch.rand(1, 5, 5, 1)], check_for_channels_first=True)
+    with pytest.raises(AssertionError):
+        _validate_input([torch.rand(1, 5, 5, 2)], check_for_channels_first=True)
+    with pytest.raises(AssertionError):
+        _validate_input([torch.rand(1, 5, 5, 3)], check_for_channels_first=True)
+
+
+def test_works_if_correct_channel_order() -> None:
+    try:
+        _validate_input([torch.rand(1, 1, 5, 5)], check_for_channels_first=True)
+        _validate_input([torch.rand(1, 2, 5, 5)], check_for_channels_first=True)
+        _validate_input([torch.rand(1, 3, 5, 5)], check_for_channels_first=True)
+    except Exception as e:
+        pytest.fail(f"Unexpected error occurred: {e}")
 
 
 # ================== Test function: `_reduce` ==================
@@ -136,3 +155,37 @@ def test_version_tuple_parses_correctly(version, expected) -> None:
 def test_version_tuple_warns_on_invalid_input(version) -> None:
     with pytest.warns(UserWarning):
         _parse_version(version)
+
+
+def test_download_tensor():
+    url = "https://github.com/photosynthesis-team/piq/releases/download/v0.7.1/clipiqa_tokens.pt"
+    file_name = os.path.basename(url)
+    root = os.path.expanduser("~/.cache/clip")
+
+    # Check that the tensor gets downloaded if it is not cached locally.
+    full_file_path = os.path.join(root, file_name)
+    if os.path.exists(full_file_path):
+        os.remove(full_file_path)
+
+    assert isinstance(download_tensor(url, root), torch.Tensor)
+
+    # Check that the tensor loads from the cache on the second call.
+    assert isinstance(download_tensor(url, root), torch.Tensor)
+
+
+# =============== Test function: `is_sha256_hash` ==============
+def test_works_for_hashes():
+    example_strings = [b'the', b'meaning', b'of', b'life', b'the', b'universe', b'and', b'everything']
+    for ex in example_strings:
+        h = hashlib.new('sha256')
+        h.update(ex)
+        h = h.hexdigest()
+        assert isinstance(is_sha256_hash(h), re.Match), f'Expected re.Match, got {type(is_sha256_hash(h))}'
+
+
+def test_does_not_work_for_plain_strings():
+    example_strings = ['the', 'meaning', 'of', 'life', 'the', 'universe', 'and', 'everything']
+    for ex in example_strings:
+        assert is_sha256_hash(ex) is None, f'Expected None for a plain string, got {is_sha256_hash(ex)}'
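For reviewers, the scoring rule implemented in `CLIPIQA.forward` reduces to a softmax over cosine similarities between one image embedding and the two anchor text embeddings ("Good photo." / "Bad photo."). The sketch below reproduces that arithmetic in isolation, using random tensors in place of real CLIP features; the 1024-dimensional embedding size is an assumption matching the RN50 checkpoint, and the temperature mirrors the `logit_scale` parameterisation in `CLIP.__init__`. It is a minimal illustration of the math, not a replacement for the `CLIPIQA` module.

```python
import torch

# Stand-ins for real CLIP features: one image embedding and two anchor text
# embeddings. Assumed embedding size of 1024 (RN50 backbone).
torch.manual_seed(0)
image_features = torch.randn(1, 1024)
anchors = torch.randn(2, 1024)

# L2-normalise, as done for anchors (in __init__) and image features (in forward).
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
anchors = anchors / anchors.norm(dim=-1, keepdim=True)

# Cosine similarities scaled by the learned temperature, then a softmax over the
# (good, bad) pair. The CLIP-IQA score is the probability assigned to the "good" anchor.
logit_scale = torch.tensor(1 / 0.07).log().exp()
logits_per_image = logit_scale * image_features @ anchors.t()
probs = logits_per_image.reshape(logits_per_image.shape[0], -1, 2).softmax(dim=-1)
score = probs[..., 0]  # shape (N, 1), values in [0, 1]
print(score)
```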
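The new `check_for_channels_first` flag in `_validate_input` relies on a simple heuristic: if the last dimension is 1, 2 or 3, the tensor is assumed to be channels-last and rejected. Below is a standalone sketch of that check and the fix suggested by the error message (permuting to NCHW); the helper name `looks_channels_last` is illustrative and not part of the library.

```python
import torch

def looks_channels_last(t: torch.Tensor) -> bool:
    # Same heuristic as the added assertion: a trailing dim of 1, 2 or 3
    # is taken to mean the channel axis is last.
    return t.shape[-1] in {1, 2, 3}

x = torch.rand(8, 224, 224, 3)      # channels-last image batch
if looks_channels_last(x):
    x = x.permute(0, 3, 1, 2)       # -> (N, C, H, W), as the error message suggests
assert x.shape == (8, 3, 224, 224)
```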
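The weight-loading helpers (`_download` in `piq/feature_extractors/clip.py` and `download_tensor` in `piq/utils/common.py`) follow the same cache-then-verify pattern: reuse a local copy when present, otherwise stream the file, and check a SHA256 digest only when the URL actually embeds one. The condensed sketch below captures that pattern under those assumptions; `cached_download` is an illustrative name and error handling is reduced to the essentials.

```python
import hashlib
import os
import re
from urllib.request import urlopen

def cached_download(url: str, root: str) -> str:
    """Download `url` into `root` once, verifying SHA256 if the URL embeds one."""
    os.makedirs(root, exist_ok=True)
    target = os.path.join(root, os.path.basename(url))
    expected = url.split("/")[-2]
    has_hash = re.fullmatch(r"[a-fA-F0-9]{64}", expected) is not None

    def digest(path: str) -> str:
        with open(path, "rb") as f:
            return hashlib.sha256(f.read()).hexdigest()

    # Cache hit: keep the local file if there is no hash to check or it verifies.
    if os.path.isfile(target) and (not has_hash or digest(target) == expected):
        return target

    with urlopen(url) as source, open(target, "wb") as output:
        while True:
            buff = source.read(8192)
            if not buff:
                break
            output.write(buff)

    if has_hash and digest(target) != expected:
        raise RuntimeError("Downloaded file failed SHA256 verification")
    return target
```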
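The text transformer is masked with the additive causal mask built in `CLIP.build_attention_mask`: `-inf` above the diagonal blocks attention to future tokens, zeros elsewhere leave the scores unchanged. A toy-sized sketch of the same construction and its effect under a softmax:

```python
import torch

context_length = 5  # toy value for illustration only

# Same construction as CLIP.build_attention_mask().
mask = torch.empty(context_length, context_length)
mask.fill_(float("-inf"))
mask.triu_(1)  # zero on and below the diagonal, -inf strictly above

scores = torch.zeros(context_length, context_length)   # dummy attention scores
weights = (scores + mask).softmax(dim=-1)
print(weights)  # row i attends uniformly over tokens 0..i only
```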
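With `CLIPIQA` registered as an NR metric and the KonIQ10k / LIVE-itW loaders added to `tests/results_benchmark.py`, evaluation amounts to scoring each distorted image and correlating the predictions with the MOS labels via SRCC. The sketch below shows that loop on synthetic data (random images and scores stand in for a real dataset, so the resulting number is meaningless); it only illustrates how `piq.CLIPIQA` from this patch and `scipy.stats.spearmanr` fit together.

```python
import torch
import piq
from scipy.stats import spearmanr

device = "cuda" if torch.cuda.is_available() else "cpu"
# Instantiating CLIPIQA downloads the CLIP RN50 weights and prompt tokens on first use.
metric = piq.CLIPIQA(data_range=255).to(device)

# Synthetic stand-ins for a no-reference dataset: images in [0, 255] and MOS labels.
images = torch.randint(0, 256, (8, 3, 224, 224), dtype=torch.float32)
mos = torch.rand(8)

scores = []
with torch.no_grad():
    for batch in images.split(4):
        scores.append(metric(batch.to(device)).squeeze().cpu())
scores = torch.cat(scores)

srcc, _ = spearmanr(scores.numpy(), mos.numpy())
print(f"SRCC on synthetic data: {srcc:.3f}")
```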