diff --git a/Jenkinsfile b/Jenkinsfile index 55e836eea13a..83e6daa8ccb7 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -125,6 +125,7 @@ pipeline { sh 'python tests/core_ptl/check_imports.py --domain "nlp"' } } + stage('L0: Unit Tests GPU') { steps { sh 'NEMO_NUMBA_MINVER=0.53 pytest -m "not pleasefixme" --with_downloads' @@ -3517,6 +3518,64 @@ pipeline { failFast true steps { sh "python examples/nlp/language_modeling/megatron_retro_pretraining.py \ + trainer.num_nodes=1 \ + trainer.devices=2 \ + trainer.precision=bf16 \ + trainer.accelerator=gpu \ + model.data.data_prefix=['none'] \ + exp_manager.exp_dir=examples/nlp/language_modeling/mcore_retro_results \ + model.mcore_gpt=True \ + model.tensor_model_parallel_size=1 \ + model.pipeline_model_parallel_size=1 \ + model.optim.name=distributed_fused_adam \ + model.retro.retro_project_dir=/home/TestData/nlp/megatron_retro/mcore_retro/micro-wiki-core \ + model.data.num_workers=4 \ + model.micro_batch_size=1 \ + model.data.shuffle_documents=False \ + trainer.val_check_interval=30 \ + +trainer.num_sanity_val_steps=0 \ + model.init_method_std=0.023 \ + model.optim.lr=6.0e-4 \ + model.megatron_amp_O2=True \ + model.data.splits_string=\'\"98,2,0\"\' \ + model.data.dataloader_type=cyclic \ + trainer.max_steps=10" + sh "python examples/nlp/language_modeling/megatron_retro_pretraining.py \ + trainer.num_nodes=1 \ + trainer.devices=2 \ + trainer.precision=bf16 \ + trainer.accelerator=gpu \ + model.data.data_prefix=['none'] \ + exp_manager.exp_dir=examples/nlp/language_modeling/mcore_retro_results \ + model.mcore_gpt=True \ + model.tensor_model_parallel_size=1 \ + model.pipeline_model_parallel_size=1 \ + model.optim.name=distributed_fused_adam \ + model.retro.retro_project_dir=/home/TestData/nlp/megatron_retro/mcore_retro/micro-wiki-core \ + model.data.num_workers=4 \ + model.micro_batch_size=1 \ + model.data.shuffle_documents=False \ + trainer.val_check_interval=30 \ + +trainer.num_sanity_val_steps=0 \ + model.init_method_std=0.023 \ + model.optim.lr=6.0e-4 \ + model.megatron_amp_O2=True \ + model.data.splits_string=\'\"98,2,0\"\' \ + model.data.dataloader_type=cyclic \ + trainer.max_steps=20" + sh "rm -rf examples/nlp/language_modeling/mcore_retro_results" + } + } + stage('L2: (Legacy) Megatron RETRO Pretraining and Resume Training') { + when { + anyOf { + branch 'main' + changeRequest target: 'main' + } + } + failFast true + steps { + sh "python examples/nlp/language_modeling/megatron_retro_pretraining_legacy.py \ trainer.devices=2 \ trainer.num_nodes=1 \ trainer.accelerator=gpu \ @@ -3527,7 +3586,7 @@ pipeline { trainer.precision=16 \ trainer.gradient_clip_val=1.0 \ trainer.val_check_interval=10 \ - exp_manager.exp_dir=examples/nlp/language_modeling/retro_results \ + exp_manager.exp_dir=examples/nlp/language_modeling/retro_legacy_results \ model.data.data_prefix='' \ model.data.knn_index='' \ model.data.retrieval_prefix='' \ @@ -3546,7 +3605,7 @@ pipeline { model.enc_cross_attention=[1] \ model.dec_cross_attention=[1] \ +model.data.mock=True" - sh "python examples/nlp/language_modeling/megatron_retro_pretraining.py \ + sh "python examples/nlp/language_modeling/megatron_retro_pretraining_legacy.py \ trainer.devices=2 \ trainer.num_nodes=1 \ trainer.accelerator=gpu \ @@ -3557,7 +3616,7 @@ pipeline { trainer.precision=16 \ trainer.gradient_clip_val=1.0 \ trainer.val_check_interval=10 \ - exp_manager.exp_dir=examples/nlp/language_modeling/retro_results \ + exp_manager.exp_dir=examples/nlp/language_modeling/retro_legacy_results \ model.data.data_prefix='' \ model.data.knn_index='' \ model.data.retrieval_prefix='' \ @@ -3576,10 +3635,10 @@ pipeline { model.enc_cross_attention=[1] \ model.dec_cross_attention=[1] \ +model.data.mock=True" - sh "rm -rf examples/nlp/language_modeling/retro_results" + sh "rm -rf examples/nlp/language_modeling/retro_legacy_results" } } - stage('L2: Megatron RETRO muTransfer Pretraining Performance') { + stage('L2: (Legacy) Megatron RETRO muTransfer Pretraining Performance') { when { anyOf { branch 'main' @@ -3600,7 +3659,7 @@ pipeline { trainer.limit_val_batches=0 \ trainer.gradient_clip_val=1.0 \ +trainer.num_sanity_val_steps=0 \ - exp_manager.exp_dir=examples/nlp/language_modeling/retro_results/ \ + exp_manager.exp_dir=examples/nlp/language_modeling/retro_legacy_results/ \ +exp_manager.version=smalltest \ model.data.neighbors=2 \ model.megatron_amp_O2=False \ @@ -3651,7 +3710,7 @@ import torch if not (torch.cuda.is_available() and 'A100' in torch.cuda.get_device_name()): import sys sys.exit(0) -event_file = list(pathlib.Path('examples/nlp/language_modeling/retro_results/megatron_retro/smalltest').glob('events.out.tfevents*'))[0] +event_file = list(pathlib.Path('examples/nlp/language_modeling/retro_legacy_results/megatron_retro/smalltest').glob('events.out.tfevents*'))[0] ea = EventAccumulator(str(event_file)).Reload() vals = [] for i in ea.Scalars('reduced_train_loss'): @@ -3659,7 +3718,7 @@ for i in ea.Scalars('reduced_train_loss'): training_curve = pd.DataFrame({'loss': vals}) gt_curve = pd.read_csv('/home/TestData/nlp/megatron_retro/expected_learning_curve.csv') assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' - sh "rm -rf examples/nlp/language_modeling/retro_results" + sh "rm -rf examples/nlp/language_modeling/retro_legacy_results" } } stage('L2: BioMegatron Bert NER Task') { diff --git a/examples/nlp/language_modeling/conf/megatron_bert_config.yaml b/examples/nlp/language_modeling/conf/megatron_bert_config.yaml index 58e874386c44..bc66ae717ebb 100644 --- a/examples/nlp/language_modeling/conf/megatron_bert_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_bert_config.yaml @@ -158,4 +158,4 @@ model: name: CosineAnnealing warmup_steps: 500 constant_steps: 50000 - min_lr: 2e-5 + min_lr: 2e-5 \ No newline at end of file diff --git a/examples/nlp/language_modeling/conf/megatron_retro_config.yaml b/examples/nlp/language_modeling/conf/megatron_retro_config.yaml old mode 100644 new mode 100755 index dafdcf542f11..159bb163ad0a --- a/examples/nlp/language_modeling/conf/megatron_retro_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_retro_config.yaml @@ -1,127 +1,257 @@ defaults: - - .@model: megatron_model_base_config + - _self_ + - optional tp_overlap@model.ub_tp_comm_overlap_cfg: -name: test_retro +name: megatron_retro restore_from_path: null # used when starting from a .nemo file trainer: - devices: 2 + devices: 1 num_nodes: 1 accelerator: gpu precision: 16 logger: False # logger provided by exp_manager enable_checkpointing: False use_distributed_sampler: False - max_epochs: -1 # PTL default. In practice we don't usually train for more than 1 epoch. + max_epochs: -1 # PTL default. In practice, max_steps will be reached first. max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches log_every_n_steps: 10 val_check_interval: 100 - limit_val_batches: null - limit_test_batches: null - accumulate_grad_batches: 1 + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually exp_manager: explicit_log_dir: null exp_dir: null - name: megatron_retro + name: ${name} create_wandb_logger: False wandb_logger_kwargs: project: null name: null resume_if_exists: True resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} create_checkpoint_callback: True checkpoint_callback_params: monitor: val_loss save_top_k: 10 mode: min always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: False # not recommended when training large models on clusters with short time limits filename: 'megatron_retro--{val_loss:.2f}-{step}-{consumed_samples}' model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} - model: - version: 1 # indicate the retro model version + # use RETROModel from megatron.core, since RETRO model inherited from gpt, mcore_gpt is used + mcore_gpt: True - # model parallelism - micro_batch_size: 4 - tensor_model_parallel_size: 1 - pipeline_model_parallel_size: 1 # has to be one. not supporting pipeline parallel yet + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 16 # limited by GPU memory + global_batch_size: 256 # will be overrided by value from RETRO preprocessed workdir + rampup_batch_size: null # Should be a list of 3 values: [, , ] + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + virtual_pipeline_model_parallel_size: null # interleaved pipeline # model architecture - encoder_seq_length: 2048 - max_position_embeddings: ${.encoder_seq_length} - - gradient_as_bucket_view: True # Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) - - dump_debug_info: False # dump out the debug information - dump_debug_info_to_file: False # dump out the debug information to files - - # retro architecture - chunk_size: 64 # the chunk size used to retrive - enc_num_layers: 4 # total number of encoder layers - dec_num_layers: 6 # total number of decoder layers - enc_cross_attention: [3] # layer numbers for cross attention in encoder - dec_cross_attention: [3, 5] # layer numbers for chunked cross attention in decoder - add_position_embedding: False # whether use the absolute position encoding - + encoder_seq_length: 512 # will be overrided by value from RETRO preprocessed workdir + max_position_embeddings: ${.encoder_seq_length} # will be overrided by value from RETRO preprocessed workdir + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 3072 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 12 + init_method_std: 0.023 # Standard deviation of the zero mean normal distribution used for weight initialization.') + use_scaled_init_method: True # use scaled residuals initialization + hidden_dropout: 0.1 # Dropout probability for hidden state transformer. + attention_dropout: 0.1 # Dropout probability for attention + ffn_dropout: 0.1 # Dropout probability in the feed-forward layer. + kv_channels: 64 # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: False # scale Q * K^T by 1 / layer-number. + normalization: 'layernorm' # Normalization layer to use. Options are 'layernorm', 'rmsnorm' + layernorm_epsilon: 1e-5 + do_layer_norm_weight_decay: False # True means weight decay on all params make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. pre_process: True # add embedding post_process: True # add pooler - bert_binary_head: True # BERT binary head + persist_layer_norm: True # Use of persistent fused layer norm kernel. + bias: True # Whether to use bias terms in all weight matrices. + activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu', 'fast-geglu', 'fast-swiglu', 'fast-reglu'] + headscale: False # Whether to learn extra parameters that scale the output of the each self-attention head. + transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer'] + openai_gelu: False # Use OpenAI's GELU instead of the default GeLU + normalize_attention_scores: True # Whether to scale the output Q * K^T by 1 / sqrt(hidden_size_per_head). This arg is provided as a configuration option mostly for compatibility with models that have been weight-converted from HF. You almost always want to se this to True. + position_embedding_type: 'learned_absolute' # Position embedding type. Options ['learned_absolute', 'rope', 'alibi', 'kerple' , 'xpos', 'sandwich'] xpos and sandwich are experimental. + rotary_percentage: 0.5 # If using position_embedding_type=rope, then the per head dim is multiplied by this. + attention_type: 'multihead' # Attention type. Options ['multihead'] + share_embeddings_and_output_weights: True # Share embedding and output layer weights. + overlap_p2p_comm: False # Overlap p2p communication with computes. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1 + batch_p2p_comm: True # Batch consecutive inter-peer send/recv operations. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1 + seq_len_interpolation_factor: null # RoPE Interpolation factor for sequence length. This is used to build long-context models with RoPE ex: https://arxiv.org/abs/2306.15595. + num_query_groups: null # Number of query groups for group query attention. If None, normal attention is used. - megatron_amp_O2: False # use AMP with O2 style mixed precision instead of native amp on-the-fly weight autocasting. - grad_allreduce_chunk_size_mb: 125 - - megatron_lm_compatible: False # a flag to indicate whether the model is compatible with Megatron LM + retro: # specific arguments for RETRO model + retro_project_dir: null + retro_encoder_num_layers: 2 + retro_encoder_hidden_dropout: 0.1 + retro_encoder_attention_dropout: 0.1 + retro_num_neighbors: 2 + retro_num_retrieved_chunks: 2 + retro_verify_neighbor_count: True tokenizer: library: 'megatron' - type: 'GPT2BPETokenizer' - model: null - vocab_file: null - merge_file: null + type: null # will be overrided by value from RETRO preprocessed workdir + model: null # will be overrided by value from RETRO preprocessed workdir + vocab_file: null # will be overrided by value from RETRO preprocessed workdir + merge_file: null # will be overrided by value from RETRO preprocessed workdir delimiter: null # only used for tabular tokenizer + sentencepiece_legacy: False # Legacy=True allows you to add special tokens to sentencepiece tokenizers. - # precision + # Mixed precision native_amp_init_scale: 4294967296 # 2 ** 32 native_amp_growth_interval: 1000 + hysteresis: 2 # Gradient scale hysteresis + fp32_residual_connection: False # Move residual connections to fp32 fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 - # miscellaneous - seed: 1234 + # Megatron O2-style half-precision + megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters + grad_allreduce_chunk_size_mb: 125 + + # Fusion + grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce. Only used with O2 and no pipeline parallelism.. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism and O2. + bias_activation_fusion: False # Use a kernel that fuses the bias addition from weight matrices with the subsequent activation function. + bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition. + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. + get_attention_mask_from_fusion: False # When using fused softmax it will create the attention mask so we won't copy it to the pipeline stages. + # Miscellaneous + seed: 1234 # will be overrided by value from RETRO preprocessed workdir + resume_from_checkpoint: null # manually set the checkpoint file to load from + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + sync_batch_comm: False # Enable stream synchronization after each p2p communication between pipeline stages + + ## Activation Checkpointing + # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed. + # These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+). + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + # 'full' will checkpoint the entire transformer layer. + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model. + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null + # when using 'uniform' this creates groups of transformer layers to checkpoint. Usually set to 1. Increase to save more memory. + # when using 'block' this this will checkpoint the first activations_checkpoint_num_layers per pipeline stage. + num_micro_batches_with_partial_activation_checkpoints: null + # This feature is valid only when used with pipeline-model-parallelism. + # When an integer value is provided, it sets the number of micro-batches where only a partial number of Transformer layers get checkpointed + # and recomputed within a window of micro-batches. The rest of micro-batches in the window checkpoint all Transformer layers. The size of window is + # set by the maximum outstanding micro-batch backpropagations, which varies at different pipeline stages. The number of partial layers to checkpoint + # per micro-batch is set by 'activations_checkpoint_num_layers' with 'activations_checkpoint_method' of 'block'. + # This feature enables using activation checkpoint at a fraction of micro-batches up to the point of full GPU memory usage. + activations_checkpoint_layers_per_pipeline: null + # This feature is valid only when used with pipeline-model-parallelism. + # When an integer value (rounded down when float is given) is provided, it sets the number of Transformer layers to skip checkpointing at later + # pipeline stages. For example, 'activations_checkpoint_layers_per_pipeline' of 3 makes pipeline stage 1 to checkpoint 3 layers less than + # stage 0 and stage 2 to checkpoint 6 layers less stage 0, and so on. This is possible because later pipeline stage + # uses less GPU memory with fewer outstanding micro-batch backpropagations. Used with 'num_micro_batches_with_partial_activation_checkpoints', + # this feature removes most of activation checkpoints at the last pipeline stage, which is the critical execution path. + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + sequence_parallel: False + + ## Transformer Engine + transformer_engine: True + fp8: False # enables fp8 in TransformerLayer forward + fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3 + fp8_hybrid: True # sets fp8_format = recipe.Format.HYBRID + fp8_margin: 0 # scaling margin + fp8_interval: 1 # scaling update interval + fp8_amax_history_len: 1024 # Number of steps for which amax history is recorded per tensor + fp8_amax_compute_algo: max # 'most_recent' or 'max'. Algorithm for computing amax from history + reduce_amax: True # Perform reduction to sync amax tensors across GPUs after every iteration + use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False. + ub_tp_comm_overlap: False + # Use userbuffer backend to overlap tensor-parallel communications with computes. + # This feature is only available with Transformer Engine and squence parallelism enabled and, currently, supports only GPT models. + ub_tp_comm_overlap_cfg: null + # A yaml file with userbuffer communicator configurations. This file should provide `method`, `dtype`, `num_sm`, `num_splits`, + # `cga_size`, `num_splits`, `set_sm_margin`, and `aggregate` for the communicators to use custom settings. + # If the configuration file is not provided a default setting is used for all communicators. + + ## Flash Attention + use_flash_attention: False # Use flash attention in self-attention module, this config does nothing when transformer_engine=True or mcore_gpt=True + data: - # Path to data must be specified by the user. - # can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]", + # Path to data must be specified by the user. + # Supports List, String and Dictionary + # List : can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]", # Or see example below: # data_prefix: # - .5 # - /raid/data/pile/my-gpt3_00_text_document # - .5 # - /raid/data/pile/my-gpt3_01_text_document - data_prefix: ??? # list of training datasets - knn_index: ??? # list of KNN map index files - retrieval_prefix: ??? # a singe path to retrieval data + # Dictionary: can override from CLI "model.data.data_prefix"={"train":[1.0, /path/to/data], "validation":/path/to/data, "test":/path/to/test} + # Or see example below: + # "model.data.data_prefix: {train:[1.0,/path/to/data], validation:[/path/to/data], test:[/path/to/test]}" + data_prefix: ??? # will be overrided by value from RETRO preprocessed workdir index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix - data_impl: retmmap # for retro model, this is the only allowed type - splits_string: 900,50,50 - seq_length: ${model.encoder_seq_length} # must be multiple of the chunk_size in your dataset + data_impl: mmap + splits_string: 98,2,0 + seq_length: ${model.encoder_seq_length} # will be overrided by value from RETRO preprocessed workdir skip_warmup: True - num_workers: 0 + num_workers: 2 dataloader_type: single # cyclic - neighbors: 2 # number of retrieved neighbors + reset_position_ids: False # Reset position ids after end-of-document token + reset_attention_mask: False # Reset attention mask after end-of-document token + eod_mask_loss: False # Mask loss for the end of document tokens + validation_drop_last: True # Set to false if the last partial validation samples is to be consumed + no_seqlen_plus_one_input_tokens: False # Set to True to disable fetching (sequence length + 1) input tokens, instead get (sequence length) input tokens and mask the last token + pad_samples_to_global_batch_size: False # Set to True if you want to pad the last partial batch with -1's to equal global batch size + shuffle_documents: True # Set to False to disable documents shuffling. Sample index will still be shuffled + exchange_indices_distributed: False # Set to True to exchange indices via torch.distributed instead of filesystem + retro_data: + retro_block_size: 10000 + retro_chunk_length: 64 + retro_split_preprocessing: 98,2,0 + retro_neighbor_dirs: null + + + # Nsys profiling options + nsys_profile: + enabled: False + start_step: 10 # Global batch to start profiling + end_step: 10 # Global batch to end profiling + ranks: [0] # Global rank IDs to profile + gen_shape: False # Generate model and kernel details including input shapes optim: - name: fused_adam - lr: 1e-4 - weight_decay: 0.01 + name: distributed_fused_adam + lr: 6.0e-4 + weight_decay: 0.1 betas: - 0.9 - - 0.98 + - 0.95 sched: name: CosineAnnealing - warmup_steps: 500 - constant_steps: 50000 - min_lr: 1e-5 + min_lr: 6.0e-5 + warmup_steps: null + max_steps: 750000 + + gc_interval: 0 + # Interval of the host memory garbage collection. When it is zero, collectiion relies on the automatic garbage collector. + # If an interger value larger than zero is set, collection is done manually by the batch step interval of `gc_interval`. diff --git a/examples/nlp/language_modeling/conf/megatron_retro_config_legacy.yaml b/examples/nlp/language_modeling/conf/megatron_retro_config_legacy.yaml new file mode 100644 index 000000000000..dafdcf542f11 --- /dev/null +++ b/examples/nlp/language_modeling/conf/megatron_retro_config_legacy.yaml @@ -0,0 +1,127 @@ +defaults: + - .@model: megatron_model_base_config + +name: test_retro +restore_from_path: null # used when starting from a .nemo file + +trainer: + devices: 2 + num_nodes: 1 + accelerator: gpu + precision: 16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice we don't usually train for more than 1 epoch. + max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 100 + limit_val_batches: null + limit_test_batches: null + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: megatron_retro + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 10 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + filename: 'megatron_retro--{val_loss:.2f}-{step}-{consumed_samples}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + + +model: + version: 1 # indicate the retro model version + + # model parallelism + micro_batch_size: 4 + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 # has to be one. not supporting pipeline parallel yet + + # model architecture + encoder_seq_length: 2048 + max_position_embeddings: ${.encoder_seq_length} + + gradient_as_bucket_view: True # Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + dump_debug_info: False # dump out the debug information + dump_debug_info_to_file: False # dump out the debug information to files + + # retro architecture + chunk_size: 64 # the chunk size used to retrive + enc_num_layers: 4 # total number of encoder layers + dec_num_layers: 6 # total number of decoder layers + enc_cross_attention: [3] # layer numbers for cross attention in encoder + dec_cross_attention: [3, 5] # layer numbers for chunked cross attention in decoder + add_position_embedding: False # whether use the absolute position encoding + + make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. + pre_process: True # add embedding + post_process: True # add pooler + bert_binary_head: True # BERT binary head + + megatron_amp_O2: False # use AMP with O2 style mixed precision instead of native amp on-the-fly weight autocasting. + grad_allreduce_chunk_size_mb: 125 + + megatron_lm_compatible: False # a flag to indicate whether the model is compatible with Megatron LM + + tokenizer: + library: 'megatron' + type: 'GPT2BPETokenizer' + model: null + vocab_file: null + merge_file: null + delimiter: null # only used for tabular tokenizer + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # miscellaneous + seed: 1234 + + data: + # Path to data must be specified by the user. + # can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]", + # Or see example below: + # data_prefix: + # - .5 + # - /raid/data/pile/my-gpt3_00_text_document + # - .5 + # - /raid/data/pile/my-gpt3_01_text_document + data_prefix: ??? # list of training datasets + knn_index: ??? # list of KNN map index files + retrieval_prefix: ??? # a singe path to retrieval data + index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix + data_impl: retmmap # for retro model, this is the only allowed type + splits_string: 900,50,50 + seq_length: ${model.encoder_seq_length} # must be multiple of the chunk_size in your dataset + skip_warmup: True + num_workers: 0 + dataloader_type: single # cyclic + neighbors: 2 # number of retrieved neighbors + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 500 + constant_steps: 50000 + min_lr: 1e-5 diff --git a/examples/nlp/language_modeling/megatron_retro_pretraining.py b/examples/nlp/language_modeling/megatron_retro_pretraining.py index c84656d4b657..2a0c04f695f6 100644 --- a/examples/nlp/language_modeling/megatron_retro_pretraining.py +++ b/examples/nlp/language_modeling/megatron_retro_pretraining.py @@ -12,88 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os +# To suppress BF16 compile related issue in the CI runs with turing/V100 +import torch._dynamo +import torch.multiprocessing as mp from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin -from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector -from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel -from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo -from nemo.collections.nlp.parts.nlp_overrides import ( - CustomProgressBar, - GradScaler, - MegatronHalfPrecisionPlugin, - NLPDDPStrategy, - NLPSaveRestoreConnector, -) +from nemo.collections.nlp.models.language_modeling.megatron_retro_model import MegatronRetroModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder from nemo.core.config import hydra_runner from nemo.utils import logging from nemo.utils.exp_manager import exp_manager +torch._dynamo.config.suppress_errors = True + @hydra_runner(config_path="conf", config_name="megatron_retro_config") def main(cfg) -> None: logging.info("\n\n************** Experiment configuration ***********") logging.info(f'\n{OmegaConf.to_yaml(cfg)}') - megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False) - plugins = [] - strategy = NLPDDPStrategy( - no_ddp_communication_hook=True if megatron_amp_O2 else False, - gradient_as_bucket_view=cfg.model.gradient_as_bucket_view, - find_unused_parameters=False, - ) - - if cfg.trainer.precision in [16, '16', 'bf16', '16-mixed', 'bf16-mixed']: - scaler = None - if cfg.trainer.precision in [16, '16', '16-mixed']: - scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), - growth_interval=cfg.model.get('native_amp_growth_interval', 1000), - hysteresis=cfg.model.get('hysteresis', 2), - ) - plugin_precision = '16-mixed' - else: - plugin_precision = 'bf16-mixed' - if megatron_amp_O2: - plugins.append(MegatronHalfPrecisionPlugin(plugin_precision, device='cuda', scaler=scaler)) - else: - plugins.append(MixedPrecisionPlugin(plugin_precision, device='cuda', scaler=scaler)) - # Set precision None after precision plugins are created as PTL >= 2.1 does not allow both - # precision plugins and precision to exist - cfg.trainer.precision = None - - if cfg.get('cluster_type', None) == 'BCP': - plugins.append(TorchElasticEnvironment()) - - callbacks = [] - # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks - if 'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar: - callbacks.append(CustomProgressBar()) - trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer, callbacks=callbacks) - + trainer = MegatronTrainerBuilder(cfg).create_trainer() exp_manager(trainer, cfg.exp_manager) - # resume_from_checkpoint = uninject_model_parallel_rank(resume_from_checkpoint) - logging.info(f'Resuming training from checkpoint: {trainer.ckpt_path}') - - # load existing nemo retro model - if cfg.get("restore_from_path", None) is not None: - save_restore_connector = NLPSaveRestoreConnector() - if os.path.isdir(cfg.restore_from_path): - save_restore_connector.model_extracted_dir = cfg.restore_from_path - model = MegatronRetrievalModel.restore_from( - restore_path=cfg.restore_from_path, - trainer=trainer, - override_config_path=cfg.model, - save_restore_connector=save_restore_connector, - strict=False, - ) - else: - model = MegatronRetrievalModel(cfg.model, trainer) + model = MegatronRetroModel(cfg.model, trainer) trainer.fit(model) diff --git a/examples/nlp/language_modeling/megatron_retro_pretraining_legacy.py b/examples/nlp/language_modeling/megatron_retro_pretraining_legacy.py new file mode 100644 index 000000000000..4653222b3438 --- /dev/null +++ b/examples/nlp/language_modeling/megatron_retro_pretraining_legacy.py @@ -0,0 +1,102 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from omegaconf.omegaconf import OmegaConf, open_dict +from pytorch_lightning import Trainer +from pytorch_lightning.plugins.environments import TorchElasticEnvironment +from pytorch_lightning.plugins.precision import MixedPrecisionPlugin +from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector + +from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel +from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo +from nemo.collections.nlp.parts.nlp_overrides import ( + CustomProgressBar, + GradScaler, + MegatronHalfPrecisionPlugin, + NLPDDPStrategy, + NLPSaveRestoreConnector, +) +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf", config_name="megatron_retro_config_legacy") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False) + plugins = [] + strategy = NLPDDPStrategy( + no_ddp_communication_hook=True if megatron_amp_O2 else False, + gradient_as_bucket_view=cfg.model.gradient_as_bucket_view, + find_unused_parameters=False, + ) + + if cfg.trainer.precision in [16, '16', 'bf16', '16-mixed', 'bf16-mixed']: + scaler = None + if cfg.trainer.precision in [16, '16', '16-mixed']: + scaler = GradScaler( + init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + growth_interval=cfg.model.get('native_amp_growth_interval', 1000), + hysteresis=cfg.model.get('hysteresis', 2), + ) + plugin_precision = '16-mixed' + else: + plugin_precision = 'bf16-mixed' + if megatron_amp_O2: + plugins.append(MegatronHalfPrecisionPlugin(plugin_precision, device='cuda', scaler=scaler)) + else: + plugins.append(MixedPrecisionPlugin(plugin_precision, device='cuda', scaler=scaler)) + # Set precision None after precision plugins are created as PTL >= 2.1 does not allow both + # precision plugins and precision to exist + cfg.trainer.precision = None + + if cfg.get('cluster_type', None) == 'BCP': + plugins.append(TorchElasticEnvironment()) + + callbacks = [] + # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks + if 'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar: + callbacks.append(CustomProgressBar()) + trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer, callbacks=callbacks) + + exp_manager(trainer, cfg.exp_manager) + + # resume_from_checkpoint = uninject_model_parallel_rank(resume_from_checkpoint) + logging.info(f'Resuming training from checkpoint: {trainer.ckpt_path}') + + # load existing nemo retro model + if cfg.get("restore_from_path", None) is not None: + save_restore_connector = NLPSaveRestoreConnector() + if os.path.isdir(cfg.restore_from_path): + save_restore_connector.model_extracted_dir = cfg.restore_from_path + model = MegatronRetrievalModel.restore_from( + restore_path=cfg.restore_from_path, + trainer=trainer, + override_config_path=cfg.model, + save_restore_connector=save_restore_connector, + strict=False, + ) + else: + model = MegatronRetrievalModel(cfg.model, trainer) + + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py index f0a501d7cc13..377bff309b7c 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py @@ -12,32 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""RETRO Style dataset.""" +"""RETRO style dataset.""" import os -from typing import List +import time import numpy as np import torch +from omegaconf.dictconfig import DictConfig from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( get_datasets_weights_and_num_samples, get_train_valid_test_split_, ) from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset -from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import ( - _build_index_mappings, - get_indexed_dataset_, -) -from nemo.collections.nlp.data.language_modeling.megatron.indexed_retrieval_dataset import ( - KNNIndex, - MMapRetrievalIndexedDataset, -) +from nemo.collections.nlp.data.language_modeling.megatron.indexed_dataset import deallocate_indexed_dataset_memory +from nemo.collections.nlp.data.language_modeling.megatron.indexed_dataset import make_dataset as make_indexed_dataset from nemo.core import Dataset from nemo.utils import logging try: - from megatron.core import parallel_state + from megatron.core import mpu, tensor_parallel + from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder + from megatron.core.datasets.retro.config import RetroGPTChunkDatasets + from megatron.core.datasets.retro.query.multi_split_gpt_dataset import ( + MultiSplitGPTDataset, + MultiSplitGPTDatasetConfig, + ) + from megatron.core.datasets.retro.query.retro_dataset import get_retro_datasets + from megatron.core.models.retro import RetroConfig + + from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids HAVE_MEGATRON_CORE = True @@ -45,425 +50,163 @@ HAVE_MEGATRON_CORE = False -__all__ = [ - "RETRODataset", - "build_train_valid_test_datasets", - "MockRETRODataset", - "build_mock_train_valid_test_datasets", -] - class RETRODataset(Dataset): - """ - Dataset for RETRO model. - - It constructs single data record from the training/retrieval indexed retrieval dataset and knn index file. - The KNN index file maps data chunk id to K-nearest neighbors in the the retrieval dataset chunk ids. - First, it loads a long sequence (2048) from training dataset. Then for each chunk in the sequence, it finds the kNN - chunks from the retrieval dataset using the KNN index. Lastly, compute the masks based on pad id. - """ - - def __init__( - self, - cfg, - trainer, - tokenizer, - name: str, - data_prefix: str, - documents, # document ids in the indexed_dataset used for this dataset - indexed_dataset: MMapRetrievalIndexedDataset, - num_samples: int, # number of data samples, max_steps * global_batch_size - seq_length: int, # input seq length - seed: int, - knn_index: KNNIndex, - retrieval_index: MMapRetrievalIndexedDataset, - ): - if not HAVE_MEGATRON_CORE: - raise ImportError( - "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." - ) - + def __init__(self, cfg, retro_config: RetroConfig, tokenizer, mcore_retro_dataset, number_samples_with_neighbors): super().__init__() - self.name = name - self.indexed_dataset: MMapRetrievalIndexedDataset = indexed_dataset - self.knn_index: KNNIndex = knn_index - self.retrieval_index: MMapRetrievalIndexedDataset = retrieval_index - self.chunk_size = self.indexed_dataset.chunk_size - - # make sure seq_length is a multiple of chunk_size - assert seq_length % self.chunk_size == 0 - # Checks - assert np.min(documents) >= 0 - assert np.max(documents) < indexed_dataset.sizes.shape[0] + self.reset_position_ids = cfg.data.get('reset_position_ids', False) + self.reset_attention_mask = cfg.data.get('reset_attention_mask', False) + self.eod_mask_loss = cfg.data.get('eod_mask_loss', False) self.eos_id = tokenizer.eos_id - self.pad_id = tokenizer.pad_id - - assert self.retrieval_index._index.retrieval_db - self._validate_pad_id() - - # save index mappings to a configurable dir - self.index_mapping_dir = cfg.data.get('index_mapping_dir', None) - self.neighbors = cfg.data.get('neighbors', self.knn_index.K) - # the number of neighbors cannot exceed the max number of neighbors in the index - assert self.neighbors <= self.knn_index.K - # create index_mapping_dir on rank 0 - if torch.distributed.is_available() and torch.distributed.is_initialized(): - if torch.distributed.get_rank() == 0: - if self.index_mapping_dir is not None and not os.path.isdir(self.index_mapping_dir): - os.makedirs(self.index_mapping_dir) - torch.distributed.barrier() - - # Build index mappings. - self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings( - self.name, - data_prefix, - documents, - self.indexed_dataset.sizes, - num_samples, - seq_length, - seed, - index_mapping_dir=self.index_mapping_dir, - ) - if len(self.doc_idx) > np.iinfo('int32').max: - raise "number of epochs exceeds the maximum number for int32 used by sample_idx" - self.padding_context = np.ones(2 * self.chunk_size, dtype=self.retrieval_index._index.dtype) * self.pad_id - - def _validate_pad_id(self): - # validate the pad_id matches the dataset pad_id - ptr, size = self.retrieval_index._index[0] - ptr += size * np.dtype(self.retrieval_index._index.dtype).itemsize - # padded chunk_size of pad_ids at the end of the doc - retrieval_paddings = np.frombuffer( - self.retrieval_index._bin_buffer, - dtype=self.retrieval_index._index.dtype, - count=self.chunk_size, - offset=ptr, - ) - assert (retrieval_paddings == self.pad_id).all() + self.retro_config = retro_config + self.mcore_retro_dataset = mcore_retro_dataset + self.number_samples_with_neighbors = number_samples_with_neighbors # quick fix for problems of mismatch in processed/indexed retro data, # of GPT samples is different from # of samples with neighbors retrieved + self.tokenizer = tokenizer - ptr, size = self.indexed_dataset._index[0] - ptr += (size - 1) * np.dtype(self.indexed_dataset._index.dtype).itemsize - data_paddings = np.frombuffer( - self.indexed_dataset._bin_buffer, dtype=self.indexed_dataset._index.dtype, count=1, offset=ptr - ) - # the last element is either a padding or an eos - assert (data_paddings == self.pad_id).all() or (data_paddings == self.eos_id).all() + return def __len__(self): - # -1 is due to data structure used to retieve the index: - # sample i --> [sample_idx[i], sample_idx[i+1]) - return self.sample_idx.shape[0] - 1 - - def _get_chunks(self, chunk_id: int, num_chunks: int, chunks: List): - """ - starting from chunk_id, loop for num_chunks, get the - KNN chunk ids from retrieval dataset, and get the chunk token ids, - put them into the chunks list - """ - for i in range(chunk_id, chunk_id + num_chunks): - knn = self.knn_index.get_KNN_chunk_ids(i) - for rid in knn[: self.neighbors]: - if rid < 0: - # no neighbor, just pad it - one_chunk = self.padding_context - else: - one_chunk = self.retrieval_index.get_chunk(rid) - chunks.append(one_chunk) - - def _get_text(self, idx: int) -> np.ndarray: - # Get the shuffled index. - idx = self.shuffle_idx[idx] - # Start and end documents and offsets. - doc_index_f = self.sample_idx[idx][0] - doc_index_l = self.sample_idx[idx + 1][0] - offset_f = self.sample_idx[idx][1] - offset_l = self.sample_idx[idx + 1][1] - # If we are within the same document, just extract the chunk. - if doc_index_f == doc_index_l: - sample = self.indexed_dataset.get( - self.doc_idx[doc_index_f], offset=offset_f, length=offset_l - offset_f + 1 - ) - chunk_id = self.indexed_dataset.get_chunk_id(self.doc_idx[doc_index_f], offset_f) - num_chunks = (offset_l - offset_f) // self.chunk_size - chunks = [] - self._get_chunks(chunk_id, num_chunks, chunks) - chunks = np.stack(chunks, axis=0).reshape(num_chunks, self.neighbors, -1).astype(np.int64) - else: - # Otherwise, get the rest of the initial document. - sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)] - num_chunks = (self.indexed_dataset._index.sizes[self.doc_idx[doc_index_f]] - offset_f) // self.chunk_size - total_chunks = num_chunks - chunks = [] - chunk_id = self.indexed_dataset.get_chunk_id(self.doc_idx[doc_index_f], offset_f) - self._get_chunks(chunk_id, num_chunks, chunks) - # Loop over all in between documents and add the entire document. - for i in range(doc_index_f + 1, doc_index_l): - sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) - chunk_id = self.indexed_dataset.get_chunk_id(self.doc_idx[i], 0) - num_chunks = self.indexed_dataset._index.sizes[self.doc_idx[i]] // self.chunk_size - total_chunks += num_chunks - self._get_chunks(chunk_id, num_chunks, chunks) - # And finally add the relevant portion of last document. - chunk_id = self.indexed_dataset.get_chunk_id(self.doc_idx[doc_index_l], 0) - num_chunks = (offset_l) // self.chunk_size - total_chunks += num_chunks - self._get_chunks(chunk_id, num_chunks, chunks) - sample_list.append(self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1)) - sample = np.concatenate(sample_list) - chunks = np.stack(chunks, axis=0).reshape(total_chunks, self.neighbors, -1).astype(np.int64) - return sample.astype(np.int64), chunks + return len(self.mcore_retro_dataset.chunk_dataset.sample_dataset) - def __getitem__(self, idx): - text, retrieved = self._get_text(idx) - text = torch.from_numpy(text) - retrieved = torch.from_numpy(retrieved) - tokens = text[:-1].contiguous() - labels = text[1:].contiguous() - hidden_mask = tokens != self.pad_id - context_mask = retrieved != self.pad_id - return { - 'tokens': tokens, - 'labels': labels, - 'tokens_mask': hidden_mask, - 'loss_mask': hidden_mask, - 'retrieved_emb_mask': context_mask, - 'retrieved_ids': retrieved, - } + def _get_text(self, idx: int): + # return the tokens ids of idx + # Caveat: these tokens are got from the already pre-tokenized data file, mcore's GPTDataset doesn't run __getitem__, only run _query_document_sample_shuffle_indices + return self.mcore_retro_dataset[idx] + def __getitem__(self, idx): -def build_train_valid_test_datasets( - cfg, - trainer, - data_prefix: List[str], - data_impl: str, - splits_string: str, - train_valid_test_num_samples, - seq_length: int, - seed: int, - skip_warmup: bool, - tokenizer, - retrieval_prefix: str, - knn_map_path: List[str], -): - """Build train, valid, and test RETRO datasets. - There is one to one mapping between data_prefix and knn_map_path. - Currently only supports one retrieval dataset. - """ - # make sure there is one to one mapping between data_prefix and knn_map_path - assert len(data_prefix) == len(knn_map_path) - - # Single dataset. - if len(data_prefix) == 1: - return _build_train_valid_test_datasets( - cfg, - trainer, - data_prefix[0], - data_impl, - splits_string, - train_valid_test_num_samples, - seq_length, - seed, - skip_warmup, - tokenizer, - retrieval_prefix, - knn_map_path[0], - ) - - # Blending dataset. - # Parse the values. - output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) - prefixes, weights, datasets_train_valid_test_num_samples = output - train_n, valid_n, test_n = map(sum, zip(*datasets_train_valid_test_num_samples)) - - # Build individual datasets. - train_datasets = [] - valid_datasets = [] - test_datasets = [] - for i in range(len(prefixes)): - train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( - cfg, - trainer, - prefixes[i], - data_impl, - splits_string, - datasets_train_valid_test_num_samples[i], - seq_length, - seed, - skip_warmup, - tokenizer, - retrieval_prefix, - knn_map_path[i], + # quick fix for problems of mismatch in processed/indexed retro data, # of GPT samples is different from # of samples with neighbors retrieved + idx = idx % self.number_samples_with_neighbors + + sample = self._get_text(idx) + + # Unpack + tokens_ = torch.from_numpy(sample['text']) + tokens_ = tokens_.long() # size should be [seq_length] + labels = tokens_[1:].contiguous() + tokens = tokens_[:-1].contiguous() + neighbor_tokens = torch.from_numpy(sample['neighbor_tokens']) + neighbor_tokens = neighbor_tokens.long() # size should be [l, k, r] + + # note: [l, k, r] => [l*k, r] + # note: 2x == neighbor, continuation + neighbor_tokens = neighbor_tokens.view(-1, self.retro_config.retro_retrieved_length).long() + + # Get the masks and postition ids for tokens and neighbor_tokens + tokens = torch.unsqueeze( + tokens, 0 + ) # get_ltor_masks_and_position_ids takes as input tokens arguments as a batch (2D tensor), so need to convert tokens from 1D to 2D + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, self.eos_id, self.reset_position_ids, self.reset_attention_mask, self.eod_mask_loss ) - if train_ds: - train_datasets.append(train_ds) - if valid_ds: - valid_datasets.append(valid_ds) - if test_ds: - test_datasets.append(test_ds) - - # Blend. - blending_train_dataset = None - if train_datasets: - blending_train_dataset = BlendableDataset(train_datasets, weights, train_n) - blending_valid_dataset = None - if valid_datasets: - blending_valid_dataset = BlendableDataset(valid_datasets, weights, valid_n) - blending_test_dataset = None - if test_datasets: - blending_test_dataset = BlendableDataset(test_datasets, weights, test_n) - - return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) - - -def _build_train_valid_test_datasets( - cfg, - trainer, - data_prefix: str, - data_impl: str, - splits_string: str, - train_valid_test_num_samples, - seq_length: int, - seed: int, - skip_warmup: bool, - tokenizer, - retrieval_prefix: str, - knn_map_path: str, -): - """Build train, valid, and test datasets.""" - - # Indexed dataset. - indexed_dataset: MMapRetrievalIndexedDataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) - knn_index: KNNIndex = KNNIndex(knn_map_path, skip_warmup) - retrieval_index: MMapRetrievalIndexedDataset = get_indexed_dataset_(retrieval_prefix, data_impl, skip_warmup) - - total_num_of_documents = indexed_dataset.sizes.shape[0] - splits = get_train_valid_test_split_(splits_string, total_num_of_documents) - - # Print stats about the splits. - logging.info(' > dataset split:') - - def print_split_stats(name, index): - logging.info(' {}:'.format(name)) - logging.info( - ' document indices in [{}, {}) total of {} ' - 'documents'.format(splits[index], splits[index + 1], splits[index + 1] - splits[index]) + tokens, attention_mask, loss_mask, position_ids = tokens[0], attention_mask[0], loss_mask[0], position_ids[0] + _, _, neighbor_position_ids = get_ltor_masks_and_position_ids( # neighbor_tokens is already a 2D array + neighbor_tokens, self.eos_id, self.reset_position_ids, self.reset_attention_mask, self.eod_mask_loss ) - - print_split_stats('train', 0) - print_split_stats('validation', 1) - print_split_stats('test', 2) - - def build_dataset(index, name): - dataset = None - if splits[index + 1] > splits[index]: - documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32) - dataset = RETRODataset( - cfg, - trainer, - tokenizer, - name, - data_prefix, - documents, - indexed_dataset, - train_valid_test_num_samples[index], - seq_length, - seed, - knn_index, - retrieval_index, - ) - return dataset - - train_dataset = build_dataset(0, 'train') - valid_dataset = build_dataset(1, 'valid') - test_dataset = build_dataset(2, 'test') - - return (train_dataset, valid_dataset, test_dataset) - - -class MockRETRODataset(torch.utils.data.Dataset): - def __init__(self, cfg, trainer, tokenizer, name, size): - super().__init__() - self.name = name - self.tokenizer = tokenizer - self._cfg = cfg - self.size = size - seed_val = parallel_state.get_data_parallel_rank() * 131 + 97 - torch.manual_seed(seed_val) - - def __len__(self): - return self.size - - def __getitem__(self, idx): - vocab_size = self.tokenizer.vocab_size - - neighbors = self._cfg.data.neighbors - input_length = self._cfg.data.seq_length - chunks = input_length // self._cfg.chunk_size - chunk_size = self._cfg.chunk_size - pad_id = self.tokenizer.pad_id - - all_tokens = torch.randint(0, vocab_size, (input_length + 1,)) - # make sure the eod happens at the end of each chunk, can add paddings to it - # e.g. [..., id, id, pad, pad, pad, eod] each has chunk_size, each sentence - # has length of multiple of chunk_size - hidden = all_tokens[:-1] - labels = all_tokens[1:] - - hidden_mask = hidden != pad_id - # to mask out the token ids [id, id, eod, id, pad, eod, id, id] - # so attention is not across eod, mask should be: - # [false, true, true, true, true, true, true, true] - # [false, false, true, true, true, true, true, true] - # [false, false, false,true, true, true, true, true] - # [true, true, true, false, true, true, true, true] - # [true, true, true, true, true, true, true, true] - # [true, true, true, false, true, false, true, true] - # [true, true, true, true, true, true, false, true] - # [true, true, true, true, true, true, false, false] - retrieved = torch.randint(0, vocab_size, (chunks, neighbors, 2 * chunk_size)) - - context_mask = retrieved != pad_id + neighbor_attention_mask = torch.zeros( + [1, 1] + ) # just a dummy values, since the batch neighbor_attention_mask will be set to None in megatron_retro_model.py following Lawrence's implementation return { - 'tokens': hidden, + 'tokens': tokens, 'labels': labels, - 'tokens_mask': hidden_mask, - 'loss_mask': hidden_mask, - 'retrieved_emb_mask': context_mask, - 'retrieved_ids': retrieved, + 'loss_mask': loss_mask, + 'attention_mask': attention_mask, + 'position_ids': position_ids, + 'context_input_ids': neighbor_tokens, + 'context_attention_mask': neighbor_attention_mask, + 'context_position_ids': neighbor_position_ids, } -def build_mock_train_valid_test_datasets( - cfg, trainer, splits_string, tokenizer, mock_data_size, +def build_train_valid_test_datasets( + cfg, retro_config: RetroConfig, train_valid_test_num_samples, seq_length, tokenizer, ): - """Build train, valid, and test datasets.""" - - splits = get_train_valid_test_split_(splits_string, mock_data_size) - # Print stats about the splits. - logging.info(' > dataset split:') - - def print_split_stats(name, index): - logging.info(' {}:'.format(name)) - logging.info( - ' document indices in [{}, {}) total of {} ' - 'documents'.format(splits[index], splits[index + 1], splits[index + 1] - splits[index]) + # gpt dataset + train_ds, valid_ds, test_ds = gpt_train_valid_test_datasets_provider(cfg, train_valid_test_num_samples, tokenizer) + + gpt_datasets = { + "train": (train_ds, train_valid_test_num_samples[0]), + "valid": (valid_ds, train_valid_test_num_samples[1]), + "test": (test_ds, train_valid_test_num_samples[2]), + } + + retro_train_ds, retro_valid_ds, retro_test_ds = get_retro_datasets( + config=retro_config, gpt_datasets=gpt_datasets, sample_length=seq_length, eod_token_id=tokenizer.eos_id, + ) + + train_ds = ( + RETRODataset( + cfg=cfg, + retro_config=retro_config, + tokenizer=tokenizer, + mcore_retro_dataset=retro_train_ds, + number_samples_with_neighbors=train_valid_test_num_samples[0], + ) + if retro_train_ds + else None + ) + valid_ds = ( + RETRODataset( + cfg=cfg, + retro_config=retro_config, + tokenizer=tokenizer, + mcore_retro_dataset=retro_valid_ds, + number_samples_with_neighbors=train_valid_test_num_samples[1], + ) + if retro_valid_ds + else None + ) + test_ds = ( + RETRODataset( + cfg=cfg, + retro_config=retro_config, + tokenizer=tokenizer, + mcore_retro_dataset=retro_test_ds, + number_samples_with_neighbors=train_valid_test_num_samples[2], ) + if retro_test_ds + else None + ) - print_split_stats('train', 0) - print_split_stats('validation', 1) - print_split_stats('test', 2) + return train_ds, valid_ds, test_ds - def build_dataset(index, name): - dataset = None - if splits[index + 1] > splits[index]: - dataset = MockRETRODataset(cfg, trainer, tokenizer, name, splits[index + 1] - splits[index],) - return dataset - train_dataset = build_dataset(0, 'train') - valid_dataset = build_dataset(1, 'valid') - test_dataset = build_dataset(2, 'test') +def gpt_train_valid_test_datasets_provider(cfg, train_val_test_num_samples, tokenizer): + """Build the train test and validation datasets. + Implemented from train_valid_test_datasets_provider in M-LM/pretrain_gpt.py + + Args: + train_val_test_num_samples : A list containing the number of samples in train test and validation. + """ - return (train_dataset, valid_dataset, test_dataset) + def is_dataset_built_on_rank(): + return ( + mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage() + ) and mpu.get_tensor_model_parallel_rank() == 0 + + data_config = MultiSplitGPTDatasetConfig( + random_seed=cfg.seed, + sequence_length=cfg.data.seq_length, + blend=cfg.data.data_prefix, + split=cfg.data.splits_string, + split_preprocessing=cfg.data.retro_data.retro_split_preprocessing, + path_to_cache=None, + return_document_ids=False, + reset_position_ids=cfg.data.get('reset_position_ids', False), + reset_attention_mask=cfg.data.get('reset_attention_mask', False), + eod_mask_loss=cfg.data.get('eod_mask_loss', False), + tokenizer=tokenizer, + ) + + print("> building train, validation, and test datasets for GPT ...") + + train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( + MultiSplitGPTDataset, train_val_test_num_samples, is_dataset_built_on_rank, data_config + ).build() + + print("> finished creating GPT datasets ...") + + return train_ds, valid_ds, test_ds diff --git a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset_legacy.py b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset_legacy.py new file mode 100644 index 000000000000..f0a501d7cc13 --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset_legacy.py @@ -0,0 +1,469 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RETRO Style dataset.""" + +import os +from typing import List + +import numpy as np +import torch + +from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( + get_datasets_weights_and_num_samples, + get_train_valid_test_split_, +) +from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset +from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import ( + _build_index_mappings, + get_indexed_dataset_, +) +from nemo.collections.nlp.data.language_modeling.megatron.indexed_retrieval_dataset import ( + KNNIndex, + MMapRetrievalIndexedDataset, +) +from nemo.core import Dataset +from nemo.utils import logging + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + +__all__ = [ + "RETRODataset", + "build_train_valid_test_datasets", + "MockRETRODataset", + "build_mock_train_valid_test_datasets", +] + + +class RETRODataset(Dataset): + """ + Dataset for RETRO model. + + It constructs single data record from the training/retrieval indexed retrieval dataset and knn index file. + The KNN index file maps data chunk id to K-nearest neighbors in the the retrieval dataset chunk ids. + First, it loads a long sequence (2048) from training dataset. Then for each chunk in the sequence, it finds the kNN + chunks from the retrieval dataset using the KNN index. Lastly, compute the masks based on pad id. + """ + + def __init__( + self, + cfg, + trainer, + tokenizer, + name: str, + data_prefix: str, + documents, # document ids in the indexed_dataset used for this dataset + indexed_dataset: MMapRetrievalIndexedDataset, + num_samples: int, # number of data samples, max_steps * global_batch_size + seq_length: int, # input seq length + seed: int, + knn_index: KNNIndex, + retrieval_index: MMapRetrievalIndexedDataset, + ): + if not HAVE_MEGATRON_CORE: + raise ImportError( + "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + ) + + super().__init__() + self.name = name + self.indexed_dataset: MMapRetrievalIndexedDataset = indexed_dataset + self.knn_index: KNNIndex = knn_index + self.retrieval_index: MMapRetrievalIndexedDataset = retrieval_index + self.chunk_size = self.indexed_dataset.chunk_size + + # make sure seq_length is a multiple of chunk_size + assert seq_length % self.chunk_size == 0 + # Checks + assert np.min(documents) >= 0 + assert np.max(documents) < indexed_dataset.sizes.shape[0] + + self.eos_id = tokenizer.eos_id + self.pad_id = tokenizer.pad_id + + assert self.retrieval_index._index.retrieval_db + self._validate_pad_id() + + # save index mappings to a configurable dir + self.index_mapping_dir = cfg.data.get('index_mapping_dir', None) + self.neighbors = cfg.data.get('neighbors', self.knn_index.K) + # the number of neighbors cannot exceed the max number of neighbors in the index + assert self.neighbors <= self.knn_index.K + # create index_mapping_dir on rank 0 + if torch.distributed.is_available() and torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + if self.index_mapping_dir is not None and not os.path.isdir(self.index_mapping_dir): + os.makedirs(self.index_mapping_dir) + torch.distributed.barrier() + + # Build index mappings. + self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings( + self.name, + data_prefix, + documents, + self.indexed_dataset.sizes, + num_samples, + seq_length, + seed, + index_mapping_dir=self.index_mapping_dir, + ) + if len(self.doc_idx) > np.iinfo('int32').max: + raise "number of epochs exceeds the maximum number for int32 used by sample_idx" + self.padding_context = np.ones(2 * self.chunk_size, dtype=self.retrieval_index._index.dtype) * self.pad_id + + def _validate_pad_id(self): + # validate the pad_id matches the dataset pad_id + ptr, size = self.retrieval_index._index[0] + ptr += size * np.dtype(self.retrieval_index._index.dtype).itemsize + # padded chunk_size of pad_ids at the end of the doc + retrieval_paddings = np.frombuffer( + self.retrieval_index._bin_buffer, + dtype=self.retrieval_index._index.dtype, + count=self.chunk_size, + offset=ptr, + ) + assert (retrieval_paddings == self.pad_id).all() + + ptr, size = self.indexed_dataset._index[0] + ptr += (size - 1) * np.dtype(self.indexed_dataset._index.dtype).itemsize + data_paddings = np.frombuffer( + self.indexed_dataset._bin_buffer, dtype=self.indexed_dataset._index.dtype, count=1, offset=ptr + ) + # the last element is either a padding or an eos + assert (data_paddings == self.pad_id).all() or (data_paddings == self.eos_id).all() + + def __len__(self): + # -1 is due to data structure used to retieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + return self.sample_idx.shape[0] - 1 + + def _get_chunks(self, chunk_id: int, num_chunks: int, chunks: List): + """ + starting from chunk_id, loop for num_chunks, get the + KNN chunk ids from retrieval dataset, and get the chunk token ids, + put them into the chunks list + """ + for i in range(chunk_id, chunk_id + num_chunks): + knn = self.knn_index.get_KNN_chunk_ids(i) + for rid in knn[: self.neighbors]: + if rid < 0: + # no neighbor, just pad it + one_chunk = self.padding_context + else: + one_chunk = self.retrieval_index.get_chunk(rid) + chunks.append(one_chunk) + + def _get_text(self, idx: int) -> np.ndarray: + # Get the shuffled index. + idx = self.shuffle_idx[idx] + # Start and end documents and offsets. + doc_index_f = self.sample_idx[idx][0] + doc_index_l = self.sample_idx[idx + 1][0] + offset_f = self.sample_idx[idx][1] + offset_l = self.sample_idx[idx + 1][1] + # If we are within the same document, just extract the chunk. + if doc_index_f == doc_index_l: + sample = self.indexed_dataset.get( + self.doc_idx[doc_index_f], offset=offset_f, length=offset_l - offset_f + 1 + ) + chunk_id = self.indexed_dataset.get_chunk_id(self.doc_idx[doc_index_f], offset_f) + num_chunks = (offset_l - offset_f) // self.chunk_size + chunks = [] + self._get_chunks(chunk_id, num_chunks, chunks) + chunks = np.stack(chunks, axis=0).reshape(num_chunks, self.neighbors, -1).astype(np.int64) + else: + # Otherwise, get the rest of the initial document. + sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)] + num_chunks = (self.indexed_dataset._index.sizes[self.doc_idx[doc_index_f]] - offset_f) // self.chunk_size + total_chunks = num_chunks + chunks = [] + chunk_id = self.indexed_dataset.get_chunk_id(self.doc_idx[doc_index_f], offset_f) + self._get_chunks(chunk_id, num_chunks, chunks) + # Loop over all in between documents and add the entire document. + for i in range(doc_index_f + 1, doc_index_l): + sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) + chunk_id = self.indexed_dataset.get_chunk_id(self.doc_idx[i], 0) + num_chunks = self.indexed_dataset._index.sizes[self.doc_idx[i]] // self.chunk_size + total_chunks += num_chunks + self._get_chunks(chunk_id, num_chunks, chunks) + # And finally add the relevant portion of last document. + chunk_id = self.indexed_dataset.get_chunk_id(self.doc_idx[doc_index_l], 0) + num_chunks = (offset_l) // self.chunk_size + total_chunks += num_chunks + self._get_chunks(chunk_id, num_chunks, chunks) + sample_list.append(self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1)) + sample = np.concatenate(sample_list) + chunks = np.stack(chunks, axis=0).reshape(total_chunks, self.neighbors, -1).astype(np.int64) + return sample.astype(np.int64), chunks + + def __getitem__(self, idx): + text, retrieved = self._get_text(idx) + text = torch.from_numpy(text) + retrieved = torch.from_numpy(retrieved) + tokens = text[:-1].contiguous() + labels = text[1:].contiguous() + hidden_mask = tokens != self.pad_id + context_mask = retrieved != self.pad_id + return { + 'tokens': tokens, + 'labels': labels, + 'tokens_mask': hidden_mask, + 'loss_mask': hidden_mask, + 'retrieved_emb_mask': context_mask, + 'retrieved_ids': retrieved, + } + + +def build_train_valid_test_datasets( + cfg, + trainer, + data_prefix: List[str], + data_impl: str, + splits_string: str, + train_valid_test_num_samples, + seq_length: int, + seed: int, + skip_warmup: bool, + tokenizer, + retrieval_prefix: str, + knn_map_path: List[str], +): + """Build train, valid, and test RETRO datasets. + There is one to one mapping between data_prefix and knn_map_path. + Currently only supports one retrieval dataset. + """ + # make sure there is one to one mapping between data_prefix and knn_map_path + assert len(data_prefix) == len(knn_map_path) + + # Single dataset. + if len(data_prefix) == 1: + return _build_train_valid_test_datasets( + cfg, + trainer, + data_prefix[0], + data_impl, + splits_string, + train_valid_test_num_samples, + seq_length, + seed, + skip_warmup, + tokenizer, + retrieval_prefix, + knn_map_path[0], + ) + + # Blending dataset. + # Parse the values. + output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + train_n, valid_n, test_n = map(sum, zip(*datasets_train_valid_test_num_samples)) + + # Build individual datasets. + train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + cfg, + trainer, + prefixes[i], + data_impl, + splits_string, + datasets_train_valid_test_num_samples[i], + seq_length, + seed, + skip_warmup, + tokenizer, + retrieval_prefix, + knn_map_path[i], + ) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + # Blend. + blending_train_dataset = None + if train_datasets: + blending_train_dataset = BlendableDataset(train_datasets, weights, train_n) + blending_valid_dataset = None + if valid_datasets: + blending_valid_dataset = BlendableDataset(valid_datasets, weights, valid_n) + blending_test_dataset = None + if test_datasets: + blending_test_dataset = BlendableDataset(test_datasets, weights, test_n) + + return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) + + +def _build_train_valid_test_datasets( + cfg, + trainer, + data_prefix: str, + data_impl: str, + splits_string: str, + train_valid_test_num_samples, + seq_length: int, + seed: int, + skip_warmup: bool, + tokenizer, + retrieval_prefix: str, + knn_map_path: str, +): + """Build train, valid, and test datasets.""" + + # Indexed dataset. + indexed_dataset: MMapRetrievalIndexedDataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) + knn_index: KNNIndex = KNNIndex(knn_map_path, skip_warmup) + retrieval_index: MMapRetrievalIndexedDataset = get_indexed_dataset_(retrieval_prefix, data_impl, skip_warmup) + + total_num_of_documents = indexed_dataset.sizes.shape[0] + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + + # Print stats about the splits. + logging.info(' > dataset split:') + + def print_split_stats(name, index): + logging.info(' {}:'.format(name)) + logging.info( + ' document indices in [{}, {}) total of {} ' + 'documents'.format(splits[index], splits[index + 1], splits[index + 1] - splits[index]) + ) + + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_dataset(index, name): + dataset = None + if splits[index + 1] > splits[index]: + documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32) + dataset = RETRODataset( + cfg, + trainer, + tokenizer, + name, + data_prefix, + documents, + indexed_dataset, + train_valid_test_num_samples[index], + seq_length, + seed, + knn_index, + retrieval_index, + ) + return dataset + + train_dataset = build_dataset(0, 'train') + valid_dataset = build_dataset(1, 'valid') + test_dataset = build_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) + + +class MockRETRODataset(torch.utils.data.Dataset): + def __init__(self, cfg, trainer, tokenizer, name, size): + super().__init__() + self.name = name + self.tokenizer = tokenizer + self._cfg = cfg + self.size = size + seed_val = parallel_state.get_data_parallel_rank() * 131 + 97 + torch.manual_seed(seed_val) + + def __len__(self): + return self.size + + def __getitem__(self, idx): + vocab_size = self.tokenizer.vocab_size + + neighbors = self._cfg.data.neighbors + input_length = self._cfg.data.seq_length + chunks = input_length // self._cfg.chunk_size + chunk_size = self._cfg.chunk_size + pad_id = self.tokenizer.pad_id + + all_tokens = torch.randint(0, vocab_size, (input_length + 1,)) + # make sure the eod happens at the end of each chunk, can add paddings to it + # e.g. [..., id, id, pad, pad, pad, eod] each has chunk_size, each sentence + # has length of multiple of chunk_size + hidden = all_tokens[:-1] + labels = all_tokens[1:] + + hidden_mask = hidden != pad_id + # to mask out the token ids [id, id, eod, id, pad, eod, id, id] + # so attention is not across eod, mask should be: + # [false, true, true, true, true, true, true, true] + # [false, false, true, true, true, true, true, true] + # [false, false, false,true, true, true, true, true] + # [true, true, true, false, true, true, true, true] + # [true, true, true, true, true, true, true, true] + # [true, true, true, false, true, false, true, true] + # [true, true, true, true, true, true, false, true] + # [true, true, true, true, true, true, false, false] + retrieved = torch.randint(0, vocab_size, (chunks, neighbors, 2 * chunk_size)) + + context_mask = retrieved != pad_id + + return { + 'tokens': hidden, + 'labels': labels, + 'tokens_mask': hidden_mask, + 'loss_mask': hidden_mask, + 'retrieved_emb_mask': context_mask, + 'retrieved_ids': retrieved, + } + + +def build_mock_train_valid_test_datasets( + cfg, trainer, splits_string, tokenizer, mock_data_size, +): + """Build train, valid, and test datasets.""" + + splits = get_train_valid_test_split_(splits_string, mock_data_size) + + # Print stats about the splits. + logging.info(' > dataset split:') + + def print_split_stats(name, index): + logging.info(' {}:'.format(name)) + logging.info( + ' document indices in [{}, {}) total of {} ' + 'documents'.format(splits[index], splits[index + 1], splits[index + 1] - splits[index]) + ) + + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_dataset(index, name): + dataset = None + if splits[index + 1] > splits[index]: + dataset = MockRETRODataset(cfg, trainer, tokenizer, name, splits[index + 1] - splits[index],) + return dataset + + train_dataset = build_dataset(0, 'train') + valid_dataset = build_dataset(1, 'valid') + test_dataset = build_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) diff --git a/nemo/collections/nlp/models/language_modeling/__init__.py b/nemo/collections/nlp/models/language_modeling/__init__.py index f63d289f8925..437a7003483b 100644 --- a/nemo/collections/nlp/models/language_modeling/__init__.py +++ b/nemo/collections/nlp/models/language_modeling/__init__.py @@ -17,4 +17,5 @@ MegatronGPTPromptLearningModel, ) from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel +from nemo.collections.nlp.models.language_modeling.megatron_retro_model import MegatronRetroModel from nemo.collections.nlp.models.language_modeling.transformer_lm_model import TransformerLMModel diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 4493532f88bf..43cc8c26444f 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -1481,7 +1481,7 @@ def setup(self, stage=None): f'Pipeline model parallel rank: {parallel_state.get_pipeline_model_parallel_rank()}, ' f'Tensor model parallel rank: {parallel_state.get_tensor_model_parallel_rank()}, ' f'Number of model parameters on device: {num_parameters_on_device:.2e}. ' - f'Total number of model parameters: {total_num_parameters:.2e}.' + f'Number of precise model parameters on device: {total_num_parameters}.' ) resume_checkpoint_path = self.trainer.ckpt_path @@ -1548,11 +1548,14 @@ def setup_validation_data(self, cfg): def setup_test_data(self, cfg): if hasattr(self, '_test_ds'): - consumed_samples = 0 - logging.info( - f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}' - ) - self._test_dl = self.build_pretraining_data_loader(self._test_ds, consumed_samples) + if self._test_ds is not None: + consumed_samples = 0 + logging.info( + f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}' + ) + self._test_dl = self.build_pretraining_data_loader(self._test_ds, consumed_samples) + else: + self._test_dl = None def generate( self, diff --git a/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py b/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py index acd85261f7e5..42323e503f7d 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py @@ -23,7 +23,7 @@ MegatronPretrainingRandomSampler, MegatronPretrainingSampler, ) -from nemo.collections.nlp.data.language_modeling.megatron.retro_dataset import ( +from nemo.collections.nlp.data.language_modeling.megatron.retro_dataset_legacy import ( build_mock_train_valid_test_datasets, build_train_valid_test_datasets, ) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py b/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py new file mode 100644 index 000000000000..8cc39056554c --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py @@ -0,0 +1,651 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import json +import os +import queue +import types +import warnings +from dataclasses import fields +from functools import partial +from typing import Any, Dict, Iterator, List, Optional, Union + +import torch +from omegaconf import OmegaConf, open_dict +from omegaconf.dictconfig import DictConfig +from pytorch_lightning.accelerators import CPUAccelerator +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import ( + MegatronPretrainingRandomSampler, + MegatronPretrainingSampler, +) + +# from nemo.collections.nlp.data.language_modeling.megatron.retro_dummy_dataset import build_train_valid_test_datasets as dummy_build_train_valid_test_datasets # turn on when running with dummy data +from nemo.collections.nlp.data.language_modeling.megatron.retro_dataset import build_train_valid_test_datasets +from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.megatron.build_model import build_model +from nemo.collections.nlp.modules.common.megatron.module import Float16Module +from nemo.collections.nlp.modules.common.megatron.utils import ( + ApexGuardDefaults, + average_losses_across_data_parallel_group, + get_all_params_for_weight_decay_optimization, + get_ltor_masks_and_position_ids, + get_params_for_weight_decay_optimization, +) +from nemo.collections.nlp.modules.common.text_generation_strategy import TextGenerationStrategy +from nemo.collections.nlp.modules.common.text_generation_utils import ( + generate, + get_computeprob_response, + get_default_length_params, + get_default_sampling_params, + megatron_gpt_generate, +) +from nemo.collections.nlp.modules.common.transformer.text_generation import ( + LengthParam, + OutputType, + SamplingParam, + TextGeneration, +) +from nemo.collections.nlp.parts import utils_funcs +from nemo.collections.nlp.parts.utils_funcs import activation_to_func, get_last_rank +from nemo.core.classes import Exportable +from nemo.core.classes.common import PretrainedModelInfo +from nemo.core.neural_types import ChannelType, NeuralType +from nemo.utils import logging + +try: + import apex.transformer.pipeline_parallel.utils + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + + HAVE_APEX = True + +except (ImportError, ModuleNotFoundError): + + HAVE_APEX = False + +try: + from megatron.core import InferenceParams, parallel_state + from megatron.core.models.retro import RetroModel as MCoreRetroModel + from megatron.core.models.retro.config import RetroConfig + from megatron.core.models.retro.decoder_spec import get_retro_decoder_block_spec + from megatron.core.models.retro.utils import get_config_path as get_retro_config_path + from megatron.core.models.retro.utils import get_gpt_data_dir as get_retro_data_dir + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + from megatron.core.transformer.module import Float16Module as MCoreFloat16Module + from megatron.core.transformer.transformer_config import TransformerConfig + from megatron.core.utils import init_method_normal, scaled_init_method_normal + + # TODO @tmoon: Use once available in Megatron-LM + # from megatron.core.pipeline_parallel.schedules import DataIteratorList + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + TransformerConfig = ApexGuardDefaults + + HAVE_MEGATRON_CORE = False + +try: + import transformer_engine + from transformer_engine.pytorch import module as te_module + + HAVE_TE = True + +except (ImportError, ModuleNotFoundError): + HAVE_TE = False + + +class MegatronRetroModel(MegatronGPTModel): + """ + Megatron Retro pretraining + """ + + def load_retro_config(self, cfg: DictConfig): + assert cfg.retro.get('retro_project_dir') is not None, "`--retro-project-dir` must be set to use Retro." + + # Retro config path. + retro_config_path = get_retro_config_path(cfg.retro.get('retro_project_dir')) + assert os.path.exists(retro_config_path), "retro project dir missing config.json." + + # Load retro config. + with open(retro_config_path) as f: + + # Parse config. + retro_preprocess_config = types.SimpleNamespace(**json.load(f)) + + # Retro data path is relative to data path (via hard or soft links). + data_dir = get_retro_data_dir(cfg.retro.get('retro_project_dir')) + data_path = list(retro_preprocess_config.retro_gpt_data_path) + if len(data_path) % 2 == 0: + for i in range(len(data_path) - 1, -1, -2): + data_path[i] = os.path.join(data_dir, data_path[i]) + else: + assert len(data_path) == 1 + data_path[0] = os.path.join(data_dir, data_path[0]) + + # Update args. + cfg.global_batch_size = retro_preprocess_config.retro_gpt_global_batch_size + cfg.seed = retro_preprocess_config.retro_gpt_seed + cfg.data.data_prefix = data_path + cfg.encoder_seq_length = retro_preprocess_config.retro_gpt_seq_length + cfg.data.seq_length = retro_preprocess_config.retro_gpt_seq_length + cfg.max_position_embeddings = retro_preprocess_config.retro_gpt_seq_length + # cfg.data.splits_string = retro_preprocess_config.retro_gpt_split # remove because lastest RETRO data-object have separate RETRO training split and RETRO preprocessing split + cfg.tokenizer.model = ( + cfg.retro.get('retro_project_dir') + '/' + retro_preprocess_config.retro_gpt_tokenizer_model + ) + cfg.tokenizer.type = retro_preprocess_config.retro_gpt_tokenizer_type + cfg.tokenizer.vocab_file = retro_preprocess_config.retro_gpt_vocab_file + cfg.tokenizer.merge_file = retro_preprocess_config.retro_gpt_merge_file + with open_dict(cfg): + cfg.retro_train_samples_with_neighbors = retro_preprocess_config.retro_gpt_train_samples + cfg.retro_valid_samples_with_neighbors = retro_preprocess_config.retro_gpt_valid_samples + cfg.data.retro_data.retro_block_size = retro_preprocess_config.retro_block_size + cfg.data.retro_data.retro_chunk_length = retro_preprocess_config.retro_gpt_chunk_length + cfg.data.retro_data.retro_split_preprocessing = retro_preprocess_config.retro_gpt_split + cfg.data.retro_data.retro_neighbor_dirs = retro_preprocess_config.retro_neighbor_dirs + + return cfg + + def __init__(self, cfg: DictConfig, trainer: Trainer): + + # override pre-processing arguments with retro pre-processing arguments + cfg = self.load_retro_config(cfg) + + super().__init__(cfg, trainer=trainer) + + logging.info( + "\n\n************** Experiment configuration (after overriding with RETRO's workdir values) ***********" + ) + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + return + + def model_provider_func(self, pre_process, post_process): + """Model depends on pipeline paralellism.""" + if self.mcore_gpt: + self.retro_model_config = self.build_retro_config() + model = MCoreRetroModel( + config=self.retro_model_config, + transformer_layer_spec=get_retro_decoder_block_spec( + self.retro_model_config, use_transformer_engine=True + ), + vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size), + max_sequence_length=self.cfg.data.get('seq_length', 512), + pre_process=pre_process, + post_process=post_process, + parallel_output=True, + share_embeddings_and_output_weights=self.cfg.get('share_embeddings_and_output_weights', True), + position_embedding_type=self.cfg.get('position_embedding_type', 'learned_absolute'), + rotary_percent=self.cfg.get('rotary_percentage', 1.0), + seq_len_interpolation_factor=self.cfg.get('seq_len_interpolation_factor', None), + ) + + return model + else: + assert self.mcore_gpt == True, "Currently only support mcore Retro." + + def forward( + self, tokens, text_position_ids, attention_mask, labels, context_input_ids, context_position_ids, context_mask + ): + output_tensor = self.model( + tokens, + text_position_ids, + attention_mask, + context_input_ids=context_input_ids, + context_position_ids=context_position_ids, + context_mask=context_mask, + labels=labels, + ) + return output_tensor + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None, **extra) -> Any: + # batch = {'prompts': List, 'neighbors': List[List]} + + inference_config = self.get_inference_config() + + if torch.distributed.get_rank() == 0: + logging.info("inference_config: ") + logging.info(inference_config) + + if inference_config is None: + return None + else: + # need to overwrite some configuration, make it immutable + inference_config = inference_config.copy() + compute_logprob = inference_config['compute_logprob'] + if compute_logprob: + inference_config['inputs'] = batch['prompts'] + inference_config['neighbors'] = batch['neighbors'] + inference_config['tokens_to_generate'] = 1 + inference_config['all_probs'] = True + inference_config["add_BOS"] = False + inference_config['greedy'] = True + inference_config['retro_inference'] = inference_config['retro_inference'] + response = generate(self, **inference_config) + compute_prob_response = get_computeprob_response(self.tokenizer, response, batch) + return compute_prob_response + else: + inference_config['inputs'] = batch['prompts'] + inference_config['neighbors'] = batch['neighbors'] + inference_config['retro_inference'] = inference_config['retro_inference'] + return generate(self, **inference_config) + + def get_batch(self, data_iterator): + """Generate a batch.""" + + # Broadcast data. + if data_iterator is not None: + # If tuple, 1st element in it is the batch since dataloader_iter returns batch, batch_idx, dataloader_idx + data = next(data_iterator) + if isinstance(data, tuple): + data = data[0] + else: + data = None + + batch = { + 'tokens': data["tokens"], + 'labels': data["labels"], + 'loss_mask': data["loss_mask"], + 'attention_mask': data["attention_mask"], + 'position_ids': data["position_ids"], + 'context_input_ids': data["context_input_ids"], + 'context_attention_mask': data["context_attention_mask"], + 'context_position_ids': data["context_position_ids"], + } + + return batch + + def get_forward_output_and_loss_func(self, validation_step=False): + def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): + + # Get data batch + batch = self.get_batch(dataloader_iter) + + # Transfer needed data to GPU + required_keys = set() + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + required_keys.update(batch.keys()) + else: + required_keys.add('attention_mask') + if parallel_state.is_pipeline_first_stage(): + required_keys.update( + ('tokens', 'position_ids', 'context_input_ids', 'context_position_ids', 'context_mask') + ) + if parallel_state.is_pipeline_last_stage(): + required_keys.update(('labels', 'loss_mask')) + if self.get_attention_mask_from_fusion: + required_keys.remove('attention_mask') + batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()} + + # reshape context_input_ids and context_position_ids for RETRO from [bs, l*k, r] => [bs*l*k, r] + context_input_ids = batch['context_input_ids'] + context_position_ids = batch['context_position_ids'] + context_input_ids = context_input_ids.view(-1, context_input_ids.shape[-1]).long() + context_position_ids = context_position_ids.view(-1, context_position_ids.shape[-1]).long() + batch['context_input_ids'] = context_input_ids + batch['context_position_ids'] = context_position_ids + + # slice batch along sequence dimension for context parallelism + batch = self.get_batch_on_this_context_parallel_rank(batch) + + # Model forward pass + forward_args = { + 'input_ids': batch['tokens'], + 'position_ids': batch['position_ids'], + 'attention_mask': batch['attention_mask'], + 'context_input_ids': batch['context_input_ids'], + 'context_position_ids': batch['context_position_ids'], + 'context_mask': None, # batch neighbor_attention_mask will be set to None following Lawrence's implementation + 'labels': batch['labels'], + 'loss_mask': batch['loss_mask'], + } + + if not self.mcore_gpt: + forward_args['checkpoint_activations_all_layers'] = checkpoint_activations_all_layers + if not self.use_loss_mask: + forward_args.pop('loss_mask') + else: + # TODO: @eharper can we add this to mcore? + forward_args.pop('loss_mask') + output_tensor = model(**forward_args) + + def loss_func(output_tensor): + # Loss for a micro-batch (ub) + loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor) + if validation_step and not self.cfg.data.get('validation_drop_last', True): + num_valid_tokens_in_ub = batch['loss_mask'].sum() + if loss_for_ub.isnan(): + assert batch['loss_mask'].count_nonzero() == 0, 'Got NaN loss with non-empty input' + loss_sum_for_ub = torch.zeros_like(num_valid_tokens_in_ub) + else: + loss_sum_for_ub = num_valid_tokens_in_ub * loss_for_ub + + loss_sum_and_ub_size_all_gpu = torch.cat( + [ + loss_sum_for_ub.clone().detach().view(1), + torch.tensor([num_valid_tokens_in_ub]).cuda().clone().detach(), + ] + ) + # Could potentially reduce num_valid_samples_in_microbatch and use that to aggregate instead of len(self._validation_ds) + torch.distributed.all_reduce( + loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group() + ) + return loss_for_ub, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu} + else: + reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) + return loss_for_ub, {'avg': reduced_loss} + + return output_tensor, loss_func + + return fwd_output_and_loss_func + + def get_forward_output_only_func(self): + def fwd_output_only_func(dataloader_iter, model): + batch = next(dataloader_iter) + extra_arg = {} + if len(batch) == 5: + batch = [x.cuda() for x in batch] + tokens, attention_mask, position_ids, context_input_ids, context_position_ids, context_mask = batch + attention_mask = attention_mask[0:1] + else: + ( + tokens, + attention_mask, + position_ids, + context_input_ids, + context_position_ids, + context_mask, + set_inference_key_value_memory, + inference_max_sequence_len, + ) = batch + tokens = tokens.cuda() + position_ids = position_ids.cuda() + if attention_mask is not None: + attention_mask = attention_mask.cuda() + attention_mask = attention_mask[0:1] + context_input_ids = context_input_ids.cuda() + context_position_ids = context_position_ids.cuda() + context_mask = None + if self.mcore_gpt: + # if first step, then clear KV cache, otherwise reuse inference_paarms + if set_inference_key_value_memory[0].item(): + self.inference_params = InferenceParams( + max_batch_size=tokens.size(0), max_sequence_length=inference_max_sequence_len[0].item() + ) + extra_arg['inference_params'] = self.inference_params + else: + extra_arg['set_inference_key_value_memory'] = set_inference_key_value_memory[0].item() + extra_arg['inference_max_sequence_len'] = inference_max_sequence_len[0].item() + output_tensor = model( + tokens, + position_ids, + attention_mask, + context_input_ids=context_input_ids, + context_position_ids=context_position_ids, + context_mask=None, # batch neighbor_attention_mask will be set to None following Lawrence's implementation + **extra_arg, + ) + + # Advance inference sequence offset. + if self.inference_params: + # if last stage, then (final) output is [b, s, h], otherwise it's [s, b, h] + if parallel_state.is_pipeline_last_stage(): + self.inference_params.sequence_len_offset += output_tensor.size(1) + else: + self.inference_params.sequence_len_offset += output_tensor.size(0) + + def id_func(output_tensor): + return output_tensor, {'logits': output_tensor} + + return output_tensor, id_func + + return fwd_output_only_func + + def build_retro_config(self) -> RetroConfig: + """ This method build RetroConfig from the already built TransformerConfig + by adding Retro relevant variables. This method runs after running build_transformer_config() method. + """ + retro_config = self.transformer_config + + # retro model args + retro_config.retro_project_dir = self.cfg.retro.get('retro_project_dir') + retro_config.retro_block_size = self.cfg.data.retro_data.get('retro_block_size') + retro_config.retro_chunk_length = self.cfg.data.retro_data.get('retro_chunk_length') + retro_config.retro_encoder_num_layers = self.cfg.retro.get('retro_encoder_num_layers', 2) + retro_config.retro_encoder_hidden_dropout = self.cfg.retro.get('retro_encoder_hidden_dropout', 0.1) + retro_config.retro_encoder_attention_dropout = self.cfg.retro.get('retro_encoder_attention_dropout', 0.1) + retro_config.retro_num_neighbors = self.cfg.retro.get('retro_num_neighbors', 2) + retro_config.retro_num_retrieved_chunks = self.cfg.retro.get('retro_num_retrieved_chunks', 2) + retro_config.retro_verify_neighbor_count = self.cfg.retro.get('retro_verify_neighbor_count', True) + retro_config.retro_retrieved_length = retro_config.retro_num_retrieved_chunks * retro_config.retro_chunk_length + retro_config.retro_split_preprocessing = self.cfg.data.retro_data.get('retro_split_preprocessing') + retro_config.retro_neighbor_dirs = self.cfg.data.retro_data.get('retro_neighbor_dirs') + logging.info("retro_config: ") + logging.info(retro_config) + + # Validate Transformer Engine version. + from importlib.metadata import version + + from pkg_resources import packaging + + te_version = packaging.version.Version(version("transformer-engine")) + if te_version >= packaging.version.Version("1.3"): + try: + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + assert os.getenv("NVTE_FLASH_ATTN") == "0" + assert os.getenv("NVTE_FUSED_ATTN") == "0" + except Exception as e: + raise Exception( + "When using Transformer Engine >= 1.3, environment vars NVTE_FLASH_ATTN and NVTE_FUSED_ATTN most both be defined and set to '0'. Currently, NVTE_FLASH_ATTN == %s, NVTE_FUSED_ATTN == %s." + % (os.getenv("NVTE_FLASH_ATTN", "[unset]"), os.getenv("NVTE_FUSED_ATTN", "[unset]"),) + ) + + return retro_config + + def build_train_valid_test_datasets(self): + # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step + # self._reconfigure_val_batches() + logging.info('Building mcore RETRO datasets.') + if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float): + raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.") + global_batch_size = self.cfg.global_batch_size + # max_train_steps = self.trainer.max_steps + # eval_iters = (max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches # check this carefully, we want to match mcore dataset value, should this computed, or overriden? + # test_iters = self.trainer.limit_test_batches + + # getting train_valid_test_num_samples from values in RETRO's workdir + train_valid_test_num_samples = [ # compute the number of training/validating samples from workdir/query/train_*; dividing number of chunks for (2048/64) + self.cfg.retro_train_samples_with_neighbors, + self.cfg.retro_valid_samples_with_neighbors, + 0, + ] + + if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float): + train_valid_test_num_samples[ + 1 + ] = 1 # This is to make sure we only have one epoch on every validation iteration + + self._train_ds, self._validation_ds, self._test_ds = build_train_valid_test_datasets( + cfg=self.cfg, + retro_config=self.retro_model_config, + train_valid_test_num_samples=train_valid_test_num_samples, + seq_length=self.cfg.data.seq_length, + tokenizer=self.tokenizer, + ) + + if self._train_ds is not None: + logging.info(f'Length of train dataset: {len(self._train_ds)}') + if self._validation_ds is not None: + logging.info(f'Length of val dataset: {len(self._validation_ds)}') + if self._test_ds is not None: + logging.info(f'Length of test dataset: {len(self._test_ds)}') + logging.info(f'Finished building mcore RETRO datasets.') + + return self._train_ds, self._validation_ds, self._test_ds + + def build_pretraining_data_loader( + self, dataset, consumed_samples, dataset_type=None, drop_last=True, pad_samples_to_global_batch_size=False + ): + """Buld dataloader given an input dataset.""" + + logging.info(f'Building dataloader with consumed samples: {consumed_samples}') + # Megatron sampler + if hasattr(self.cfg.data, 'dataloader_type') and self.cfg.data.dataloader_type is not None: + if self.cfg.data.dataloader_type == 'single': + batch_sampler = MegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=self.cfg.micro_batch_size, + data_parallel_rank=parallel_state.get_data_parallel_rank(), + data_parallel_size=parallel_state.get_data_parallel_world_size(), + drop_last=drop_last, + global_batch_size=self.cfg.global_batch_size, + rampup_batch_size=self.cfg.get('rampup_batch_size', None), + pad_samples_to_global_batch_size=pad_samples_to_global_batch_size, + ) + elif self.cfg.data.dataloader_type == 'cyclic': + batch_sampler = MegatronPretrainingRandomSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=self.cfg.micro_batch_size, + data_parallel_rank=parallel_state.get_data_parallel_rank(), + data_parallel_size=parallel_state.get_data_parallel_world_size(), + drop_last=self.cfg.get('drop_last', True), + ) + else: + raise ValueError('cfg.data.dataloader_type must be "single" or "cyclic"') + else: + raise ValueError('cfg.data.dataloader_type not found. Must be "single" or "cyclic"') + + return torch.utils.data.DataLoader( + dataset, + batch_sampler=batch_sampler, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + persistent_workers=True if self.cfg.data.num_workers > 0 else False, + ) + + def fwd_bwd_step(self, dataloader_iter, forward_only): + + # handle asynchronous grad reduction + no_sync_func = None + grad_sync_func = None + param_sync_func = None + if not forward_only and self.with_distributed_adam: + no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + grad_sync_func = self.reduce_overlap_gradients + param_sync_func = self.sync_overlap_parameters + + # pipeline schedules will get these from self.model.config + for module in self.get_model_module_list(): + module.config.no_sync_func = no_sync_func + module.config.grad_sync_func = grad_sync_func + module.config.param_sync_func = param_sync_func + + # run forward and backwards passes for an entire global batch + # we do this inside training_step to support pipeline parallelism + fwd_bwd_function = get_forward_backward_func() + + # TODO @akhattar: add num_micro_batches_with_partial_activation_checkpoints when ready + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(forward_only), + data_iterator=self._make_data_iterator_list(dataloader_iter), + model=self.model, + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=self.cfg.encoder_seq_length, + micro_batch_size=self.cfg.micro_batch_size, + ) + + # only the last stages of the pipeline return losses + if losses_reduced_per_micro_batch: + if (not forward_only) or self.cfg.data.get('validation_drop_last', True): + # average loss across micro batches + loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensor = torch.concat(loss_tensors_list) + loss_mean = loss_tensor.mean() + else: + # Get the total loss since micro batches sizes are not uniform + loss_sum_tensors_list = [ + loss_sum['loss_sum_and_ub_size'] + for loss_sum in losses_reduced_per_micro_batch + if loss_sum['loss_sum_and_ub_size'][1] > 0 + ] + loss_sum = ( + torch.vstack(loss_sum_tensors_list).sum(axis=0) + if len(loss_sum_tensors_list) > 0 + else torch.tensor([0.0, 0.0]).cuda() + ) + return loss_sum + else: + # we're not on the last pipeline stage so no losses + if forward_only: + loss_mean = [] + else: + loss_mean = torch.tensor(0.0).cuda() + + return loss_mean + + def validation_step(self, dataloader_iter, dataloader_idx=0): + """ + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. + """ + mode = 'test' if self.trainer.testing else 'val' + # Initialize userbuffer communicators. + if self.initialize_ub: + self.initialize_ub_func() + + if isinstance(self.model, list): + for model_module in self.model: + model_module.eval() + else: + self.model.eval() + + if self.cfg.get('fp8', False): + first_val_step = self.prev_step_training and not self.training + self.prev_step_training = self.training + else: + first_val_step = None + + with torch.no_grad(): + loss = self.fwd_bwd_step(dataloader_iter, True) + + if isinstance(self.model, list): + for model_module in self.model: + model_module.train() + else: + self.model.train() + + if mode == 'val': + # Append with the correct dataloader_idx in case of multiple dataloaders + if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: + self.validation_step_outputs[dataloader_idx].append(loss) + else: + self.validation_step_outputs.append(loss) + else: + if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: + self.test_step_outputs[dataloader_idx].append(loss) + else: + self.test_step_outputs.append(loss) + + return loss diff --git a/nemo/collections/nlp/modules/common/tokenizer_utils.py b/nemo/collections/nlp/modules/common/tokenizer_utils.py index 84df4a6965e1..67c94ae5d608 100644 --- a/nemo/collections/nlp/modules/common/tokenizer_utils.py +++ b/nemo/collections/nlp/modules/common/tokenizer_utils.py @@ -114,6 +114,7 @@ def get_tokenizer( tokenizer_name = get_megatron_tokenizer(tokenizer_name) if tokenizer_name == 'sentencepiece': + logging.info("tokenizer_model: " + str(tokenizer_model)) return nemo.collections.common.tokenizers.sentencepiece_tokenizer.SentencePieceTokenizer( model_path=tokenizer_model, special_tokens=special_tokens, legacy=True ) @@ -195,6 +196,14 @@ def get_nmt_tokenizer( logging.info(f'Using regex tokenization') return RegExTokenizer().load_tokenizer(regex_file=tokenizer_model, vocab_file=vocab_file) elif library == 'megatron': + + if model_name == 'GPTSentencePieceTokenizer': + logging.info("tokenizer_model: ") + logging.info(tokenizer_model) + return nemo.collections.common.tokenizers.sentencepiece_tokenizer.SentencePieceTokenizer( + model_path=tokenizer_model, legacy=legacy + ) + if model_name in megatron_tokenizer_model_map: model_name = megatron_tokenizer_model_map[model_name] logging.info( diff --git a/nemo/utils/callbacks/nemo_model_checkpoint.py b/nemo/utils/callbacks/nemo_model_checkpoint.py index 059ce4455977..e532297d9747 100644 --- a/nemo/utils/callbacks/nemo_model_checkpoint.py +++ b/nemo/utils/callbacks/nemo_model_checkpoint.py @@ -357,12 +357,15 @@ def remove_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barri barrier_before: Synchronize ranks before removing the marker file. Defaults to False. """ - if barrier_before and torch.distributed.is_initialized(): - torch.distributed.barrier() - if is_global_rank_zero(): - marker_path = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path) - if marker_path.exists(): - marker_path.unlink() + try: + if barrier_before and torch.distributed.is_initialized(): + torch.distributed.barrier() + if is_global_rank_zero(): + marker_path = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path) + if marker_path.exists(): + marker_path.unlink() + except: + return def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None: # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed. diff --git a/tests/collections/nlp/test_indexed_retrieval_dataset.py b/tests/collections/nlp/test_indexed_retrieval_dataset.py index e35c3ab36840..5110651b34a6 100644 --- a/tests/collections/nlp/test_indexed_retrieval_dataset.py +++ b/tests/collections/nlp/test_indexed_retrieval_dataset.py @@ -28,7 +28,7 @@ MMapRetrievalIndexedDatasetBuilder, merge_knn_files, ) -from nemo.collections.nlp.data.language_modeling.megatron.retro_dataset import RETRODataset +from nemo.collections.nlp.data.language_modeling.megatron.retro_dataset_legacy import RETRODataset try: from megatron.core import parallel_state