Skip to content

Commit

Permalink
feat: initial setup h3 feature
Browse files Browse the repository at this point in the history
  • Loading branch information
jimthompson5802 committed Jun 11, 2020
1 parent a7810e0 commit 5959be8
Showing 1 changed file with 19 additions and 59 deletions.
78 changes: 19 additions & 59 deletions ludwig/features/h3_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@


class H3BaseFeature(BaseFeature):
def __init__(self, feature):
super().__init__(feature)
self.type = H3

type = H3
preprocessing_defaults = {
'missing_value_strategy': FILL_WITH_CONST,
'fill_value': 576495936675512319
# mode 1 edge 0 resolution 0 base_cell 0
}

def __init__(self, feature):
super().__init__(feature)

@staticmethod
def get_feature_meta(column, preprocessing_parameters):
return {}
Expand All @@ -66,9 +66,7 @@ def h3_to_list(h3_int):
def add_feature_data(
feature,
dataset_df,
data,
metadata,
preprocessing_parameters=None
data
):
data[feature['name']] = np.array(
[H3BaseFeature.h3_to_list(row)
Expand All @@ -77,56 +75,18 @@ def add_feature_data(


class H3InputFeature(H3BaseFeature, InputFeature):
def __init__(self, feature):
super().__init__(feature)
encoder = 'embed'

self.encoder = 'embed'
def __init__(self, feature, encoder_obj=None):
H3BaseFeature.__init__(self, feature)
InputFeature.__init__(self)

encoder_parameters = self.overwrite_defaults(feature)
self.overwrite_defaults(feature)
if encoder_obj:
self.encoder_obj = encoder_obj
else:
self.encoder_obj = self.initialize_encoder(feature)

self.encoder_obj = self.get_h3_encoder(encoder_parameters)

def get_h3_encoder(self, encoder_parameters):
return get_from_registry(
self.encoder, h3_encoder_registry)(
**encoder_parameters
)

def _get_input_placeholder(self):
# None dimension is for dealing with variable batch size
return tf.placeholder(
tf.int32,
shape=[None, H3_VECTOR_LENGTH],
name=self.feature_name
)

def build_input(
self,
regularizer,
dropout_rate,
is_training=False,
**kwargs
):
placeholder = self._get_input_placeholder()
logger.debug('placeholder: {0}'.format(placeholder))

feature_representation, feature_representation_size = self.encoder_obj(
placeholder,
regularizer=regularizer,
dropout_rate=dropout_rate,
is_training=is_training
)
logging.debug(' feature_representation: {0}'.format(
feature_representation))

feature_representation = {
'name': self.feature_name,
'type': self.type,
'representation': feature_representation,
'size': feature_representation_size,
'placeholder': placeholder
}
return feature_representation

@staticmethod
def update_model_definition_with_metadata(
Expand All @@ -142,8 +102,8 @@ def populate_defaults(input_feature):
set_default_value(input_feature, TIED, None)


h3_encoder_registry = {
'embed': H3Embed,
'weighted_sum': H3WeightedSum,
'rnn': H3RNN
}
encoder_registry = {
'embed': H3Embed,
'weighted_sum': H3WeightedSum,
'rnn': H3RNN
}

0 comments on commit 5959be8

Please sign in to comment.