Skip to content

Commit

Permalink
revert TS data processor integration
Browse files Browse the repository at this point in the history
  • Loading branch information
Francisco Santos committed Dec 1, 2021
1 parent 611cb67 commit efa508a
Showing 1 changed file with 6 additions and 12 deletions.
18 changes: 6 additions & 12 deletions src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from tqdm import trange
from numpy import array, vstack, hstack
from numpy.random import normal
from typing import List

from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, tile
from tensorflow import data as tfdata
Expand All @@ -24,9 +23,9 @@ class TSCWGAN(BaseModel):

def __init__(self, model_parameters, gradient_penalty_weight=10):
"""Create a base TSCWGAN."""
super().__init__(model_parameters)
self.gradient_penalty_weight = gradient_penalty_weight
self.cond_dim = model_parameters.condition
super().__init__(model_parameters)

def define_gan(self):
self.generator = Generator(self.batch_size). \
Expand All @@ -45,18 +44,14 @@ def define_gan(self):
score = concat([cond, gen], axis=1)
score = self.critic(score)

def train(self, data, train_arguments: TrainParameters, num_cols: List[str], cat_cols: List[str],
preprocess: bool = True):
super().train(data, num_cols, cat_cols, preprocess)

processed_data = self.processor.transform(data)
real_batches = self.get_batch_data(processed_data)
def train(self, data, train_arguments: TrainParameters):
real_batches = self.get_batch_data(data)
noise_batches = self.get_batch_noise()

for epoch in trange(train_arguments.epochs):
for i in range(train_arguments.critic_iter):
real_batch = next(real_batches)
noise_batch = next(noise_batches)[:len(real_batch)] # Truncate noise tensor to real data shape
noise_batch = next(noise_batches)[:len(real_batch)] # Truncate the noise tensor in the shape of the real data tensor

c_loss = self.update_critic(real_batch, noise_batch)

Expand Down Expand Up @@ -149,10 +144,9 @@ def get_batch_data(self, data, n_windows= None):

def sample(self, condition: array, n_samples: int = 100, seq_len: int = 24):
"""For a given condition, produce n_samples of length seq_len.
The samples are returned in the original data format (any preprocessing transformation is inverted).
Args:
condition (numpy.array): Condition for the generated samples, must have the same length .
condition (numpy.array): Condition for the generated samples, must have the same length.
n_samples (int): Minimum number of generated samples (returns always a multiple of batch_size).
seq_len (int): Length of the generated samples.
Expand All @@ -175,7 +169,7 @@ def sample(self, condition: array, n_samples: int = 100, seq_len: int = 24):
data_.append(records)
data_ = hstack(data_)[:, :seq_len]
data.append(data_)
return self.processor.inverse_transform(array(vstack(data)))
return array(vstack(data))


class Generator(Model):
Expand Down

0 comments on commit efa508a

Please sign in to comment.