Skip to content

Commit

Permalink
Init for model_parallel == 1 (facebookresearch#577)
Browse files Browse the repository at this point in the history
* gate by arch, not by mp size

* add back mp > 1 conditional
  • Loading branch information
suchenzang committed Jan 1, 2023
1 parent 59403be commit 511504b
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion metaseq/distributed/utils.py
Expand Up @@ -165,7 +165,10 @@ def distributed_init(cfg: MetaseqConfig):
if nodelist:
logger.info(f"SLURM nodelist: {nodelist}")

if cfg.common.model_parallel_size > 1:
if (
getattr(cfg.model, "arch", None) == "transformer_lm_megatron"
or cfg.common.model_parallel_size > 1
):
try:
from megatron.mpu import (
initialize_model_parallel,
Expand Down

0 comments on commit 511504b

Please sign in to comment.