Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ def train(self,
validation_steps: int,
epochs: Optional[int] = None,
batch_size: Optional[int] = None,
val_json_file: Optional[str] = None) -> tf.keras.Model:
val_json_file: Optional[str] = None,
load_checkpoint_path: Optional[str] = None) -> tf.keras.Model:
"""Run EfficientDet training."""
config = self.config
if not epochs:
Expand All @@ -263,6 +264,10 @@ def train(self,
batch_size=batch_size))
train.setup_model(model, config)
train.init_experimental(config)

if load_checkpoint_path is not None:
model.load_weights(load_checkpoint_path)

model.fit(
train_dataset,
epochs=epochs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def train(self,
validation_data: Optional[
object_detector_dataloader.DataLoader] = None,
epochs: Optional[int] = None,
batch_size: Optional[int] = None) -> tf.keras.Model:
batch_size: Optional[int] = None,
load_checkpoint_path: Optional[str] = None) -> tf.keras.Model:
"""Feeds the training data for training."""
if not self.model_spec.config.drop_remainder:
raise ValueError('Must set `drop_remainder=True` during training. '
Expand Down Expand Up @@ -122,7 +123,7 @@ def train(self,
validation_data, batch_size, is_training=False)
return self.model_spec.train(self.model, train_ds, steps_per_epoch,
validation_ds, validation_steps, epochs,
batch_size, val_json_file)
batch_size, val_json_file, load_checkpoint_path)

def evaluate(self,
data: object_detector_dataloader.DataLoader,
Expand Down Expand Up @@ -225,7 +226,8 @@ def create(cls,
epochs: Optional[object_detector_dataloader.DataLoader] = None,
batch_size: Optional[int] = None,
train_whole_model: bool = False,
do_train: bool = True) -> T:
do_train: bool = True,
load_checkpoint_path: Optional[str] = None) -> T:
"""Loads data and train the model for object detection.

Args:
Expand All @@ -238,6 +240,8 @@ def create(cls,
model. Otherwise, only train the layers that are not match
`model_spec.config.var_freeze_expr`.
do_train: Whether to run training.
load_checkpoint_path: Optional, Path to checkpoint to load model weights from,
before training is started.

Returns:
An instance based on ObjectDetector.
Expand All @@ -257,7 +261,7 @@ def create(cls,

if do_train:
tf.compat.v1.logging.info('Retraining the models...')
object_detector.train(train_data, validation_data, epochs, batch_size)
object_detector.train(train_data, validation_data, epochs, batch_size, load_checkpoint_path)
else:
object_detector.create_model()

Expand Down