#0: Preallocate trace buffer and remove trace owned_pool
Update device trace cmds to take in cq_id, remove multi-device apis
Add tracing tests for metal Resnet50. TODO: Cleanup/reuse code
Disable allocations after capturing trace
Update trace apis to return/take in a trace id. Make the device own the TraceBuffer mapping. Remove the trace apis that let users create Trace objects directly
#8383: End any active traces during device close and assert tracing is not enabled for terminate cmd
tt-aho committed May 16, 2024
1 parent adb73b0 commit 2bcb208
Showing 32 changed files with 738 additions and 1,361 deletions.
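
This commit replaces the old Trace-object workflow (BeginTrace/EndTrace/InstantiateTrace) with an id-based capture/replay flow in which the device owns the trace buffers. Below is a minimal sketch of the new lifecycle; `device` and `run_model` are hypothetical stand-ins, while the `tt_lib.device` calls, the explicit cq_id argument, and the 1304576-byte buffer size are taken from the tests in this diff:

import tt_lib

cq_id = 0  # the new trace commands take an explicit command queue id

run_model(tt_inputs)  # warm-up/compile pass before capturing (hypothetical helper)

# BeginTraceCapture preallocates the trace buffer up front and returns a
# trace id; the device, not the caller, owns the TraceBuffer mapping.
tid = tt_lib.device.BeginTraceCapture(device, cq_id, 1304576)
tt_outputs = run_model(tt_inputs)  # dispatched commands are recorded into the trace
tt_lib.device.EndTraceCapture(device, cq_id, tid)

# Re-run the captured commands by id; the final argument selects blocking replay.
tt_lib.device.ReplayTrace(device, cq_id, tid, True)

# Deallocate the preallocated trace buffer once the trace is no longer needed.
tt_lib.device.ReleaseTrace(device, tid)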
7 changes: 4 additions & 3 deletions docs/aspell-dictionary.pws
@@ -12,7 +12,7 @@ BRISC
 BRISCs
 BUF
 BUFs
-BeginTrace
+BeginTraceCapture
 BertIntermediate
 BinaryOpType
 BorrowedStorage
@@ -62,7 +62,7 @@ DumpDeviceProfileResults
 ENDL
 ETH
 EltwiseUnary
-EndTrace
+EndTraceCapture
 EnqueueProgram
 EnqueueReadBuffer
 EnqueueRecordEvent
@@ -84,7 +84,6 @@ Grayskull
 HW
 HiFi
 HostDataType
-InstantiateTrace
 InterleavedBufferConfig
 Jupyter
 KernelHandle
@@ -136,6 +135,8 @@ RISCVs
 RISCs
 ReadFromDevice
 ReduceDim
+RepeatTrace
+ReplayTrace
 ResNet
 RuntimeArgs
 SETPRECISION

This file was deleted.

@@ -0,0 +1,4 @@
BeginTraceCapture
=================

.. doxygenfunction:: BeginTraceCapture

This file was deleted.

@@ -0,0 +1,4 @@
EndTraceCapture
===============

.. doxygenfunction:: EndTraceCapture

This file was deleted.

@@ -0,0 +1,4 @@
ReleaseTrace
============

.. doxygenfunction:: ReleaseTrace
@@ -0,0 +1,4 @@
ReplayTrace
===========

.. doxygenfunction:: ReplayTrace
@@ -9,8 +9,9 @@ CommandQueue
 EnqueueWaitForEvent
 EventQuery
 EventSynchronize
-BeginTrace
-EndTrace
-InstantiateTrace
+BeginTraceCapture
+EndTraceCapture
+ReplayTrace
+ReleaseTrace
 EnqueueTrace
 Finish
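
For reference, the renames implied by the dictionary and doc changes above: BeginTrace becomes BeginTraceCapture, EndTrace becomes EndTraceCapture, InstantiateTrace is removed outright, and ReplayTrace plus ReleaseTrace are new id-based entry points documented alongside the existing EnqueueTrace.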
120 changes: 120 additions & 0 deletions models/demos/resnet/tests/test_metal_resnet50.py
@@ -219,3 +219,123 @@ def test_run_resnet50_inference(
        passing_pcc, _ = comp_pcc(torch_output, tt_output, pcc=valid_pcc)
        assert passing_pcc
        # assert passing # fails because of torch.allclose


@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.")
@pytest.mark.parametrize("device_l1_small_size", [24576], indirect=True)
@pytest.mark.parametrize("batch_size", [1, 2, 16, 20], ids=["batch_1", "batch_2", "batch_16", "batch_20"])
@pytest.mark.parametrize(
    "weights_dtype",
    [tt_lib.tensor.DataType.BFLOAT16, tt_lib.tensor.DataType.BFLOAT8_B],
    ids=["weights_BFLOAT16", "weights_BFLOAT8_B"],
)
@pytest.mark.parametrize(
    "activations_dtype",
    [tt_lib.tensor.DataType.BFLOAT16, tt_lib.tensor.DataType.BFLOAT8_B],
    ids=["activations_BFLOAT16", "activations_BFLOAT8_B"],
)
@pytest.mark.parametrize(
    "math_fidelity",
    [tt_lib.tensor.MathFidelity.HiFi4, tt_lib.tensor.MathFidelity.HiFi2, tt_lib.tensor.MathFidelity.LoFi],
    ids=["HiFi4", "HiFi2", "LoFi"],
)
def test_run_resnet50_trace_inference(
    device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input
):
    if is_e75(device):
        pytest.skip("Resnet50 is not supported on E75")

    if batch_size > 8 and (
        activations_dtype != tt_lib.tensor.DataType.BFLOAT8_B or weights_dtype != tt_lib.tensor.DataType.BFLOAT8_B
    ):
        pytest.skip("Batch > 8 must be run fully bfp8")
    if batch_size <= 2:
        pytest.skip("batch 1 and 2 are not supported with sharded data")
    image1 = imagenet_sample_input
    image = image1
    model_config = {
        "MATH_FIDELITY": math_fidelity,
        "WEIGHTS_DTYPE": weights_dtype,
        "ACTIVATIONS_DTYPE": activations_dtype,
    }
    for i in range(batch_size - 1):
        image = torch.cat((image, image1), dim=0)
    with torch.no_grad():
        torch.manual_seed(1234)

        tt_lib.device.EnableMemoryReports()

        torch_resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        torch_resnet50.eval()

        state_dict = torch_resnet50.state_dict()
        storage_in_dram = False
        sharded = False
        if batch_size >= 8:
            sharded = True
        # run once to compile ops
        tt_resnet50 = ResNet(
            Bottleneck,
            [3, 4, 6, 3],
            device=device,
            state_dict=state_dict,
            base_address="",
            fold_batchnorm=True,
            storage_in_dram=storage_in_dram,
            batch_size=batch_size,
            model_config=model_config,
            sharded=sharded,
        )

        torch_output = torch_resnet50(image).unsqueeze(1).unsqueeze(1)
        interleaved_mem_config_DRAM = tt_lib.tensor.MemoryConfig(
            memory_layout=tt_lib.tensor.TensorMemoryLayout.INTERLEAVED,
            buffer_type=tt_lib.tensor.BufferType.DRAM,
        )

        tt_image_res = tt_resnet50.preprocessing(image).to(device, interleaved_mem_config_DRAM)

        # Compile
        tt_resnet50(tt_image_res)
        # Trace
        tid = tt_lib.device.BeginTraceCapture(device, 0, 1304576)
        tt_output_res = tt_resnet50(tt_image_res)
        tt_lib.device.EndTraceCapture(device, 0, tid)

        tt_lib.device.ReplayTrace(device, 0, tid, True)

        tt_output = tt_output_res.cpu().to_torch().to(torch.float)

        # # run again to measure end to end perf
        # start_time = datetime.now()
        # tt_output = tt_resnet50(image)
        # end_time = datetime.now()
        # diff = end_time - start_time
        # logger.info("End to end time (microseconds))", diff.microseconds)
        # throughput_fps = (float) (1000000 / diff.microseconds)
        # logger.info("Throughput (fps)", throughput_fps)

        _, _, _, info = get_atol_rtol_pcc(torch_output, tt_output)
        logger.info(info)

        valid_pcc = 1.0
        if batch_size >= 8:
            valid_pcc = golden_pcc[batch_size][
                (model_config["MATH_FIDELITY"], model_config["WEIGHTS_DTYPE"], model_config["ACTIVATIONS_DTYPE"])
            ]
        else:
            if model_config["ACTIVATIONS_DTYPE"] == tt_lib.tensor.DataType.BFLOAT8_B:
                if model_config["MATH_FIDELITY"] == tt_lib.tensor.MathFidelity.LoFi:
                    valid_pcc = 0.87
                else:
                    valid_pcc = 0.94
            else:
                if model_config["MATH_FIDELITY"] == tt_lib.tensor.MathFidelity.LoFi:
                    valid_pcc = 0.93
                else:
                    valid_pcc = 0.982
        passing_pcc, _ = comp_pcc(torch_output, tt_output, pcc=valid_pcc)
        assert passing_pcc
        # assert passing # fails because of torch.allclose
        # Done with the trace, can deallocate the buffers now.
        tt_lib.device.ReleaseTrace(device, tid)
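
Note the final argument to ReplayTrace: the correctness test above passes True to block until the replay finishes before reading tt_output_res back, while the perf test below passes False and pipelines several replays, draining with an explicit Synchronize. A sketch of that non-blocking pattern, using the same names as the test below:

tt_lib.tensor.write_tensor(tt_inputs, tt_image_res)  # refresh the on-device input in place
tt_lib.device.ReplayTrace(device, 0, tid, False)     # enqueue the replay without waiting
outputs.append(tt_output_res.cpu(blocking=False))    # queue an async readback
tt_lib.device.Synchronize(device)                    # drain everything before timing/asserts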
163 changes: 156 additions & 7 deletions models/demos/resnet/tests/test_perf_resnet.py
@@ -88,22 +88,19 @@ def run_perf_resnet(
     warm_end = warm_start + num_warm_iterations
 
     outputs = []
-    inference_time_sum = 0
+    profiler.start(f"run")
     for iter in range(warm_start, warm_end):
-        profiler.start(f"run")
         outputs.append(tt_resnet50(tt_inputs).cpu(blocking=False))
-        profiler.end(f"run")
-        inference_time_sum += profiler.get("run")
-        tt_lib.device.DumpDeviceProfiler(device)
-
     tt_lib.device.Synchronize(device)
+    profiler.end(f"run")
+    tt_lib.device.DumpDeviceProfiler(device)
 
     # enable_persistent_kernel_cache()
 
     first_iter_time = profiler.get(f"{0}_key")
 
     # ensuring inference time fluctuations are not noise
-    inference_time_avg = inference_time_sum / num_warm_iterations
+    inference_time_avg = profiler.get("run") / num_warm_iterations
 
     cpu_time = profiler.get(cpu_key)
     compile_time = first_iter_time - inference_time_avg
@@ -152,3 +149,155 @@ def test_perf_bare_metal(
        hf_cat_image_sample_input,
        device,
    )


def run_perf_resnet_trace(
    batch_size,
    expected_inference_time,
    expected_compile_time,
    hf_cat_image_sample_input,
    device,
):
    disable_persistent_kernel_cache()
    if batch_size <= 2:
        pytest.skip("Batch size 1 and 2 are not supported with sharded data")
    first_key = f"first_iter_batchsize{batch_size}"
    second_key = f"second_iter_batchsize{batch_size}"
    cpu_key = f"ref_key_batchsize{batch_size}"
    model_name = "microsoft/resnet-50"

    image = hf_cat_image_sample_input
    image_processor = AutoImageProcessor.from_pretrained(model_name)
    inputs = image_processor(image, return_tensors="pt")

    inputs = inputs["pixel_values"]
    comments = f"{list(inputs.shape)[-2]}x{list(inputs.shape)[-1]}_batchsize{batch_size}"

    inputs1 = inputs
    for i in range(batch_size - 1):
        inputs = torch.cat((inputs, inputs1), dim=0)

    torch_resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    torch_resnet50.eval()

    state_dict = torch_resnet50.state_dict()
    sharded = False
    if batch_size >= 8:
        sharded = True
    tt_resnet50 = ResNet(
        Bottleneck,
        [3, 4, 6, 3],
        device=device,
        state_dict=state_dict,
        base_address="",
        fold_batchnorm=True,
        storage_in_dram=False,
        batch_size=batch_size,
        model_config=model_config,
        sharded=sharded,
    )

    with torch.no_grad():
        profiler.start(cpu_key)
        logits = torch_resnet50(inputs)
        profiler.end(cpu_key)

        tt_inputs = tt_resnet50.preprocessing(inputs)
        interleaved_mem_config_DRAM = tt_lib.tensor.MemoryConfig(
            memory_layout=tt_lib.tensor.TensorMemoryLayout.INTERLEAVED,
            buffer_type=tt_lib.tensor.BufferType.DRAM,
        )
        tt_image_res = tt_inputs.to(device, interleaved_mem_config_DRAM)
        # Compile
        profiler.start(f"{0}_key")
        tt_lib.tensor.write_tensor(tt_inputs, tt_image_res)
        tt_resnet50(tt_image_res).cpu(blocking=True)
        profiler.end(f"{0}_key")
        tt_lib.device.DumpDeviceProfiler(device)

        # Capture
        tid = tt_lib.device.BeginTraceCapture(device, 0, 1304576)
        tt_output_res = tt_resnet50(tt_image_res)
        tt_lib.device.EndTraceCapture(device, 0, tid)
        tt_lib.device.DumpDeviceProfiler(device)

        warmup_end = 6
        for iter in range(1, warmup_end):
            profiler.start(f"{iter}_key")
            tt_lib.tensor.write_tensor(tt_inputs, tt_image_res)
            tt_lib.device.ReplayTrace(device, 0, tid, False)
            _ = tt_output_res.cpu(blocking=True)
            profiler.end(f"{iter}_key")
            tt_lib.device.DumpDeviceProfiler(device)

        num_warm_iterations = 15
        warm_start = warmup_end
        warm_end = warm_start + num_warm_iterations

        outputs = []
        profiler.start(f"run")
        for iter in range(warm_start, warm_end):
            tt_lib.tensor.write_tensor(tt_inputs, tt_image_res)
            tt_lib.device.ReplayTrace(device, 0, tid, False)
            outputs.append(tt_output_res.cpu(blocking=False))
        tt_lib.device.Synchronize(device)
        profiler.end(f"run")
        tt_lib.device.DumpDeviceProfiler(device)

        # enable_persistent_kernel_cache()

        first_iter_time = profiler.get(f"{0}_key")

        # ensuring inference time fluctuations are not noise
        inference_time_avg = profiler.get("run") / num_warm_iterations

        cpu_time = profiler.get(cpu_key)
        compile_time = first_iter_time - inference_time_avg
        prep_perf_report(
            model_name=f"resnet50_trace_batch_size{batch_size}",
            batch_size=batch_size,
            inference_and_compile_time=first_iter_time,
            inference_time=inference_time_avg,
            expected_compile_time=expected_compile_time,
            expected_inference_time=expected_inference_time,
            comments=comments,
            inference_time_cpu=cpu_time,
        )

        logger.info(f"resnet50 {comments} inference time (avg): {inference_time_avg}")
        logger.info(f"resnet50 compile time: {compile_time}")

        tt_lib.device.ReleaseTrace(device, tid)

        assert inference_time_avg < expected_inference_time, f"resnet50 {comments} inference is too slow"
        assert compile_time < expected_compile_time, f"resnet50 {comments} compilation is too slow"


@skip_for_wormhole_b0(reason_str="Not tested on single WH")
@pytest.mark.parametrize("device_l1_small_size", [32768], indirect=True)
@pytest.mark.models_performance_bare_metal
@pytest.mark.parametrize(
    "batch_size, expected_inference_time, expected_compile_time",
    (
        (16, 0.03, 25),
        (20, 0.03, 25),
    ),
)
def test_perf_trace_bare_metal(
    device,
    use_program_cache,
    batch_size,
    expected_inference_time,
    expected_compile_time,
    hf_cat_image_sample_input,
):
    if is_e75(device):
        pytest.skip("Resnet is not supported on E75")

    run_perf_resnet_trace(
        batch_size,
        expected_inference_time,
        expected_compile_time,
        hf_cat_image_sample_input,
        device,
    )
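
A likely reason both tests refresh tt_image_res in place with write_tensor instead of moving a fresh tensor to the device each iteration: per the commit message, allocations are disabled after a trace is captured, so replay iterations must reuse preallocated buffers rather than allocate new ones. This is an inference from the commit message and test structure, not from the runtime code itself.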
