run formatter and linter

sascha-kirch committed Nov 3, 2023
1 parent 442dddc · commit 018f376
Showing 17 changed files with 99 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
@@ -26,4 +26,4 @@ jobs:
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
python -m build
python -m twine upload -r testpypi dist/*
python -m twine upload dist/*
4 changes: 3 additions & 1 deletion DeepSaki/__init__.py
@@ -6,9 +6,10 @@
from DeepSaki import losses
from DeepSaki import models
from DeepSaki import optimizers
from DeepSaki import types
from DeepSaki import utils

__version__ = "1.0.2"
__version__ = "1.0.0"

__author__ = "Sascha Kirch"

@@ -21,5 +22,6 @@
"losses",
"models",
"optimizers",
"types",
"utils",
]
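
The `__all__` lists added throughout this commit pin each package's public names: they control what `from DeepSaki import *` binds and what linters and documentation tools treat as public API. A minimal sketch of the effect, assuming DeepSaki is installed:

# With "types" listed in __all__, a star import binds it explicitly.
from DeepSaki import *

print(types)   # <module 'DeepSaki.types' ...>
print(losses)  # <module 'DeepSaki.losses' ...>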
4 changes: 4 additions & 0 deletions DeepSaki/activations/__init__.py
@@ -1 +1,5 @@
from DeepSaki.activations.complex_valued_activations import ComplexActivation

__all__ = [
"ComplexActivation",
]
1 change: 0 additions & 1 deletion DeepSaki/activations/complex_valued_activations.py
@@ -1,6 +1,5 @@
from typing import Any
from typing import Dict
from typing import Union

import tensorflow as tf

5 changes: 5 additions & 0 deletions DeepSaki/augmentations/__init__.py
@@ -1,2 +1,7 @@
from DeepSaki.augmentations.grid_cutting import cut_mix
from DeepSaki.augmentations.grid_cutting import cut_out

__all__ = [
"cut_mix",
"cut_out",
]
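
For orientation, cut-out style augmentation masks a random square of each image. The sketch below is illustrative only, not DeepSaki's `cut_out` (the NHWC layout, fixed square size, and function name are assumptions):

import tensorflow as tf

def simple_cut_out(images: tf.Tensor, size: int = 16) -> tf.Tensor:
    # images: (batch, height, width, channels); zero one square per image.
    shape = tf.shape(images)
    batch, height, width = shape[0], shape[1], shape[2]
    # Random top-left corner per image (assumes size < height, width).
    ys = tf.random.uniform([batch], 0, height - size, dtype=tf.int32)
    xs = tf.random.uniform([batch], 0, width - size, dtype=tf.int32)
    row = tf.range(height)[tf.newaxis, :, tf.newaxis]  # (1, H, 1)
    col = tf.range(width)[tf.newaxis, tf.newaxis, :]   # (1, 1, W)
    in_rows = (row >= ys[:, tf.newaxis, tf.newaxis]) & (row < (ys + size)[:, tf.newaxis, tf.newaxis])
    in_cols = (col >= xs[:, tf.newaxis, tf.newaxis]) & (col < (xs + size)[:, tf.newaxis, tf.newaxis])
    # Mask is 0 inside the square, 1 elsewhere, broadcast over channels.
    mask = tf.cast(~(in_rows & in_cols), images.dtype)[..., tf.newaxis]
    return images * mask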
4 changes: 4 additions & 0 deletions DeepSaki/constraints/__init__.py
@@ -1 +1,5 @@
from DeepSaki.constraints.constraints import NonNegative

__all__ = [
"NonNegative",
]
8 changes: 8 additions & 0 deletions DeepSaki/initializers/__init__.py
@@ -1,6 +1,14 @@
# he_alpha.py
from DeepSaki.initializers.he_alpha import HeAlpha
from DeepSaki.initializers.he_alpha import HeAlphaUniform
from DeepSaki.initializers.he_alpha import HeAlphaNormal

# complex_initializer.py
from DeepSaki.initializers.complex_initializer import ComplexInitializer

__all__ = [
"HeAlpha",
"HeAlphaUniform",
"HeAlphaNormal",
"ComplexInitializer",
]
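
`HeAlpha` presumably adapts He initialization to an activation slope alpha; the standard He formula for leaky-ReLU-like activations uses variance 2 / ((1 + alpha^2) * fan_in). A generic sketch under that assumption (the exact HeAlpha math is not confirmed here):

import numpy as np
import tensorflow as tf

def he_alpha_normal(shape, alpha: float = 0.3, seed=None) -> tf.Tensor:
    # fan_in = product of all but the last (output-channel) dimension.
    fan_in = int(np.prod(shape[:-1]))
    stddev = np.sqrt(2.0 / ((1.0 + alpha**2) * fan_in))
    return tf.random.normal(shape, stddev=stddev, seed=seed)

kernel = he_alpha_normal([3, 3, 64, 128])  # e.g. a 3x3 conv kernel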
33 changes: 32 additions & 1 deletion DeepSaki/layers/__init__.py
@@ -7,8 +7,8 @@
from DeepSaki.layers.fourier_layer import FourierFilter2D
from DeepSaki.layers.fourier_layer import FFT2D
from DeepSaki.layers.fourier_layer import iFFT2D
from DeepSaki.layers.fourier_layer import FFT2D
from DeepSaki.layers.fourier_layer import FFT3D
from DeepSaki.layers.fourier_layer import iFFT3D
from DeepSaki.layers.fourier_layer import FourierPooling2D
from DeepSaki.layers.fourier_layer import rFFT2DFilter

@@ -38,3 +38,34 @@
from DeepSaki.layers.layer_helper import get_initializer
from DeepSaki.layers.layer_helper import pad_func
from DeepSaki.layers.layer_helper import dropout_func

__all__ = [
"GlobalSumPooling2D",
"LearnedPooling",
"FourierFilter2D",
"FFT2D",
"iFFT2D",
"FFT3D",
"iFFT3D",
"FourierConvolution2D",
"FourierPooling2D",
"rFFT2DFilter",
"ReflectionPadding2D",
"Conv2DSplitted",
"Conv2DBlock",
"DenseBlock",
"DownSampleBlock",
"ResBlockDown",
"ResBlockUp",
"UpSampleBlock",
"ScaleLayer",
"ScalarGatedSelfAttention",
"Encoder",
"ResidualBlock",
"Bottleneck",
"Decoder",
"get_initializer",
"plot_layer",
"pad_func",
"dropout_func",
]
3 changes: 2 additions & 1 deletion DeepSaki/layers/fourier_layer.py
@@ -69,7 +69,8 @@ def _elementwise_product(self, a: tf.Tensor, b: tf.Tensor) -> tf.Tensor:
return tf.math.reduce_sum(c, axis=-1)

def _get_multiplication_function(
self, multiplication_type: MultiplicationType
self,
multiplication_type: MultiplicationType,
) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
"""Returns the corresponding elementwise multiplication function for a given type."""
valid_multiplication_types = {
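
The reflowed signature belongs to a lookup that maps a `MultiplicationType` to the matching multiplication callable. A self-contained sketch of that dispatch pattern (the enum members and functions below are illustrative, not DeepSaki's exact ones):

from enum import Enum, auto
from typing import Callable

import tensorflow as tf

class MultiplicationType(Enum):
    ELEMENT_WISE = auto()
    MATRIX_PRODUCT = auto()

def _elementwise(a: tf.Tensor, b: tf.Tensor) -> tf.Tensor:
    return a * b

def _matmul(a: tf.Tensor, b: tf.Tensor) -> tf.Tensor:
    return tf.linalg.matmul(a, b)

def get_multiplication_function(
    multiplication_type: MultiplicationType,
) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
    valid_multiplication_types = {
        MultiplicationType.ELEMENT_WISE: _elementwise,
        MultiplicationType.MATRIX_PRODUCT: _matmul,
    }
    try:
        return valid_multiplication_types[multiplication_type]
    except KeyError as error:
        raise ValueError(f"Unsupported multiplication type: {multiplication_type}") from error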
7 changes: 7 additions & 0 deletions DeepSaki/losses/__init__.py
@@ -1,2 +1,9 @@
from DeepSaki.losses.image_based_losses import ImageBasedLoss
from DeepSaki.losses.image_based_losses import PixelDistanceLoss
from DeepSaki.losses.image_based_losses import StructuralSimilarityLoss

__all__ = [
"ImageBasedLoss",
"PixelDistanceLoss",
"StructuralSimilarityLoss",
]
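
For context, a structural-similarity loss is commonly defined as one minus the mean SSIM. A minimal sketch with TensorFlow's built-in SSIM (whether `StructuralSimilarityLoss` reduces exactly this way is an assumption):

import tensorflow as tf

def ssim_loss(y_true: tf.Tensor, y_pred: tf.Tensor, max_val: float = 1.0) -> tf.Tensor:
    # tf.image.ssim returns a per-image similarity in [-1, 1].
    return 1.0 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=max_val))

loss = ssim_loss(tf.ones([1, 64, 64, 3]), tf.zeros([1, 64, 64, 3]))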
8 changes: 8 additions & 0 deletions DeepSaki/models/__init__.py
@@ -4,3 +4,11 @@

from DeepSaki.models.autoencoders import UNet
from DeepSaki.models.autoencoders import ResNet

__all__ = [
"LayoutContentDiscriminator",
"PatchDiscriminator",
"UNetDiscriminator",
"UNet",
"ResNet",
]
2 changes: 1 addition & 1 deletion DeepSaki/models/discriminators.py
@@ -125,7 +125,7 @@ def __init__(self,
# To enable mixed precision support for matplotlib and distributed training and to increase training stability
self.linear_dtype = tf.keras.layers.Activation("linear", dtype = tf.float32)

def build(self, input_shape):
def build(self, input_shape: tf.TensorShape) -> None:
_, height, width, _ = input_shape
if height < 256 or width < 256:
raise ValueError(f"Input requires a height and width of minimum 256, but got {height=} and {width=}")
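
The newly annotated `build` validates the spatial input size before wiring the layer. The same pattern as a standalone sketch (the helper name and NHWC assumption are illustrative):

import tensorflow as tf

def check_min_spatial_size(input_shape: tf.TensorShape, minimum: int = 256) -> None:
    # Expects NHWC: (batch, height, width, channels).
    _, height, width, _ = input_shape
    if height < minimum or width < minimum:
        # {height=} is the Python 3.8+ f-string debug form, e.g. "height=128".
        raise ValueError(
            f"Input requires a height and width of minimum {minimum}, but got {height=} and {width=}"
        )

check_min_spatial_size(tf.TensorShape([1, 256, 256, 3]))    # passes
# check_min_spatial_size(tf.TensorShape([1, 128, 128, 3]))  # raises ValueError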
4 changes: 0 additions & 4 deletions DeepSaki/optimizers/swats.py
@@ -2,8 +2,6 @@
from __future__ import division
from __future__ import print_function

from enum import Enum
from enum import auto
from typing import Any
from typing import Dict
from typing import List
@@ -21,8 +19,6 @@

from DeepSaki.types.optimizers_enums import CurrentOptimizer



class SwatsAdam(optimizer_v2.OptimizerV2):
"""Initializer that can switch from ADAM to SGD and vice versa.
16 changes: 16 additions & 0 deletions DeepSaki/types/__init__.py
@@ -10,3 +10,19 @@
# Losses Enums
from DeepSaki.types.losses_enums import LossType
from DeepSaki.types.losses_enums import LossCalcType

# Optimizers Enums
from DeepSaki.types.optimizers_enums import CurrentOptimizer

__all__ = [
"PaddingType",
"InitializerFunc",
"MultiplicationType",
"FrequencyFilter",
"UpSampleType",
"LinearLayerType",
"DownSampleType",
"LossType",
"LossCalcType",
"CurrentOptimizer",
]
1 change: 1 addition & 0 deletions DeepSaki/types/optimizers_enums.py
@@ -9,6 +9,7 @@ class CurrentOptimizer(Enum):
ADAM: Indicates to switch to the ADAM optimizer.
NADAM: Indicates to switch to the NADAM optimizer.
"""

SGD = auto()
ADAM = auto()
NADAM = auto()
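
SwatsAdam uses this enum to track which optimizer is currently active. A hedged sketch of how such an enum is typically consumed (the factory below is illustrative and not SWATS's actual switching logic):

import tensorflow as tf

from DeepSaki.types.optimizers_enums import CurrentOptimizer

def make_optimizer(current: CurrentOptimizer, lr: float = 1e-3) -> tf.keras.optimizers.Optimizer:
    # Map each enum member to a Keras optimizer factory.
    factories = {
        CurrentOptimizer.SGD: lambda: tf.keras.optimizers.SGD(learning_rate=lr),
        CurrentOptimizer.ADAM: lambda: tf.keras.optimizers.Adam(learning_rate=lr),
        CurrentOptimizer.NADAM: lambda: tf.keras.optimizers.Nadam(learning_rate=lr),
    }
    return factories[current]()

opt = make_optimizer(CurrentOptimizer.ADAM)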
6 changes: 6 additions & 0 deletions DeepSaki/utils/__init__.py
@@ -1,3 +1,9 @@
from DeepSaki.utils.environment import detect_accelerator
from DeepSaki.utils.environment import enable_xla_acceleration
from DeepSaki.utils.environment import enable_mixed_precision

__all__ = [
"detect_accelerator",
"enable_xla_acceleration",
"enable_mixed_precision",
]
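
A common shape for `detect_accelerator` is to probe TPU first, then GPU, then fall back to CPU. A rough sketch under that assumption (DeepSaki's exact logic and return type are not shown here):

import tensorflow as tf

def detect_accelerator_sketch() -> str:
    try:
        # Raises if no TPU is reachable in the current environment.
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        return "TPU"
    except (ValueError, tf.errors.NotFoundError):
        return "GPU" if tf.config.list_physical_devices("GPU") else "CPU"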
2 changes: 1 addition & 1 deletion DeepSaki/utils/environment.py
@@ -83,4 +83,4 @@ def enable_mixed_precision() -> None:
policy_config = "mixed_bfloat16" if tpu else "mixed_float16"
policy = tf.keras.mixed_precision.Policy(policy_config)
tf.keras.mixed_precision.set_global_policy(policy)
logging.info("Mixed precision enabled to {}".format(policy_config))
logging.info(f"Mixed precision enabled to {policy_config}")
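
Calling the function above is a one-liner; a short usage sketch (assuming `enable_mixed_precision` takes no arguments, as its signature in the hunk header shows):

import tensorflow as tf

from DeepSaki.utils import enable_mixed_precision

enable_mixed_precision()

# Under a mixed policy, layers compute in (b)float16 but keep float32
# variables; freshly created layers pick up the global policy.
layer = tf.keras.layers.Dense(8)
print(layer.dtype_policy)  # "mixed_float16" (or "mixed_bfloat16" on TPU)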
