Skip to content

Commit

Permalink
Torch unet maskout (#209)
Browse files Browse the repository at this point in the history
* initial commit

* to keep changes

* input msk argument added

* to keep changes

* implementation completed

* black fix

* removed redundant print statement

* unit tests are passing

Co-authored-by: acs-ws <asolak@ku.edu.tr>
  • Loading branch information
AhmetCanSolak and acs-ws committed Jul 13, 2022
1 parent cb6df38 commit fae3c75
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 20 deletions.
14 changes: 2 additions & 12 deletions aydin/nn/models/torch/test/test_torch_models.py
Expand Up @@ -45,16 +45,6 @@ def test_supervised_2D_n2t():
input_image = torch.tensor(input_image)
lizard_image = torch.tensor(lizard_image)

# learning_rate = 0.01
# training_noise = 0.001
# l2_weight_regularisation = 1e-9
# patience = 128
# patience_epsilon = 0.0
# reduce_lr_factor = 0.5
# reduce_lr_patience = patience // 2
# reload_best_model_period = 1024
# best_val_loss_value = math.inf

# dataset = TorchDataset(input_image, lizard_image, 64, self_supervised=False)

# data_loader = DataLoader(
Expand All @@ -79,7 +69,7 @@ def test_masking_2D():
supervised=False,
spacetime_ndim=2,
)
result = model2d(input_array)
result = model2d(input_array, torch.ones(input_array.shape))
assert result.shape == input_array.shape
assert result.dtype == input_array.dtype

Expand Down Expand Up @@ -113,7 +103,7 @@ def test_masking_3D():
supervised=False,
spacetime_ndim=3,
)
result = model3d(input_array)
result = model3d(input_array, torch.ones(input_array.shape))
assert result.shape == input_array.shape
assert result.dtype == input_array.dtype

Expand Down
27 changes: 20 additions & 7 deletions aydin/nn/models/torch/torch_unet.py
Expand Up @@ -2,16 +2,13 @@
from collections import OrderedDict
from itertools import chain

# import napari
import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter

from aydin.nn.layers.conv_with_batch_norm import ConvWithBatchNorm
from aydin.nn.layers.pooling_down import PoolingDown

# from aydin.nn.pytorch.it_ptcnn import to_numpy
from aydin.nn.pytorch.optimizers.esadam import ESAdam
from aydin.util.log.log import lprint

Expand Down Expand Up @@ -97,9 +94,20 @@ def __init__(
else:
self.conv = nn.Conv3d(8, 1, 1)

self.maskout = None # TODO: assign correct maskout module
def forward(self, x, input_msk=None):
"""
UNet forward method.
Parameters
----------
x
input_msk : numpy.ArrayLike
A mask per image must be passed with self-supervised training.
Returns
-------
def forward(self, x):
"""

skip_layer = [x]

Expand Down Expand Up @@ -138,8 +146,13 @@ def forward(self, x):

x = self.conv(x)

# if not self.supervised:
# x = self.maskout(x)
if not self.supervised:
if input_msk is not None:
x *= input_msk
else:
raise ValueError(
"input_msk cannot be None for self-supervised training"
)

return x

Expand Down
2 changes: 1 addition & 1 deletion aydin/nn/util/mask_generator.py
Expand Up @@ -12,7 +12,7 @@ def masker(batch_vol, i=None, mask_shape=None, p=None):
Parameters
----------
batch_vol
batch volume, desn't include batch and ch dimensions
batch volume, doesn't include batch and ch dimensions
i
mask_shape
mask shape e.g. (3, 3)
Expand Down

0 comments on commit fae3c75

Please sign in to comment.