Skip to content

Commit

Permalink
feat(wgan-gp): Add to init.
Browse files Browse the repository at this point in the history
  • Loading branch information
fabclmnt committed Oct 29, 2020
1 parent 5d37e62 commit b5867da
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/ydata_synthetic/synthesizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ydata_synthetic.synthesizers.regular.cgan.model import CGAN
from ydata_synthetic.synthesizers.regular.wgan.model import WGAN
from ydata_synthetic.synthesizers.regular.vanillagan.model import VanilllaGAN

from ydata_synthetic.synthesizers.regular.wgan_gp.model import WGAN_GP

__all__ = [
"VanilllaGAN",
Expand Down
3 changes: 1 addition & 2 deletions src/ydata_synthetic/synthesizers/regular/wgan_gp/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam


#Auxiliary Keras backend class to calculate the Random Weighted average
#https://stackoverflow.com/questions/58133430/how-to-substitute-keras-layers-merge-merge-in-tensorflow-keras
class RandomWeightedAverage(tf.keras.layers.Layer):
Expand All @@ -27,7 +26,7 @@ def call(self, inputs, **kwargs):
def compute_output_shape(self, input_shape):
return input_shape[0]

class WGAN(gan.Model):
class WGAN_GP(gan.Model):

def __init__(self, model_parameters, n_critic):
# As recommended in WGAN paper - https://arxiv.org/abs/1701.07875
Expand Down

0 comments on commit b5867da

Please sign in to comment.