[pipeline] fix return_dict/fix pure_pipeline_test (hpcaitech#4331)
Fridge003 authored and ver217 committed Aug 15, 2023
1 parent 1cf9e01 commit f24f3c2
Showing 5 changed files with 29 additions and 53 deletions.
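The common thread in the modeling files below is return_dict handling: each pipeline forward previously forced return_dict = False with a warning, and after this commit falls back to the HuggingFace config default (self.config.use_return_dict) when the caller passes None. A minimal runnable sketch of that convention, assuming only a HuggingFace-style config attribute; TinyConfig and tiny_forward are hypothetical names for illustration, not code from this commit:

from dataclasses import dataclass
from typing import Optional

@dataclass
class TinyConfig:
    # Stands in for a HuggingFace config with its use_return_dict attribute.
    use_return_dict: bool = True

def tiny_forward(config: TinyConfig, hidden: float, return_dict: Optional[bool] = None):
    # Respect an explicit caller choice; otherwise defer to the config,
    # as the upstream transformers forwards do.
    return_dict = return_dict if return_dict is not None else config.use_return_dict
    output = hidden * 2.0
    if return_dict:
        return {"last_hidden_state": output}
    return (output,)

print(tiny_forward(TinyConfig(), 1.5))         # {'last_hidden_state': 3.0}
print(tiny_forward(TinyConfig(), 1.5, False))  # (3.0,)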
33 changes: 4 additions & 29 deletions colossalai/shardformer/modeling/bert.py
@@ -1,3 +1,4 @@
+import warnings
from typing import Any, Dict, List, Optional, Tuple

import torch
@@ -277,9 +278,6 @@ def bert_for_pretraining_forward(
if output_hidden_states:
    logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
    output_hidden_states = False
-if return_dict:
-    logger.warning_once('return_dict is not supported for pipeline models at the moment')
-    return_dict = False

outputs = BertPipelineForwards.bert_model_forward(
self.bert,
@@ -387,9 +385,6 @@ def bert_lm_head_model_forward(
if output_hidden_states:
    logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
    output_hidden_states = False
-if return_dict:
-    logger.warning_once('return_dict is not supported for pipeline models at the moment')
-    return_dict = False

outputs = BertPipelineForwards.bert_model_forward(
self.bert,
@@ -478,9 +473,6 @@ def bert_for_masked_lm_forward(
if output_hidden_states:
    logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
    output_hidden_states = False
-if return_dict:
-    logger.warning_once('return_dict is not supported for pipeline models at the moment')
-    return_dict = False

outputs = BertPipelineForwards.bert_model_forward(
self.bert,
@@ -579,16 +571,15 @@ def bert_for_next_sentence_prediction_forward(
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")
+
+return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
if output_attentions:
    logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
    output_attentions = False
if output_hidden_states:
    logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
    output_hidden_states = False
-if return_dict:
-    logger.warning_once('return_dict is not supported for pipeline models at the moment')
-    return_dict = False
-return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = BertPipelineForwards.bert_model_forward(self.bert,
input_ids,
@@ -661,10 +652,6 @@ def bert_for_sequence_classification_forward(
if output_hidden_states:
    logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
    output_hidden_states = False
-if return_dict:
-    logger.warning_once('return_dict is not supported for pipeline models at the moment')
-    return_dict = False
-return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = BertPipelineForwards.bert_model_forward(self.bert,
input_ids,
@@ -753,10 +740,6 @@ def bert_for_token_classification_forward(
if output_hidden_states:
    logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
    output_hidden_states = False
-if return_dict:
-    logger.warning_once('return_dict is not supported for pipeline models at the moment')
-    return_dict = False
-return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = BertPipelineForwards.bert_model_forward(
self.bert,
@@ -832,10 +815,6 @@ def bert_for_multiple_choice_forward(
if output_hidden_states:
    logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
    output_hidden_states = False
-if return_dict:
-    logger.warning_once('return_dict is not supported for pipeline models at the moment')
-    return_dict = False
-return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# In our pipeline design, input ids are copied for every stage and shouldn't be None.
# The input_ids for the multiple choice model is [batch_size, num_choices, sequence_length]
@@ -928,10 +907,6 @@ def bert_for_question_answering_forward(
if output_hidden_states:
    logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
    output_hidden_states = False
-if return_dict:
-    logger.warning_once('return_dict is not supported for pipeline models at the moment')
-    return_dict = False
-return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = BertPipelineForwards.bert_model_forward(
self.bert,
12 changes: 0 additions & 12 deletions colossalai/shardformer/modeling/bloom.py
@@ -313,9 +313,6 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM,
if output_hidden_states:
    logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
    output_hidden_states = False
-if return_dict:
-    logger.warning_once('return_dict is not supported for pipeline models at the moment')
-    return_dict = False

transformer_outputs = BloomPipelineForwards.bloom_model_forward(self.transformer,
input_ids,
@@ -411,9 +408,6 @@ def bloom_for_sequence_classification_forward(
if output_hidden_states:
    logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
    output_hidden_states = False
-if return_dict:
-    logger.warning_once('return_dict is not supported for pipeline models at the moment')
-    return_dict = False

transformer_outputs = BloomPipelineForwards.bloom_model_forward(
self.transformer,
@@ -537,9 +531,6 @@ def bloom_for_token_classification_forward(
if output_hidden_states:
    logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
    output_hidden_states = False
-if return_dict:
-    logger.warning_once('return_dict is not supported for pipeline models at the moment')
-    return_dict = False

transformer_outputs = BloomPipelineForwards.bloom_model_forward(
self.transformer,
@@ -626,9 +617,6 @@ def bloom_for_question_answering_forward(
if output_hidden_states:
    logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
    output_hidden_states = False
-if return_dict:
-    logger.warning_once('return_dict is not supported for pipeline models at the moment')
-    return_dict = False

outputs = BloomPipelineForwards.bloom_model_forward(
self.transformer,
2 changes: 2 additions & 0 deletions colossalai/shardformer/modeling/gpt2.py
@@ -52,6 +52,8 @@ def gpt2_model_forward(
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
# Please refer to original code of transformers for more details.

+return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
logger = logging.get_logger(__name__)

# Preprocess passed in arguments
28 changes: 20 additions & 8 deletions colossalai/shardformer/policies/opt.py
@@ -8,6 +8,18 @@
import torch.nn as nn
from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutputWithPast,
+)
+from transformers.models.opt.modeling_opt import (
+    OPTForCausalLM,
+    OPTForQuestionAnswering,
+    OPTForSequenceClassification,
+    OPTModel,
+)

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
@@ -317,7 +329,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]

@staticmethod
def opt_model_forward(
-self: 'OPTModel',
+self: OPTModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
@@ -330,7 +342,7 @@ def opt_model_forward(
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
-) -> Union[Tuple, 'BaseModelOutputWithPast']:
+) -> Union[Tuple, BaseModelOutputWithPast]:
'''
This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
'''
@@ -506,7 +518,7 @@ def custom_forward(*inputs):

@staticmethod
def opt_for_causal_lm_forward(
-self: 'OPTForCausalLM',
+self: OPTForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
@@ -520,7 +532,7 @@ def opt_for_causal_lm_forward(
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
-) -> Union[Tuple, 'CausalLMOutputWithPast']:
+) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -646,7 +658,7 @@ def opt_for_causal_lm_forward(

@staticmethod
def opt_for_sequence_classification_forward(
-self: 'OPTForSequenceClassification',
+self: OPTForSequenceClassification,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
@@ -660,7 +672,7 @@ def opt_for_sequence_classification_forward(
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
-) -> Union[Tuple, 'SequenceClassifierOutputWithPast']:
+) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -746,7 +758,7 @@ def opt_for_sequence_classification_forward(

@staticmethod
def opt_for_question_answering_forward(
-self: 'OPTForQuestionAnswering',
+self: OPTForQuestionAnswering,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
@@ -761,7 +773,7 @@ def opt_for_question_answering_forward(
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
-) -> Union[Tuple, 'QuestionAnsweringModelOutput']:
+) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
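A note on the opt.py hunks above: the forward signatures previously referred to OPTModel, BaseModelOutputWithPast and the other classes only as quoted string annotations, with no matching import in the module; with the imports added, the annotations name the real classes and the output types are available at runtime in this file. A hedged illustration of the difference (example_forward is a hypothetical function, not part of the commit) — without the import, resolving the string annotation would raise a NameError:

from typing import Optional, Tuple, Union, get_type_hints

from transformers.modeling_outputs import BaseModelOutputWithPast

def example_forward(hidden_states: Optional[Tuple] = None) -> Union[Tuple, BaseModelOutputWithPast]:
    # With the class imported at module level, the return annotation is the
    # real type, so tools that resolve annotations work without extra context.
    return hidden_states if hidden_states is not None else ()

print(get_type_hints(example_forward)["return"])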
7 changes: 3 additions & 4 deletions tests/test_shardformer/test_model/test_pure_pipeline.py
@@ -1,6 +1,5 @@
import copy
import random
-from contextlib import nullcontext
from typing import Any, Callable, Iterator, List, Optional, Tuple

import numpy as np
@@ -100,8 +99,8 @@ def __getitem__(self, x):
return torch.ones((4, 128), dtype=torch.int).cuda() * 10


-def loss(x, y):
-    return (x[0].float().mean() - y[0].float().mean())
+def loss(y, x):
+    return (y[0].float().mean() - x[0].float().mean())


@parameterize('enable_fused_normalization', [False])
@@ -137,7 +136,7 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la
batch = next(data_iter)
with torch.no_grad():
y = model_copy(batch)
-org_loss = loss(batch, y)
+org_loss = loss(y, batch)
optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3)
schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager)
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
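On the test change above: the loss helper's arguments were swapped so that the model output comes first and the batch second, presumably to match the order in which the pipeline schedule hands output and input to its criterion. A toy sketch of that calling convention (run_step is a hypothetical stand-in, not ColossalAI's OneForwardOneBackwardSchedule):

import torch

def loss(y, x):
    # Output first, batch second, matching the updated test.
    return y[0].float().mean() - x[0].float().mean()

def run_step(model, batch, criterion):
    # The schedule runs the forward pass, then calls the criterion
    # with (output, input) in that order.
    output = model(batch)
    return criterion(output, batch)

batch = [torch.ones(4, 128)]
print(run_step(lambda b: [b[0] * 2], batch, loss))  # tensor(1.)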
