Skip to content

Commit

Permalink
Torch jinet implementation (#196)
Browse files Browse the repository at this point in the history
* initial commit

* forward model roughly implemented

* calling the cat in correct way

* separating jinet test module

* populating the jinet n2t loop

* merge conflict resolved2

* class __init__ is improved, kernel_size and num_features properties implemented, some of needed missing imports for training loop is added

* dilated conv implemented

* dilated conv integrated

* black fixes

* changed the cat function call to correct one

* activation functions made member in init

* moved dilated conv generation to init

* fixing channels to 1

* cleaned forward method further

* syntactical implementation complete

* comments in forward method is finalized

* kernel size and num features size validation bug fixed

* fixed the prediction call in the test casE

* missing attributes added to torch jinet

* dilation and concat problems are fixed

* magic f constant is removed

* black and flake8 fixes

* loop range for channel-wise dense layers is corrected

* working version

* n2t trainig loop is completed

* adding degressive residuals

* input channel number parameter added

* black fix

Co-authored-by: acs-ws <asolak@ku.edu.tr>
  • Loading branch information
AhmetCanSolak and acs-ws authored Sep 26, 2022
1 parent 77868a1 commit 7e39b03
Show file tree
Hide file tree
Showing 5 changed files with 392 additions and 7 deletions.
53 changes: 53 additions & 0 deletions aydin/nn/layers/dilated_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from torch import nn
from torch.nn import ZeroPad2d


class DilatedConv(nn.Module):
def __init__(
self,
in_channels,
out_channels,
spacetime_ndim,
padding,
kernel_size,
dilation,
activation="ReLU",
):
super(DilatedConv, self).__init__()

self.in_channels = in_channels
self.out_channels = out_channels
self.spacetime_ndim = spacetime_ndim
self.activation = activation

self.zero_padding = ZeroPad2d(padding)

if spacetime_ndim == 2:
self.conv_class = nn.Conv2d
elif spacetime_ndim == 3:
self.conv_class = nn.Conv3d
else:
raise ValueError("spacetime_ndim parameter can only be 2 or 3...")

self.conv = self.conv_class(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
dilation=dilation,
padding='valid',
)

self.activation_function = {
"ReLU": nn.ReLU(),
"swish": nn.SiLU(),
"lrel": nn.LeakyReLU(0.1),
}[self.activation]

def forward(self, x):
x = self.zero_padding(x)

x = self.conv(x)

x = self.activation_function(x)

return x
1 change: 1 addition & 0 deletions aydin/nn/models/jinet.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def jinet_core(self, input_lyr):
# stack all features into one tensor:
x = Concatenate(axis=-1)(dilated_conv_list)

print("after concat", x.shape)
# We keep the number of features:
self.total_num_features = total_num_features

Expand Down
52 changes: 46 additions & 6 deletions aydin/nn/models/torch/test/test_torch_jinet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,46 @@
# def test_jinet_2D():
# input_array = torch.zeros((1, 1, 64, 64))
# model2d = JINetModel((64, 64, 1), spacetime_ndim=2)
# result = model2d.predict([input_array])
# assert result.shape == input_array.shape
# assert result.dtype == input_array.dtype
import numpy
import torch

from aydin.io.datasets import add_noise, normalise, camera
from aydin.nn.models.torch.torch_jinet import JINetModel, n2t_jinet_train_loop
from aydin.nn.pytorch.it_ptcnn import to_numpy


def test_forward_2D_jinet():
input_array = torch.zeros((1, 1, 64, 64))
model2d = JINetModel(spacetime_ndim=2)
result = model2d(input_array)

assert result.shape == input_array.shape
assert result.dtype == input_array.dtype


def test_supervised_2D_n2t():
visualize = False
lizard_image = normalise(camera()[:256, :256])
lizard_image = numpy.expand_dims(lizard_image, axis=0)
lizard_image = numpy.expand_dims(lizard_image, axis=0)

input_image = add_noise(lizard_image)

input_image = torch.tensor(input_image)
lizard_image = torch.tensor(lizard_image)

model = JINetModel(spacetime_ndim=2)

n2t_jinet_train_loop(input_image, lizard_image, model)

denoised = model(input_image)

if visualize:
import napari

viewer = napari.Viewer()
viewer.add_image(to_numpy(lizard_image), name="groundtruth")
viewer.add_image(to_numpy(input_image), name="noisy")
viewer.add_image(to_numpy(denoised), name="denoised")

napari.run()

# assert result.shape == input_image.shape
# assert result.dtype == input_image.dtype
2 changes: 1 addition & 1 deletion aydin/nn/models/torch/test/test_torch_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aydin.nn.models.torch.torch_unet import UNetModel


def test_supervised_2D():
def test_forward_2D():
input_array = torch.zeros((1, 1, 64, 64))
model2d = UNetModel(
# (64, 64, 1),
Expand Down
Loading

0 comments on commit 7e39b03

Please sign in to comment.