Skip to content

Commit

Permalink
fix(save): Making keras pickable.
Browse files Browse the repository at this point in the history
  • Loading branch information
fabclmnt committed Jan 23, 2021
1 parent ae8e021 commit 99773dc
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 11 deletions.
14 changes: 3 additions & 11 deletions src/ydata_synthetic/synthesizers/gan.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
import tqdm
from joblib import dump, load
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import model_from_json
from ydata_synthetic.synthesizers.saving_keras import make_keras_picklable

class Model():
def __init__(
Expand Down Expand Up @@ -35,26 +35,18 @@ def model_name(self):

def train(self, data, train_arguments):
raise NotImplementedError

def to_pickle(self):
self.generator = self.generator.to_json()
try:
self.discriminator
self.discriminitor = None
except:
self.critic = None

def sample(self, n_samples):
steps = n_samples // self.batch_size + 1
data = []
for _ in tqdm.trange(steps):
for _ in tqdm.trange(steps, desc='Synthetic data generation'):
z = tf.random.uniform([self.batch_size, self.noise_dim])
records = tf.make_ndarray(tf.make_tensor_proto(self.generator(z, training=False)))
data.append(pd.DataFrame(records))
return pd.concat(data)

def save(self, path):
self.to_pickle()
make_keras_picklable()
try:
dump(self, path)
except:
Expand Down
21 changes: 21 additions & 0 deletions src/ydata_synthetic/synthesizers/saving_keras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from tensorflow.keras import Model
from tensorflow.python.keras.layers import deserialize, serialize
from tensorflow.python.keras.saving import saving_utils

def unpack(model, training_config, weights):
restored_model = deserialize(model)
if training_config is not None:
restored_model.compile(**saving_utils.compile_args_from_training_config(training_config))
restored_model.set_weights(weights)
return restored_model

def make_keras_picklable():
def __reduce__(self):
model_metadata = saving_utils.model_metadata(self)
training_config = model_metadata.get("training_config", None)
model = serialize(self)
weights = self.get_weights()
return (unpack, (model, training_config, weights))

cls = Model
cls.__reduce__=__reduce__

0 comments on commit 99773dc

Please sign in to comment.