Skip to content

Commit

Permalink
feat(ts): Add sample method.
Browse files Browse the repository at this point in the history
  • Loading branch information
fabclmnt committed Jan 23, 2021
1 parent 99773dc commit 2742168
Showing 1 changed file with 72 additions and 1 deletion.
73 changes: 72 additions & 1 deletion src/ydata_synthetic/synthesizers/time_series/timegan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy, MeanSquaredError

from tqdm import tqdm
import numpy as np
from tqdm import tqdm, trange

from ydata_synthetic.synthesizers import gan

Expand Down Expand Up @@ -270,6 +271,17 @@ def train(self, data, train_steps):
if step_d_loss > 0.15:
step_d_loss = self.train_discriminator(X_, Z_)

def sample(self, n_samples):
    """Generate synthetic records with the trained generator.

    Noise batches are drawn from ``self.get_batch_noise()`` and passed
    through ``self.generator`` until at least ``n_samples`` records exist;
    the stacked result is then trimmed so that exactly ``n_samples``
    records are returned.

    Args:
        n_samples: Number of synthetic records to return. Must be >= 0.

    Returns:
        A NumPy array whose first axis has length ``n_samples``. The
        remaining axes are whatever ``self.generator`` emits per record
        (presumably ``(seq_len, n_seq)`` -- confirm against the generator).

    Raises:
        ValueError: If ``n_samples`` is negative.
    """
    if n_samples < 0:
        raise ValueError(
            f'n_samples must be non-negative, got {n_samples}')

    # Ceiling division: the previous `n // batch_size + 1` always ran one
    # batch too many when n_samples was an exact multiple of batch_size.
    # At least one batch is generated so that vstack has something to
    # stack and sample(0) yields a correctly shaped empty array.
    # Assumes each generator call yields `batch_size` records -- TODO confirm.
    steps = max(1, -(-n_samples // self.batch_size))

    data = []
    for _ in trange(steps, desc='Synthetic data generation'):
        Z_ = next(self.get_batch_noise())
        data.append(self.generator(Z_))

    # vstack already returns an ndarray; slice off the surplus records of
    # the last batch so the caller gets exactly what was requested.
    return np.vstack(data)[:n_samples]




class Generator(Model):
def __init__(self, hidden_dim, net_type='GRU'):
Expand Down Expand Up @@ -343,3 +355,62 @@ def build(self, input_shape):
output_units=self.hidden_dim)
return model


if __name__ == '__main__':
    # Manual smoke run: fit TimeGAN on rolling windows of adjusted close
    # prices for a few tickers, draw synthetic samples and save the model.
    import pandas as pd
    import numpy as np
    from sklearn.preprocessing import MinMaxScaler

    # NOTE(review): 'wiki_prices.csv' is presumably an export of the Quandl
    # 'WIKI/PRICES' table (one `quandl.get_table('WIKI/PRICES', ticker=...)`
    # call per ticker, concatenated with `pd.concat`). The download snippet
    # that used to live here embedded an API key and disabled SSL
    # verification; it was removed. Credentials must come from the
    # environment (e.g. a QUANDL_API_KEY variable), never from source, and
    # the previously committed key should be treated as compromised.
    tickers = ['BA', 'CAT', 'DIS', 'GE', 'IBM', 'KO']
    data = pd.read_csv('wiki_prices.csv')
    # Presumably the unnamed index column written out with the CSV -- verify.
    data = data.drop('None', axis=1)

    # One adjusted-close column per ticker, from 2000 onwards, keeping only
    # dates where every ticker has a value.
    data = (data.set_index(['ticker', 'date'])
                .adj_close
                .unstack(level=0)
                .loc['2000':, tickers]
                .dropna())

    # Normalize the data
    scaler = MinMaxScaler()
    scaled_data = scaler.fit_transform(data).astype(np.float32)

    # Create rolling windows for the data
    seq_len = 24
    n_seq = len(tickers)  # one series per ticker (was hard-coded to 6)

    dataset = [scaled_data[i:i + seq_len]
               for i in range(len(data) - seq_len)]

    noise_dim = 32
    dim = 128
    batch_size = 128
    learning_rate = 5e-4

    gan_args = [batch_size, learning_rate, noise_dim, 24, 2, (0, 1), dim]

    synth = TimeGAN(model_parameters=gan_args, hidden_dim=24,
                    seq_len=seq_len, n_seq=n_seq, gamma=1)
    synth.train(dataset, train_steps=10)

    # Keep the generated windows instead of discarding the return value.
    synth_data = synth.sample(1000)

    synth.save('./ts_synth.pkl')

    print('result')

0 comments on commit 2742168

Please sign in to comment.