Skip to content

Commit

Permalink
fix: remove gauumbel softmax dependencies (#235)
Browse files Browse the repository at this point in the history
  • Loading branch information
fabclmnt committed Feb 23, 2023
1 parent cea1d8e commit f25ff47
Show file tree
Hide file tree
Showing 7 changed files with 3 additions and 11 deletions.
1 change: 0 additions & 1 deletion src/ydata_synthetic/synthesizers/regular/cgan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#Import ydata synthetic classes
from ....synthesizers import TrainParameters
from ....synthesizers.gan import ConditionalModel
from ....utils.gumbel_softmax import GumbelSoftmaxActivation

class CGAN(ConditionalModel):
"CGAN model for discrete conditions"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from ....synthesizers import TrainParameters
from ....synthesizers.gan import BaseModel
from ....synthesizers.loss import Mode, gradient_penalty
from ....utils.gumbel_softmax import GumbelSoftmaxActivation

class CRAMERGAN(BaseModel):

Expand Down
2 changes: 1 addition & 1 deletion src/ydata_synthetic/synthesizers/regular/cwgangp/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

#Import ydata synthetic classes
from ....synthesizers import TrainParameters
from ....synthesizers.gan import BaseModel, ConditionalModel
from ....synthesizers.gan import ConditionalModel
from ....synthesizers.regular.wgangp.model import WGAN_GP

class CWGANGP(ConditionalModel, WGAN_GP):
Expand Down
1 change: 0 additions & 1 deletion src/ydata_synthetic/synthesizers/regular/dragan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#Import ydata synthetic classes
from ....synthesizers.gan import BaseModel
from ....synthesizers.loss import Mode, gradient_penalty
from ....utils.gumbel_softmax import GumbelSoftmaxActivation

class DRAGAN(BaseModel):

Expand Down
7 changes: 2 additions & 5 deletions src/ydata_synthetic/synthesizers/regular/vanillagan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ def __init__(self, model_parameters):

def define_gan(self, activation_info: Optional[NamedTuple]):
self.generator = Generator(self.batch_size).\
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim,
activation_info = activation_info, tau = self.tau)
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim,)

self.discriminator = Discriminator(self.batch_size).\
build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
Expand Down Expand Up @@ -136,14 +135,12 @@ class Generator(tf.keras.Model):
def __init__(self, batch_size):
self.batch_size=batch_size

def build_model(self, input_shape, dim, data_dim, activation_info: Optional[NamedTuple] = None, tau: Optional[float] = None):
def build_model(self, input_shape, dim, data_dim):
input= Input(shape=input_shape, batch_size=self.batch_size)
x = Dense(dim, activation='relu')(input)
x = Dense(dim * 2, activation='relu')(x)
x = Dense(dim * 4, activation='relu')(x)
x = Dense(data_dim)(x)
if activation_info:
x = GumbelSoftmaxActivation(activation_info, tau=tau)(x)
return Model(inputs=input, outputs=x)

class Discriminator(tf.keras.Model):
Expand Down
1 change: 0 additions & 1 deletion src/ydata_synthetic/synthesizers/regular/wgan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#Import ydata synthetic classes
from ....synthesizers import TrainParameters
from ....synthesizers.gan import BaseModel
from ....utils.gumbel_softmax import GumbelSoftmaxActivation

#Auxiliary Keras backend class to calculate the Random Weighted average
#https://stackoverflow.com/questions/58133430/how-to-substitute-keras-layers-merge-merge-in-tensorflow-keras
Expand Down
1 change: 0 additions & 1 deletion src/ydata_synthetic/synthesizers/regular/wgangp/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#Import ydata synthetic classes
from ....synthesizers import TrainParameters
from ....synthesizers.gan import BaseModel
from ....utils.gumbel_softmax import GumbelSoftmaxActivation

class WGAN_GP(BaseModel):

Expand Down

0 comments on commit f25ff47

Please sign in to comment.