#7560: Slow down MMs when running non-perf tests to avoid ND hang. No hang observed on BM machine where perf tests are run.
AleksKnezevic committed May 4, 2024
1 parent b390434 commit 65c745e
Showing 10 changed files with 50 additions and 32 deletions.
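In short, the commit gates matmul output subblock sizes behind a SLOW_MATMULS environment variable: when the flag is set, subblocks are clamped to 1x1, which slows the matmuls enough to avoid the non-deterministic (ND) hang tracked in issue #7560 while the root cause is investigated. A minimal sketch of the idea (the helper name below is illustrative, not a function in the repo):

import os

def clamp_subblocks(out_subblock_h: int, out_subblock_w: int) -> tuple[int, int]:
    # Illustrative helper: with SLOW_MATMULS=1, force 1x1 output subblocks so
    # each matmul does less work per destination-register pass and runs more slowly.
    if os.environ.get("SLOW_MATMULS", "0") == "1":
        return 1, 1
    return out_subblock_h, out_subblock_w

The perf pipelines do not set the flag, so the BM (benchmark) machine keeps the fast configuration, where no hang has been observed.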
@@ -232,7 +232,7 @@ def test_stable_diffusion_device_perf(expected_perf):
margin = 0.02
batch = 1
iterations = 1
command = f"pytest tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model.py::test_unet_2d_condition_model512x512[batch_size=2-in_channels=4-input_height=64-input_width=64-device_l1_small_size=32768]"
command = f"pytest tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model.py::test_unet_2d_condition_model_512x512[batch_size=2-in_channels=4-input_height=64-input_width=64-device_l1_small_size=32768]"
cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]

inference_time_key = "AVG DEVICE KERNEL SAMPLES/S"
@@ -304,7 +304,7 @@ def time_sharded_attention(self, query, t_key, value, head_size):
per_core_M=tiles_per_shard,
per_core_N=seq_len // 32,
out_subblock_h=1,
out_subblock_w=8,
out_subblock_w=4,
fuse_batch=True,
fused_activation=None,
mcast_in0=False,
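Halving out_subblock_w in the sharded-attention matmul above is the same lever: a smaller output subblock means less math per destination-register pass. A hedged validity check for such configs (the divisibility and tile-budget constraints here are assumptions about typical matmul program configs, not taken from this diff):

def subblock_config_is_plausible(per_core_M: int, per_core_N: int,
                                 out_subblock_h: int, out_subblock_w: int,
                                 dest_tile_budget: int = 8) -> bool:
    # Assumed constraints: subblocks should evenly tile the per-core block and
    # their tile count should stay within the destination register budget.
    return (per_core_M % out_subblock_h == 0
            and per_core_N % out_subblock_w == 0
            and out_subblock_h * out_subblock_w <= dest_tile_budget)

The switch from 8 to 4 trades throughput for stability, consistent with the commit's intent.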
@@ -58,7 +58,6 @@ def __init__(self, device, parameters):
def __call__(self, config, hidden_states):
hidden_states = self.geglu(config, hidden_states)

# TODO: Output sharded once https://github.com/tenstorrent/tt-metal/issues/6775 is fixed
interleaved_output = False
size = hidden_states.shape[-2]
grid_size = self.grid_sizes[size]
@@ -68,6 +67,10 @@ def __call__(self, config, hidden_states):
out_block_h = math.ceil(M / grid_size[1] / 32)
out_block_w = math.ceil(N / grid_size[0] / 32)
out_subblock_h, out_subblock_w = determine_largest_subblock_size(out_block_h, out_block_w)
# TODO: https://github.com/tenstorrent/tt-metal/issues/7560
if size == 512:
out_subblock_h = 1
out_subblock_w = 1
program_config = ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig(
compute_with_storage_grid_size=grid_size,
in0_block_w=in0_block_w,
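For the 512x512 case the hunk above now discards whatever determine_largest_subblock_size picked and falls back to 1x1. For context, a sketch of what such a selection helper typically does (assumed behavior; the repo's implementation may differ in its limits and tie-breaking):

def largest_subblock_size_sketch(out_block_h: int, out_block_w: int,
                                 max_tiles: int = 8) -> tuple[int, int]:
    # Assumed: pick the largest (h, w) that evenly divides the output block
    # and whose tile count fits the destination register budget.
    best_h, best_w = 1, 1
    for h in range(1, out_block_h + 1):
        if out_block_h % h:
            continue
        for w in range(1, out_block_w + 1):
            if out_block_w % w or h * w > max_tiles:
                continue
            if h * w > best_h * best_w:
                best_h, best_w = h, w
    return best_h, best_w

Overriding the result to (1, 1) is the most conservative value this kind of helper could ever return, which is why it serves as the hang workaround.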
@@ -8,6 +8,7 @@
import tt_lib as ttl
from models.experimental.functional_stable_diffusion.tt2.ttnn_functional_utility_functions import (
determine_largest_subblock_size,
determine_blocking,
)


@@ -95,11 +96,13 @@ def __call__(self, config, hidden_states):
ttnn.experimental.tensor.TensorMemoryLayout.BLOCK_SHARDED,
ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR,
)
in0_block_h = M // grid_size[1] // 32
in0_block_w = K // grid_size[0] // 32
out_block_h = math.ceil(M / grid_size[1] / 32)
out_block_w = math.ceil(N / grid_size[0] / 32)
out_subblock_h, out_subblock_w = determine_largest_subblock_size(out_block_h, out_block_w)
in0_block_h, in0_block_w, out_subblock_h, out_subblock_w, out_block_h, out_block_w = determine_blocking(
M, K, N, grid_size
)
# TODO: https://github.com/tenstorrent/tt-metal/issues/7560
if size == 512:
out_subblock_h = 1
out_subblock_w = 1
program_config = ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig(
compute_with_storage_grid_size=grid_size,
in0_block_w=in0_block_w,
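The inline blocking math is consolidated into determine_blocking. Reconstructing its body from the code it replaces here and from the return statement visible in the utility-functions hunk further down, it plausibly looks like this (a sketch, not the repo's exact code):

import math

from models.experimental.functional_stable_diffusion.tt2.ttnn_functional_utility_functions import (
    determine_largest_subblock_size,
)

def determine_blocking_sketch(M: int, K: int, N: int, grid_size, tile: int = 32):
    # Mirrors the replaced inline code: split M/K/N across the core grid in
    # units of 32x32 tiles, then pick the largest valid output subblock.
    grid_x, grid_y = grid_size
    in0_block_h = M // grid_y // tile
    in0_block_w = K // grid_x // tile
    out_block_h = math.ceil(M / grid_y / tile)
    out_block_w = math.ceil(N / grid_x / tile)
    out_subblock_h, out_subblock_w = determine_largest_subblock_size(out_block_h, out_block_w)
    return in0_block_h, in0_block_w, out_subblock_h, out_subblock_w, out_block_h, out_block_w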
@@ -114,7 +114,7 @@ def __init__(
out_channels = parameters.conv_in.weight.shape[0]
in_channels = parameters.conv_in.weight.shape[1]

print(f"CIN: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")
logger.info(f"CIN: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")
self.conv_in = ttnn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
@@ -140,7 +140,7 @@ def __init__(
self.down_blocks = []
input_height = self.conv_in.output_height
input_width = self.conv_in.output_height
print(f"D-1: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")
logger.info(f"D-1: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")
self.down_block_types = down_block_types
for i, down_block_type in enumerate(down_block_types):
if down_block_type == "CrossAttnDownBlock2D":
@@ -169,7 +169,7 @@ def __init__(
self.down_blocks.append(down_block)
input_height = down_block.output_height
input_width = down_block.output_width
print(f"D{i}: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")
logger.info(f"D{i}: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")

assert mid_block_type == "UNetMidBlock2DCrossAttn"
self.mid_block = unet_mid_block_2d_cross_attn(
@@ -183,7 +183,7 @@ def __init__(
)
input_height = self.mid_block.output_height
input_width = self.mid_block.output_width
print(f"MID: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")
logger.info(f"MID: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")

self.up_blocks = []
self.up_block_types = up_block_types
@@ -214,7 +214,7 @@ def __init__(
self.up_blocks.append(up_block)
input_height = up_block.output_height
input_width = up_block.output_width
print(f"UP{i}: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")
logger.info(f"UP{i}: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")

parameters.conv_out.weight, parameters.conv_out.bias = permute_conv_weights(
parameters.conv_out.weight, parameters.conv_out.bias
@@ -227,7 +227,7 @@ def __init__(
out_channels = parameters.conv_out.weight.shape[0]
in_channels = parameters.conv_out.weight.shape[1]

print(f"COU: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")
logger.info(f"COU: height: {input_height}, width: {input_width}, dim: {2 * input_height * input_width}")
self.conv_out = ttnn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
@@ -432,7 +432,7 @@ def __call__(
output_channel = block_out_channels[0]
for i, (down_block_type, down_block) in enumerate(zip(self.down_block_types, self.down_blocks)):
ttl.device.DumpDeviceProfiler(self.device)
print(f"Down block {i}")
logger.info(f"Down block {i}")
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
@@ -486,7 +486,7 @@ def __call__(
down_block_res_samples += res_samples

# 4.mid
print("Mid block")
logger.info("Mid block")
sample = self.mid_block(
hidden_states=sample,
temb=emb,
@@ -517,7 +517,7 @@ def __call__(
output_channel = reversed_block_out_channels[0]
for i, (up_block_type, up_block) in enumerate(zip(self.up_block_types, self.up_blocks)):
ttl.device.DumpDeviceProfiler(self.device)
print(f"Up block {i}")
logger.info(f"Up block {i}")
is_final_block = i == len(block_out_channels) - 1

prev_output_channel = output_channel
@@ -591,8 +591,6 @@ def __call__(
), f"CrossAttnUpBlock2D, and UpBlock2D are the only up blocks implemented! you requested {up_block_type}"

# 6.post-process
# print(sample.shape)
# print(sample.memory_config())
sample = ttnn.to_layout(sample, ttnn.ROW_MAJOR_LAYOUT)
if self.fallback_on_groupnorm:
assert self.norm_num_groups == norm_num_groups
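The shape-report print calls above become logger.info calls. Assuming the module uses loguru's logger (the usual choice in this repo; the import isn't shown in these hunks), the extra output can then be filtered per run rather than deleted, e.g.:

import sys
from loguru import logger  # assumed logger; not shown in this diff

logger.remove()                       # drop the default sink
logger.add(sys.stderr, level="INFO")  # keep INFO and above, including the new calls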
@@ -233,4 +233,11 @@ def determine_blocking(M, K, N, grid_size, transpose_mcast=False):
out_block_h = math.ceil(M / logical_grid_size[1] / 32)
out_block_w = math.ceil(N / logical_grid_size[0] / 32)
out_subblock_h, out_subblock_w = determine_largest_subblock_size(out_block_h, out_block_w)
# TODO: https://github.com/tenstorrent/tt-metal/issues/7560
# There's a bug that causes an ND hang; until it's solved, reduce subblock sizes to 1 if we're not running perf tests.
import os

if os.environ.get("SLOW_MATMULS", "0") == "1":
out_subblock_h = 1
out_subblock_w = 1
return in0_block_h, in0_block_w, out_subblock_h, out_subblock_w, out_block_h, out_block_w
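The gate reads the environment at call time, so the test scripts below simply export SLOW_MATMULS=1 for the stable-diffusion runs. For a single local test the flag can also be set from within pytest (illustrative; not part of this commit):

import os
import pytest

def test_runs_with_slow_matmuls(monkeypatch: pytest.MonkeyPatch) -> None:
    # Illustrative only: enable the reduced-subblock path for this one test.
    monkeypatch.setenv("SLOW_MATMULS", "1")
    assert os.environ["SLOW_MATMULS"] == "1"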
2 changes: 1 addition & 1 deletion tests/scripts/nightly/run_wh_b0_only.sh
@@ -10,7 +10,7 @@ fi
echo "Running nightly tests for WH B0 only"

env pytest tests/ttnn/integration_tests/unet # -> failing: issue #7556
# env pytest tests/ttnn/integration_tests/stable_diffusion # -> failing/hanging: issue #7560
SLOW_MATMULS=1 WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml env pytest tests/ttnn/integration_tests/stable_diffusion

env pytest models/demos/mamba/tests/test_mamba_ssm.py
env pytest models/demos/mamba/tests/test_mamba_block.py
2 changes: 2 additions & 0 deletions tests/scripts/run_python_model_tests.sh
@@ -53,4 +53,6 @@ if [ "$ARCH_NAME" != "wormhole_b0" ]; then
pytest $TT_METAL_HOME/models/demos/falcon7b/tests/unit_tests/test_falcon_attn_matmul.py -k "not attn_matmul_from_cache"
# higher sequence lengths and different formats trigger memory issues
pytest $TT_METAL_HOME/models/demos/falcon7b/tests/unit_tests/test_falcon_matmuls_and_bmms_with_mixed_precision.py -k "seq_len_128 and in0_BFLOAT16-in1_BFLOAT8_B-out_BFLOAT16-weights_DRAM"

SLOW_MATMULS=1 WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest $TT_METAL_HOME/tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model.py -k 512
fi
@@ -16,7 +16,7 @@
cross_attention as tt2_ttnn_cross_attention,
)
from ttnn.model_preprocessing import preprocess_model_parameters
from tests.ttnn.utils_for_testing import assert_with_pcc
from tests.ttnn.utils_for_testing import assert_with_pcc, comp_pcc
from models.utility_functions import (
skip_for_grayskull,
)
@@ -254,17 +254,13 @@ def test_cross_attention_512x512(device, model_name, N, C, H, W, index, has_enco
ttnn_hidden_states = ttnn.to_device(ttnn_hidden_states, device)

model = tt2_ttnn_cross_attention(device, parameters)
signpost(header="start")
ttnn_output = model(
ttnn_hidden_states,
ttnn_encoder_hidden_states,
attention_mask=None,
dim_head=W // 8,
)
signpost(header="stop")

ttnn_output = ttnn.from_device(ttnn_output)
ttnn_output = ttnn.to_torch(ttnn_output)
ttnn_output = ttnn_output.reshape(N, C, H, W)

assert_with_pcc(torch_output, ttnn_output, pcc=0.99)
passing, output = comp_pcc(torch_output, ttnn_output, pcc=0.99)
print(output)
assert passing
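The cross-attention test now prints the PCC report before asserting, instead of asserting inside assert_with_pcc. The relationship between the two helpers can be expressed roughly as follows (a sketch assuming comp_pcc returns a (passing, message) pair, which is how the tests in this diff use it):

from tests.ttnn.utils_for_testing import comp_pcc

def assert_with_pcc_sketch(expected, actual, pcc: float = 0.99):
    # Sketch: the asserting wrapper expressed in terms of the non-asserting check.
    passing, message = comp_pcc(expected, actual, pcc)
    print(message)  # the updated tests print the report even when passing
    assert passing, message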
@@ -8,7 +8,6 @@
import pytest
from tqdm.auto import tqdm
import time
from tracy import signpost

from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import (
@@ -211,7 +210,13 @@ def test_unet_2d_condition_model_512x512(device, batch_size, in_channels, input_
model = UNet2D(device, parameters, batch_size, input_height, input_width, reader_patterns_cache)

first_iter = time.time()
signpost(header="start")
use_signpost = True
try:
from tracy import signpost
except ModuleNotFoundError:
use_signpost = False
if use_signpost:
signpost(header="start")
ttnn_output = model(
input,
timestep=ttnn_timestep,
@@ -222,8 +227,10 @@ def test_unet_2d_condition_model_512x512(device, batch_size, in_channels, input_
return_dict=return_dict,
config=config,
)
signpost(header="stop")
if use_signpost:
signpost(header="stop")
first_iter = time.time() - first_iter
ttnn_output = ttnn_to_torch(ttnn_output)
print(f"First iteration took {first_iter} seconds")
# times = []
# for i in range(50):
@@ -238,6 +245,9 @@ def test_unet_2d_condition_model_512x512(device, batch_size, in_channels, input_
# return_dict=return_dict,
# config=config,
# )
# ttnn_output = ttnn_to_torch(ttnn_output)
# passing, output = comp_pcc(torch_output, ttnn_output, pcc=0.99)
# print(output)
# end = time.time()
# times.append(end - start)
# print(f"Current iteration took {end - start} seconds")
@@ -247,7 +257,6 @@ def test_unet_2d_condition_model_512x512(device, batch_size, in_channels, input_
# print(iter)
# print(f"Time taken for 50 iterations: {total_time}")
# print(f"Samples per second: {50 / total_time}")
ttnn_output = ttnn_to_torch(ttnn_output)
passing, output = comp_pcc(torch_output, ttnn_output, pcc=0.99)
print(output)
assert passing
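The try/except guard added around the tracy import lets this test run on machines without the profiler. An equivalent, slightly tidier variant (a sketch, not what the commit does) swaps the boolean flag for a no-op fallback so the signpost calls stay unconditional:

try:
    from tracy import signpost  # available only in profiling builds
except ModuleNotFoundError:
    def signpost(header: str) -> None:
        # No-op fallback so the test body needs no use_signpost checks.
        pass

signpost(header="start")
# ... run the model ...
signpost(header="stop")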
