Squashed commit of the following:
commit 511504b
Author: Susan Zhang <suchenzang@users.noreply.github.com>
Date:   Sun Jan 1 17:00:25 2023 +0100

    Init for model_parallel == 1 (facebookresearch#577)

    * gate by arch, not by mp size

    * add back mp > 1 conditional

commit 59403be
Author: Susan Zhang <suchenzang@users.noreply.github.com>
Date:   Sun Jan 1 00:42:37 2023 +0100

    [Cleanup] Remove MegatronTrainer (facebookresearch#576)

commit 6687b6f
Author: Susan Zhang <suchenzang@users.noreply.github.com>
Date:   Sat Dec 31 17:38:28 2022 +0100

    use bash (facebookresearch#575)

commit a87e08f
Author: Stephen Roller <roller@fb.com>
Date:   Fri Dec 30 14:11:57 2022 -0500

    Add Sharan to CODEOWNERS (facebookresearch#558)

commit 1d4af00
Author: Stephen Roller <roller@fb.com>
Date:   Fri Dec 30 14:11:47 2022 -0500

    Fix config.yml dump in training runs. (facebookresearch#557)

commit ed85aad
Author: Christian Clauss <cclauss@me.com>
Date:   Fri Dec 30 07:43:41 2022 +0100

    Current flake8 no longer accepts comments on config lines (facebookresearch#570)

    * Current flake8 no longer accepts comments on config lines

    `ValueError: Error code '#' supplied to 'extend-ignore' option does not match '^[A-Z]{1,3}[0-9]{0,3}$'`

    * flake8==6.0.0

    * Update .flake8

    * Update setup.py

    Co-authored-by: Stephen Roller <roller@fb.com>

    Co-authored-by: Stephen Roller <roller@fb.com>

commit db6842b
Author: Taichi Nishimura <lack_un@yahoo.co.jp>
Date:   Fri Dec 30 12:14:49 2022 +0900

    Add backslash to the script in projects/OPT/download_opt175b.md (facebookresearch#573)

    * add backslash to script

    * add backslash to docs/api.md

commit 966561e
Author: Binh Tang <tangbinh.na@gmail.com>
Date:   Wed Dec 28 13:11:39 2022 -0800

    Add a new script to reshard model parallel parts (facebookresearch#556)

    Co-authored-by: Binh Tang <tangbinhna@gmail.com>
sahajgg committed Jan 2, 2023
1 parent 4eb133c commit 802aa42
Showing 18 changed files with 303 additions and 267 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -115,7 +115,7 @@ create_conda_env: &create_conda_env
command: |
curl -o ~/miniconda.sh -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
chmod +x ~/miniconda.sh
~/miniconda.sh -b -p $HOME/miniconda
bash ~/miniconda.sh -b -p $HOME/miniconda
rm ~/miniconda.sh
echo 'export PATH=$HOME/miniconda/bin:$PATH' >> $BASH_ENV
source $BASH_ENV
9 changes: 6 additions & 3 deletions .flake8
@@ -1,7 +1,10 @@
[flake8]
extend-ignore =
F541 # f-string is missing placeholders
E203 # whitespace with black
E741 # "l" is ambiguous
# E203: whitespace with black
E203
# E741: "l" is ambiguous
E741
# F541: f-string is missing placeholders
F541
# github size
max-line-length=127
2 changes: 1 addition & 1 deletion CODEOWNERS
@@ -1 +1 @@
* @suchenzang @stephenroller @ngoyal2707 @punitkoura @moyapchen @klshuster @ruanslv @davides @dgrnbrg-meta @igormolybogFB @Xirider
* @suchenzang @stephenroller @ngoyal2707 @punitkoura @moyapchen @klshuster @ruanslv @davides @dgrnbrg-meta @igormolybogFB @Xirider @sharannarang
2 changes: 1 addition & 1 deletion docs/api.md
@@ -18,7 +18,7 @@ Complete all of the setup as mentioned in [the Setup doc](setup.md).
- Reshard the FSDP checkpoints using the script `metaseq/scripts/reshard_fsdp.py`. For example, we can merge all FSDP shards within each of the 8 model parallel parts of OPT-175B using the following command:
```bash
for j in {0..7}; do
python -m metaseq.scripts.reshard_fsdp
python -m metaseq.scripts.reshard_fsdp \
--input-glob-pattern "/path/to/raw/checkpoints/checkpoint_last-model_part-$j-shard*.pt" \
--output-shard-name "/path/to/resharded/checkpoints/reshard-model_part-$j.pt" \
--num-output-shards 1 --skip-optimizer-state True --unflatten-weights True
14 changes: 7 additions & 7 deletions metaseq/cli/train.py
@@ -37,7 +37,6 @@
from metaseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils
from metaseq.file_io import PathManager
from metaseq.logging import meters, metrics, progress_bar
from metaseq.model_parallel.megatron_trainer import MegatronTrainer
from metaseq.trainer import Trainer

logging.basicConfig(
@@ -67,11 +66,15 @@ def main(cfg: DictConfig) -> None:

checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

if distributed_utils.is_master(cfg.distributed_training):
if distributed_utils.is_master(cfg.distributed_training) and os.environ.get(
"METASEQ_SAVE_DIR"
):
# save a (vaguely human readable) copy of the training config
# TODO(roller): only works when launched with a sweep script
# should fix that
OmegaConf.save(
config=_flatten_config(cfg),
f=os.path.join(cfg.checkpoint.save_dir, "config.yml"),
f=os.path.join(os.environ["METASEQ_SAVE_DIR"], "config.yml"),
)

if (
@@ -145,10 +148,7 @@ def main(cfg: DictConfig) -> None:
task.load_dataset(valid_sub_split, combine=False, epoch=1)

# Build trainer
if cfg.common.model_parallel_size == 1:
trainer = Trainer(cfg, task, model, criterion)
else:
trainer = MegatronTrainer(cfg, task, model, criterion)
trainer = Trainer(cfg, task, model, criterion)
logger.info(
"training on {} devices (GPUs/TPUs)".format(
cfg.distributed_training.distributed_world_size
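For reference, a minimal sketch of the config-dump gating added in facebookresearch#557 (the `/checkpoints/run0` path is a hypothetical value): `config.yml` is now written only on the master rank and only when the launcher, typically a sweep script, exports `METASEQ_SAVE_DIR`.

```python
import os


def should_dump_config(is_master: bool) -> bool:
    # Both conditions must hold; if METASEQ_SAVE_DIR is unset, the dump is skipped silently.
    return is_master and bool(os.environ.get("METASEQ_SAVE_DIR"))


os.environ["METASEQ_SAVE_DIR"] = "/checkpoints/run0"  # hypothetical export done by the launcher
assert should_dump_config(is_master=True)
assert not should_dump_config(is_master=False)
```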
5 changes: 4 additions & 1 deletion metaseq/distributed/utils.py
@@ -165,7 +165,10 @@ def distributed_init(cfg: MetaseqConfig):
if nodelist:
logger.info(f"SLURM nodelist: {nodelist}")

if cfg.common.model_parallel_size > 1:
if (
getattr(cfg.model, "arch", None) == "transformer_lm_megatron"
or cfg.common.model_parallel_size > 1
):
try:
from megatron.mpu import (
initialize_model_parallel,
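For reference, a minimal sketch of the gating introduced in facebookresearch#577 (field names follow the hunk above): Megatron's model-parallel state is now initialized whenever the architecture is `transformer_lm_megatron`, even with `model_parallel_size == 1`, and still for any architecture when `model_parallel_size > 1`.

```python
from typing import Optional


def needs_megatron_init(arch: Optional[str], model_parallel_size: int) -> bool:
    # Gate by architecture first, then fall back to the model-parallel world size.
    return arch == "transformer_lm_megatron" or model_parallel_size > 1


assert needs_megatron_init("transformer_lm_megatron", model_parallel_size=1)
assert needs_megatron_init("transformer_lm", model_parallel_size=8)
assert not needs_megatron_init("transformer_lm", model_parallel_size=1)
```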
86 changes: 0 additions & 86 deletions metaseq/model_parallel/megatron_trainer.py

This file was deleted.

metaseq/model_parallel/models/{transformer.py → transformer_decoder.py}
@@ -7,10 +7,8 @@

from metaseq.model_parallel.modules import (
ModelParallelTransformerDecoderLayer,
ModelParallelTransformerEncoderLayer,
)
from metaseq.models.transformer_decoder import TransformerDecoder
from metaseq.models.transformer_encoder import TransformerEncoder

try:
from megatron import mpu
@@ -26,22 +24,6 @@
logger = logging.getLogger(__name__)


class ModelParallelTransformerEncoder(TransformerEncoder):
"""
Model parallel Transformer encoder consisting of *args.encoder_layers* layers. Each layer
is a :class:`ModelParallelTransformerEncoderLayer`.
"""

def __init__(self, args, dictionary, embed_tokens):
super().__init__(args, dictionary, embed_tokens)

if args.no_final_layer_norm:
self.layer_norm = None

def build_encoder_layer(self, args):
return ModelParallelTransformerEncoderLayer(args)


class ModelParallelTransformerDecoder(TransformerDecoder):
"""
Model Parallel Transformer decoder consisting of *args.decoder_layers* layers. Each layer
@@ -73,7 +55,7 @@ def output_layer(self, features, **kwargs):
False, # async_grad_allreduce
is_sequence_parallel, # sequence_parallel
)
# Gather output if model in in inference mode (i.e. evallm or generation) cause both are not yet compatible with
# Gather output if model is in inference mode (i.e. evallm or generation) cause both are not yet compatible with
# parallel vocab embeddings
if getattr(self.args, "criterion") != "vocab_parallel_cross_entropy" or getattr(
self, "inference", False
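A small sketch of the gather condition referenced in the comment above (names taken from the hunk): the decoder gathers logits across model-parallel ranks whenever the criterion is not the vocab-parallel cross entropy, or when the model is running inference (evallm or generation), since neither path is compatible with parallel vocab embeddings yet.

```python
def needs_full_vocab_gather(criterion: str, inference: bool) -> bool:
    # Mirrors the predicate in ModelParallelTransformerDecoder.output_layer above.
    return criterion != "vocab_parallel_cross_entropy" or inference


assert needs_full_vocab_gather("cross_entropy", inference=False)
assert needs_full_vocab_gather("vocab_parallel_cross_entropy", inference=True)
assert not needs_full_vocab_gather("vocab_parallel_cross_entropy", inference=False)
```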
23 changes: 23 additions & 0 deletions metaseq/model_parallel/models/transformer_encoder.py
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from metaseq.model_parallel.modules import ModelParallelTransformerEncoderLayer
from metaseq.models.transformer_encoder import TransformerEncoder


class ModelParallelTransformerEncoder(TransformerEncoder):
"""
Model parallel Transformer encoder consisting of *args.encoder_layers* layers. Each layer
is a :class:`ModelParallelTransformerEncoderLayer`.
"""

def __init__(self, args, dictionary, embed_tokens):
super().__init__(args, dictionary, embed_tokens)

if args.no_final_layer_norm:
self.layer_norm = None

def build_encoder_layer(self, args):
return ModelParallelTransformerEncoderLayer(args)
4 changes: 3 additions & 1 deletion metaseq/model_parallel/models/transformer_lm.py
@@ -5,7 +5,9 @@

import torch
import torch.nn as nn
from metaseq.model_parallel.models.transformer import ModelParallelTransformerDecoder
from metaseq.model_parallel.models.transformer_decoder import (
ModelParallelTransformerDecoder,
)
from metaseq.models import register_model, register_model_architecture
from metaseq.models.transformer_lm import TransformerLanguageModel

6 changes: 3 additions & 3 deletions metaseq/model_parallel/modules/__init__.py
@@ -5,16 +5,16 @@
"""isort:skip_file"""

from .multihead_attention import ModelParallelMultiheadAttention
from .transformer_layer import (
ModelParallelTransformerEncoderLayer,
from .transformer_decoder_layer import (
ModelParallelTransformerDecoderLayer,
)
from .transformer_encoder_layer import ModelParallelTransformerEncoderLayer

from .sequence_parallel_transformer_layer import SequeuceParallelTransformerBlock

__all__ = [
"ModelParallelMultiheadAttention",
"ModelParallelTransformerEncoderLayer",
"ModelParallelTransformerDecoderLayer",
"ModelParallelTransformerEncoderLayer",
"SequeuceParallelTransformerBlock",
]
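A quick usage sketch, assuming metaseq and its Megatron dependency are installed: after the split above, both layer classes remain importable from the package `__init__`, so callers do not need to track the new file layout.

```python
from metaseq.model_parallel.modules import (
    ModelParallelTransformerDecoderLayer,
    ModelParallelTransformerEncoderLayer,
)

# The classes now live in transformer_decoder_layer.py and transformer_encoder_layer.py,
# but the public import path is unchanged.
print(ModelParallelTransformerDecoderLayer.__module__)
print(ModelParallelTransformerEncoderLayer.__module__)
```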
metaseq/model_parallel/modules/{transformer_layer.py → transformer_decoder_layer.py}
@@ -10,7 +10,7 @@

from metaseq import utils
from metaseq.model_parallel.modules import ModelParallelMultiheadAttention
from metaseq.modules import TransformerDecoderLayer, TransformerEncoderLayer
from metaseq.modules import TransformerDecoderLayer

try:
from megatron.mpu import (
@@ -23,31 +23,6 @@
has_megatron_submodule = False


class ModelParallelTransformerEncoderLayer(TransformerEncoderLayer):
"""Encoder layer block over multiple gpus.
See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
"""

def build_fc1(self, input_dim, output_dim):
return ColumnParallelLinear(
input_dim, output_dim, gather_output=False, skip_bias_add=True
)

def build_fc2(self, input_dim, output_dim):
return RowParallelLinear(
input_dim, output_dim, input_is_parallel=True, skip_bias_add=True
)

def build_self_attention(self, embed_dim, args, **unused_kwargs):
return ModelParallelMultiheadAttention(
embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
)


def _weight_init(weight):
return nn.init.kaiming_uniform_(weight, a=math.sqrt(5))

@@ -70,6 +45,11 @@ def build_fc1(
disable_bias=False,
truncate_init=False,
):
if not has_megatron_submodule:
raise ImportError(
"\n\nPlease install megatron using the setup instructions!"
)

def _init_method_bias(bias):
fan_in = input_dim
bound = 1 / math.sqrt(fan_in)
@@ -112,6 +92,11 @@ def build_fc2(
disable_bias=False,
truncate_init=False,
):
if not has_megatron_submodule:
raise ImportError(
"\n\nPlease install megatron using the setup instructions!"
)

skip_bias_add = self.skip_bias_add
if full_megatron_init:
init_method_weights = utils.scaled_init_method_normal(
50 changes: 50 additions & 0 deletions metaseq/model_parallel/modules/transformer_encoder_layer.py
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

try:
from megatron.mpu import (
ColumnParallelLinear,
RowParallelLinear,
)

has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False

from metaseq.model_parallel.modules import ModelParallelMultiheadAttention
from metaseq.modules import TransformerEncoderLayer


class ModelParallelTransformerEncoderLayer(TransformerEncoderLayer):
"""Encoder layer block over multiple gpus.
See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
"""

def build_fc1(self, input_dim, output_dim):
if not has_megatron_submodule:
raise ImportError(
"\n\nPlease install megatron using the setup instructions!"
)
return ColumnParallelLinear(
input_dim, output_dim, gather_output=False, skip_bias_add=True
)

def build_fc2(self, input_dim, output_dim):
if not has_megatron_submodule:
raise ImportError(
"\n\nPlease install megatron using the setup instructions!"
)
return RowParallelLinear(
input_dim, output_dim, input_is_parallel=True, skip_bias_add=True
)

def build_self_attention(self, embed_dim, args, **unused_kwargs):
return ModelParallelMultiheadAttention(
embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
)

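A generic, hedged sketch of the optional-dependency guard used in the two new layer files (`fancy_mpu` is a made-up module standing in for `megatron.mpu`): the import is attempted once at module load, and a clear error is raised only when a parallel layer is actually built without the dependency installed.

```python
try:
    import fancy_mpu  # hypothetical optional backend, analogous to megatron.mpu

    has_backend = True
except (ImportError, ModuleNotFoundError):
    has_backend = False


def build_parallel_fc(input_dim: int, output_dim: int):
    # Fail lazily with an actionable message instead of failing at import time.
    if not has_backend:
        raise ImportError("Please install the optional backend using the setup instructions!")
    return fancy_mpu.ColumnParallelLinear(input_dim, output_dim)  # hypothetical class
```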