Skip to content

Commit

Permalink
fix: save load TimeGAN (#58)
Browse files Browse the repository at this point in the history
* fix: remove optimizers from the class to enable save and load.

* fix: Remove not used grads from the supervised training.
  • Loading branch information
fabclmnt committed Mar 24, 2021
1 parent 1ad9c7e commit f631807
Showing 1 changed file with 22 additions and 26 deletions.
48 changes: 22 additions & 26 deletions src/ydata_synthetic/synthesizers/timeseries/timegan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,6 @@ def define_gan(self):
outputs=Y_real,
name="RealDiscriminator")

# ----------------------------
# Init the optimizers
# ----------------------------
self.autoencoder_opt = Adam(learning_rate=self.lr)
self.supervisor_opt = Adam(learning_rate=self.lr)
self.generator_opt = Adam(learning_rate=self.lr)
self.discriminator_opt = Adam(learning_rate=self.lr)
self.embedding_opt = Adam(learning_rate=self.lr)

# ----------------------------
# Define the loss functions
# ----------------------------
Expand All @@ -112,31 +103,32 @@ def define_gan(self):


@function
def train_autoencoder(self, x):
def train_autoencoder(self, x, opt):
with GradientTape() as tape:
x_tilde = self.autoencoder(x)
embedding_loss_t0 = self._mse(x, x_tilde)
e_loss_0 = 10 * sqrt(embedding_loss_t0)

var_list = self.embedder.trainable_variables + self.recovery.trainable_variables
gradients = tape.gradient(e_loss_0, var_list)
self.autoencoder_opt.apply_gradients(zip(gradients, var_list))
opt.apply_gradients(zip(gradients, var_list))
return sqrt(embedding_loss_t0)

@function
def train_supervisor(self, x):
def train_supervisor(self, x, opt):
with GradientTape() as tape:
h = self.embedder(x)
h_hat_supervised = self.supervisor(h)
g_loss_s = self._mse(h[:, 1:, :], h_hat_supervised[:, 1:, :])

var_list = self.supervisor.trainable_variables + self.generator.trainable_variables
gradients = tape.gradient(g_loss_s, var_list)
self.supervisor_opt.apply_gradients(zip(gradients, var_list))
apply_grads = [(grad, var) for (grad, var) in zip(gradients, var_list) if grad is not None]
opt.apply_gradients(apply_grads)
return g_loss_s

@function
def train_embedder(self,x):
def train_embedder(self,x, opt):
with GradientTape() as tape:
h = self.embedder(x)
h_hat_supervised = self.supervisor(h)
Expand All @@ -148,7 +140,7 @@ def train_embedder(self,x):

var_list = self.embedder.trainable_variables + self.recovery.trainable_variables
gradients = tape.gradient(e_loss, var_list)
self.embedding_opt.apply_gradients(zip(gradients, var_list))
opt.apply_gradients(zip(gradients, var_list))
return sqrt(embedding_loss_t0)

def discriminator_loss(self, x, z):
Expand Down Expand Up @@ -176,7 +168,7 @@ def calc_generator_moments_loss(y_true, y_pred):
return g_loss_mean + g_loss_var

@function
def train_generator(self, x, z):
def train_generator(self, x, z, opt):
with GradientTape() as tape:
y_fake = self.adversarial_supervised(z)
generator_loss_unsupervised = self._bce(y_true=ones_like(y_fake),
Expand All @@ -199,17 +191,17 @@ def train_generator(self, x, z):

var_list = self.generator_aux.trainable_variables + self.supervisor.trainable_variables
gradients = tape.gradient(generator_loss, var_list)
self.generator_opt.apply_gradients(zip(gradients, var_list))
opt.apply_gradients(zip(gradients, var_list))
return generator_loss_unsupervised, generator_loss_supervised, generator_moment_loss

@function
def train_discriminator(self, x, z):
def train_discriminator(self, x, z, opt):
with GradientTape() as tape:
discriminator_loss = self.discriminator_loss(x, z)

var_list = self.discriminator.trainable_variables
gradients = tape.gradient(discriminator_loss, var_list)
self.discriminator_opt.apply_gradients(zip(gradients, var_list))
opt.apply_gradients(zip(gradients, var_list))
return discriminator_loss

def get_batch_data(self, data, n_windows):
Expand All @@ -229,16 +221,22 @@ def get_batch_noise(self):

def train(self, data, train_steps):
## Embedding network training
autoencoder_opt = Adam(learning_rate=self.lr)
for _ in tqdm(range(train_steps), desc='Emddeding network training'):
X_ = next(self.get_batch_data(data, n_windows=len(data)))
step_e_loss_t0 = self.train_autoencoder(X_)
step_e_loss_t0 = self.train_autoencoder(X_, autoencoder_opt)

## Supervised Network training
supervisor_opt = Adam(learning_rate=self.lr)
for _ in tqdm(range(train_steps), desc='Supervised network training'):
X_ = next(self.get_batch_data(data, n_windows=len(data)))
step_g_loss_s = self.train_supervisor(X_)
step_g_loss_s = self.train_supervisor(X_, supervisor_opt)

## Joint training
generator_opt = Adam(learning_rate=self.lr)
embedder_opt = Adam(learning_rate=self.lr)
discriminator_opt = Adam(learning_rate=self.lr)

step_g_loss_u = step_g_loss_s = step_g_loss_v = step_e_loss_t0 = step_d_loss = 0
for _ in tqdm(range(train_steps), desc='Joint networks training'):

Expand All @@ -250,18 +248,18 @@ def train(self, data, train_steps):
# --------------------------
# Train the generator
# --------------------------
step_g_loss_u, step_g_loss_s, step_g_loss_v = self.train_generator(X_, Z_)
step_g_loss_u, step_g_loss_s, step_g_loss_v = self.train_generator(X_, Z_, generator_opt)

# --------------------------
# Train the embedder
# --------------------------
step_e_loss_t0 = self.train_embedder(X_)
step_e_loss_t0 = self.train_embedder(X_, embedder_opt)

X_ = next(self.get_batch_data(data, n_windows=len(data)))
Z_ = next(self.get_batch_noise())
step_d_loss = self.discriminator_loss(X_, Z_)
if step_d_loss > 0.15:
step_d_loss = self.train_discriminator(X_, Z_)
step_d_loss = self.train_discriminator(X_, Z_, discriminator_opt)

def sample(self, n_samples):
steps = n_samples // self.batch_size + 1
Expand All @@ -273,8 +271,6 @@ def sample(self, n_samples):
return np.array(np.vstack(data))




class Generator(Model):
def __init__(self, hidden_dim, net_type='GRU'):
self.hidden_dim = hidden_dim
Expand Down

0 comments on commit f631807

Please sign in to comment.