Skip to content

Commit

Permalink
fix: TimeGAn training. (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
fabclmnt committed Feb 20, 2021
1 parent ca810b2 commit 5058416
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
20 changes: 17 additions & 3 deletions src/ydata_synthetic/synthesizers/gan.py
Expand Up @@ -2,13 +2,23 @@
from joblib import dump, load
import pandas as pd
import tensorflow as tf
from tensorflow import config as tfconfig

from ydata_synthetic.synthesizers.saving_keras import make_keras_picklable

class Model():
def __init__(
self,
model_parameters
):
gpu_devices = tfconfig.list_physical_devices('GPU')
if len(gpu_devices) > 0:
try:
tfconfig.experimental.set_memory_growth(gpu_devices[0], True)
except:
# Invalid device or cannot modify virtual devices once initialized.
pass

self._model_parameters = model_parameters
[self.batch_size, self.lr, self.beta_1, self.beta_2, self.noise_dim,
self.data_dim, self.layers_dim] = model_parameters
Expand Down Expand Up @@ -53,8 +63,12 @@ def save(self, path):

@classmethod
def load(cls, path):
physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)

gpu_devices = tf.config.list_physical_devices('GPU')
if len(gpu_devices) > 0:
try:
tfconfig.experimental.set_memory_growth(gpu_devices[0], True)
except:
# Invalid device or cannot modify virtual devices once initialized.
pass
synth = load(path)
return synth
10 changes: 1 addition & 9 deletions src/ydata_synthetic/synthesizers/timeseries/timegan/model.py
Expand Up @@ -36,14 +36,6 @@ def make_net(model, n_layers, hidden_units, output_units, net_type='GRU'):

class TimeGAN(gan.Model):
def __init__(self, model_parameters, hidden_dim, seq_len, n_seq, gamma):
physical_devices = tfconfig.list_physical_devices('GPU')
if len(physical_devices) > 0:
try:
tfconfig.experimental.set_memory_growth(physical_devices[0], True)
except:
# Invalid device or cannot modify virtual devices once initialized.
pass

self.seq_len=seq_len
self.n_seq=n_seq
self.hidden_dim=hidden_dim
Expand Down Expand Up @@ -138,7 +130,7 @@ def train_supervisor(self, x):
h_hat_supervised = self.supervisor(h)
g_loss_s = self._mse(h[:, 1:, :], h_hat_supervised[:, 1:, :])

var_list = self.supervisor.trainable_variables
var_list = self.supervisor.trainable_variables + self.generator.trainable_variables
gradients = tape.gradient(g_loss_s, var_list)
self.supervisor_opt.apply_gradients(zip(gradients, var_list))
return g_loss_s
Expand Down

0 comments on commit 5058416

Please sign in to comment.