Skip to content

Commit

Permalink
fix(timeseries): Sampling method corrected
Browse files Browse the repository at this point in the history
  • Loading branch information
fabclmnt committed Jan 23, 2021
1 parent 2742168 commit ac6ff37
Showing 1 changed file with 0 additions and 60 deletions.
60 changes: 0 additions & 60 deletions src/ydata_synthetic/synthesizers/time_series/timegan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,63 +354,3 @@ def build(self, input_shape):
hidden_units=self.hidden_dim,
output_units=self.hidden_dim)
return model


if __name__ == '__main__':
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler


"""
import quandl
quandl_api_key = "ssf2fazqBjcLq-qhWTvo"
quandl.ApiConfig.api_key=quandl_api_key
quandl.ApiConfig.verify_ssl = False
dataset = []
for tick in tickers:
dataset.append(quandl.get_table('WIKI/PRICES', ticker=tick))
data = pd.concat(dataset)
"""
tickers = ['BA', 'CAT', 'DIS', 'GE', 'IBM', 'KO']
data = pd.read_csv('wiki_prices.csv')
data=data.drop('None', axis=1)

data = data.set_index(['ticker', 'date']).adj_close.unstack(level=0).loc['2000':, tickers].dropna()

#Normalize the data
scaler = MinMaxScaler()
scaled_data = scaler.fit_transform(data).astype(np.float32)

#Create rolling windows for the data
seq_len=24
n_seq=6

dataset = []
for i in range(len(data) - seq_len):
dataset.append(scaled_data[i:i+seq_len])
n_windows=len(dataset)

noise_dim = 32
dim = 128
batch_size = 128

log_step = 100
epochs = 500 + 1
learning_rate = 5e-4
models_dir = './cache'

gan_args = [batch_size, learning_rate, noise_dim, 24, 2, (0, 1), dim]

synth = TimeGAN(model_parameters=gan_args, hidden_dim=24, seq_len=seq_len, n_seq=n_seq, gamma=1)
synth.train(dataset, train_steps=10)

synth.sample(1000)

synth.save('./ts_synth.pkl')

print('result')

0 comments on commit ac6ff37

Please sign in to comment.