Skip to content

Commit

Permalink
Moved encoder and decoder registries in base input and output feature…
Browse files Browse the repository at this point in the history
…s to enforce every new feature type to have them
  • Loading branch information
w4nderlust committed Mar 21, 2020
1 parent f650b72 commit 27cba17
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 31 deletions.
22 changes: 21 additions & 1 deletion ludwig/features/base_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from ludwig.models.modules.fully_connected_modules import FCStack
from ludwig.models.modules.reduction_modules import reduce_sequence
from ludwig.utils.misc import merge_dict
from ludwig.utils.misc import merge_dict, get_from_registry
from ludwig.utils.tf_utils import sequence_length_3D


Expand Down Expand Up @@ -66,6 +66,16 @@ def update_model_definition_with_metadata(
def populate_defaults(input_feature):
pass

@property
@abstractmethod
def encoder_registry(self):
pass

def initialize_encoder(self, encoder_parameters):
return get_from_registry(self.encoder, self.encoder_registry)(
**encoder_parameters
)


class OutputFeature(ABC, BaseFeature, tf.keras.Model):

Expand Down Expand Up @@ -110,6 +120,16 @@ def __init__(self, feature):
# default_bias_constraint=None,
)

@property
@abstractmethod
def decoder_registry(self):
pass

def initialize_decoder(self, decoder_parameters):
return get_from_registry(self.decoder, self.decoder_registry)(
**decoder_parameters
)

def train_loss(self, targets, predictions):
return self.train_loss_function(targets, predictions)

Expand Down
48 changes: 18 additions & 30 deletions ludwig/features/numerical_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from ludwig.models.modules.metric_modules import R2Score
from ludwig.models.modules.numerical_decoders import Regressor
from ludwig.models.modules.numerical_encoders import NumericalPassthroughEncoder
from ludwig.utils.misc import set_default_value, get_from_registry
from ludwig.utils.misc import set_default_value
from ludwig.utils.misc import set_default_values

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -109,12 +109,7 @@ def __init__(self, feature, encoder_obj=None):
if encoder_obj:
self.encoder_obj = encoder_obj
else:
self.encoder_obj = self.get_numerical_encoder(encoder_parameters)

def get_numerical_encoder(self, encoder_parameters):
return get_from_registry(self.encoder, numerical_encoder_registry)(
**encoder_parameters
)
self.encoder_obj = self.initialize_encoder(encoder_parameters)

def call(self, inputs, training=None, mask=None):
assert isinstance(inputs, tf.Tensor)
Expand All @@ -141,15 +136,14 @@ def update_model_definition_with_metadata(
def populate_defaults(input_feature):
set_default_value(input_feature, TIED, None)


numerical_encoder_registry = {
'dense': FCStack,
'passthrough': NumericalPassthroughEncoder,
'null': NumericalPassthroughEncoder,
'none': NumericalPassthroughEncoder,
'None': NumericalPassthroughEncoder,
None: NumericalPassthroughEncoder
}
encoder_registry = {
'dense': FCStack,
'passthrough': NumericalPassthroughEncoder,
'null': NumericalPassthroughEncoder,
'none': NumericalPassthroughEncoder,
'None': NumericalPassthroughEncoder,
None: NumericalPassthroughEncoder
}


class NumericalOutputFeature(NumericalBaseFeature, OutputFeature):
Expand All @@ -165,16 +159,11 @@ def __init__(self, feature):

decoder_parameters = self.overwrite_defaults(feature)

self.decoder_obj = self.get_numerical_decoder(decoder_parameters)
self.decoder_obj = self.initialize_decoder(decoder_parameters)

self._setup_loss()
self._setup_metrics()

def get_numerical_decoder(self, decoder_parameters):
return get_from_registry(self.decoder, numerical_decoder_registry)(
**decoder_parameters
)

def predictions(
self,
inputs, # hidden
Expand Down Expand Up @@ -303,11 +292,10 @@ def populate_defaults(output_feature):
}
)


numerical_decoder_registry = {
'regressor': Regressor,
'null': Regressor,
'none': Regressor,
'None': Regressor,
None: Regressor
}
decoder_registry = {
'regressor': Regressor,
'null': Regressor,
'none': Regressor,
'None': Regressor,
None: Regressor
}
2 changes: 2 additions & 0 deletions ludwig/models/modules/numerical_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# ==============================================================================
import logging

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Layer

logger = logging.getLogger(__name__)
Expand Down

0 comments on commit 27cba17

Please sign in to comment.