-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial commit * forward model roughly implemented * calling the cat in correct way * separating jinet test module * populating the jinet n2t loop * merge conflict resolved2 * class __init__ is improved, kernel_size and num_features properties implemented, some of needed missing imports for training loop is added * dilated conv implemented * dilated conv integrated * black fixes * changed the cat function call to correct one * activation functions made member in init * moved dilated conv generation to init * fixing channels to 1 * cleaned forward method further * syntactical implementation complete * comments in forward method is finalized * kernel size and num features size validation bug fixed * fixed the prediction call in the test casE * missing attributes added to torch jinet * dilation and concat problems are fixed * magic f constant is removed * black and flake8 fixes * loop range for channel-wise dense layers is corrected * working version * n2t trainig loop is completed * adding degressive residuals * input channel number parameter added * black fix Co-authored-by: acs-ws <asolak@ku.edu.tr>
- Loading branch information
1 parent
77868a1
commit 7e39b03
Showing
5 changed files
with
392 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from torch import nn | ||
from torch.nn import ZeroPad2d | ||
|
||
|
||
class DilatedConv(nn.Module): | ||
def __init__( | ||
self, | ||
in_channels, | ||
out_channels, | ||
spacetime_ndim, | ||
padding, | ||
kernel_size, | ||
dilation, | ||
activation="ReLU", | ||
): | ||
super(DilatedConv, self).__init__() | ||
|
||
self.in_channels = in_channels | ||
self.out_channels = out_channels | ||
self.spacetime_ndim = spacetime_ndim | ||
self.activation = activation | ||
|
||
self.zero_padding = ZeroPad2d(padding) | ||
|
||
if spacetime_ndim == 2: | ||
self.conv_class = nn.Conv2d | ||
elif spacetime_ndim == 3: | ||
self.conv_class = nn.Conv3d | ||
else: | ||
raise ValueError("spacetime_ndim parameter can only be 2 or 3...") | ||
|
||
self.conv = self.conv_class( | ||
in_channels=in_channels, | ||
out_channels=out_channels, | ||
kernel_size=kernel_size, | ||
dilation=dilation, | ||
padding='valid', | ||
) | ||
|
||
self.activation_function = { | ||
"ReLU": nn.ReLU(), | ||
"swish": nn.SiLU(), | ||
"lrel": nn.LeakyReLU(0.1), | ||
}[self.activation] | ||
|
||
def forward(self, x): | ||
x = self.zero_padding(x) | ||
|
||
x = self.conv(x) | ||
|
||
x = self.activation_function(x) | ||
|
||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,46 @@ | ||
# def test_jinet_2D(): | ||
# input_array = torch.zeros((1, 1, 64, 64)) | ||
# model2d = JINetModel((64, 64, 1), spacetime_ndim=2) | ||
# result = model2d.predict([input_array]) | ||
# assert result.shape == input_array.shape | ||
# assert result.dtype == input_array.dtype | ||
import numpy | ||
import torch | ||
|
||
from aydin.io.datasets import add_noise, normalise, camera | ||
from aydin.nn.models.torch.torch_jinet import JINetModel, n2t_jinet_train_loop | ||
from aydin.nn.pytorch.it_ptcnn import to_numpy | ||
|
||
|
||
def test_forward_2D_jinet(): | ||
input_array = torch.zeros((1, 1, 64, 64)) | ||
model2d = JINetModel(spacetime_ndim=2) | ||
result = model2d(input_array) | ||
|
||
assert result.shape == input_array.shape | ||
assert result.dtype == input_array.dtype | ||
|
||
|
||
def test_supervised_2D_n2t(): | ||
visualize = False | ||
lizard_image = normalise(camera()[:256, :256]) | ||
lizard_image = numpy.expand_dims(lizard_image, axis=0) | ||
lizard_image = numpy.expand_dims(lizard_image, axis=0) | ||
|
||
input_image = add_noise(lizard_image) | ||
|
||
input_image = torch.tensor(input_image) | ||
lizard_image = torch.tensor(lizard_image) | ||
|
||
model = JINetModel(spacetime_ndim=2) | ||
|
||
n2t_jinet_train_loop(input_image, lizard_image, model) | ||
|
||
denoised = model(input_image) | ||
|
||
if visualize: | ||
import napari | ||
|
||
viewer = napari.Viewer() | ||
viewer.add_image(to_numpy(lizard_image), name="groundtruth") | ||
viewer.add_image(to_numpy(input_image), name="noisy") | ||
viewer.add_image(to_numpy(denoised), name="denoised") | ||
|
||
napari.run() | ||
|
||
# assert result.shape == input_image.shape | ||
# assert result.dtype == input_image.dtype |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.