Fix bugs on saving model.
wangermeng2021 committed Jun 29, 2021
1 parent 574a95a commit 0fdc147
Showing 3 changed files with 19 additions and 5 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -1,8 +1,14 @@

# PVT-tensorflow2
[![Python 3.7](https://img.shields.io/badge/Python-3.7-3776AB)](https://www.python.org/downloads/release/python-360/)
[![TensorFlow 2.4](https://img.shields.io/badge/TensorFlow-2.4-FF6F00?logo=tensorflow)](https://github.com/tensorflow/tensorflow/releases/tag/v2.2.0)

A TensorFlow 2.x implementation of the Pyramid Vision Transformer, as described in [Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions](https://arxiv.org/abs/2102.12122)

## Update Log
[2021-06-29]
* Fix a bug in model saving

[2021-03-20]
* Add PVT-tiny, PVT-small, PVT-medium, PVT-large.

8 changes: 8 additions & 0 deletions model/PVT.py
@@ -101,6 +101,14 @@ def build(self, input_shape):
     def call(self, x):
         return x+self.pos_embed
+
+    def get_config(self):
+
+        config = super().get_config().copy()
+        config.update({
+            'img_len': self.img_len,
+        })
+        return config
 
 def get_pvt(img_size,num_classes,block_depth,mlp_ratio,drop_path_rate,first_level_patch_size,embed_dims,num_heads,sr_ratio,attention_drop_rate,drop_rate):
     block_drop_path_rate = np.linspace(0, drop_path_rate, sum(block_depth))
     block_depth_index = 0
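The added `get_config` is what lets Keras rebuild this custom layer from a saved config: serialization fails for any layer whose constructor arguments are not reported. A minimal, self-contained sketch of the pattern, assuming a hypothetical `PatchPosEmbed` stand-in (the class name, weight shape, and `img_len=16` are illustrative, not the repository's exact layer):

```python
import tensorflow as tf

class PatchPosEmbed(tf.keras.layers.Layer):
    """Illustrative positional-embedding layer with one constructor argument."""

    def __init__(self, img_len, **kwargs):
        super().__init__(**kwargs)
        self.img_len = img_len

    def build(self, input_shape):
        # One learned embedding per position.
        self.pos_embed = self.add_weight(
            name="pos_embed",
            shape=(1, self.img_len, int(input_shape[-1])),
            initializer="zeros",
        )

    def call(self, x):
        return x + self.pos_embed

    def get_config(self):
        # Report every __init__ argument so Keras can re-instantiate the layer.
        config = super().get_config().copy()
        config.update({"img_len": self.img_len})
        return config

# Round-trip: without get_config, from_config() would fail because
# __init__ requires img_len.
layer = PatchPosEmbed(img_len=16)
restored = PatchPosEmbed.from_config(layer.get_config())
```

Without the override, `get_config` returns only the base `Layer` fields, and reloading a saved model raises a `TypeError` for the missing `img_len` argument.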
10 changes: 5 additions & 5 deletions train.py
@@ -49,18 +49,18 @@ def main(args):
     os.makedirs(args.checkpoints)
     # lr_cb = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=20, verbose=1, min_lr=0)
     lr_cb = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
-    model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(filepath=args.checkpoints+'/best_weight_{epoch}_{accuracy:.3f}_{val_accuracy:.3f}',
-                                                             monitor='val_accuracy',mode='max',
-                                                             verbose=1,save_best_only=True)
+    model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(filepath=args.checkpoints+'/best_weight_{epoch}_{accuracy:.3f}_{val_accuracy:.3f}.h5',
+                                                             monitor='val_accuracy',mode='max',
+                                                             verbose=1,save_best_only=True,save_weights_only=True)
     cbs=[lr_cb,
-         # model_checkpoint_cb
+         model_checkpoint_cb
         ]
     model.compile(optimizer,loss_object,metrics=["accuracy"],)
     model.fit(train_generator,
               validation_data=val_generator,
               epochs=args.epochs,
               callbacks=cbs,
-              verbose=2,
+              verbose=1,
               )
 
 if __name__== "__main__":
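The train.py change switches the checkpoint to weights-only `.h5` files: writing weights alone sidesteps serializing the full model graph, which breaks when custom layers lack `get_config`. A sketch of the resulting callback in isolation, assuming an illustrative `checkpoints/` directory (the filepath pattern mirrors the diff):

```python
import tensorflow as tf

# Weights-only checkpointing: only variable values are written, so no layer
# configs need to be serialized. The .h5 suffix selects the HDF5 weight format.
ckpt_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="checkpoints/best_weight_{epoch}_{accuracy:.3f}_{val_accuracy:.3f}.h5",
    monitor="val_accuracy",  # keep the checkpoint with the best validation accuracy
    mode="max",
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
)
# After training, restore by rebuilding the model in code and calling
# model.load_weights(<best checkpoint path>).
```

The trade-off is that `load_weights` requires reconstructing the architecture in code first, whereas a full saved model would be self-describing.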
