Renamed BERT parameters
w4nderlust committed Jul 24, 2019
1 parent b6a1b8c commit 7f528de
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions ludwig/models/modules/sequence_encoders.py
@@ -1667,7 +1667,7 @@ class BERT:
     def __init__(
             self,
             config_path,
-            init_checkpoint_path=None,
+            checkpoint_path=None,
             do_lower_case=True,
             reduce_output=True,
             **kwargs
@@ -1680,10 +1680,10 @@ def __init__(
                 "Please install bert-tensorflow: pip install bert-tensorflow"
             )
 
-        self.init_checkpoint_path = init_checkpoint_path
+        self.checkpoint_path = checkpoint_path
         self.do_lower_case = do_lower_case
 
-        if config_path is None or init_checkpoint_path is None:
+        if config_path is None or checkpoint_path is None:
             raise ValueError(
                 'BERT config and model checkpoint paths are required'
             )
@@ -1715,10 +1715,10 @@ def __call__(
             )
 
         # initialize weights from the checkpoint file
-        if self.init_checkpoint_path is not None:
+        if self.checkpoint_path is not None:
             validate_case_matches_checkpoint(
                 self.do_lower_case,
-                self.init_checkpoint_path
+                self.checkpoint_path
             )
 
             tvars = tf.trainable_variables()
@@ -1728,12 +1728,12 @@ def __call__(
                 initialized_variable_names
             ) = BERT.get_assignment_map_from_checkpoint(
                 tvars,
-                self.init_checkpoint_path,
+                self.checkpoint_path,
                 prefix=prefix
             )
 
             tf.train.init_from_checkpoint(
-                self.init_checkpoint_path,
+                self.checkpoint_path,
                 assignment_map
            )
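For illustration, a minimal sketch of how calling code adapts to the renamed keyword argument. Only the class, module path, and constructor parameters come from the diff above; the checkpoint and config file paths are hypothetical placeholders, and direct construction of the encoder outside a Ludwig model definition is shown purely as an assumption for the example.

# Hedged usage sketch, not part of the commit. Requires bert-tensorflow installed;
# the file paths below are made-up placeholders.
from ludwig.models.modules.sequence_encoders import BERT

encoder = BERT(
    config_path='/path/to/bert_config.json',     # BERT model configuration (required)
    checkpoint_path='/path/to/bert_model.ckpt',  # renamed from init_checkpoint_path in this commit
    do_lower_case=True,
    reduce_output=True
)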
