Skip to content

Commit

Permalink
[shardformer] rewrite tests for opt/bloom/llama/vit/chatglm (hpcaitec…
Browse files Browse the repository at this point in the history
…h#4395)

* rewrite opt tests

* rewrite llama tests

* rewrite bloom & vit tests

* rewrite chatglm tests

* fix LinearCol for classfiers

* add judge for other tp layers, fix lazy init in util
  • Loading branch information
Fridge003 authored and ver217 committed Aug 15, 2023
1 parent 17a34e3 commit 77620f0
Show file tree
Hide file tree
Showing 19 changed files with 1,072 additions and 1,281 deletions.
16 changes: 16 additions & 0 deletions colossalai/shardformer/layer/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]

tp_size = dist.get_world_size(process_group)
if out_features < tp_size:
return module

if out_features % tp_size != 0:
raise ValueError(
f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!")

linear_1d = Linear1D_Col(in_features=in_features,
out_features=out_features,
bias=bias,
Expand Down Expand Up @@ -293,6 +301,14 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]

tp_size = dist.get_world_size(process_group)
if in_features < tp_size:
return module

if in_features % tp_size != 0:
raise ValueError(
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")

linear_1d = Linear1D_Row(in_features=in_features,
out_features=out_features,
bias=bias,
Expand Down
16 changes: 16 additions & 0 deletions colossalai/shardformer/layer/qkv_fused_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,14 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]

tp_size = dist.get_world_size(process_group)
if out_features < tp_size:
return module

if out_features % tp_size != 0:
raise ValueError(
f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!")

linear_1d = GPT2FusedLinearConv1D_Col(in_features=in_features,
out_features=out_features,
bias=bias,
Expand Down Expand Up @@ -420,6 +428,14 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]

tp_size = dist.get_world_size(process_group)
if in_features < tp_size:
return module

if in_features % tp_size != 0:
raise ValueError(
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")

linear_1d = GPT2FusedLinearConv1D_Row(in_features=in_features,
out_features=out_features,
bias=bias,
Expand Down
Loading

0 comments on commit 77620f0

Please sign in to comment.