Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torch jinet implementation #196

Merged
merged 40 commits into from
Sep 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
8616a46
initial commit
AhmetCanSolak Jun 14, 2022
26c3e6f
forward model roughly implemented
Jun 15, 2022
f864e2b
Merge remote-tracking branch 'upstream/master' into torch-jinet
Jun 16, 2022
52ca7d8
calling the cat in correct way
Jun 21, 2022
4b5d15c
Merge remote-tracking branch 'upstream/master' into torch-jinet
Jun 21, 2022
d3e6b1b
merge conflict resolved
AhmetCanSolak Jul 14, 2022
2f91cc9
Merge branch 'torch-jinet' of github.com:AhmetCanSolak/aydin into tor…
AhmetCanSolak Jul 14, 2022
3798f4f
merge conflict resolved
Jul 18, 2022
531efd0
merge conflict resolved
Jul 18, 2022
5835b03
separating jinet test module
Jul 18, 2022
336f392
populating the jinet n2t loop
Jul 18, 2022
60cc58b
Merge remote-tracking branch 'upstream/master' into torch-jinet
Jul 18, 2022
1b2247f
merge conflict resolved
Jul 21, 2022
450ffc9
merge conflict resolved2
Jul 21, 2022
b2eef42
class __init__ is improved, kernel_size and num_features properties i…
Jul 21, 2022
9d106a1
dilated conv implemented
Jul 21, 2022
047d840
dilated conv integrated
Jul 21, 2022
9c16f9e
black fixes
Jul 21, 2022
515c69e
changed the cat function call to correct one
Jul 25, 2022
f36ccab
activation functions made member in init
Jul 25, 2022
df4f070
moved dilated conv generation to init
Jul 25, 2022
6bcbc98
fixing channels to 1
Jul 25, 2022
f48f35c
cleaned forward method further
Jul 25, 2022
0216272
syntactical implementation complete
Jul 25, 2022
27d73ef
comments in forward method are finalized
Jul 25, 2022
ffea5e9
kernel size and num features size validation bug fixed
Jul 25, 2022
3685028
fixed the prediction call in the test case
Jul 25, 2022
f99637c
missing attributes added to torch jinet
Jul 25, 2022
36e4e33
dilation and concat problems are fixed
Jul 25, 2022
bcc7c19
magic f constant is removed
Jul 25, 2022
8ec202d
black and flake8 fixes
Jul 25, 2022
5ac60fd
loop range for channel-wise dense layers is corrected
Jul 26, 2022
df6577f
working version
Jul 26, 2022
f4b22de
n2t training loop is completed
Jul 26, 2022
69fe6b3
adding degressive residuals
Jul 26, 2022
f411182
input channel number parameter added
Jul 26, 2022
d0eeb6c
Merge remote-tracking branch 'upstream/master' into torch-jinet
Sep 14, 2022
2adfcea
black fix
Sep 14, 2022
ab7eb52
Merge remote-tracking branch 'upstream/master' into torch-jinet
Sep 22, 2022
8b4a7b2
Merge remote-tracking branch 'upstream/master' into torch-jinet
Sep 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions aydin/nn/layers/dilated_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from torch import nn
from torch.nn import ZeroPad2d


class DilatedConv(nn.Module):
    """Zero-padded dilated convolution followed by a pointwise nonlinearity.

    The input is explicitly zero-padded, then passed through a dilated
    convolution with ``padding='valid'`` (so the caller-supplied ``padding``
    fully controls the output size), and finally through the chosen
    activation function.

    Parameters
    ----------
    in_channels : int
        Number of channels in the input tensor.
    out_channels : int
        Number of channels produced by the convolution.
    spacetime_ndim : int
        Number of spatial(-temporal) dimensions; must be 2 or 3.
    padding : int or tuple
        Amount of zero padding applied before the convolution.
    kernel_size : int or tuple
        Size of the convolution kernel.
    dilation : int or tuple
        Dilation factor of the convolution.
    activation : str
        One of "ReLU", "swish" or "lrel" (LeakyReLU with slope 0.1).

    Raises
    ------
    ValueError
        If ``spacetime_ndim`` is not 2 or 3.
    KeyError
        If ``activation`` is not one of the supported names.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        spacetime_ndim,
        padding,
        kernel_size,
        dilation,
        activation="ReLU",
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.spacetime_ndim = spacetime_ndim
        self.activation = activation

        # Pick the conv AND padding implementations that match the input
        # dimensionality.
        if spacetime_ndim == 2:
            self.conv_class = nn.Conv2d
            self.zero_padding = nn.ZeroPad2d(padding)
        elif spacetime_ndim == 3:
            self.conv_class = nn.Conv3d
            # BUGFIX: the original used ZeroPad2d unconditionally, which pads
            # only the last two axes; a 3D input then loses size along the
            # depth axis after the 'valid' convolution. ConstantPad3d with a
            # zero fill pads all three spatial axes.
            self.zero_padding = nn.ConstantPad3d(padding, 0.0)
        else:
            raise ValueError("spacetime_ndim parameter can only be 2 or 3...")

        # 'valid' padding: the explicit zero-padding module above is the only
        # source of padding, keeping output size fully under caller control.
        self.conv = self.conv_class(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding='valid',
        )

        self.activation_function = {
            "ReLU": nn.ReLU(),
            "swish": nn.SiLU(),
            "lrel": nn.LeakyReLU(0.1),
        }[self.activation]

    def forward(self, x):
        """Apply zero-padding, dilated convolution and activation to ``x``."""
        x = self.zero_padding(x)

        x = self.conv(x)

        x = self.activation_function(x)

        return x
1 change: 1 addition & 0 deletions aydin/nn/models/jinet.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def jinet_core(self, input_lyr):
# stack all features into one tensor:
x = Concatenate(axis=-1)(dilated_conv_list)

print("after concat", x.shape)
# We keep the number of features:
self.total_num_features = total_num_features

Expand Down
52 changes: 46 additions & 6 deletions aydin/nn/models/torch/test/test_torch_jinet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,46 @@
# def test_jinet_2D():
# input_array = torch.zeros((1, 1, 64, 64))
# model2d = JINetModel((64, 64, 1), spacetime_ndim=2)
# result = model2d.predict([input_array])
# assert result.shape == input_array.shape
# assert result.dtype == input_array.dtype
import numpy
import torch

from aydin.io.datasets import add_noise, normalise, camera
from aydin.nn.models.torch.torch_jinet import JINetModel, n2t_jinet_train_loop
from aydin.nn.pytorch.it_ptcnn import to_numpy


def test_forward_2D_jinet():
    """Smoke-test the 2D JINet forward pass: shape and dtype are preserved."""
    dummy_input = torch.zeros((1, 1, 64, 64))
    model = JINetModel(spacetime_ndim=2)
    output = model(dummy_input)

    assert output.shape == dummy_input.shape
    assert output.dtype == dummy_input.dtype


def test_supervised_2D_n2t():
    """Run the noise2truth training loop on a noisy camera image and denoise it."""
    visualize = False

    # Clean target image with singleton batch and channel axes prepended.
    # NOTE: despite the original variable name, this is the camera() image.
    clean_image = normalise(camera()[:256, :256])
    clean_image = numpy.expand_dims(clean_image, axis=0)
    clean_image = numpy.expand_dims(clean_image, axis=0)

    noisy_image = add_noise(clean_image)

    noisy_image = torch.tensor(noisy_image)
    clean_image = torch.tensor(clean_image)

    model = JINetModel(spacetime_ndim=2)

    n2t_jinet_train_loop(noisy_image, clean_image, model)

    denoised = model(noisy_image)

    if visualize:
        import napari

        viewer = napari.Viewer()
        viewer.add_image(to_numpy(clean_image), name="groundtruth")
        viewer.add_image(to_numpy(noisy_image), name="noisy")
        viewer.add_image(to_numpy(denoised), name="denoised")

        napari.run()

# assert result.shape == input_image.shape
# assert result.dtype == input_image.dtype
2 changes: 1 addition & 1 deletion aydin/nn/models/torch/test/test_torch_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aydin.nn.models.torch.torch_unet import UNetModel


def test_supervised_2D():
def test_forward_2D():
input_array = torch.zeros((1, 1, 64, 64))
model2d = UNetModel(
# (64, 64, 1),
Expand Down
Loading