From bd4bd79cbdc885e406b5a48b695a4352f127ed34 Mon Sep 17 00:00:00 2001 From: Viktor Nilsson Date: Tue, 18 Jan 2022 10:39:05 +0100 Subject: [PATCH] Add option to load model weights from checkpoint before starting to train objectdetector --- .../core/task/model_spec/object_detector_spec.py | 7 ++++++- .../lite/model_maker/core/task/object_detector.py | 12 ++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/tensorflow_examples/lite/model_maker/core/task/model_spec/object_detector_spec.py b/tensorflow_examples/lite/model_maker/core/task/model_spec/object_detector_spec.py index 02777ccfd66..c0c680982af 100644 --- a/tensorflow_examples/lite/model_maker/core/task/model_spec/object_detector_spec.py +++ b/tensorflow_examples/lite/model_maker/core/task/model_spec/object_detector_spec.py @@ -246,7 +246,8 @@ def train(self, validation_steps: int, epochs: Optional[int] = None, batch_size: Optional[int] = None, - val_json_file: Optional[str] = None) -> tf.keras.Model: + val_json_file: Optional[str] = None, + load_checkpoint_path: Optional[str] = None) -> tf.keras.Model: """Run EfficientDet training.""" config = self.config if not epochs: @@ -263,6 +264,10 @@ def train(self, batch_size=batch_size)) train.setup_model(model, config) train.init_experimental(config) + + if load_checkpoint_path is not None: + model.load_weights(load_checkpoint_path) + model.fit( train_dataset, epochs=epochs, diff --git a/tensorflow_examples/lite/model_maker/core/task/object_detector.py b/tensorflow_examples/lite/model_maker/core/task/object_detector.py index bdf659ced99..57ef774bda4 100644 --- a/tensorflow_examples/lite/model_maker/core/task/object_detector.py +++ b/tensorflow_examples/lite/model_maker/core/task/object_detector.py @@ -94,7 +94,8 @@ def train(self, validation_data: Optional[ object_detector_dataloader.DataLoader] = None, epochs: Optional[int] = None, - batch_size: Optional[int] = None) -> tf.keras.Model: + batch_size: Optional[int] = None, + load_checkpoint_path: Optional[str] = None) -> tf.keras.Model: """Feeds the training data for training.""" if not self.model_spec.config.drop_remainder: raise ValueError('Must set `drop_remainder=True` during training. ' @@ -122,7 +123,7 @@ def train(self, validation_data, batch_size, is_training=False) return self.model_spec.train(self.model, train_ds, steps_per_epoch, validation_ds, validation_steps, epochs, - batch_size, val_json_file) + batch_size, val_json_file, load_checkpoint_path) def evaluate(self, data: object_detector_dataloader.DataLoader, @@ -225,7 +226,8 @@ def create(cls, epochs: Optional[object_detector_dataloader.DataLoader] = None, batch_size: Optional[int] = None, train_whole_model: bool = False, - do_train: bool = True) -> T: + do_train: bool = True, + load_checkpoint_path: Optional[str] = None) -> T: """Loads data and train the model for object detection. Args: @@ -238,6 +240,8 @@ def create(cls, model. Otherwise, only train the layers that are not match `model_spec.config.var_freeze_expr`. do_train: Whether to run training. + load_checkpoint_path: Optional, Path to checkpoint to load model weights from, + before training is started. Returns: An instance based on ObjectDetector. @@ -257,7 +261,7 @@ def create(cls, if do_train: tf.compat.v1.logging.info('Retraining the models...') - object_detector.train(train_data, validation_data, epochs, batch_size) + object_detector.train(train_data, validation_data, epochs, batch_size, load_checkpoint_path) else: object_detector.create_model()