-
Notifications
You must be signed in to change notification settings - Fork 116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature: DISTS metric #131
Conversation
Codecov Report
@@ Coverage Diff @@
## master #131 +/- ##
==========================================
+ Coverage 90.49% 91.00% +0.50%
==========================================
Files 21 22 +1
Lines 1505 1556 +51
==========================================
+ Hits 1362 1416 +54
+ Misses 143 140 -3
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great progress.
I suggest to address some changes
|
||
class DISTS(ContentLoss): | ||
r"""Deep Image Structure and Texture Similarity metric. | ||
Expects input to be in range [0, 1] or normalized with ImageNet statistics into range [-1, 1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we really support [-1, 1]? Let's add images with [-1, 1] dynamic range to the tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, but forward
pass fails because of _validate_input
function :\
piq/perceptual.py
Outdated
def construct_kernel(self, x: torch.Tensor) -> torch.Tensor: | ||
r"""Returns 2D Hann window kernel with number of channels equal to input channels""" | ||
C = x.size(1) | ||
|
||
# Take bigger window and drop borders | ||
window = torch.hann_window(self.kernel_size + 2, periodic=False)[1:-1] | ||
kernel = window[:, None] * window[None, :] | ||
|
||
# Normalize and reshape kernel | ||
self.kernel = (kernel / kernel.sum()).repeat((C, 1, 1, 1)).to(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we added functional.py
into the project, which contains some filters, I would propose to move to there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I propose the following:
- Make a
utils
folder - Put the
functional.py
here "as is" - Rename
utils.py
->utils/common.py
- Create a new file
utils/layers.py
- Add
L2Pool2d
here
I also think that this refactoring is the subject of another PR. Current structure will work for the current PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did this refactoring in a single commit. It's not so big to require a separate PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice one @zakajd!
Just a few minor changes are needed and then we are good to go 👍
piq/perceptual.py
Outdated
def construct_kernel(self, x: torch.Tensor) -> torch.Tensor: | ||
r"""Returns 2D Hann window kernel with number of channels equal to input channels""" | ||
C = x.size(1) | ||
|
||
# Take bigger window and drop borders | ||
window = torch.hann_window(self.kernel_size + 2, periodic=False)[1:-1] | ||
kernel = window[:, None] * window[None, :] | ||
|
||
# Normalize and reshape kernel | ||
self.kernel = (kernel / kernel.sum()).repeat((C, 1, 1, 1)).to(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I propose the following:
- Make a
utils
folder - Put the
functional.py
here "as is" - Rename
utils.py
->utils/common.py
- Create a new file
utils/layers.py
- Add
L2Pool2d
here
I also think that this refactoring is the subject of another PR. Current structure will work for the current PR.
Also closed #132 in one of the commits |
bab2b6d
to
4022043
Compare
Kudos, SonarCloud Quality Gate passed! 0 Bugs No Coverage information |
@snk4tr Ready to be merged |
Closes #123
Proposed Changes
use_average_pooling
toreplace_pooling
because in DISTS pooling is replaced by L2pooling layer, not AveragePool.__init__
descriptions to class description, so that it's visible in VS Code and Jupyter IDA