Skip to content

Commit

Permalink
Refactor text formatters (#376)
Browse files: browse the repository at this point in the history
  • Loading branch information
ydudin3-zz authored and w4nderlust committed Jun 17, 2019
1 parent 5526d46 commit 9be4432
Show file tree
Hide file tree
Showing 3 changed files with 521 additions and 342 deletions.
6 changes: 3 additions & 3 deletions ludwig/features/feature_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ludwig.constants import TEXT
from ludwig.constants import TIMESERIES
from ludwig.utils.strings_utils import UNKNOWN_SYMBOL
from ludwig.utils.strings_utils import format_registry
from ludwig.utils.strings_utils import tokenizer_registry

SEQUENCE_TYPES = {SEQUENCE, TEXT, TIMESERIES}

Expand All @@ -37,11 +37,11 @@ def should_regularize(regularize_layers):

def set_str_to_idx(set_string, feature_dict, format_func):
    """Tokenize ``set_string`` and map each token to its integer index.

    :param set_string: raw value handed to the tokenizer.
    :param feature_dict: mapping from token to integer index. Its
        ``UNKNOWN_SYMBOL`` entry is used for tokens that are not in the
        mapping.
    :param format_func: key selecting a tokenizer class in
        ``tokenizer_registry``.
    :return: ``np.int32`` array holding one index per token.
    :raises Exception: if ``format_func`` is not a supported format.
    """
    try:
        tokenizer = tokenizer_registry[format_func]()
    except (KeyError, ValueError):
        # A failed registry lookup raises KeyError, not ValueError; catch
        # both so an unknown format produces the message below instead of
        # a bare KeyError.
        raise Exception('Format {} not supported'.format(format_func))

    out = [feature_dict.get(item, feature_dict[UNKNOWN_SYMBOL]) for item in
           tokenizer(set_string)]

    return np.array(out, dtype=np.int32)
18 changes: 9 additions & 9 deletions ludwig/features/timeseries_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ludwig.models.modules.measure_modules import squared_error
from ludwig.utils.misc import get_from_registry
from ludwig.utils.misc import set_default_value
from ludwig.utils.strings_utils import format_registry
from ludwig.utils.strings_utils import tokenizer_registry


logger = logging.getLogger(__name__)
Expand All @@ -53,13 +53,13 @@ def __init__(self, feature):

@staticmethod
def get_feature_meta(column, preprocessing_parameters):
format_function = get_from_registry(
tokenizer = get_from_registry(
preprocessing_parameters['format'],
format_registry
)
tokenizer_registry
)()
max_length = 0
for timeseries in column:
processed_line = format_function(timeseries)
processed_line = tokenizer(timeseries)
max_length = max(max_length, len(processed_line))
max_length = min(
preprocessing_parameters['timeseries_length_limit'],
Expand All @@ -76,14 +76,14 @@ def build_matrix(
padding_value,
padding='right'
):
format_function = get_from_registry(
tokenizer = get_from_registry(
format_str,
format_registry
)
tokenizer_registry
)()
max_length = 0
ts_vectors = []
for ts in timeseries:
ts_vector = np.array(format_function(ts)).astype(np.float32)
ts_vector = np.array(tokenizer(ts)).astype(np.float32)
ts_vectors.append(ts_vector)
if len(ts_vector) > max_length:
max_length = len(ts_vector)
Expand Down

0 comments on commit 9be4432

Please sign in to comment.