From ac6ff37a26e23fa9a090d6fa150854f43157b913 Mon Sep 17 00:00:00 2001 From: fabclmnt Date: Sat, 23 Jan 2021 19:42:24 +0000 Subject: [PATCH] fix(timeseries): Sampling method corrected --- .../synthesizers/time_series/timegan/model.py | 60 ------------------- 1 file changed, 60 deletions(-) diff --git a/src/ydata_synthetic/synthesizers/time_series/timegan/model.py b/src/ydata_synthetic/synthesizers/time_series/timegan/model.py index 8232d27b..9bf8eee9 100644 --- a/src/ydata_synthetic/synthesizers/time_series/timegan/model.py +++ b/src/ydata_synthetic/synthesizers/time_series/timegan/model.py @@ -354,63 +354,3 @@ def build(self, input_shape): hidden_units=self.hidden_dim, output_units=self.hidden_dim) return model - - -if __name__ == '__main__': - import pandas as pd - import numpy as np - import matplotlib.pyplot as plt - from sklearn.preprocessing import MinMaxScaler - - - """ - import quandl - quandl_api_key = "ssf2fazqBjcLq-qhWTvo" - quandl.ApiConfig.api_key=quandl_api_key - quandl.ApiConfig.verify_ssl = False - - - dataset = [] - for tick in tickers: - dataset.append(quandl.get_table('WIKI/PRICES', ticker=tick)) - - data = pd.concat(dataset) - """ - tickers = ['BA', 'CAT', 'DIS', 'GE', 'IBM', 'KO'] - data = pd.read_csv('wiki_prices.csv') - data=data.drop('None', axis=1) - - data = data.set_index(['ticker', 'date']).adj_close.unstack(level=0).loc['2000':, tickers].dropna() - - #Normalize the data - scaler = MinMaxScaler() - scaled_data = scaler.fit_transform(data).astype(np.float32) - - #Create rolling windows for the data - seq_len=24 - n_seq=6 - - dataset = [] - for i in range(len(data) - seq_len): - dataset.append(scaled_data[i:i+seq_len]) - n_windows=len(dataset) - - noise_dim = 32 - dim = 128 - batch_size = 128 - - log_step = 100 - epochs = 500 + 1 - learning_rate = 5e-4 - models_dir = './cache' - - gan_args = [batch_size, learning_rate, noise_dim, 24, 2, (0, 1), dim] - - synth = TimeGAN(model_parameters=gan_args, hidden_dim=24, seq_len=seq_len, n_seq=n_seq, gamma=1) - synth.train(dataset, train_steps=10) - - synth.sample(1000) - - synth.save('./ts_synth.pkl') - - print('result')