Skip to content

Commit

Permalink
Auto regressive timeseries sampling method
Browse files Browse the repository at this point in the history
  • Loading branch information
Francisco Santos committed Dec 1, 2021
1 parent 3f6cbe5 commit 611cb67
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 27 deletions.
6 changes: 3 additions & 3 deletions examples/timeseries/tscwgan_example.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from numpy import squeeze
from numpy import reshape

from ydata_synthetic.preprocessing.timeseries import processed_stock
from ydata_synthetic.synthesizers.timeseries import TSCWGAN
Expand Down Expand Up @@ -51,9 +51,9 @@
#Sampling the data
#Note that the data returned is not inverse processed.
cond_index = 100 # Arbitrary sequence for conditioning
cond_array = squeeze(processed_data[cond_index][:cond_dim], axis=1)
cond_array = reshape(processed_data[cond_index][:cond_dim], (1,-1))

data_sample = synth.sample(cond_array, 1000)
data_sample = synth.sample(cond_array, 1000, 100)

# Inverting the scaling of the synthetic samples
data_sample = inverse_transform(data_sample, scaler)
51 changes: 27 additions & 24 deletions src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
And on: https://github.com/CasperHogenboom/WGAN_financial_time-series
"""
from tqdm import trange
from numpy import array, vstack
from numpy import array, vstack, hstack
from numpy.random import normal
from typing import List

from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, make_ndarray, make_tensor_proto, tile, constant
from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, tile
from tensorflow import data as tfdata
from tensorflow.keras import Model, Sequential
from tensorflow.keras.optimizers import Adam
Expand All @@ -17,7 +17,6 @@
from ydata_synthetic.synthesizers.gan import BaseModel
from ydata_synthetic.synthesizers import TrainParameters
from ydata_synthetic.synthesizers.loss import Mode, gradient_penalty
from ydata_synthetic.synthesizers.timeseries import TimeSeriesDataProcessor

class TSCWGAN(BaseModel):

Expand Down Expand Up @@ -148,31 +147,35 @@ def get_batch_data(self, data, n_windows= None):
.shuffle(buffer_size=n_windows)
.batch(self.batch_size).repeat())

def sample(self, cond_array: array, n_samples: int, inverse_transform: bool = True):
"""Provided that cond_array is passed, produce n_samples for each condition vector in cond_array.
The returned samples per condition will always be a multiple of batch_size and equal or bigger than n_samples.
Arguments:
cond_array (numpy array): Array with the set of conditions for the sampling process.
n_samples (int): Number of samples to be taken for each condition in cond_array.
inverse_transform (bool): """
assert len(cond_array.shape) == 2, "Condition array should be two-dimensional. N_conditions x cond_dim"
assert cond_array.shape[1] == self.cond_dim, \
f"The condition sequences should have a {self.cond_dim} length."
steps = n_samples // self.batch_size + 1
def sample(self, condition: array, n_samples: int = 100, seq_len: int = 24):
"""For a given condition, produce n_samples of length seq_len.
The samples are returned in the original data format (any preprocessing transformation is inverted).
Args:
condition (numpy.array): Condition for the generated samples, must have the same length .
n_samples (int): Minimum number of generated samples (returns always a multiple of batch_size).
seq_len (int): Length of the generated samples.
Returns:
data (numpy.array): An array of data of shape [n_samples, seq_len]"""
assert len(condition.shape) == 2, "Condition array should be two-dimensional."
assert condition.shape[1] == self.cond_dim, \
f"The condition sequence should have {self.cond_dim} length."
batches = n_samples // self.batch_size + 1
ar_steps = seq_len // self.data_dim + 1
data = []
z_dist = self.get_batch_noise()
for condition in cond_array:
for batch in trange(batches, desc=f'Synthetic data generation'):
data_ = []
cond_seq = convert_to_tensor(condition, float32)
cond_seq = tile(cond_seq, multiples=[self.batch_size, 1])
for step in trange(steps, desc=f'Synthetic data generation'):
gen_input = concat([cond_seq, next(z_dist)], axis=1)
gen_input = concat([tile(cond_seq, multiples=[self.batch_size, 1]), next(z_dist)], axis=1)
for step in range(ar_steps):
records = self.generator(gen_input, training=False)
data.append(records)
data = array(vstack(data))
if inverse_transform:
return self.processor.inverse_transform(data)
return data
gen_input = concat([records[:, -self.cond_dim:], next(z_dist)], axis=1)
data_.append(records)
data_ = hstack(data_)[:, :seq_len]
data.append(data_)
return self.processor.inverse_transform(array(vstack(data)))


class Generator(Model):
Expand Down

0 comments on commit 611cb67

Please sign in to comment.