#0: Reduce subblock_w in MM to avoid hang
AleksKnezevic committed May 3, 2024
1 parent 10a517e commit fb3ca54
Showing 4 changed files with 20 additions and 18 deletions.
@@ -58,7 +58,6 @@ def __init__(self, device, parameters):
def __call__(self, config, hidden_states):
hidden_states = self.geglu(config, hidden_states)

- # TODO: Output sharded once https://github.com/tenstorrent/tt-metal/issues/6775 is fixed
interleaved_output = False
size = hidden_states.shape[-2]
grid_size = self.grid_sizes[size]
@@ -68,6 +67,9 @@ def __call__(self, config, hidden_states):
out_block_h = math.ceil(M / grid_size[1] / 32)
out_block_w = math.ceil(N / grid_size[0] / 32)
out_subblock_h, out_subblock_w = determine_largest_subblock_size(out_block_h, out_block_w)
+ if size == 512:
+     out_subblock_h = 1
+     out_subblock_w = 1
program_config = ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig(
compute_with_storage_grid_size=grid_size,
in0_block_w=in0_block_w,
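For context on the hunk above: determine_largest_subblock_size chooses the per-core output subblock shape, and the new size == 512 branch overrides that choice with 1x1 subblocks. The sketch below is only an illustration of what such a helper typically computes; the name pick_largest_subblock, the 8-tile dest-register budget, and the divisibility rule are assumptions, not code from this repository.

def pick_largest_subblock(out_block_h, out_block_w, max_tiles=8):
    # Hypothetical helper: find the largest (h, w) subblock whose tile count fits the
    # assumed dest-register budget and whose dims evenly divide the output block dims.
    best_h, best_w = 1, 1
    for h in range(1, out_block_h + 1):
        if out_block_h % h != 0:
            continue
        for w in range(1, out_block_w + 1):
            if out_block_w % w != 0 or h * w > max_tiles:
                continue
            if h * w > best_h * best_w:
                best_h, best_w = h, w
    return best_h, best_w

Forcing 1x1 subblocks for the 512 case, as this commit does, gives up some matmul throughput in exchange for the smallest per-core working set, which is the lever used here to avoid the reported hang.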
@@ -8,6 +8,7 @@
import tt_lib as ttl
from models.experimental.functional_stable_diffusion.tt2.ttnn_functional_utility_functions import (
determine_largest_subblock_size,
+ determine_blocking,
)


@@ -95,11 +96,12 @@ def __call__(self, config, hidden_states):
ttnn.experimental.tensor.TensorMemoryLayout.BLOCK_SHARDED,
ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR,
)
- in0_block_h = M // grid_size[1] // 32
- in0_block_w = K // grid_size[0] // 32
- out_block_h = math.ceil(M / grid_size[1] / 32)
- out_block_w = math.ceil(N / grid_size[0] / 32)
- out_subblock_h, out_subblock_w = determine_largest_subblock_size(out_block_h, out_block_w)
+ in0_block_h, in0_block_w, out_subblock_h, out_subblock_w, out_block_h, out_block_w = determine_blocking(
+     M, K, N, grid_size
+ )
+ if size == 512:
+     out_subblock_h = 1
+     out_subblock_w = 1
program_config = ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig(
compute_with_storage_grid_size=grid_size,
in0_block_w=in0_block_w,
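The inlined arithmetic deleted in the hunk above is a reasonable guide to what the new determine_blocking helper returns. A minimal sketch that reproduces that deleted math under this assumption (the helper's real body in ttnn_functional_utility_functions may differ):

import math

from models.experimental.functional_stable_diffusion.tt2.ttnn_functional_utility_functions import (
    determine_largest_subblock_size,
)


def determine_blocking_sketch(M, K, N, grid_size, tile=32):
    # in0 (activations) is split across grid rows for M and grid columns for K.
    in0_block_h = M // grid_size[1] // tile
    in0_block_w = K // grid_size[0] // tile
    # Output block per core, rounded up so every core covers its slice of M x N.
    out_block_h = math.ceil(M / grid_size[1] / tile)
    out_block_w = math.ceil(N / grid_size[0] / tile)
    # Same subblock search the call sites used before this change.
    out_subblock_h, out_subblock_w = determine_largest_subblock_size(out_block_h, out_block_w)
    return in0_block_h, in0_block_w, out_subblock_h, out_subblock_w, out_block_h, out_block_w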
@@ -114,7 +114,7 @@ def __init__(
out_channels = parameters.conv_in.weight.shape[0]
in_channels = parameters.conv_in.weight.shape[1]

print(f"CIN: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")
logger.info(f"CIN: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")
self.conv_in = ttnn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
@@ -140,7 +140,7 @@ def __init__(
self.down_blocks = []
input_height = self.conv_in.output_height
input_width = self.conv_in.output_height
print(f"D-1: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")
logger.info(f"D-1: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")
self.down_block_types = down_block_types
for i, down_block_type in enumerate(down_block_types):
if down_block_type == "CrossAttnDownBlock2D":
@@ -169,7 +169,7 @@ def __init__(
self.down_blocks.append(down_block)
input_height = down_block.output_height
input_width = down_block.output_width
print(f"D{i}: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")
logger.info(f"D{i}: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")

assert mid_block_type == "UNetMidBlock2DCrossAttn"
self.mid_block = unet_mid_block_2d_cross_attn(
@@ -183,7 +183,7 @@ def __init__(
)
input_height = self.mid_block.output_height
input_width = self.mid_block.output_width
print(f"MID: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")
logger.info(f"MID: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")

self.up_blocks = []
self.up_block_types = up_block_types
@@ -214,7 +214,7 @@ def __init__(
self.up_blocks.append(up_block)
input_height = up_block.output_height
input_width = up_block.output_width
print(f"UP{i}: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")
logger.info(f"UP{i}: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")

parameters.conv_out.weight, parameters.conv_out.bias = permute_conv_weights(
parameters.conv_out.weight, parameters.conv_out.bias
@@ -227,7 +227,7 @@ def __init__(
out_channels = parameters.conv_out.weight.shape[0]
in_channels = parameters.conv_out.weight.shape[1]

print(f"COU: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")
logger.info(f"COU: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")
self.conv_out = ttnn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
@@ -432,7 +432,7 @@ def __call__(
output_channel = block_out_channels[0]
for i, (down_block_type, down_block) in enumerate(zip(self.down_block_types, self.down_blocks)):
ttl.device.DumpDeviceProfiler(self.device)
print(f"Down block {i}")
logger.info(f"Down block {i}")
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
@@ -486,7 +486,7 @@ def __call__(
down_block_res_samples += res_samples

# 4.mid
print("Mid block")
logger.info("Mid block")
sample = self.mid_block(
hidden_states=sample,
temb=emb,
@@ -517,7 +517,7 @@ def __call__(
output_channel = reversed_block_out_channels[0]
for i, (up_block_type, up_block) in enumerate(zip(self.up_block_types, self.up_blocks)):
ttl.device.DumpDeviceProfiler(self.device)
print(f"Up block {i}")
logger.info(f"Up block {i}")
is_final_block = i == len(block_out_channels) - 1

prev_output_channel = output_channel
@@ -591,8 +591,6 @@ def __call__(
), f"CrossAttnUpBlock2D, and UpBlock2D are the only up blocks implemented! you requested {up_block_type}"

# 6.post-process
- # print(sample.shape)
- # print(sample.memory_config())
sample = ttnn.to_layout(sample, ttnn.ROW_MAJOR_LAYOUT)
if self.fallback_on_groupnorm:
assert self.norm_num_groups == norm_num_groups
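The print-to-logger.info swaps above assume a logger is already bound in this model file; in tt-metal models that is normally loguru, though the import sits outside this diff. A minimal sketch of the pattern, with the loguru import as an assumption:

from loguru import logger  # assumed; the file's actual logger import is not shown in this diff

input_height, input_width = 64, 64
# Leveled, timestamped log lines are easier to filter in CI output than bare prints.
logger.info(f"D-1: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")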
2 changes: 1 addition & 1 deletion tests/scripts/run_python_model_tests.sh
@@ -54,5 +54,5 @@ if [ "$ARCH_NAME" != "wormhole_b0" ]; then
# higher sequence lengths and different formats trigger memory issues
pytest $TT_METAL_HOME/models/demos/falcon7b/tests/unit_tests/test_falcon_matmuls_and_bmms_with_mixed_precision.py -k "seq_len_128 and in0_BFLOAT16-in1_BFLOAT8_B-out_BFLOAT16-weights_DRAM"

- WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest $TT_METAL_HOME/tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model.py -k 512
+ WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -svv $TT_METAL_HOME/tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model.py -k 512
fi
