Skip to content

Commit

Permalink
Merge branch 'master' into compute_grads_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
denproc committed Feb 7, 2023
2 parents faf561a + d04a848 commit cc75629
Show file tree
Hide file tree
Showing 26 changed files with 178 additions and 168 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci-linting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ on:
jobs:
build:

runs-on: ubuntu-20.04
runs-on: ubuntu-latest
strategy:
max-parallel: 4
matrix:
python-version: [3.6]
python-version: ["3.10"]

steps:
- uses: actions/checkout@v2
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/ci-mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ on:
jobs:
build:

runs-on: ubuntu-20.04
runs-on: ubuntu-latest
strategy:
max-parallel: 4
matrix:
python-version: [3.8]
python-version: ["3.10"]

steps:
- uses: actions/checkout@v2
Expand Down
15 changes: 10 additions & 5 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,20 @@ on:
jobs:
build:

runs-on: ubuntu-20.04
runs-on: ubuntu-latest
env:
USING_COVERAGE: '3.6'
USING_COVERAGE: '3.7'

strategy:
max-parallel: 4
matrix:
python-version: [3.6, 3.7, 3.8]
torchvision-version: [0.6.1, 0.9.1]
include:
- python-version: "3.7"
torchvision-version: "0.6.1"
- python-version: "3.7"
torchvision-version: "0.14.1"
- python-version: "3.10"
torchvision-version: "0.14.1"

steps:
- uses: actions/checkout@v2
Expand All @@ -36,7 +41,7 @@ jobs:
${{ runner.os }}-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade pip setuptools wheel
pip install torchvision==${{ matrix.torchvision-version }}
pip install -r requirements.txt
pip install pytest \
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,10 @@ venv.bak/
# VSCode settings
.vscode/

# MacOS
.DS_Store

# Files
data/
*.zip
*.ipynb
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ We provide:
* Extensive user input validation. Your code will not crash in the middle of the training.
* Fast (GPU computations available) and reliable.
* Most metrics can be backpropagated for model optimization.
* Supports python 3.6-3.8.
* Supports python 3.7-3.10.

PIQ was initially named `PhotoSynthesis.Metrics <https://pypi.org/project/photosynthesis-metrics/0.4.0/>`_.

Expand Down
12 changes: 6 additions & 6 deletions piq/brisque.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _aggd_parameters(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch

def _natural_scene_statistics(luma: torch.Tensor, kernel_size: int = 7, sigma: float = 7. / 6) -> torch.Tensor:
kernel = gaussian_filter(kernel_size=kernel_size,
sigma=sigma, dtype=luma.dtype).view(1, 1, kernel_size, kernel_size).to(luma)
sigma=sigma, dtype=luma.dtype, device=luma.device).view(1, 1, kernel_size, kernel_size)
C = 1
mu = F.conv2d(luma, kernel, padding=kernel_size // 2)
mu_sq = mu ** 2
Expand Down Expand Up @@ -209,10 +209,10 @@ def _scale_features(features: torch.Tensor) -> torch.Tensor:
[0.471, 3.264], [0.012809, 0.703171], [0.218, 1.046],
[-0.094876, 0.187459], [1.5e-005, 0.442057], [0.001272, 0.40803],
[0.222, 1.042], [-0.115772, 0.162604], [1.6e-005, 0.444362],
[0.001374, 0.40243], [0.227, 0.996],
[-0.117188, 0.09832299999999999], [3e-005, 0.531903],
[0.001122, 0.369589], [0.228, 0.99], [-0.12243, 0.098658],
[2.8e-005, 0.530092], [0.001118, 0.370399]]).to(features)
[0.001374, 0.40243], [0.227, 0.996], [-0.117188, 0.09832299999999999],
[3e-005, 0.531903], [0.001122, 0.369589], [0.228, 0.99], [-0.12243, 0.098658],
[2.8e-005, 0.530092],
[0.001118, 0.370399]], device=features.device, dtype=features.dtype)

scaled_features = lower_bound + (upper_bound - lower_bound) * (features - feature_ranges[..., 0]) / (
feature_ranges[..., 1] - feature_ranges[..., 0])
Expand All @@ -236,5 +236,5 @@ def _score_svr(features: torch.Tensor) -> torch.Tensor:
rho = -153.591
sv.t_()
kernel_features = _rbf_kernel(features=features, sv=sv, gamma=gamma)
score = kernel_features @ sv_coef.to(dtype=features.dtype)
score = kernel_features @ sv_coef.type(features.dtype)
return score - rho
22 changes: 12 additions & 10 deletions piq/dss.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import torch.nn.functional as F

from typing import Union
from typing import Union, Optional
from torch.nn.modules.loss import _Loss

from piq.utils import _validate_input, _reduce
Expand Down Expand Up @@ -84,7 +84,7 @@ def dss(x: torch.Tensor, y: torch.Tensor, reduction: str = 'mean',
dct_y = _dct_decomp(y_lum, dct_size)

# Create a Gaussian window that will be used to weight subbands scores
coords = torch.arange(1, dct_size + 1).to(x)
coords = torch.arange(1, dct_size + 1, dtype=x.dtype, device=x.device)
weight = (coords - 0.5) ** 2
weight = (- (weight.unsqueeze(0) + weight.unsqueeze(1)) / (2 * sigma_weight ** 2)).exp()

Expand Down Expand Up @@ -135,8 +135,8 @@ def _subband_similarity(x: torch.Tensor, y: torch.Tensor, first_term: bool,
c = dc_coeff if first_term else ac_coeff

# Compute local variance
kernel = gaussian_filter(kernel_size=kernel_size, sigma=sigma)
kernel = kernel.view(1, 1, kernel_size, kernel_size).to(x)
kernel = gaussian_filter(kernel_size=kernel_size, sigma=sigma, dtype=x.dtype, device=x.device)
kernel = kernel.view(1, 1, kernel_size, kernel_size)
mu_x = F.conv2d(x, kernel, padding=kernel_size // 2)
mu_y = F.conv2d(y, kernel, padding=kernel_size // 2)

Expand All @@ -162,17 +162,19 @@ def _subband_similarity(x: torch.Tensor, y: torch.Tensor, first_term: bool,
return similarity


def _dct_matrix(size: int) -> torch.Tensor:
def _dct_matrix(size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
r""" Computes the matrix coefficients for DCT transform using the following formula:
https://fr.mathworks.com/help/images/discrete-cosine-transform.html
Args:
size : size of DCT matrix to create. (`size`, `size`)
size : size of DCT matrix to create. (`size`, `size`)
device: target device for generation
dtype: target data type for generation
"""
p = torch.arange(1, size).reshape((size - 1, 1))
q = torch.arange(1, 2 * size, 2)
p = torch.arange(1, size, device=device, dtype=dtype).reshape((size - 1, 1))
q = torch.arange(1, 2 * size, 2, device=device, dtype=dtype)
return torch.cat((
math.sqrt(1 / size) * torch.ones((1, size)),
math.sqrt(1 / size) * torch.ones((1, size), device=device, dtype=dtype),
math.sqrt(2 / size) * torch.cos(math.pi / (2 * size) * p * q)), 0)


Expand All @@ -196,7 +198,7 @@ def _dct_decomp(x: torch.Tensor, dct_size: int = 8) -> torch.Tensor:
blocks = blocks.view(bs, 1, -1, dct_size, dct_size) # shape (bs, 1, block_num, N, N)

# apply DCT transform
coeffs = _dct_matrix(dct_size).to(x)
coeffs = _dct_matrix(dct_size, device=x.device, dtype=x.dtype)

blocks = coeffs @ blocks @ coeffs.t() # @ does operation on last 2 channels only

Expand Down
107 changes: 34 additions & 73 deletions piq/fsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""
import math
import functools
from typing import Union, Tuple
from typing import Union

import torch
from torch.nn.modules.loss import _Loss
Expand Down Expand Up @@ -84,20 +84,16 @@ def fsim(x: torch.Tensor, y: torch.Tensor, reduction: str = 'mean',
x_lum = x
y_lum = y

# Compute filters
filters = _construct_filters(x_lum, scales, orientations, min_length, mult, sigma_f, delta_theta)

# Compute phase congruency maps
pc_x = _phase_congruency(
x_lum, scales=scales, orientations=orientations,
min_length=min_length, mult=mult, sigma_f=sigma_f,
delta_theta=delta_theta, k=k
)
pc_y = _phase_congruency(
y_lum, scales=scales, orientations=orientations,
min_length=min_length, mult=mult, sigma_f=sigma_f,
delta_theta=delta_theta, k=k
)
pc_x = _phase_congruency(x_lum, filters=filters, scales=scales, orientations=orientations, k=k)
pc_y = _phase_congruency(y_lum, filters=filters, scales=scales, orientations=orientations, k=k)

# Gradient maps
kernels = torch.stack([scharr_filter(), scharr_filter().transpose(-1, -2)])
sch_filter = scharr_filter(device=x_lum.device, dtype=x_lum.dtype)
kernels = torch.stack([sch_filter, sch_filter.transpose(-1, -2)])
grad_map_x = gradient_map(x_lum, kernels)
grad_map_y = gradient_map(y_lum, kernels)

Expand All @@ -123,9 +119,8 @@ def fsim(x: torch.Tensor, y: torch.Tensor, reduction: str = 'mean',
return _reduce(result, reduction)


def _construct_filters(x: torch.Tensor, scales: int = 4, orientations: int = 4,
min_length: int = 6, mult: int = 2, sigma_f: float = 0.55,
delta_theta: float = 1.2, k: float = 2.0):
def _construct_filters(x: torch.Tensor, scales: int = 4, orientations: int = 4, min_length: int = 6,
mult: int = 2, sigma_f: float = 0.55, delta_theta: float = 1.2):
"""Creates a stack of filters used for computation of phase congruency maps
Args:
Expand All @@ -140,24 +135,33 @@ def _construct_filters(x: torch.Tensor, scales: int = 4, orientations: int = 4,
delta_theta: Ratio of angular interval between filter orientations
and the standard deviation of the angular Gaussian function
used to construct filters in the freq. plane.
k: No of standard deviations of the noise energy beyond the mean
at which we set the noise threshold point, below which phase
congruency values get penalized.
"""
Returns:
Tensor with filters. Shape :math:`(1, scales * orientations, H, W)`
"""
N, _, H, W = x.shape

# Calculate the standard deviation of the angular Gaussian function
# used to construct filters in the freq. plane.
theta_sigma = math.pi / (orientations * delta_theta)

# Pre-compute some stuff to speed up filter construction
grid_x, grid_y = get_meshgrid((H, W))
grid_x, grid_y = get_meshgrid((H, W), device=x.device, dtype=x.dtype)

# Move grid to GPU early on, so that all math heavy stuff computes faster.
grid_x, grid_y = grid_x.to(x), grid_y.to(x)
radius = torch.sqrt(grid_x ** 2 + grid_y ** 2)
theta = torch.atan2(-grid_y, grid_x)

# First construct a low-pass filter that is as large as possible, yet falls
# away to zero at the boundaries. All log Gabor filters are multiplied by
# this to ensure no extra frequencies at the 'corners' of the FFT are
# incorporated as this seems to upset the normalisation process when
# calculating phase congruency.
# Computed explicitly without _lowpassfilter
# Explicit low pass filter computation
n = 15 # default parameter
cutoff = .45 # default parameter
assert 0 < cutoff <= 0.5, "Cutoff frequency must be between 0 and 0.5"
assert n > 1 and int(n) == n, "n must be an integer >= 1"
lp = ifftshift(1. / (1.0 + (radius / cutoff) ** (2 * n)))

# Quadrant shift radius and theta so that filters are constructed with 0 frequency at the corners.
# Get rid of the 0 radius value at the 0 frequency point (now at top-left corner)
# so that taking the log of the radius will not cause trouble.
Expand All @@ -173,12 +177,6 @@ def _construct_filters(x: torch.Tensor, scales: int = 4, orientations: int = 4,
# 2) The angular component, which controls the orientation that the filter responds to.
# The two components are multiplied together to construct the overall filter.

# First construct a low-pass filter that is as large as possible, yet falls
# away to zero at the boundaries. All log Gabor filters are multiplied by
# this to ensure no extra frequencies at the 'corners' of the FFT are
# incorporated as this seems to upset the normalisation process when
# calculating phase congruency.
lp = _lowpassfilter(size=(H, W), cutoff=.45, n=15).to(x)

# Construct the radial filter components...
log_gabor = []
for s in range(scales):
Expand All @@ -189,6 +187,8 @@ def _construct_filters(x: torch.Tensor, scales: int = 4, orientations: int = 4,
gabor_filter[0, 0] = 0
log_gabor.append(gabor_filter)

log_gabor = torch.stack(log_gabor)

# Then construct the angular filter components...
spread = []
for o in range(orientations):
Expand All @@ -204,30 +204,21 @@ def _construct_filters(x: torch.Tensor, scales: int = 4, orientations: int = 4,
spread.append(torch.exp((- dtheta ** 2) / (2 * theta_sigma ** 2)))

spread = torch.stack(spread)
log_gabor = torch.stack(log_gabor)

# Multiply, add batch dimension and transfer to correct device.
filters = (spread.repeat_interleave(scales, dim=0) * log_gabor.repeat(orientations, 1, 1)).unsqueeze(0)
return filters


def _phase_congruency(x: torch.Tensor, scales: int = 4, orientations: int = 4,
min_length: int = 6, mult: int = 2, sigma_f: float = 0.55,
delta_theta: float = 1.2, k: float = 2.0) -> torch.Tensor:
def _phase_congruency(x: torch.Tensor, filters: torch.Tensor, scales: int = 4, orientations: int = 4,
k: float = 2.0) -> torch.Tensor:
r"""Compute Phase Congruence for a batch of greyscale images
Args:
x: Tensor. Shape :math:`(N, 1, H, W)`.
filters: Kernels to extract features.
scales: Number of wavelet scales
orientations: Number of filter orientations
min_length: Wavelength of smallest scale filter
mult: Scaling factor between successive filters
sigma_f: Ratio of the standard deviation of the Gaussian
describing the log Gabor filter's transfer function
in the frequency domain to the filter center frequency.
delta_theta: Ratio of angular interval between filter orientations
and the standard deviation of the angular Gaussian function
used to construct filters in the freq. plane.
k: No of standard deviations of the noise energy beyond the mean
at which we set the noise threshold point, below which phase
congruency values get penalized.
Expand All @@ -241,7 +232,6 @@ def _phase_congruency(x: torch.Tensor, scales: int = 4, orientations: int = 4,
N, _, H, W = x.shape

# Fourier transform
filters = _construct_filters(x, scales, orientations, min_length, mult, sigma_f, delta_theta, k)
recommended_torch_version = _parse_version('1.8.0')
torch_version = _parse_version(torch.__version__)
if len(torch_version) != 0 and torch_version >= recommended_torch_version:
Expand Down Expand Up @@ -308,7 +298,7 @@ def _phase_congruency(x: torch.Tensor, scales: int = 4, orientations: int = 4,

sum_an2 = torch.sum(filters_ifft ** 2, dim=-3, keepdim=True)

sum_ai_aj = torch.zeros(N, orientations, 1, H, W).to(x)
sum_ai_aj = torch.zeros(N, orientations, 1, H, W, dtype=x.dtype, device=x.device)
for s in range(scales - 1):
sum_ai_aj = sum_ai_aj + (filters_ifft[:, :, s: s + 1] * filters_ifft[:, :, s + 1:]).sum(dim=-3, keepdim=True)

Expand Down Expand Up @@ -338,41 +328,12 @@ def _phase_congruency(x: torch.Tensor, scales: int = 4, orientations: int = 4,
# Apply noise threshold
energy = torch.max(energy - T, torch.zeros_like(T))

eps = torch.finfo(energy.dtype).eps
energy_all = energy.sum(dim=[1, 2]) + eps
an_all = an.sum(dim=[1, 2]) + eps
energy_all = energy.sum(dim=[1, 2]) + EPS
an_all = an.sum(dim=[1, 2]) + EPS
result_pc = energy_all / an_all
return result_pc.unsqueeze(1)


def _lowpassfilter(size: Tuple[int, int], cutoff: float, n: int) -> torch.Tensor:
r"""
Constructs a low-pass Butterworth filter.
Args:
size: Tuple with height and width of filter to construct
cutoff: Cutoff frequency of the filter in (0, 0.5]
n: Filter order. Higher `n` means sharper transition.
Note that `n` is doubled so that it is always an even integer.
Returns:
f = 1 / (1 + (w / cutoff) ^ (2n))
Note:
The frequency origin of the returned filter is at the corners.
"""
assert 0 < cutoff <= 0.5, "Cutoff frequency must be between 0 and 0.5"
assert n > 1 and int(n) == n, "n must be an integer >= 1"

grid_x, grid_y = get_meshgrid(size)

# A matrix with every pixel = radius relative to centre.
radius = torch.sqrt(grid_x ** 2 + grid_y ** 2)

return ifftshift(1. / (1.0 + (radius / cutoff) ** (2 * n)))


class FSIMLoss(_Loss):
r"""Creates a criterion that measures the FSIM or FSIMc for input :math:`x` and target :math:`y`.
Expand Down
Loading

0 comments on commit cc75629

Please sign in to comment.