Skip to content

Commit

Permalink
fix: warning message timeGAN (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
fabclmnt committed Mar 21, 2021
1 parent 4ad3c21 commit 1ad9c7e
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions src/ydata_synthetic/synthesizers/timeseries/timegan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def __init__(self, model_parameters, hidden_dim, seq_len, n_seq, gamma):
def define_gan(self):
self.generator_aux=Generator(self.hidden_dim).build(input_shape=(self.seq_len, self.n_seq))
self.supervisor=Supervisor(self.hidden_dim).build(input_shape=(self.hidden_dim, self.hidden_dim))
self.discriminator=Discriminator(self.hidden_dim).build(input_shape=(self.seq_len, self.n_seq))
self.discriminator=Discriminator(self.hidden_dim).build(input_shape=(self.hidden_dim, self.hidden_dim))
self.recovery = Recovery(self.hidden_dim, self.n_seq).build(input_shape=(self.hidden_dim, self.hidden_dim))
self.embedder = Embedder(self.hidden_dim).build(input_shape=(self.hidden_dim, self.n_seq))
self.embedder = Embedder(self.hidden_dim).build(input_shape=(self.seq_len, self.n_seq))

X = Input(shape=[self.seq_len, self.n_seq], batch_size=self.batch_size, name='RealData')
Z = Input(shape=[self.seq_len, self.n_seq], batch_size=self.batch_size, name='RandomNoise')
Expand Down Expand Up @@ -282,7 +282,6 @@ def __init__(self, hidden_dim, net_type='GRU'):

def build(self, input_shape):
model = Sequential(name='Generator')
model.add(Input(shape=input_shape))
model = make_net(model,
n_layers=3,
hidden_units=self.hidden_dim,
Expand Down Expand Up @@ -312,7 +311,6 @@ def __init__(self, hidden_dim, n_seq):

def build(self, input_shape):
recovery = Sequential(name='Recovery')
recovery.add(Input(shape=input_shape, name='EmbeddedData'))
recovery = make_net(recovery,
n_layers=3,
hidden_units=self.hidden_dim,
Expand All @@ -327,7 +325,6 @@ def __init__(self, hidden_dim):

def build(self, input_shape):
embedder = Sequential(name='Embedder')
embedder.add(Input(shape=input_shape, name='Data'))
embedder = make_net(embedder,
n_layers=3,
hidden_units=self.hidden_dim,
Expand All @@ -340,7 +337,6 @@ def __init__(self, hidden_dim):

def build(self, input_shape):
model = Sequential(name='Supervisor')
model.add(Input(shape=input_shape))
model = make_net(model,
n_layers=2,
hidden_units=self.hidden_dim,
Expand Down

0 comments on commit 1ad9c7e

Please sign in to comment.