Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: 3D jinet bug patch #244

Merged
merged 7 commits into from Sep 29, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 13 additions & 0 deletions aydin/analysis/image_metrics.py
Expand Up @@ -2,6 +2,19 @@
from numpy.linalg import norm
from scipy.fft import dct
from skimage.metrics import mean_squared_error
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim


def calculate_print_psnr_ssim(clean_image, noisy_image, denoised_image):
psnr_noisy = psnr(clean_image, noisy_image)
ssim_noisy = ssim(clean_image, noisy_image)
psnr_denoised = psnr(clean_image, denoised_image)
ssim_denoised = ssim(clean_image, denoised_image)
print("noisy :", psnr_noisy, ssim_noisy)
print("denoised:", psnr_denoised, ssim_denoised)

return psnr_noisy, psnr_denoised, ssim_noisy, ssim_denoised


def spectral_psnr(norm_true_image, norm_test_image):
Expand Down
10 changes: 9 additions & 1 deletion aydin/nn/models/torch/test/test_torch_jinet.py
Expand Up @@ -6,7 +6,7 @@
from aydin.nn.pytorch.it_ptcnn import to_numpy


def test_forward_2D_jinet():
def test_forward_2D():
input_array = torch.zeros((1, 1, 64, 64))
model2d = JINetModel(spacetime_ndim=2)
result = model2d(input_array)
Expand All @@ -15,6 +15,14 @@ def test_forward_2D_jinet():
assert result.dtype == input_array.dtype


def test_forward_3D():
input_array = torch.zeros((1, 1, 128, 128, 128))
model3d = JINetModel(spacetime_ndim=3)
result = model3d(input_array)
assert result.shape == input_array.shape
assert result.dtype == input_array.dtype
AhmetCanSolak marked this conversation as resolved.
Show resolved Hide resolved


def test_supervised_2D_n2t():
visualize = False
lizard_image = normalise(camera()[:256, :256])
Expand Down
59 changes: 28 additions & 31 deletions aydin/nn/models/torch/test/test_torch_models.py
@@ -1,8 +1,11 @@
# flake8: noqa
import numpy

import torch

from aydin.analysis.image_metrics import calculate_print_psnr_ssim
from aydin.io.datasets import lizard, add_noise, camera, normalise
from aydin.nn.models.torch.torch_jinet import JINetModel
from aydin.nn.models.torch.torch_res_unet import ResidualUNetModel
from aydin.nn.models.torch.torch_unet import UNetModel, n2t_train, n2s_train

Expand Down Expand Up @@ -41,46 +44,40 @@ def test_supervised_2D_n2t():
assert result.dtype == input_image.dtype


def test_supervised_2D_n2s():
lizard_image = normalise(camera())
lizard_image = numpy.expand_dims(lizard_image, axis=0)
lizard_image = numpy.expand_dims(lizard_image, axis=0)
def test_2D_n2s_unet():
run_2D_n2s(
UNetModel(
nb_unet_levels=2,
spacetime_ndim=2,
)
)

input_image = add_noise(lizard_image)

input_image = torch.tensor(input_image)
def test_2D_n2s_jinet():
run_2D_n2s(JINetModel(spacetime_ndim=2))

model = UNetModel(
nb_unet_levels=2,
spacetime_ndim=2,
)

n2s_train(input_image, model, nb_epochs=2)
model.cpu()
result = model(input_image)
def run_2D_n2s(model):
camera_image = normalise(camera())
camera_image = numpy.expand_dims(camera_image, axis=0)
camera_image = numpy.expand_dims(camera_image, axis=0)
noisy_image = add_noise(camera_image)
noisy_image = torch.tensor(noisy_image)

assert result.shape == input_image.shape
assert result.dtype == input_image.dtype
n2s_train(noisy_image, model, nb_epochs=20)
model.cpu()
denoised = model(noisy_image)

camera_image = camera_image[0, 0, :, :]
noisy_image = noisy_image.detach().numpy()[0, 0, :, :]
denoised = denoised.detach().numpy()[0, 0, :, :]

def test_masking_2D():
input_array = torch.zeros((1, 1, 64, 64))
model2d = UNetModel(
# (64, 64, 1),
nb_unet_levels=2,
spacetime_ndim=2,
_, _, ssim_noisy, ssim_denoised = calculate_print_psnr_ssim(
clean_image=camera_image, noisy_image=noisy_image, denoised_image=denoised
)
result = model2d(input_array)
assert result.shape == input_array.shape
assert result.dtype == input_array.dtype


# def test_jinet_2D():
# input_array = torch.zeros((1, 1, 64, 64))
# model2d = JINetModel((64, 64, 1), spacetime_ndim=2)
# result = model2d.predict([input_array])
# assert result.shape == input_array.shape
# assert result.dtype == input_array.dtype
assert ssim_denoised > ssim_noisy
assert ssim_denoised > 0.46


def test_supervised_3D():
Expand Down
14 changes: 7 additions & 7 deletions aydin/nn/models/torch/torch_jinet.py
Expand Up @@ -39,7 +39,7 @@ def __init__(
if len(self.kernel_sizes) != len(self.num_features):
raise ValueError("Number of kernel sizes and features does not match.")

self.dilated_conv_functions = []
self.dilated_conv_functions = nn.ModuleList()
current_receptive_field_radius = 0
for scale_index in range(len(self.kernel_sizes)):
# Get kernel size and number of features:
Expand Down Expand Up @@ -77,7 +77,7 @@ def __init__(
self.nb_channels = sum(self.num_features) # * 2

nb_out = self.nb_channels
self.kernel_one_conv_functions = []
self.kernel_one_conv_functions = nn.ModuleList()
for index in range(self.nb_dense_layers):
nb_in = nb_out
nb_out = (
Expand Down Expand Up @@ -130,15 +130,15 @@ def forward(self, x):
for index in range(len(self.kernel_sizes)):
x = self.dilated_conv_functions[index](x)
dilated_conv_list.append(x)
print(x.shape)
# print(x.shape)

# Concat the results
x = torch.cat(dilated_conv_list, dim=1)
print(f"after cat: {x.shape}")
# print(f"after cat: {x.shape}")

# First kernel size one conv
x = self.kernel_one_conv_functions[0](x)
print(f"after first kernel one conv: {x.shape}")
# print(f"after first kernel one conv: {x.shape}")
x = self.lrelu(x)
y = x
f = 1
Expand All @@ -147,10 +147,10 @@ def forward(self, x):
for index in range(1, self.nb_dense_layers):
x = self.kernel_one_conv_functions[index](x)
x = self.lrelu(x)
y += f * x
y = y + f * x

if self.degressive_residuals:
f *= 0.5
f = f * 0.5

# Final kernel size one convolution
y = self.final_kernel_one_conv(y)
Expand Down
3 changes: 2 additions & 1 deletion aydin/nn/models/torch/torch_unet.py
Expand Up @@ -155,7 +155,6 @@ def n2s_train(
model : UNetModel
nb_epochs : int
learning_rate : float
patch_size : int

"""
if torch.cuda.is_available():
Expand All @@ -165,6 +164,8 @@ def n2s_train(
device = torch.device(dev)
print(dev)

torch.autograd.set_detect_anomaly(True)

model = model.to(device)
print(f"device {device}")

Expand Down