fix: save load TimeGAN #58

Merged · 2 commits · Mar 24, 2021
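This PR is what makes the TimeGAN synthesizer serializable (the "save load" of the title). Previously, define_gan created five Adam optimizers as instance attributes, and the @function-compiled training steps reached them through self; the diff below removes those attributes and instead builds each optimizer locally in train(), handing it to the training step as an argument. With no tf.Variable-bearing optimizer state left on the object, the synthesizer can be pickled whole. A minimal sketch of the pattern, with illustrative names that are not part of the PR:

# Sketch of the pattern this diff adopts (names are illustrative):
# the optimizer is built where training happens and passed into the
# tf.function step instead of living on the model object.
import tensorflow as tf

@tf.function
def train_step(model, x, y, opt):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))
    return loss

def train(model, batches, lr=1e-3, steps=10):
    opt = tf.keras.optimizers.Adam(learning_rate=lr)  # local, not an attribute
    for _ in range(steps):
        x, y = next(batches)
        train_step(model, x, y, opt)

Passing the optimizer as a Python argument costs one retrace of the tf.function per distinct optimizer object, which is harmless here: each training phase uses exactly one optimizer for all of its steps.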
48 changes: 22 additions & 26 deletions src/ydata_synthetic/synthesizers/timeseries/timegan/model.py
@@ -95,15 +95,6 @@ def define_gan(self):
                                    outputs=Y_real,
                                    name="RealDiscriminator")
 
-        # ----------------------------
-        # Init the optimizers
-        # ----------------------------
-        self.autoencoder_opt = Adam(learning_rate=self.lr)
-        self.supervisor_opt = Adam(learning_rate=self.lr)
-        self.generator_opt = Adam(learning_rate=self.lr)
-        self.discriminator_opt = Adam(learning_rate=self.lr)
-        self.embedding_opt = Adam(learning_rate=self.lr)
-
         # ----------------------------
         # Define the loss functions
         # ----------------------------
@@ -112,31 +103,32 @@ def define_gan(self):
 
 
     @function
-    def train_autoencoder(self, x):
+    def train_autoencoder(self, x, opt):
         with GradientTape() as tape:
             x_tilde = self.autoencoder(x)
             embedding_loss_t0 = self._mse(x, x_tilde)
             e_loss_0 = 10 * sqrt(embedding_loss_t0)
 
         var_list = self.embedder.trainable_variables + self.recovery.trainable_variables
         gradients = tape.gradient(e_loss_0, var_list)
-        self.autoencoder_opt.apply_gradients(zip(gradients, var_list))
+        opt.apply_gradients(zip(gradients, var_list))
         return sqrt(embedding_loss_t0)
 
     @function
-    def train_supervisor(self, x):
+    def train_supervisor(self, x, opt):
         with GradientTape() as tape:
             h = self.embedder(x)
             h_hat_supervised = self.supervisor(h)
             g_loss_s = self._mse(h[:, 1:, :], h_hat_supervised[:, 1:, :])
 
         var_list = self.supervisor.trainable_variables + self.generator.trainable_variables
         gradients = tape.gradient(g_loss_s, var_list)
-        self.supervisor_opt.apply_gradients(zip(gradients, var_list))
+        apply_grads = [(grad, var) for (grad, var) in zip(gradients, var_list) if grad is not None]
+        opt.apply_gradients(apply_grads)
         return g_loss_s
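Note the new None filter in train_supervisor: var_list includes the generator's variables, but g_loss_s depends only on the embedder and supervisor outputs, so tape.gradient returns None for the untouched variables. Dropping those pairs keeps apply_gradients from complaining about missing gradients (a warning on most TF 2.x versions, an error on some). A standalone illustration, independent of TimeGAN:

import tensorflow as tf

a = tf.Variable(1.0)
b = tf.Variable(2.0)                     # never used in the loss
with tf.GradientTape() as tape:
    loss = a * a
grads = tape.gradient(loss, [a, b])      # -> [2.0, None]
pairs = [(g, v) for g, v in zip(grads, [a, b]) if g is not None]
tf.keras.optimizers.Adam().apply_gradients(pairs)   # b is simply skipped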
 
     @function
-    def train_embedder(self,x):
+    def train_embedder(self,x, opt):
         with GradientTape() as tape:
             h = self.embedder(x)
             h_hat_supervised = self.supervisor(h)
@@ -148,7 +140,7 @@ def train_embedder(self,x):
 
         var_list = self.embedder.trainable_variables + self.recovery.trainable_variables
         gradients = tape.gradient(e_loss, var_list)
-        self.embedding_opt.apply_gradients(zip(gradients, var_list))
+        opt.apply_gradients(zip(gradients, var_list))
         return sqrt(embedding_loss_t0)
 
     def discriminator_loss(self, x, z):
@@ -176,7 +168,7 @@ def calc_generator_moments_loss(y_true, y_pred):
         return g_loss_mean + g_loss_var
 
     @function
-    def train_generator(self, x, z):
+    def train_generator(self, x, z, opt):
         with GradientTape() as tape:
             y_fake = self.adversarial_supervised(z)
             generator_loss_unsupervised = self._bce(y_true=ones_like(y_fake),
@@ -199,17 +191,17 @@ def train_generator(self, x, z):
 
         var_list = self.generator_aux.trainable_variables + self.supervisor.trainable_variables
         gradients = tape.gradient(generator_loss, var_list)
-        self.generator_opt.apply_gradients(zip(gradients, var_list))
+        opt.apply_gradients(zip(gradients, var_list))
         return generator_loss_unsupervised, generator_loss_supervised, generator_moment_loss
 
     @function
-    def train_discriminator(self, x, z):
+    def train_discriminator(self, x, z, opt):
         with GradientTape() as tape:
             discriminator_loss = self.discriminator_loss(x, z)
 
         var_list = self.discriminator.trainable_variables
         gradients = tape.gradient(discriminator_loss, var_list)
-        self.discriminator_opt.apply_gradients(zip(gradients, var_list))
+        opt.apply_gradients(zip(gradients, var_list))
         return discriminator_loss
 
     def get_batch_data(self, data, n_windows):
@@ -229,16 +221,22 @@ def get_batch_noise(self):
 
     def train(self, data, train_steps):
         ## Embedding network training
+        autoencoder_opt = Adam(learning_rate=self.lr)
         for _ in tqdm(range(train_steps), desc='Emddeding network training'):
             X_ = next(self.get_batch_data(data, n_windows=len(data)))
-            step_e_loss_t0 = self.train_autoencoder(X_)
+            step_e_loss_t0 = self.train_autoencoder(X_, autoencoder_opt)
 
         ## Supervised Network training
+        supervisor_opt = Adam(learning_rate=self.lr)
         for _ in tqdm(range(train_steps), desc='Supervised network training'):
             X_ = next(self.get_batch_data(data, n_windows=len(data)))
-            step_g_loss_s = self.train_supervisor(X_)
+            step_g_loss_s = self.train_supervisor(X_, supervisor_opt)
 
         ## Joint training
+        generator_opt = Adam(learning_rate=self.lr)
+        embedder_opt = Adam(learning_rate=self.lr)
+        discriminator_opt = Adam(learning_rate=self.lr)
+
         step_g_loss_u = step_g_loss_s = step_g_loss_v = step_e_loss_t0 = step_d_loss = 0
         for _ in tqdm(range(train_steps), desc='Joint networks training'):
 
@@ -250,18 +248,18 @@ def train(self, data, train_steps):
             # --------------------------
             # Train the generator
             # --------------------------
-            step_g_loss_u, step_g_loss_s, step_g_loss_v = self.train_generator(X_, Z_)
+            step_g_loss_u, step_g_loss_s, step_g_loss_v = self.train_generator(X_, Z_, generator_opt)
 
             # --------------------------
             # Train the embedder
             # --------------------------
-            step_e_loss_t0 = self.train_embedder(X_)
+            step_e_loss_t0 = self.train_embedder(X_, embedder_opt)
 
             X_ = next(self.get_batch_data(data, n_windows=len(data)))
             Z_ = next(self.get_batch_noise())
             step_d_loss = self.discriminator_loss(X_, Z_)
             if step_d_loss > 0.15:
-                step_d_loss = self.train_discriminator(X_, Z_)
+                step_d_loss = self.train_discriminator(X_, Z_, discriminator_opt)
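Unchanged by this PR but worth noting while reading the loop: the discriminator loss is first measured on a fresh batch, and train_discriminator only runs when that loss exceeds 0.15, so the discriminator is not updated once it is already well ahead of the generator. The same 0.15 heuristic appears in the reference TimeGAN implementation.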
 
     def sample(self, n_samples):
         steps = n_samples // self.batch_size + 1
@@ -273,8 +271,6 @@ def sample(self, n_samples):
         return np.array(np.vstack(data))
 
 
-
-
 class Generator(Model):
     def __init__(self, hidden_dim, net_type='GRU'):
         self.hidden_dim = hidden_dim
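With the optimizers out of the instance state, the synthesizer object should survive serialization end to end. A hedged usage sketch; the constructor arguments follow the package's published examples, and the save/load helper names are assumed from the PR title rather than verified against this exact version:

# Hypothetical usage after this fix (argument values are placeholders).
from ydata_synthetic.synthesizers import ModelParameters
from ydata_synthetic.synthesizers.timeseries import TimeGAN

gan_args = ModelParameters(batch_size=128, lr=5e-4, noise_dim=32, layers_dim=128)
synth = TimeGAN(model_parameters=gan_args, hidden_dim=24, seq_len=24, n_seq=6, gamma=1)
synth.train(data, train_steps=500)   # `data`: list of (seq_len, n_seq) windows
synth.save('timegan.pkl')            # works now: no optimizer attributes on self
loaded = TimeGAN.load('timegan.pkl')
samples = loaded.sample(n_samples=100)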