
Fixing logging and errors blocking multi-GPU training of Torch models #1509

Merged 11 commits on Feb 21, 2023
17 changes: 15 additions & 2 deletions darts/models/forecasting/pl_forecasting_module.py
@@ -150,7 +150,13 @@ def training_step(self, train_batch, batch_idx) -> torch.Tensor:
-1
] # By convention target is always the last element returned by datasets
loss = self._compute_loss(output, target)
self.log("train_loss", loss, batch_size=train_batch[0].shape[0], prog_bar=True)
self.log(
"train_loss",
loss,
batch_size=train_batch[0].shape[0],
prog_bar=True,
sync_dist=True,
Contributor:
The PTL docs say: "Use with care as this may lead to a significant communication overhead."
Do we have any idea if/when this could cause issues?

Contributor Author:
So far, in practical testing on 8 GPUs, I noticed no adverse effects. That said, it also depends on the distribution strategy; I used the default ddp_spawn, as mentioned.

)
self._calculate_metrics(output, target, self.train_metrics)
return loss
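For context on the reviewer's question above: with `sync_dist=True`, Lightning all-reduces the logged value across ranks (by default a mean) before recording it, so every process logs the same number. A minimal pure-Python sketch of that reduction, with a hypothetical helper name and made-up per-rank loss values (neither comes from the PR):

```python
# Illustrative sketch of the mean all-reduce that sync_dist=True performs.
# The helper name and the per-rank loss values are invented for demonstration.

def sync_dist_mean(per_rank_values):
    """Average a metric across ranks, as Lightning's default reduce_fx does."""
    return sum(per_rank_values) / len(per_rank_values)

# Each of 8 GPUs computed a slightly different train_loss for the same step:
per_rank_train_loss = [0.52, 0.48, 0.50, 0.51, 0.49, 0.50, 0.53, 0.47]

# With sync_dist=True, every rank logs this single synced value instead of
# its own local loss; the cost is one extra collective op per logged metric,
# which is the communication overhead the PTL docs warn about.
synced = sync_dist_mean(per_rank_train_loss)
print(synced)
```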

@@ -159,7 +165,13 @@ def validation_step(self, val_batch, batch_idx) -> torch.Tensor:
output = self._produce_train_output(val_batch[:-1])
target = val_batch[-1]
loss = self._compute_loss(output, target)
self.log("val_loss", loss, batch_size=val_batch[0].shape[0], prog_bar=True)
self.log(
"val_loss",
loss,
batch_size=val_batch[0].shape[0],
prog_bar=True,
sync_dist=True,
)
self._calculate_metrics(output, target, self.val_metrics)
return loss
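Syncing `val_loss` matters more than it may look: callbacks such as early stopping or checkpointing compare the monitored value against a criterion, and if each rank sees only its local loss, ranks can reach conflicting decisions and stall a multi-GPU run. A toy sketch of that failure mode, with an invented `should_stop` rule and illustrative numbers (nothing here is from the Darts codebase):

```python
# Sketch of why an unsynced "val_loss" can break multi-GPU training: a
# callback that compares the monitored value to a threshold may decide
# differently on each rank. All names and numbers are illustrative.

def should_stop(val_loss, threshold=0.50):
    """Toy stopping rule: stop once val_loss drops below a threshold."""
    return val_loss < threshold

per_rank_val_loss = [0.49, 0.51, 0.48, 0.52]  # rank-local values

# Without sync_dist: ranks disagree on whether to stop (a hang risk,
# since some processes exit the training loop while others continue).
local_decisions = [should_stop(v) for v in per_rank_val_loss]
print(local_decisions)   # mixed True/False across ranks

# With sync_dist: every rank sees the same mean and decides identically.
mean_loss = sum(per_rank_val_loss) / len(per_rank_val_loss)
synced_decisions = [should_stop(mean_loss) for _ in per_rank_val_loss]
print(synced_decisions)  # the same decision on every rank
```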

@@ -274,6 +286,7 @@ def _calculate_metrics(self, output, target, metrics):
on_step=False,
logger=True,
prog_bar=True,
sync_dist=True,
)

def configure_optimizers(self):