Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
qywu committed Nov 12, 2021
1 parent be44501 commit 3a01236
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions torchfly/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,24 @@ def __init__(self, config: DictConfig, model: FlyModel, *args, **kwargs):
self.global_step_count = 0
self.epochs_trained = 0
self.local_step_count = 0

self.init_distributed_environment()

# Model is sent to GPU or CPU
self.init_device()
if reset_optimizers or len(self.optimizers) == 0:
self.optimizers, self.schedulers = self.configure_optimizers()

self.model = move_to_device(self.model, self.device)
self.model.device = self.device
self.init_fp16()

if self.distributed_training:
self.init_distributed_model(self.model)

# make sure the model has access to trainer info
self.model.set_trainer(self)

self.callback_handler = CallbackHandler(config,
trainer=self,
callbacks=[],
Expand Down Expand Up @@ -187,21 +202,6 @@ def train(self,

self.init_training_constants()

# Model is sent to GPU or CPU
self.init_device()
if reset_optimizers or len(self.optimizers) == 0:
self.optimizers, self.schedulers = self.configure_optimizers()

self.model = move_to_device(self.model, self.device)
self.model.device = self.device
self.init_fp16()

if self.distributed_training:
self.init_distributed_model(self.model)

# make sure the model has access to trainer info
self.model.set_trainer(self)

# Training begins
self.callback_handler.fire_event(Events.TRAIN_BEGIN)

Expand Down

0 comments on commit 3a01236

Please sign in to comment.