-
Notifications
You must be signed in to change notification settings - Fork 235
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Add CGAN for timeseries #108
Open
jfsantos-ds
wants to merge
6
commits into
dev
Choose a base branch
from
feat/tscwgan
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -373,4 +373,4 @@ DerivedData/ | |
|
||
# User created | ||
VERSION | ||
version.py | ||
version.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from numpy import reshape | ||
|
||
from ydata_synthetic.preprocessing.timeseries import processed_stock | ||
from ydata_synthetic.synthesizers.timeseries import TSCWGAN | ||
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters | ||
from ydata_synthetic.postprocessing.regular.inverse_preprocesser import inverse_transform | ||
|
||
model = TSCWGAN | ||
|
||
#Define the GAN and training parameters | ||
noise_dim = 32 | ||
dim = 128 | ||
seq_len = 48 | ||
cond_dim = 24 | ||
batch_size = 128 | ||
|
||
log_step = 100 | ||
epochs = 300+1 | ||
learning_rate = 5e-4 | ||
beta_1 = 0.5 | ||
beta_2 = 0.9 | ||
models_dir = './cache' | ||
critic_iter = 5 | ||
|
||
# Get transformed data stock - Univariate | ||
data, processed_data, scaler = processed_stock(path='./data/stock_data.csv', seq_len=seq_len, cols = ['Open']) | ||
data_sample = processed_data[0] | ||
|
||
model_parameters = ModelParameters(batch_size=batch_size, | ||
lr=learning_rate, | ||
betas=(beta_1, beta_2), | ||
noise_dim=noise_dim, | ||
n_cols=seq_len, | ||
layers_dim=dim, | ||
condition = cond_dim) | ||
|
||
train_args = TrainParameters(epochs=epochs, | ||
sample_interval=log_step, | ||
critic_iter=critic_iter) | ||
|
||
#Training the TSCWGAN model | ||
synthesizer = model(model_parameters, gradient_penalty_weight=10) | ||
synthesizer.train(processed_data, train_args) | ||
|
||
#Saving the synthesizer to later generate new events | ||
synthesizer.save(path='./tscwgan_stock.pkl') | ||
|
||
#Loading the synthesizer | ||
synth = model.load(path='./tscwgan_stock.pkl') | ||
|
||
#Sampling the data | ||
#Note that the data returned is not inverse processed. | ||
cond_index = 100 # Arbitrary sequence for conditioning | ||
cond_array = reshape(processed_data[cond_index][:cond_dim], (1,-1)) | ||
|
||
data_sample = synth.sample(cond_array, 1000, 100) | ||
|
||
# Inverting the scaling of the synthetic samples | ||
inv_data_sample = inverse_transform(data_sample, scaler) |
25 changes: 13 additions & 12 deletions
25
src/ydata_synthetic/postprocessing/regular/inverse_preprocesser.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,45 +1,46 @@ | ||
# Inverts all preprocessing pipelines provided in the preprocessing examples | ||
from typing import Union | ||
|
||
import pandas as pd | ||
from pandas import DataFrame, concat | ||
|
||
from sklearn.pipeline import Pipeline | ||
from sklearn.compose import ColumnTransformer | ||
from sklearn.preprocessing import PowerTransformer, OneHotEncoder, StandardScaler | ||
from sklearn.preprocessing import PowerTransformer, OneHotEncoder, StandardScaler, MinMaxScaler | ||
|
||
|
||
def inverse_transform(data: pd.DataFrame, processor: Union[Pipeline, ColumnTransformer, PowerTransformer, OneHotEncoder, StandardScaler]) -> pd.DataFrame: | ||
def inverse_transform(data: DataFrame, processor: Union[Pipeline, ColumnTransformer, PowerTransformer, | ||
OneHotEncoder, StandardScaler, MinMaxScaler]) -> DataFrame: | ||
"""Inverts data transformations taking place in a standard sklearn processor. | ||
Supported processes are sklearn pipelines, column transformers or base estimators like standard scalers. | ||
|
||
Args: | ||
data (pd.DataFrame): The data object that needs inversion of preprocessing | ||
data (DataFrame): The data object that needs inversion of preprocessing | ||
processor (Union[Pipeline, ColumnTransformer, BaseEstimator]): The processor applied on the original data | ||
|
||
Returns: | ||
inv_data (pd.DataFrame): The data object after inverting preprocessing""" | ||
inv_data (DataFrame): The data object after inverting preprocessing""" | ||
inv_data = data.copy() | ||
if isinstance(processor, (PowerTransformer, OneHotEncoder, StandardScaler, Pipeline)): | ||
inv_data = pd.DataFrame(processor.inverse_transform(data), columns=processor.feature_names_in_) | ||
if isinstance(processor, (PowerTransformer, OneHotEncoder, StandardScaler, MinMaxScaler, Pipeline)): | ||
inv_data = DataFrame(processor.inverse_transform(data), columns=processor.feature_names_in_ if hasattr(processor, "feature_names_in") else None) | ||
elif isinstance(processor, ColumnTransformer): | ||
output_indices = processor.output_indices_ | ||
assert isinstance(data, pd.DataFrame), "The data to be inverted from a ColumnTransformer has to be a Pandas DataFrame." | ||
assert isinstance(data, DataFrame), "The data to be inverted from a ColumnTransformer has to be a Pandas DataFrame." | ||
for t_name, t, t_cols in processor.transformers_[::-1]: | ||
slice_ = output_indices[t_name] | ||
t_indices = list(range(slice_.start, slice_.stop, 1 if slice_.step is None else slice_.step)) | ||
if t == 'drop': | ||
continue | ||
elif t == 'passthrough': | ||
inv_cols = pd.DataFrame(data.iloc[:,t_indices].values, columns = t_cols, index = data.index) | ||
inv_cols = DataFrame(data.iloc[:,t_indices].values, columns = t_cols, index = data.index) | ||
inv_col_names = inv_cols.columns | ||
else: | ||
inv_cols = pd.DataFrame(t.inverse_transform(data.iloc[:,t_indices].values), columns = t_cols, index = data.index) | ||
inv_cols = DataFrame(t.inverse_transform(data.iloc[:,t_indices].values), columns = t_cols, index = data.index) | ||
inv_col_names = inv_cols.columns | ||
if set(inv_col_names).issubset(set(inv_data.columns)): | ||
inv_data[inv_col_names] = inv_cols[inv_col_names] | ||
else: | ||
inv_data = pd.concat([inv_data, inv_cols], axis=1) | ||
inv_data = concat([inv_data, inv_cols], axis=1) | ||
else: | ||
print('The provided data processor is not supported and cannot be inverted with this method.') | ||
return None | ||
return inv_data[processor.feature_names_in_] | ||
return inv_data[processor.feature_names_in_] if hasattr(processor, "feature_names_in") else inv_data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,17 +2,30 @@ | |
Get the stock data from Yahoo finance data | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Allowing subsetting of columns |
||
Data from the period 01 January 2017 - 24 January 2021 | ||
""" | ||
from typing import Optional, List | ||
|
||
import pandas as pd | ||
from typeguard import typechecked | ||
|
||
from ydata_synthetic.preprocessing.timeseries.utils import real_data_loading | ||
|
||
def transformations(path, seq_len: int): | ||
stock_df = pd.read_csv(path) | ||
@typechecked | ||
def transformations(path, seq_len: int, cols: Optional[List] = None): | ||
"""Apply min max scaling and roll windows of a temporal dataset. | ||
|
||
Args: | ||
path(str): path to a csv temporal dataframe | ||
seq_len(int): length of the rolled sequences | ||
cols (Union[str, List]): Column or list of columns to be used""" | ||
if isinstance(cols, list): | ||
stock_df = pd.read_csv(path)[cols] | ||
else: | ||
stock_df = pd.read_csv(path) | ||
try: | ||
stock_df = stock_df.set_index('Date').sort_index() | ||
except: | ||
stock_df=stock_df | ||
#Data transformations to be applied prior to be used with the synthesizer model | ||
processed_data = real_data_loading(stock_df.values, seq_len=seq_len) | ||
data, processed_data, scaler = real_data_loading(stock_df.values, seq_len=seq_len) | ||
|
||
return processed_data | ||
return data, processed_data, scaler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
import numpy as np | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Enabling inverse_transform of the timeseries scaler (processor needs to be returned here) |
||
from sklearn.preprocessing import MinMaxScaler | ||
|
||
# Method implemented here: https://github.com/jsyoon0823/TimeGAN/blob/master/data_loading.py | ||
# Method adapted from here: https://github.com/jsyoon0823/TimeGAN/blob/master/data_loading.py | ||
# Originally used in TimeGAN research | ||
def real_data_loading(data: np.array, seq_len): | ||
"""Load and preprocess real-world datasets. | ||
|
@@ -30,7 +30,7 @@ def real_data_loading(data: np.array, seq_len): | |
|
||
# Mix the datasets (to make it similar to i.i.d) | ||
idx = np.random.permutation(len(temp_data)) | ||
data = [] | ||
processed_data = [] | ||
for i in range(len(temp_data)): | ||
data.append(temp_data[idx[i]]) | ||
return data | ||
processed_data.append(temp_data[idx[i]]) | ||
return data, processed_data, scaler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
from ydata_synthetic.synthesizers.timeseries.timegan.model import TimeGAN | ||
from ydata_synthetic.synthesizers.timeseries.tscwgan.model import TSCWGAN | ||
|
||
__all__ = [ | ||
'TimeGAN', | ||
'TSCWGAN', | ||
] |
Empty file.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changes in this script consist in extending the inverse support to the MinMaxScaler