Skip to content

Commit

Permalink
feat(gan): Adding sample method
Browse files Browse the repository at this point in the history
  • Loading branch information
fabclmnt committed Oct 30, 2020
1 parent 41fcb1d commit f52a06e
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/ydata_synthetic/synthesizers/gan.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os
import tqdm

import pandas as pd
import tensorflow as tf
from tensorflow.python import keras

Expand All @@ -22,10 +25,6 @@ def define_gan(self):
def trainable_variables(self, network):
return network.trainable_variables

@property
def model(self):
return self._model

@property
def model_parameters(self):
return self._model_parameters
Expand All @@ -37,9 +36,14 @@ def model_name(self):
def train(self, data, train_arguments):
raise NotImplementedError

@tf.function
def samples(self, z):
return self.generator(z, training=False)
def sample(self, n_samples):
steps = n_samples // self.batch_size + 1
data = []
for step in tqdm.trange(steps):
z = tf.random.uniform([self.batch_size, self.noise_dim])
records = tf.make_ndarray(tf.make_tensor_proto(self.generator(z, training=False)))
data.append(pd.DataFrame(records))
return pd.concat(data)

def save(self, path, name):
assert os.path.isdir(path) == True, \
Expand Down

0 comments on commit f52a06e

Please sign in to comment.