Skip to content

Commit

Permalink
fix(base): Save and load methods
Browse files Browse the repository at this point in the history
  • Loading branch information
fabclmnt committed Dec 28, 2020
1 parent 5ccebd8 commit 5a5e342
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 51 deletions.
16 changes: 14 additions & 2 deletions examples/wgan_example.py
Expand Up @@ -74,7 +74,19 @@
synthesizer = model(gan_args, n_critic=2)
synthesizer.train(train_sample, train_args)

#WGAN_GP models is now trained
#So we can easily generate a few samples
#Saving the synthesizer to later generate new events
synthesizer.save(path='models/wgan_creditcard.pkl')

#Loading the synthesizer
synth = WGAN_GP.load(path='models/wgan_creditcard.pkl')

#Sampling the data
#Note that the data returned it is not inverse processed.
data_sample = synth.sample(100000)

print('Testing the sample of data.')

#Sample events using the trained model



32 changes: 23 additions & 9 deletions src/ydata_synthetic/synthesizers/gan.py
@@ -1,9 +1,9 @@
import os
import tqdm

from joblib import dump, load
import pandas as pd
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.keras.models import model_from_json

class Model():
def __init__(
Expand Down Expand Up @@ -35,19 +35,33 @@ def model_name(self):

def train(self, data, train_arguments):
raise NotImplementedError

def to_pickle(self):
self.generator = self.generator.to_json()
try:
self.discriminator
self.discriminitor = None
except:
self.critic = None

def sample(self, n_samples):
steps = n_samples // self.batch_size + 1
data = []
for step in tqdm.trange(steps):
for _ in tqdm.trange(steps):
z = tf.random.uniform([self.batch_size, self.noise_dim])
records = tf.make_ndarray(tf.make_tensor_proto(self.generator(z, training=False)))
data.append(pd.DataFrame(records))
return pd.concat(data)

def save(self, path, name):
assert os.path.isdir(path) == True, \
"Please provide a valid path. Path must be a directory."
model_path = os.path.join(path, name)
self.generator.save_weights(model_path) # Load the generator
return
def save(self, path):
self.to_pickle()
try:
dump(self, path)
except:
raise Exception('Please provide a valid path to save the model.')

@classmethod
def load(cls, path):
synth = load(path)
synth.generator = model_from_json(synth.generator)
return synth
14 changes: 0 additions & 14 deletions src/ydata_synthetic/synthesizers/regular/cgan/model.py
Expand Up @@ -108,20 +108,6 @@ def train(self, data, train_arguments):
label_z = tf.random.uniform((432,), minval=min(self.classes), maxval=max(self.classes)+1, dtype=tf.dtypes.int32)
gen_data = self.generator([z, label_z])

def save(self, path, name):
assert os.path.isdir(path) == True, \
"Please provide a valid path. Path must be a directory."
model_path = os.path.join(path, name)
self.generator.save_weights(model_path) # Load the generator
return

def load(self, path):
assert os.path.isdir(path) == True, \
"Please provide a valid path. Path must be a directory."
self.generator = Generator(self.batch_size)
self.generator = self.generator.load_weights(path)
return self.generator

class Generator():
def __init__(self, batch_size, num_classes):
self.batch_size = batch_size
Expand Down
13 changes: 0 additions & 13 deletions src/ydata_synthetic/synthesizers/regular/vanillagan/model.py
Expand Up @@ -105,19 +105,6 @@ def train(self, data, train_arguments):
gen_data = self.generator(z)
print('generated_data')

def save(self, path, name):
assert os.path.isdir(path) == True, \
"Please provide a valid path. Path must be a directory."
model_path = os.path.join(path, name)
self.generator.save_weights(model_path) # Load the generator
return

def load(self, path):
assert os.path.isdir(path) == True, \
"Please provide a valid path. Path must be a directory."
self.generator = Generator(self.batch_size)
self.generator = self.generator.load_weights(path)
return self.generator

class Generator(tf.keras.Model):
def __init__(self, batch_size):
Expand Down
7 changes: 0 additions & 7 deletions src/ydata_synthetic/synthesizers/regular/wgan/model.py
Expand Up @@ -131,13 +131,6 @@ def train(self, data, train_arguments):
self.generator.save_weights(model_checkpoint_base_name.format('generator', epoch))
self.critic.save_weights(model_checkpoint_base_name.format('critic', epoch))

def load(self, path):
assert os.path.isdir(path) == True, \
"Please provide a valid path. Path must be a directory."
self.generator = Generator(self.batch_size)
self.generator = self.generator.load_weights(path)
return self.generator


class Generator(tf.keras.Model):
def __init__(self, batch_size):
Expand Down
6 changes: 0 additions & 6 deletions src/ydata_synthetic/synthesizers/regular/wgangp/model.py
Expand Up @@ -152,12 +152,6 @@ def train(self, data, train_arguments):
self.generator.save_weights(model_checkpoint_base_name.format('generator', iteration))
self.critic.save_weights(model_checkpoint_base_name.format('critic', iteration))

def load(self, path):
assert os.path.isdir(path) == True, \
"Please provide a valid path. Path must be a directory."
self.generator = Generator(self.batch_size)
self.generator = self.generator.load_weights(path)
return self.generator

class Generator(tf.keras.Model):
def __init__(self, batch_size):
Expand Down

0 comments on commit 5a5e342

Please sign in to comment.