Add working addons (#182)
* Add working addons

* Add eradicate

* Add dlint

* Decrease complexity (#184)

* Add addon (#186)

* Add `pytest-style` (#192)

* Add addon

* Fix randomized error message

* Add addon (#188)

* Add addon (#191)

* Add `pandas-vet` (#190)

* Add addon

* noqa torch.stack

* remove double quotes (#187)

* Add addon (#185)

* Add `flake8-docstrings` (#193)

* Add addon

* Fix D100

* Add more docstrings

* Fix docstrings

* Update docstrings

* Fix lint

* Add `flake8-builtins` (#189)

* Add addon

* Add variables-names

* Fix bug

* Fix mistakes

* Add `flake8-multiline-containers` (#183)

* Add addon

* Add addon

* Address feedback

* Fix lint

* Fix bugs

* Remove pydoclint

* Ignore D101 errors

* Update ignores
fealho committed Dec 2, 2021
1 parent 5c0f5bf commit 7e11e15
Showing 20 changed files with 277 additions and 133 deletions.
13 changes: 8 additions & 5 deletions ctgan/__main__.py
@@ -1,3 +1,5 @@
+"""CLI."""
+
 import argparse
 
 from ctgan.data import read_csv, read_tsv, write_tsv
@@ -45,10 +47,10 @@ def _parse_args():
     parser.add_argument('--load', default=None, type=str,
                         help='A filename to load a trained synthesizer.')
 
-    parser.add_argument("--sample_condition_column", default=None, type=str,
-                        help="Select a discrete column name.")
-    parser.add_argument("--sample_condition_column_value", default=None, type=str,
-                        help="Specify the value of the selected discrete column.")
+    parser.add_argument('--sample_condition_column', default=None, type=str,
+                        help='Select a discrete column name.')
+    parser.add_argument('--sample_condition_column_value', default=None, type=str,
+                        help='Specify the value of the selected discrete column.')
 
     parser.add_argument('data', help='Path to training data')
     parser.add_argument('output', help='Path of the output file')
@@ -57,6 +59,7 @@ def _parse_args():
 
 
 def main():
+    """CLI."""
     args = _parse_args()
     if args.tsv:
         data, discrete_columns = read_tsv(args.data, args.metadata)
@@ -95,5 +98,5 @@ def main():
     sampled.to_csv(args.output, index=False)
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
     main()
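
(For reference, a hypothetical invocation of this CLI using the flags touched above: python -m ctgan --sample_condition_column workclass --sample_condition_column_value Private data.csv samples.csv. The column name, value, and file paths are illustrative, not taken from this commit.)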
17 changes: 11 additions & 6 deletions ctgan/data.py
@@ -1,11 +1,13 @@
+"""Data loading."""
+
 import json
 
 import numpy as np
 import pandas as pd
 
 
 def read_csv(csv_filename, meta_filename=None, header=True, discrete=None):
-
+    """Read a csv file."""
     data = pd.read_csv(csv_filename, header='infer' if header else None)
 
     if meta_filename:
@@ -30,11 +32,12 @@ def read_csv(csv_filename, meta_filename=None, header=True, discrete=None):
 
 
 def read_tsv(data_filename, meta_filename):
+    """Read a tsv file."""
     with open(meta_filename) as f:
         column_info = f.readlines()
 
     column_info_raw = [
-        x.replace("{", " ").replace("}", " ").split()
+        x.replace('{', ' ').replace('}', ' ').split()
         for x in column_info
     ]
 
@@ -52,9 +55,9 @@ def read_tsv(data_filename, meta_filename):
             column_info.append(item[1:])
 
     meta = {
-        "continuous_columns": continuous,
-        "discrete_columns": discrete,
-        "column_info": column_info
+        'continuous_columns': continuous,
+        'discrete_columns': discrete,
+        'column_info': column_info
     }
 
     with open(data_filename) as f:
@@ -77,7 +80,9 @@ def read_tsv(data_filename, meta_filename):
 
 
 def write_tsv(data, meta, output_filename):
-    with open(output_filename, "w") as f:
+    """Write to a tsv file."""
+    with open(output_filename, 'w') as f:
+
         for row in data:
             for idx, col in enumerate(row):
                 if idx in meta['continuous_columns']:
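
As a quick orientation to the helpers above, a minimal sketch of loading a CSV with its metadata; the file names are assumptions, and the return shape is inferred from how main() unpacks the matching read_tsv call:

    from ctgan.data import read_csv

    # Hypothetical files: read_csv returns the loaded data plus the list of
    # discrete column names declared in the JSON metadata.
    data, discrete_columns = read_csv('census.csv', meta_filename='census.json')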
42 changes: 23 additions & 19 deletions ctgan/data_sampler.py
@@ -1,3 +1,5 @@
+"""DataSampler module."""
+
 import numpy as np
 
 
@@ -9,13 +11,13 @@ def __init__(self, data, output_info, log_frequency):
 
         def is_discrete_column(column_info):
             return (len(column_info) == 1
-                    and column_info[0].activation_fn == "softmax")
+                    and column_info[0].activation_fn == 'softmax')
 
         n_discrete_columns = sum(
             [1 for column_info in output_info if is_discrete_column(column_info)])
 
         self._discrete_column_matrix_st = np.zeros(
-            n_discrete_columns, dtype="int32")
+            n_discrete_columns, dtype='int32')
 
         # Store the row id for each category in each discrete column.
         # For example _rid_by_cat_cols[a][b] is a list of all rows with the
@@ -39,19 +41,21 @@ def is_discrete_column(column_info):
         assert st == data.shape[1]
 
         # Prepare an interval matrix for efficiently sample conditional vector
-        max_category = max(
-            [column_info[0].dim for column_info in output_info
-             if is_discrete_column(column_info)], default=0)
+        max_category = max([
+            column_info[0].dim
+            for column_info in output_info
+            if is_discrete_column(column_info)
+        ], default=0)
 
         self._discrete_column_cond_st = np.zeros(n_discrete_columns, dtype='int32')
-        self._discrete_column_n_category = np.zeros(
-            n_discrete_columns, dtype='int32')
-        self._discrete_column_category_prob = np.zeros(
-            (n_discrete_columns, max_category))
+        self._discrete_column_n_category = np.zeros(n_discrete_columns, dtype='int32')
+        self._discrete_column_category_prob = np.zeros((n_discrete_columns, max_category))
         self._n_discrete_columns = n_discrete_columns
-        self._n_categories = sum(
-            [column_info[0].dim for column_info in output_info
-             if is_discrete_column(column_info)])
+        self._n_categories = sum([
+            column_info[0].dim
+            for column_info in output_info
+            if is_discrete_column(column_info)
+        ])
 
         st = 0
         current_id = 0
@@ -64,8 +68,7 @@ def is_discrete_column(column_info):
                 if log_frequency:
                     category_freq = np.log(category_freq + 1)
                 category_prob = category_freq / np.sum(category_freq)
-                self._discrete_column_category_prob[current_id, :span_info.dim] = (
-                    category_prob)
+                self._discrete_column_category_prob[current_id, :span_info.dim] = category_prob
                 self._discrete_column_cond_st[current_id] = current_cond_st
                 self._discrete_column_n_category[current_id] = span_info.dim
                 current_cond_st += span_info.dim
@@ -102,8 +105,7 @@ def sample_condvec(self, batch):
         mask = np.zeros((batch, self._n_discrete_columns), dtype='float32')
         mask[np.arange(batch), discrete_column_id] = 1
         category_id_in_col = self._random_choice_prob_index(discrete_column_id)
-        category_id = (self._discrete_column_cond_st[discrete_column_id]
-                       + category_id_in_col)
+        category_id = (self._discrete_column_cond_st[discrete_column_id] + category_id_in_col)
         cond[np.arange(batch), category_id] = 1
 
         return cond, mask, discrete_column_id, category_id_in_col
@@ -142,11 +144,13 @@ def sample_data(self, n, col, opt):
         return self._data[idx]
 
     def dim_cond_vec(self):
+        """Return the total number of categories."""
         return self._n_categories
 
     def generate_cond_from_condition_column_info(self, condition_info, batch):
+        """Generate the condition vector."""
         vec = np.zeros((batch, self._n_categories), dtype='float32')
-        id = self._discrete_column_matrix_st[condition_info["discrete_column_id"]
-                                             ] + condition_info["value_id"]
-        vec[:, id] = 1
+        id_ = self._discrete_column_matrix_st[condition_info['discrete_column_id']]
+        id_ += condition_info['value_id']
+        vec[:, id_] = 1
         return vec
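
To show how this sampler is wired up, a minimal sketch pairing it with the DataTransformer from the next file; the toy data, the output_info_list attribute name, and log_frequency=True are assumptions, not lines from this diff:

    import pandas as pd

    from ctgan.data_sampler import DataSampler
    from ctgan.data_transformer import DataTransformer

    # Toy frame with enough rows (>= 10) for the 10-component Bayesian GMM.
    train_df = pd.DataFrame({
        'age': [23, 45, 31, 52, 27, 38, 44, 61, 29, 35, 48, 56],
        'job': ['a', 'b', 'a', 'b', 'a', 'a', 'b', 'b', 'a', 'b', 'a', 'b'],
    })

    transformer = DataTransformer()
    transformer.fit(train_df, discrete_columns=('job',))
    train = transformer.transform(train_df)

    # sample_condvec returns (cond, mask, discrete_column_id, category_id_in_col),
    # matching the return statement shown above.
    sampler = DataSampler(train, transformer.output_info_list, log_frequency=True)
    cond, mask, col_ids, cat_ids = sampler.sample_condvec(4)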
50 changes: 29 additions & 21 deletions ctgan/data_transformer.py
@@ -1,15 +1,20 @@
+"""DataTransformer module."""
+
 from collections import namedtuple
 
 import numpy as np
 import pandas as pd
 from rdt.transformers import OneHotEncodingTransformer
 from sklearn.mixture import BayesianGaussianMixture
 
-SpanInfo = namedtuple("SpanInfo", ["dim", "activation_fn"])
+SpanInfo = namedtuple('SpanInfo', ['dim', 'activation_fn'])
 ColumnTransformInfo = namedtuple(
-    "ColumnTransformInfo", ["column_name", "column_type",
-                            "transform", "transform_aux",
-                            "output_info", "output_dimensions"])
+    'ColumnTransformInfo', [
+        'column_name', 'column_type',
+        'transform', 'transform_aux',
+        'output_info', 'output_dimensions'
+    ]
+)
 
 
 class DataTransformer(object):
@@ -45,7 +50,7 @@ def _fit_continuous(self, column_name, raw_column_data):
         num_components = valid_component_indicator.sum()
 
         return ColumnTransformInfo(
-            column_name=column_name, column_type="continuous", transform=gm,
+            column_name=column_name, column_type='continuous', transform=gm,
             transform_aux=valid_component_indicator,
             output_info=[SpanInfo(1, 'tanh'), SpanInfo(num_components, 'softmax')],
             output_dimensions=1 + num_components)
@@ -59,12 +64,12 @@ def _fit_discrete(self, column_name, raw_column_data):
         num_categories = len(ohe.dummies)
 
         return ColumnTransformInfo(
-            column_name=column_name, column_type="discrete", transform=ohe,
+            column_name=column_name, column_type='discrete', transform=ohe,
             transform_aux=None,
             output_info=[SpanInfo(num_categories, 'softmax')],
             output_dimensions=num_categories)
 
-    def fit(self, raw_data, discrete_columns=tuple()):
+    def fit(self, raw_data, discrete_columns=()):
         """Fit GMM for continuous columns and One hot encoder for discrete columns.
 
         This step also counts the #columns in matrix data, and span information.
@@ -83,7 +88,7 @@ def fit(self, raw_data, discrete_columns=tuple()):
         self._column_raw_dtypes = raw_data.infer_objects().dtypes
         self._column_transform_info_list = []
         for column_name in raw_data.columns:
-            raw_column_data = raw_data[column_name].values
+            raw_column_data = raw_data[column_name].to_numpy()
             if column_name in discrete_columns:
                 column_transform_info = self._fit_discrete(column_name, raw_data[column_name])
             else:
@@ -109,10 +114,12 @@ def _transform_continuous(self, column_transform_info, raw_column_data):
             component_porb_t = component_probs[i] + 1e-6
             component_porb_t = component_porb_t / component_porb_t.sum()
             selected_component[i] = np.random.choice(
-                np.arange(num_components), p=component_porb_t)
+                np.arange(num_components),
+                p=component_porb_t
+            )
 
-        selected_normalized_value = normalized_values[
-            np.arange(len(raw_column_data)), selected_component].reshape([-1, 1])
+        aranged = np.arange(len(raw_column_data))
+        selected_normalized_value = normalized_values[aranged, selected_component].reshape([-1, 1])
         selected_normalized_value = np.clip(selected_normalized_value, -.99, .99)
 
         selected_component_onehot = np.zeros_like(component_probs)
@@ -122,7 +129,7 @@ def _transform_continuous(self, column_transform_info, raw_column_data):
     def _transform_discrete(self, column_transform_info, raw_column_data):
         ohe = column_transform_info.transform
         data = pd.DataFrame(raw_column_data, columns=[column_transform_info.column_name])
-        return [ohe.transform(data).values]
+        return [ohe.transform(data).to_numpy()]
 
     def transform(self, raw_data):
         """Take raw data and output a matrix data."""
@@ -132,11 +139,11 @@ def transform(self, raw_data):
 
         column_data_list = []
         for column_transform_info in self._column_transform_info_list:
-            column_data = raw_data[[column_transform_info.column_name]].values
-            if column_transform_info.column_type == "continuous":
+            column_data = raw_data[[column_transform_info.column_name]].to_numpy()
+            if column_transform_info.column_type == 'continuous':
                 column_data_list += self._transform_continuous(column_transform_info, column_data)
             else:
-                assert column_transform_info.column_type == "discrete"
+                assert column_transform_info.column_type == 'discrete'
                 column_data_list += self._transform_discrete(column_transform_info, column_data)
 
         return np.concatenate(column_data_list, axis=1).astype(float)
@@ -200,17 +207,18 @@ def inverse_transform(self, data, sigmas=None):
         recovered_data = (pd.DataFrame(recovered_data, columns=column_names)
                           .astype(self._column_raw_dtypes))
         if not self.dataframe:
-            recovered_data = recovered_data.values
+            recovered_data = recovered_data.to_numpy()
 
         return recovered_data
 
     def convert_column_name_value_to_id(self, column_name, value):
+        """Get the ids of the given `column_name`."""
        discrete_counter = 0
         column_id = 0
         for column_transform_info in self._column_transform_info_list:
             if column_transform_info.column_name == column_name:
                 break
-            if column_transform_info.column_type == "discrete":
+            if column_transform_info.column_type == 'discrete':
                 discrete_counter += 1
 
             column_id += 1
@@ -220,12 +228,12 @@ def convert_column_name_value_to_id(self, column_name, value):
 
         ohe = column_transform_info.transform
         data = pd.DataFrame([value], columns=[column_transform_info.column_name])
-        one_hot = ohe.transform(data).values[0]
+        one_hot = ohe.transform(data).to_numpy()[0]
         if sum(one_hot) == 0:
             raise ValueError(f"The value `{value}` doesn't exist in the column `{column_name}`.")
 
         return {
-            "discrete_column_id": discrete_counter,
-            "column_id": column_id,
-            "value_id": np.argmax(one_hot)
+            'discrete_column_id': discrete_counter,
+            'column_id': column_id,
+            'value_id': np.argmax(one_hot)
         }
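
For the transformer on its own, a short round-trip sketch using the methods shown above; the toy values and column names are illustrative:

    import pandas as pd

    from ctgan.data_transformer import DataTransformer

    df = pd.DataFrame({
        'amount': [1.2, 3.4, 2.2, 5.1, 0.4, 2.9, 3.3, 4.8, 1.7, 2.5, 3.9, 2.0],
        'status': ['ok', 'ko', 'ok', 'ok', 'ko', 'ok', 'ko', 'ok', 'ok', 'ko', 'ok', 'ko'],
    })

    transformer = DataTransformer()
    transformer.fit(df, discrete_columns=('status',))
    matrix = transformer.transform(df)                # tanh + softmax spans per column
    recovered = transformer.inverse_transform(matrix)

    # Locate a category's one-hot position, as used when conditioning sampling
    # on a specific column value; returns the dict built in the method above.
    ids = transformer.convert_column_name_value_to_id('status', 'ko')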
3 changes: 3 additions & 0 deletions ctgan/demo.py
@@ -1,7 +1,10 @@
+"""Demo module."""
+
 import pandas as pd
 
 DEMO_URL = 'http://ctgan-data.s3.amazonaws.com/census.csv.gz'
 
 
 def load_demo():
+    """Load the demo."""
     return pd.read_csv(DEMO_URL, compression='gzip')
2 changes: 2 additions & 0 deletions ctgan/synthesizers/__init__.py
@@ -1,3 +1,5 @@
+"""Synthesizers module."""
+
 from ctgan.synthesizers.ctgan import CTGANSynthesizer
 from ctgan.synthesizers.tvae import TVAESynthesizer
 
8 changes: 6 additions & 2 deletions ctgan/synthesizers/base.py
@@ -1,3 +1,5 @@
+"""BaseSynthesizer module."""
+
 import torch
 
 
@@ -8,14 +10,16 @@ class BaseSynthesizer:
     """
 
     def save(self, path):
+        """Save the model in the passed `path`."""
         device_backup = self._device
-        self.set_device(torch.device("cpu"))
+        self.set_device(torch.device('cpu'))
         torch.save(self, path)
         self.set_device(device_backup)
 
     @classmethod
     def load(cls, path):
-        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        """Load the model stored in the passed `path`."""
+        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
         model = torch.load(path)
         model.set_device(device)
         return model
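
Together, save and load give every synthesizer a device-safe persistence round trip; a hypothetical usage, with the constructor arguments and file name assumed:

    from ctgan.synthesizers import CTGANSynthesizer

    model = CTGANSynthesizer(epochs=1)
    # ... fit on training data here ...
    model.save('ctgan.pt')                      # parked on CPU while serializing
    model = CTGANSynthesizer.load('ctgan.pt')   # moved back to CUDA when available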