Torch jinet implementation #196

Merged (40 commits, Sep 26, 2022)
The file changes below are shown from 33 of the 40 commits.

Commits
8616a46
initial commit
AhmetCanSolak Jun 14, 2022
26c3e6f
forward model roughly implemented
Jun 15, 2022
f864e2b
Merge remote-tracking branch 'upstream/master' into torch-jinet
Jun 16, 2022
52ca7d8
calling the cat in correct way
Jun 21, 2022
4b5d15c
Merge remote-tracking branch 'upstream/master' into torch-jinet
Jun 21, 2022
d3e6b1b
merge conflict resolved
AhmetCanSolak Jul 14, 2022
2f91cc9
Merge branch 'torch-jinet' of github.com:AhmetCanSolak/aydin into tor…
AhmetCanSolak Jul 14, 2022
3798f4f
merge conflict resolved
Jul 18, 2022
531efd0
merge conflict resolved
Jul 18, 2022
5835b03
separating jinet test module
Jul 18, 2022
336f392
populating the jinet n2t loop
Jul 18, 2022
60cc58b
Merge remote-tracking branch 'upstream/master' into torch-jinet
Jul 18, 2022
1b2247f
merge conflict resolved
Jul 21, 2022
450ffc9
merge conflict resolved2
Jul 21, 2022
b2eef42
class __init__ is improved, kernel_size and num_features properties i…
Jul 21, 2022
9d106a1
dilated conv implemented
Jul 21, 2022
047d840
dilated conv integrated
Jul 21, 2022
9c16f9e
black fixes
Jul 21, 2022
515c69e
changed the cat function call to correct one
Jul 25, 2022
f36ccab
activation functions made member in init
Jul 25, 2022
df4f070
moved dilated conv generation to init
Jul 25, 2022
6bcbc98
fixing channels to 1
Jul 25, 2022
f48f35c
cleaned forward method further
Jul 25, 2022
0216272
syntactical implementation complete
Jul 25, 2022
27d73ef
comments in forward method is finalized
Jul 25, 2022
ffea5e9
kernel size and num features size validation bug fixed
Jul 25, 2022
3685028
fixed the prediction call in the test case
Jul 25, 2022
f99637c
missing attributes added to torch jinet
Jul 25, 2022
36e4e33
dilation and concat problems are fixed
Jul 25, 2022
bcc7c19
magic f constant is removed
Jul 25, 2022
8ec202d
black and flake8 fixes
Jul 25, 2022
5ac60fd
loop range for channel-wise dense layers is corrected
Jul 26, 2022
df6577f
working version
Jul 26, 2022
f4b22de
n2t training loop is completed
Jul 26, 2022
69fe6b3
adding degressive residuals
Jul 26, 2022
f411182
input channel number parameter added
Jul 26, 2022
d0eeb6c
Merge remote-tracking branch 'upstream/master' into torch-jinet
Sep 14, 2022
2adfcea
black fix
Sep 14, 2022
ab7eb52
Merge remote-tracking branch 'upstream/master' into torch-jinet
Sep 22, 2022
8b4a7b2
Merge remote-tracking branch 'upstream/master' into torch-jinet
Sep 22, 2022
53 changes: 53 additions & 0 deletions aydin/nn/layers/dilated_conv.py
@@ -0,0 +1,53 @@
from torch import nn
from torch.nn import ZeroPad2d


class DilatedConv(nn.Module):
def __init__(
self,
in_channels,
out_channels,
spacetime_ndim,
padding,
kernel_size,
dilation,
activation="ReLU",
):
super(DilatedConv, self).__init__()

self.in_channels = in_channels
self.out_channels = out_channels
self.spacetime_ndim = spacetime_ndim
self.activation = activation

self.zero_padding = ZeroPad2d(padding)

if spacetime_ndim == 2:
self.conv_class = nn.Conv2d
elif spacetime_ndim == 3:
self.conv_class = nn.Conv3d
else:
raise ValueError("spacetime_ndim parameter can only be 2 or 3...")

self.conv = self.conv_class(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
dilation=dilation,
padding='valid',
)

self.activation_function = {
"ReLU": nn.ReLU(),
"swish": nn.SiLU(),
"lrel": nn.LeakyReLU(0.1),
}[self.activation]

def forward(self, x):
x = self.zero_padding(x)

x = self.conv(x)

x = self.activation_function(x)

return x
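
A minimal shape-check sketch of how DilatedConv is used by JINetModel further down: explicit zero padding of dilation * radius followed by a padding='valid' convolution keeps the spatial size unchanged. The import path and parameter values mirror the diff above; the snippet itself is illustrative, not part of the PR.

import torch
from aydin.nn.layers.dilated_conv import DilatedConv

kernel_size = 3
dilation = 2
radius = (kernel_size - 1) // 2  # same radius formula JINetModel uses below

conv = DilatedConv(
    in_channels=1,
    out_channels=8,
    spacetime_ndim=2,
    padding=dilation * radius,  # pad exactly what the 'valid' convolution will consume
    kernel_size=kernel_size,
    dilation=dilation,
    activation="lrel",
)

x = torch.zeros((1, 1, 64, 64))
y = conv(x)
assert y.shape == (1, 8, 64, 64)  # spatial dimensions are preserved

Note that the padding layer is always ZeroPad2d, which only pads the last two axes; for spacetime_ndim=3 the depth axis would shrink after the 'valid' Conv3d, so the 3D path may deserve a second look.
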
1 change: 1 addition & 0 deletions aydin/nn/models/jinet.py
@@ -238,6 +238,7 @@ def jinet_core(self, input_lyr):
# stack all features into one tensor:
x = Concatenate(axis=-1)(dilated_conv_list)

print("after concat", x.shape)
# We keep the number of features:
self.total_num_features = total_num_features

18 changes: 12 additions & 6 deletions aydin/nn/models/torch/test/test_torch_jinet.py
@@ -1,6 +1,12 @@
# def test_jinet_2D():
# input_array = torch.zeros((1, 1, 64, 64))
# model2d = JINetModel((64, 64, 1), spacetime_ndim=2)
# result = model2d.predict([input_array])
# assert result.shape == input_array.shape
# assert result.dtype == input_array.dtype
import torch

from aydin.nn.models.torch.torch_jinet import JINetModel


def test_forward_2D_jinet():
input_array = torch.zeros((1, 1, 64, 64))
model2d = JINetModel(spacetime_ndim=2)
result = model2d(input_array)

assert result.shape == input_array.shape
assert result.dtype == input_array.dtype
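
A natural companion to this forward test, sketched here as a suggestion rather than part of the PR, is a backward-pass smoke test confirming that gradients flow through the stacked dilated convolutions and the 1x1 dense layers:

def test_backward_2D_jinet():
    # Suggested (hypothetical) smoke test: verifies the backward pass runs end to end.
    input_array = torch.zeros((1, 1, 64, 64), requires_grad=True)
    model2d = JINetModel(spacetime_ndim=2)
    result = model2d(input_array)

    result.mean().backward()

    assert input_array.grad is not None
    assert input_array.grad.shape == input_array.shape
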
2 changes: 1 addition & 1 deletion aydin/nn/models/torch/test/test_torch_unet.py
@@ -8,7 +8,7 @@
from aydin.nn.pytorch.it_ptcnn import to_numpy


def test_supervised_2D():
def test_forward_2D():
input_array = torch.zeros((1, 1, 64, 64))
model2d = UNetModel(
# (64, 64, 1),
280 changes: 280 additions & 0 deletions aydin/nn/models/torch/torch_jinet.py
@@ -0,0 +1,280 @@
from collections import OrderedDict
from itertools import chain

import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter

from aydin.nn.layers.dilated_conv import DilatedConv
from aydin.nn.pytorch.optimizers.esadam import ESAdam
from aydin.util.log.log import lprint


class JINetModel(nn.Module):
def __init__(
self,
spacetime_ndim,
nb_out_channels: int = 1,
kernel_sizes=None,
num_features=None,
nb_dense_layers: int = 3,
nb_channels: int = None,
final_relu: bool = False,
):
super(JINetModel, self).__init__()

self.spacetime_ndim = spacetime_ndim
self.nb_out_channels = nb_out_channels
self._kernel_sizes = kernel_sizes
self._num_features = num_features
self.nb_dense_layers = nb_dense_layers
self.nb_channels = nb_channels
self.final_relu = final_relu

if len(self.kernel_sizes) != len(self.num_features):
raise ValueError("Number of kernel sizes and features does not match.")

self.dilated_conv_functions = []
current_receptive_field_radius = 0
for scale_index in range(len(self.kernel_sizes)):
# Get kernel size and number of features:
kernel_size = self.kernel_sizes[scale_index]

# radius and dilation:
radius = (kernel_size - 1) // 2
dilation = 1 + current_receptive_field_radius

self.dilated_conv_functions.append(
DilatedConv(
1 if scale_index == 0 else self.num_features[scale_index - 1],
self.num_features[scale_index],
self.spacetime_ndim,
padding=dilation * radius,
kernel_size=kernel_size,
dilation=dilation,
activation="lrel",
)
)

# update receptive field radius
current_receptive_field_radius += dilation * radius

if spacetime_ndim == 2:
self.conv = nn.Conv2d
elif spacetime_ndim == 3:
self.conv = nn.Conv3d
else:
raise ValueError("spacetime_ndim can not be anything other than 2 or 3...")

if self.nb_channels is None:
self.nb_channels = sum(self.num_features) # * 2

nb_out = self.nb_channels
self.kernel_one_conv_functions = []
for index in range(self.nb_dense_layers):
nb_in = nb_out
nb_out = (
self.nb_out_channels
if index == (self.nb_dense_layers - 1)
else self.nb_channels
)
print(index, nb_in, nb_out)

self.kernel_one_conv_functions.append(
self.conv(
in_channels=nb_in,
out_channels=nb_out,
kernel_size=(1,) * spacetime_ndim,
)
)

self.final_kernel_one_conv = self.conv(
in_channels=self.nb_channels,
out_channels=1,
kernel_size=(1,) * spacetime_ndim,
)

self.relu = nn.ReLU()
self.lrelu = nn.LeakyReLU(negative_slope=0.01)

@property
def kernel_sizes(self):
if self._kernel_sizes is None:
if self.spacetime_ndim == 2:
self._kernel_sizes = [7, 5, 3, 3, 3, 3, 3, 3]
elif self.spacetime_ndim == 3:
self._kernel_sizes = [7, 5, 3, 3]

return self._kernel_sizes

@property
def num_features(self):
if self._num_features is None:
if self.spacetime_ndim == 2:
self._num_features = [64, 32, 16, 8, 4, 2, 1, 1]
elif self.spacetime_ndim == 3:
self._num_features = [10, 8, 4, 2]

return self._num_features

def forward(self, x):
dilated_conv_list = []

# Calculate dilated convolutions
for index in range(len(self.kernel_sizes)):
x = self.dilated_conv_functions[index](x)
dilated_conv_list.append(x)
print(x.shape)

# Concat the results
x = torch.cat(dilated_conv_list, dim=1)
print(f"after cat: {x.shape}")

# First kernel size one conv
x = self.kernel_one_conv_functions[0](x)
print(f"after first kernel one conv: {x.shape}")
x = self.lrelu(x)
y = x

# Rest of the kernel size one convolutions
for index in range(1, self.nb_dense_layers):
x = self.kernel_one_conv_functions[index](x)
x = self.lrelu(x)
y += x

# Final kernel size one convolution
y = self.final_kernel_one_conv(y)

# Final ReLU
if self.final_relu:
y = self.relu(y)

return y


def n2t_jinet_train_loop(
input_images,
target_images,
model: JINetModel,
nb_epochs: int = 1024,
learning_rate=0.01,
training_noise=0.001,
l2_weight_regularization=1e-9,
patience=128,
patience_epsilon=0.0,
reduce_lr_factor=0.5,
reload_best_model_period=1024,
best_val_loss_value=None,
):
writer = SummaryWriter()

optimizer = ESAdam(
chain(model.parameters()),
lr=learning_rate,
start_noise_level=training_noise,
weight_decay=l2_weight_regularization,
)

scheduler = ReduceLROnPlateau(
optimizer,
'min',
factor=reduce_lr_factor,
verbose=True,
# patience=reduce_lr_patience, TODO: enable this parameter
)

def loss_function(u, v):
return torch.abs(u - v)

for epoch in range(nb_epochs):
train_loss_value = 0
validation_loss_value = 0
iteration = 0
for i, (input_image, target_image) in enumerate(
zip([input_images], [target_images])
):
optimizer.zero_grad()

model.train()

translated_image = model(input_image)

translation_loss = loss_function(translated_image, target_image)

translation_loss_value = translation_loss.mean()

translation_loss_value.backward()

optimizer.step()

train_loss_value += translation_loss_value.item()
iteration += 1

# Validation:
with torch.no_grad():
model.eval()

translated_image = model(input_image)

translation_loss = loss_function(translated_image, target_image)

translation_loss_value = translation_loss.mean().cpu().item()

validation_loss_value += translation_loss_value
iteration += 1

train_loss_value /= iteration
lprint(f"Training loss value: {train_loss_value}")

validation_loss_value /= iteration
lprint(f"Validation loss value: {validation_loss_value}")

writer.add_scalar("Loss/train", train_loss_value, epoch)
writer.add_scalar("Loss/valid", validation_loss_value, epoch)

scheduler.step(validation_loss_value)

if validation_loss_value < best_val_loss_value:
lprint("## New best val loss!")
if validation_loss_value < best_val_loss_value - patience_epsilon:
lprint("## Good enough to reset patience!")
patience_counter = 0

best_val_loss_value = validation_loss_value

best_model_state_dict = OrderedDict(
{k: v.to('cpu') for k, v in model.state_dict().items()}
)
else:
if epoch % max(1, reload_best_model_period) == 0 and best_model_state_dict:
lprint("Reloading best models to date!")
model.load_state_dict(best_model_state_dict)

if patience_counter > patience:
lprint("Early stopping!")
break

lprint(
f"No improvement of validation losses, patience = {patience_counter}/{patience}"
)
patience_counter += 1

lprint(f"## Best val loss: {best_val_loss_value}")

writer.flush()
writer.close()


# def n2s_jinet_train_loop():

Review comment (Collaborator):
Is this implementation also going to be included in the current PR?

Reply (Collaborator, Author):
It can be a separate one as it needs more testing.

# writer = SummaryWriter()
#
# optimizer = ESAdam(
# chain(model.parameters()),
# lr=learning_rate,
# start_noise_level=training_noise,
# weight_decay=l2_weight_regularisation,
# )
#
# writer.flush()
# writer.close()
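
Finally, a minimal end-to-end usage sketch of the new module and the n2t loop together. It only uses what the diff introduces; nb_epochs=2 and best_val_loss_value=float('inf') are choices made for the sketch (the default best_val_loss_value=None would be compared against a float on the first epoch, so passing an explicit starting value seems safer).

import torch

from aydin.nn.models.torch.torch_jinet import JINetModel, n2t_jinet_train_loop

# Dummy noise2truth-style pair; real usage would pass actual image tensors.
input_images = torch.zeros((1, 1, 64, 64))
target_images = torch.zeros((1, 1, 64, 64))

model = JINetModel(spacetime_ndim=2)

# Quick forward check, mirroring test_forward_2D_jinet:
assert model(input_images).shape == input_images.shape

# Very short training run; the remaining keyword arguments keep their defaults.
n2t_jinet_train_loop(
    input_images,
    target_images,
    model,
    nb_epochs=2,
    best_val_loss_value=float("inf"),
)

One design point that may be worth addressing: dilated_conv_functions and kernel_one_conv_functions are plain Python lists, so nn.Module never registers their contents as submodules. As a consequence model.parameters() and model.to(device) skip them, and the ESAdam optimizer in the loop above only sees the final 1x1 convolution. Wrapping both lists in nn.ModuleList would register them without changing the forward pass.
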