
Commit

[feature] Add True bert
JayYip committed Apr 24, 2019
1 parent 2a24456 commit 5bdac38
Showing 8 changed files with 41 additions and 47 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -124,4 +124,5 @@ dmypy.json
 .pyre/
 .vscode
 data/
-**/.DS_Store
+**/.DS_Store
+pubmed_pmc_470k/
Empty file added __init__.py
Empty file.
8 changes: 4 additions & 4 deletions dataset.py
@@ -310,8 +310,8 @@ def create_dataset_for_bert(
         prefetch=10000,
         batch_size=32,
         dynamic_padding=False,
-        bucket_batch_sizes=[64, 32, 16],
-        bucket_boundaries=[100, 300],
+        bucket_batch_sizes=[32, 16, 8],
+        bucket_boundaries=[64, 128],
         element_length_func=_qa_ele_to_length):
 
     tfrecord_file_list = glob(os.path.join(
@@ -320,7 +320,7 @@ def create_dataset_for_bert(
         print('TF Record not found')
         make_tfrecord(
             data_dir, create_generator_for_bert,
-            bert_serialize_fn, 'BertFFN', tokenizer=tokenizer, dynamic_padding=True)
+            bert_serialize_fn, 'BertFFN', tokenizer=tokenizer, dynamic_padding=True, max_seq_length=max_seq_length)
         tfrecord_file_list = glob(os.path.join(
             data_dir, '*_BertFFN_{0}.tfrecord'.format((mode))))

@@ -351,7 +351,7 @@ def _parse_bert_example(example_proto):
             tf.data.experimental.bucket_by_sequence_length(
                 element_length_func=element_length_func,
                 bucket_batch_sizes=bucket_batch_sizes,
-                bucket_boundaries=bucket_boundaries,
+                bucket_boundaries=bucket_boundaries
             ))
     else:
         dataset = dataset.batch(batch_size)
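(For reference, a minimal standalone sketch of how the new bucketing settings behave, on toy data rather than the repository's TFRecord pipeline: with bucket_boundaries=[64, 128] and bucket_batch_sizes=[32, 16, 8], examples shorter than 64 tokens are batched 32 at a time, those of 64 to 127 tokens 16 at a time, and longer ones 8 at a time, each bucket padded independently.)

import numpy as np
import tensorflow as tf

# Toy variable-length "token id" sequences (illustration only).
lengths = [10, 50, 90, 200]
ds = tf.data.Dataset.from_generator(
    lambda: (np.ones([n], dtype=np.int32) for n in lengths),
    output_types=tf.int32, output_shapes=tf.TensorShape([None]))

# Same transform the diff configures: bucket by length, pad within each bucket.
ds = ds.apply(tf.data.experimental.bucket_by_sequence_length(
    element_length_func=lambda x: tf.shape(x)[0],
    bucket_boundaries=[64, 128],       # two boundaries -> three buckets
    bucket_batch_sizes=[32, 16, 8]))   # one batch size per bucket

for batch in ds:
    print(batch.shape)  # e.g. (2, 50), (1, 90), (1, 200) for this toy input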
2 changes: 1 addition & 1 deletion keras_bert/bert.py
@@ -96,7 +96,7 @@ def get_model(token_num,
         trainable=trainable,
     )
     if not training:
-        return inputs[:2], transformed
+        return inputs, transformed
     mlm_dense_layer = keras.layers.Dense(
         units=embed_dim,
         activation=feed_forward_activation,
2 changes: 1 addition & 1 deletion keras_bert/layers/inputs.py
@@ -10,6 +10,6 @@ def get_inputs(seq_len):
     """
     names = ['Token', 'Segment', 'Masked']
     return [keras.layers.Input(
-        shape=(seq_len,),
+        shape=(None,),
         name='Input-%s' % name,
     ) for name in names]
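(Switching the Input shape from (seq_len,) to (None,) leaves the sequence axis unspecified, so one built model can accept batches padded to different lengths, which is what the bucketed, dynamically padded dataset above produces. A minimal self-contained illustration, not the keras_bert graph itself:)

import numpy as np
from tensorflow import keras

# Sequence axis left as None: only the per-batch padded length varies.
tokens = keras.layers.Input(shape=(None,), name='Input-Token')
embedded = keras.layers.Embedding(input_dim=100, output_dim=8)(tokens)
model = keras.Model(inputs=tokens, outputs=embedded)

print(model.predict(np.ones((2, 64), dtype=np.int32)).shape)   # (2, 64, 8)
print(model.predict(np.ones((2, 128), dtype=np.int32)).shape)  # (2, 128, 8)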
16 changes: 8 additions & 8 deletions keras_bert/loader.py
@@ -34,7 +34,8 @@ def build_model_from_config(config_file,
     with open(config_file, 'r') as reader:
         config = json.loads(reader.read())
     if seq_len is not None:
-        config['max_position_embeddings'] = min(seq_len, config['max_position_embeddings'])
+        config['max_position_embeddings'] = min(
+            seq_len, config['max_position_embeddings'])
     if trainable is None:
         trainable = training
     model = get_model(
@@ -51,10 +52,6 @@
     if not training:
         inputs, outputs = model
         model = keras.models.Model(inputs=inputs, outputs=outputs)
-        model.compile(
-            optimizer=keras.optimizers.Adam(),
-            loss=keras.losses.sparse_categorical_crossentropy,
-        )
     return model, config


@@ -76,7 +73,8 @@ def load_model_weights_from_checkpoint(model,
         loader('bert/embeddings/word_embeddings'),
     ])
     model.get_layer(name='Embedding-Position').set_weights([
-        loader('bert/embeddings/position_embeddings')[:config['max_position_embeddings'], :],
+        loader(
+            'bert/embeddings/position_embeddings')[:config['max_position_embeddings'], :],
     ])
     model.get_layer(name='Embedding-Segment').set_weights([
         loader('bert/embeddings/token_type_embeddings'),
@@ -152,6 +150,8 @@ def load_trained_model_from_checkpoint(config_file,
         position embeddings will be sliced to fit the new length.
     :return: model
     """
-    model, config = build_model_from_config(config_file, training=training, trainable=trainable, seq_len=seq_len)
-    load_model_weights_from_checkpoint(model, config, checkpoint_file, training=training)
+    model, config = build_model_from_config(
+        config_file, training=training, trainable=trainable, seq_len=seq_len)
+    load_model_weights_from_checkpoint(
+        model, config, checkpoint_file, training=training)
     return model
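(A sketch of how the updated loader is meant to be called with a BioBERT checkpoint. The directory name comes from the new .gitignore entry and the file names from train_all.py below, so treat the paths as assumptions about the local download:)

import os
from keras_bert.loader import load_trained_model_from_checkpoint

pretrained_path = 'pubmed_pmc_470k/'  # assumed local BioBERT checkpoint directory
biobert = load_trained_model_from_checkpoint(
    config_file=os.path.join(pretrained_path, 'bert_config.json'),
    checkpoint_file=os.path.join(pretrained_path, 'biobert_model.ckpt'),
    training=False,   # inference graph: transformer output only, no MLM/NSP heads
    trainable=True,   # keep the BERT weights trainable for fine-tuning
    seq_len=None)     # None keeps the full position-embedding length
biobert.summary()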
36 changes: 11 additions & 25 deletions models.py
@@ -8,6 +8,8 @@
 import tensorflow as tf
 import tensorflow.keras.backend as K
 
+from keras_bert.loader import load_trained_model_from_checkpoint
+
 
 class FFN(tf.keras.layers.Layer):
     def __init__(
@@ -54,37 +56,21 @@ def call(self, inputs):
         return tf.stack([q_embedding, a_embedding], axis=1)
 
 
-class BioBert(tf.keras.Model):
-    def __init__(self, name=''):
-        super(BioBert, self).__init__(name=name)
-
-    def call(self, inputs):
-
-        # inputs is dict with input features
-        input_ids, input_masks, segment_ids = inputs
-        # pass to bert
-        # with shape of (batch_size/2*batch_size, max_seq_len, hidden_size)
-        # TODO(Alex): Add true bert model
-        # Input: input_ids, input_masks, segment_ids all with shape (None, max_seq_len)
-        # Output: a tensor with shape (None, max_seq_len, hidden_size)
-        fake_bert_output = tf.expand_dims(tf.ones_like(
-            input_ids, dtype=tf.float32), axis=-1)*tf.ones([1, 1, 768], dtype=tf.float32)
-        max_seq_length = tf.shape(fake_bert_output)[-2]
-        hidden_size = tf.shape(fake_bert_output)[-1]
-
-        bert_output = fake_bert_output
-        return bert_output
-
-
 class MedicalQAModelwithBert(tf.keras.Model):
     def __init__(
             self,
             hidden_size=768,
             dropout=0.1,
             residual=True,
+            config_file=None,
+            checkpoint_file=None,
             name=''):
         super(MedicalQAModelwithBert, self).__init__(name=name)
-        self.biobert = BioBert()
+        self.biobert = load_trained_model_from_checkpoint(
+            config_file=config_file,
+            checkpoint_file=checkpoint_file,
+            training=False,
+            trainable=True)
         self.q_ffn_layer = FFN(
             hidden_size=hidden_size,
             dropout=dropout,
@@ -102,9 +88,9 @@ def _avg_across_token(self, tensor):
     def call(self, inputs):
 
         q_bert_embedding = self.biobert(
-            (inputs['q_input_ids'], inputs['q_input_masks'], inputs['q_segment_ids']))
+            (inputs['q_input_ids'], inputs['q_segment_ids'], inputs['q_input_masks']))
         a_bert_embedding = self.biobert(
-            (inputs['a_input_ids'], inputs['a_input_masks'], inputs['a_segment_ids']))
+            (inputs['a_input_ids'], inputs['a_segment_ids'], inputs['a_input_masks']))
 
         # according to USE, the DAN network average embedding across tokens
         q_bert_embedding = self._avg_across_token(q_bert_embedding)
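(Two notes on this change. The tuples passed to self.biobert are reordered because keras_bert's get_inputs builds its inputs as ['Token', 'Segment', 'Masked'], so the loaded model expects token ids, then segment ids, then the mask. The "according to USE" comment refers to DAN-style pooling: averaging the BERT output over the token axis before the FFN heads. A toy sketch of that reduction, assuming _avg_across_token is a mean over the sequence dimension:)

import tensorflow as tf

# (batch, seq_len, hidden) -> (batch, hidden): mean over the token axis.
bert_output = tf.random.normal([2, 128, 768])   # stand-in for the BioBERT output
pooled = tf.reduce_mean(bert_output, axis=1)
print(pooled.shape)  # (2, 768)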
21 changes: 14 additions & 7 deletions train_all.py
@@ -11,20 +11,26 @@
 
 
 def train_all(args):
-    tokenizer = FullTokenizer(os.path.join(args.vocab_path, 'vocab.txt'))
+    K.set_floatx('float32')
+    tokenizer = FullTokenizer(os.path.join(args.pretrained_path, 'vocab.txt'))
     d = create_dataset_for_bert(
-        args.data_path, tokenizer=tokenizer, batch_size=args.batch_size, shuffle_buffer=100000)
+        args.data_path, tokenizer=tokenizer, batch_size=args.batch_size,
+        shuffle_buffer=1000, dynamic_padding=True, max_seq_length=args.max_seq_len)
     eval_d = create_dataset_for_bert(
-        args.data_path, tokenizer=tokenizer, batch_size=args.batch_size, mode='eval')
-    medical_qa_model = MedicalQAModelwithBert()
+        args.data_path, tokenizer=tokenizer, batch_size=args.batch_size,
+        mode='eval', dynamic_padding=True, max_seq_length=args.max_seq_len)
+    medical_qa_model = MedicalQAModelwithBert(
+        config_file=os.path.join(
+            args.pretrained_path, 'bert_config.json'),
+        checkpoint_file=os.path.join(args.pretrained_path, 'biobert_model.ckpt'))
     optimizer = tf.keras.optimizers.Adam()
     medical_qa_model.compile(
         optimizer=optimizer, loss=qa_pair_loss)
 
     epochs = args.num_epochs
     loss_metric = tf.keras.metrics.Mean()
 
-    medical_qa_model.fit(d, epochs=epochs, validation_data=eval_d)
+    medical_qa_model.fit(d, epochs=epochs)
     medical_qa_model.summary()
     K.set_learning_phase(0)
     q_embedding, a_embedding = tf.unstack(
@@ -48,10 +54,11 @@ def train_all(args):
                         default='models/', help='path for saving trained models')
     parser.add_argument('--data_path', type=str,
                         default='/content/gdrive/', help='path for saving trained models')
-    parser.add_argument('--vocab_path', type=str,
-                        default='/content/gdrive/', help='path for saving trained models')
+    parser.add_argument('--pretrained_path', type=str,
+                        default='/content/gdrive/', help='pretrained model path')
     parser.add_argument('--num_epochs', type=int, default=1)
     parser.add_argument('--batch_size', type=int, default=32)
+    parser.add_argument('--max_seq_len', type=int, default=256)
     parser.add_argument('--learning_rate', type=float, default=0.001)
     parser.add_argument('--validation_split', type=float, default=0.2)

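(With the new arguments, a typical invocation would look something like the line below; the paths are placeholders, and the pretrained directory is expected to contain vocab.txt, bert_config.json and biobert_model.ckpt:)

python train_all.py --data_path /path/to/tfrecords/ --pretrained_path /path/to/biobert/ --batch_size 32 --max_seq_len 256 --num_epochs 1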
