Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torch unet residual #215

Merged
merged 17 commits into from
Jul 21, 2022
65 changes: 65 additions & 0 deletions aydin/nn/models/torch/test/test_torch_res_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy
import pytest
import torch

from aydin.io.datasets import camera, normalise, add_noise
from aydin.nn.models.torch.torch_res_unet import ResidualUNetModel
from aydin.nn.models.torch.torch_unet import n2t_unet_train_loop
from aydin.nn.pytorch.it_ptcnn import to_numpy


@pytest.mark.parametrize("nb_unet_levels", [2, 3, 5, 8])
def test_masking_2D(nb_unet_levels):
    """Self-supervised 2D forward pass: output must keep the input's shape and dtype."""
    image = torch.zeros((1, 1, 1024, 1024))
    mask = torch.ones_like(image)

    model = ResidualUNetModel(
        nb_unet_levels=nb_unet_levels,
        supervised=False,
        spacetime_ndim=2,
    )

    output = model(image, mask)

    assert output.shape == image.shape
    assert output.dtype == image.dtype


@pytest.mark.parametrize("nb_unet_levels", [2, 3, 5])
def test_masking_3D(nb_unet_levels):
    """Self-supervised 3D forward pass: output must keep the input's shape and dtype."""
    volume = torch.zeros((1, 1, 64, 64, 64))
    mask = torch.ones_like(volume)

    model = ResidualUNetModel(
        nb_unet_levels=nb_unet_levels,
        supervised=False,
        spacetime_ndim=3,
    )

    output = model(volume, mask)

    assert output.shape == volume.shape
    assert output.dtype == volume.dtype


def test_supervised_2D_n2t():
    """Supervised (noise2truth) 2D training smoke test.

    Trains a small residual U-Net on a noisy camera image against the clean
    image, then checks the denoised output preserves shape and dtype.
    Set ``visualize = True`` locally to inspect results in napari.
    """
    visualize = False

    # Clean ground-truth image, expanded to (batch, channel, H, W).
    clean_image = normalise(camera()[:256, :256])
    clean_image = numpy.expand_dims(clean_image, axis=0)
    clean_image = numpy.expand_dims(clean_image, axis=0)

    input_image = add_noise(clean_image)

    input_image = torch.tensor(input_image)
    clean_image = torch.tensor(clean_image)

    model = ResidualUNetModel(nb_unet_levels=2, supervised=True, spacetime_ndim=2)

    n2t_unet_train_loop(input_image, clean_image, model)

    denoised = model(input_image)

    if visualize:
        import napari

        viewer = napari.Viewer()
        viewer.add_image(to_numpy(clean_image), name="groundtruth")
        viewer.add_image(to_numpy(input_image), name="noisy")
        viewer.add_image(to_numpy(denoised), name="denoised")

        napari.run()

    # These were previously commented out and referenced a nonexistent
    # `result` variable — the test asserted nothing about the model output.
    assert denoised.shape == input_image.shape
    assert denoised.dtype == input_image.dtype
137 changes: 137 additions & 0 deletions aydin/nn/models/torch/torch_res_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import torch
from torch import nn

from aydin.nn.layers.custom_conv import double_conv_block
from aydin.nn.layers.pooling_down import PoolingDown


class ResidualUNetModel(nn.Module):
    """Residual U-Net for image denoising.

    Encoder/decoder U-Net where each decoder level combines the upsampled
    feature map with the corresponding encoder output via element-wise
    addition (residual skip) rather than channel concatenation.

    Parameters
    ----------
    spacetime_ndim : int
        Number of spatial(-temporal) dimensions, 2 or 3; selects Conv2d/Conv3d.
    nb_unet_levels : int
        Number of encoder/decoder levels (depth of the U-Net).
    nb_filters : int
        Number of filters at the top level; doubled at each deeper level.
    learning_rate : float
        Stored for use by external training loops; not used inside the model.
    supervised : bool
        If True, `forward` returns the raw prediction; if False (self-supervised),
        the prediction is multiplied by a required input mask.
    pooling_mode : str
        Pooling flavor passed to `PoolingDown` (e.g. 'max').
    """

    def __init__(
        self,
        spacetime_ndim,
        nb_unet_levels: int = 4,
        nb_filters: int = 8,
        learning_rate=0.01,
        supervised: bool = False,
        pooling_mode: str = 'max',
    ):
        super(ResidualUNetModel, self).__init__()

        self.spacetime_ndim = spacetime_ndim
        self.nb_unet_levels = nb_unet_levels
        self.nb_filters = nb_filters
        self.learning_rate = learning_rate
        self.supervised = supervised
        self.pooling_down = PoolingDown(spacetime_ndim, pooling_mode)
        self.upsampling = nn.Upsample(scale_factor=2, mode='nearest')

        self.double_conv_blocks_encoder = self._encoder_convolutions()

        # Channel count entering (and leaving) the bottom of the U.
        self.unet_bottom_conv_out_channels = self.nb_filters * (
            2 ** (self.nb_unet_levels - 1)
        )
        self.unet_bottom_conv_block = double_conv_block(
            self.unet_bottom_conv_out_channels,
            self.unet_bottom_conv_out_channels * 2,
            self.unet_bottom_conv_out_channels,
            spacetime_ndim,
        )

        self.double_conv_blocks_decoder = self._decoder_convolutions()

        # 1x1 convolution mapping top-level features to a single output channel.
        if spacetime_ndim == 2:
            self.final_conv = nn.Conv2d(self.nb_filters, 1, 1)
        else:
            self.final_conv = nn.Conv3d(self.nb_filters, 1, 1)

    def forward(self, x, input_msk=None):
        """
        UNet forward method.

        Parameters
        ----------
        x : torch.Tensor
            Input batch, shape (N, 1, ...) with `spacetime_ndim` spatial dims.
        input_msk : torch.Tensor, optional
            A mask per image; must be passed with self-supervised training
            (`supervised=False`), where the output is multiplied by it.

        Returns
        -------
        torch.Tensor
            Prediction with the same shape as `x`.

        Raises
        ------
        ValueError
            If `supervised` is False and no `input_msk` was given.
        """
        skip_layer = []

        # Encoder: convolve, remember the result for the skip connection,
        # then downsample.
        for layer_index in range(self.nb_unet_levels):
            x = self.double_conv_blocks_encoder[layer_index](x)
            skip_layer.append(x)
            x = self.pooling_down(x)

        # Bottom of the U.
        x = self.unet_bottom_conv_block(x)

        # Decoder: upsample, add the matching encoder output (residual skip),
        # then convolve.
        for layer_index in range(self.nb_unet_levels):
            x = self.upsampling(x)
            x = torch.add(x, skip_layer.pop())
            x = self.double_conv_blocks_decoder[layer_index](x)

        # Final convolution
        x = self.final_conv(x)

        # Masking for self-supervised training
        if not self.supervised:
            if input_msk is not None:
                x *= input_msk
            else:
                raise ValueError(
                    "input_msk cannot be None for self-supervised training"
                )

        return x

    def _encoder_convolutions(self):
        """Build the per-level encoder double-conv blocks.

        Returns an ``nn.ModuleList`` (a plain Python list would NOT register
        the submodules' parameters with this module, so they would be
        invisible to optimizers, ``.to(device)`` and ``state_dict``).
        """
        convolutions = nn.ModuleList()
        for layer_index in range(self.nb_unet_levels):
            if layer_index == 0:
                # First level maps the single input channel up to nb_filters.
                nb_filters_in = 1
                nb_filters_inner = self.nb_filters
                nb_filters_out = self.nb_filters
            else:
                # Channel count doubles at each deeper level.
                nb_filters_in = self.nb_filters * (2 ** (layer_index - 1))
                nb_filters_inner = self.nb_filters * (2**layer_index)
                nb_filters_out = self.nb_filters * (2**layer_index)

            convolutions.append(
                double_conv_block(
                    nb_filters_in,
                    nb_filters_inner,
                    nb_filters_out,
                    self.spacetime_ndim,
                )
            )

        return convolutions

    def _decoder_convolutions(self):
        """Build the per-level decoder double-conv blocks.

        Returned as an ``nn.ModuleList`` for parameter registration (see
        `_encoder_convolutions`). Channel counts halve at each level going up;
        the last level stays at ``nb_filters`` for the final 1x1 convolution.
        """
        convolutions = nn.ModuleList()
        for layer_index in range(self.nb_unet_levels):
            nb_filters_in = self.nb_filters * (
                2 ** (self.nb_unet_levels - layer_index - 1)
            )

            if layer_index == self.nb_unet_levels - 1:
                nb_filters_inner = nb_filters_out = self.nb_filters
            else:
                nb_filters_inner = nb_filters_out = nb_filters_in // 2

            convolutions.append(
                double_conv_block(
                    nb_filters_in,
                    nb_filters_inner,
                    nb_filters_out,
                    spacetime_ndim=self.spacetime_ndim,
                    normalizations=(None, "batch"),
                )
            )

        return convolutions