Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torch model demos #218

Merged
merged 1 commit into from Jul 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Empty file.
46 changes: 46 additions & 0 deletions aydin/nn/models/torch/demo/demo_n2t.py
@@ -0,0 +1,46 @@
# flake8: noqa
import numpy
import torch

from aydin.io.datasets import add_noise, normalise, camera
from aydin.nn.models.torch.torch_res_unet import ResidualUNetModel
from aydin.nn.models.torch.torch_unet import n2t_unet_train_loop, UNetModel
from aydin.nn.pytorch.it_ptcnn import to_numpy


def demo_supervised_2D_n2t(model_class):
visualize = True
lizard_image = normalise(camera()[:256, :256])
lizard_image = numpy.expand_dims(lizard_image, axis=0)
lizard_image = numpy.expand_dims(lizard_image, axis=0)

input_image = add_noise(lizard_image)

input_image = torch.tensor(input_image)
lizard_image = torch.tensor(lizard_image)

model = model_class(nb_unet_levels=2, supervised=True, spacetime_ndim=2)

n2t_unet_train_loop(input_image, lizard_image, model)

denoised = model(input_image)

if visualize:
import napari

viewer = napari.Viewer()
viewer.add_image(to_numpy(lizard_image), name="groundtruth")
viewer.add_image(to_numpy(input_image), name="noisy")
viewer.add_image(to_numpy(denoised), name="denoised")

napari.run()

# assert result.shape == input_image.shape
# assert result.dtype == input_image.dtype


if __name__ == '__main__':
model_class = UNetModel
# model_class = ResidualUNetModel

demo_supervised_2D_n2t(model_class)
36 changes: 1 addition & 35 deletions aydin/nn/models/torch/test/test_torch_res_unet.py
@@ -1,11 +1,8 @@
import numpy
# flake8: noqa
import pytest
import torch

from aydin.io.datasets import camera, normalise, add_noise
from aydin.nn.models.torch.torch_res_unet import ResidualUNetModel
from aydin.nn.models.torch.torch_unet import n2t_unet_train_loop
from aydin.nn.pytorch.it_ptcnn import to_numpy


@pytest.mark.parametrize("nb_unet_levels", [2, 3, 5, 8])
Expand All @@ -32,34 +29,3 @@ def test_masking_3D(nb_unet_levels):
result = model3d(input_array, torch.ones(input_array.shape))
assert result.shape == input_array.shape
assert result.dtype == input_array.dtype


def test_supervised_2D_n2t():
visualize = False
lizard_image = normalise(camera()[:256, :256])
lizard_image = numpy.expand_dims(lizard_image, axis=0)
lizard_image = numpy.expand_dims(lizard_image, axis=0)

input_image = add_noise(lizard_image)

input_image = torch.tensor(input_image)
lizard_image = torch.tensor(lizard_image)

model = ResidualUNetModel(nb_unet_levels=2, supervised=True, spacetime_ndim=2)

n2t_unet_train_loop(input_image, lizard_image, model)

denoised = model(input_image)

if visualize:
import napari

viewer = napari.Viewer()
viewer.add_image(to_numpy(lizard_image), name="groundtruth")
viewer.add_image(to_numpy(input_image), name="noisy")
viewer.add_image(to_numpy(denoised), name="denoised")

napari.run()

# assert result.shape == input_image.shape
# assert result.dtype == input_image.dtype
36 changes: 1 addition & 35 deletions aydin/nn/models/torch/test/test_torch_unet.py
@@ -1,11 +1,8 @@
# flake8: noqa
import numpy
import pytest
import torch

from aydin.io.datasets import add_noise, camera, normalise
from aydin.nn.models.torch.torch_unet import UNetModel, n2t_unet_train_loop
from aydin.nn.pytorch.it_ptcnn import to_numpy
from aydin.nn.models.torch.torch_unet import UNetModel


def test_supervised_2D():
Expand All @@ -21,37 +18,6 @@ def test_supervised_2D():
assert result.dtype == input_array.dtype


def test_supervised_2D_n2t():
visualize = False
lizard_image = normalise(camera()[:256, :256])
lizard_image = numpy.expand_dims(lizard_image, axis=0)
lizard_image = numpy.expand_dims(lizard_image, axis=0)

input_image = add_noise(lizard_image)

input_image = torch.tensor(input_image)
lizard_image = torch.tensor(lizard_image)

model = UNetModel(nb_unet_levels=2, supervised=True, spacetime_ndim=2)

n2t_unet_train_loop(input_image, lizard_image, model)

denoised = model(input_image)

if visualize:
import napari

viewer = napari.Viewer()
viewer.add_image(to_numpy(lizard_image), name="groundtruth")
viewer.add_image(to_numpy(input_image), name="noisy")
viewer.add_image(to_numpy(denoised), name="denoised")

napari.run()

# assert result.shape == input_image.shape
# assert result.dtype == input_image.dtype


@pytest.mark.parametrize("nb_unet_levels", [2, 3, 5, 8])
def test_masking_2D(nb_unet_levels):
input_array = torch.zeros((1, 1, 1024, 1024))
Expand Down