#8364: Disable fallback for reshape, add ttnn.fallback.reshape, update usage
ayerofieiev-tt committed May 22, 2024
1 parent 1e41f09 commit 5d256fa
Showing 9 changed files with 46 additions and 22 deletions.
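The change replaces the implicit CPU fallback inside ttnn.reshape with an explicit opt-in: call sites that need shapes the device kernel cannot handle now fetch the fallback via the new ttnn.get_fallback_function(ttnn.reshape) and call it directly. A minimal sketch of the updated call-site pattern, condensed from the model changes below (the split_heads name and the tensor setup are illustrative, not part of the commit):

import ttnn

def split_heads(hidden_states, *, num_heads):
    # Fetch the CPU (golden) fallback for reshape explicitly instead of
    # relying on ttnn.reshape silently falling back on failure.
    fallback_reshape = ttnn.get_fallback_function(ttnn.reshape)

    batch_size, sequence_size, hidden_size = hidden_states.shape
    head_size = hidden_size // num_heads

    # Reshapes such as (1, 128, 512) -> (1, 128, 8, 64) are among the cases
    # the device op does not support, so the fallback is used deliberately.
    hidden_states = ttnn.to_layout(hidden_states, layout=ttnn.ROW_MAJOR_LAYOUT)
    hidden_states = fallback_reshape(hidden_states, (batch_size, sequence_size, num_heads, head_size))
    hidden_states = ttnn.to_layout(hidden_states, layout=ttnn.TILE_LAYOUT)
    return ttnn.permute(hidden_states, (0, 2, 1, 3))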
10 changes: 6 additions & 4 deletions models/demos/bert/tt/ttnn_bert.py
@@ -14,28 +14,30 @@ def bert_attention(
*,
parameters,
):
fallback_reshape = ttnn.get_fallback_function(ttnn.reshape)

num_heads = config.num_attention_heads
batch_size, sequence_size, hidden_size = hidden_states.shape
head_size = hidden_size // num_heads

query = hidden_states @ parameters.self.query.weight
query = query + parameters.self.query.bias
query = ttnn.to_layout(query, layout=ttnn.ROW_MAJOR_LAYOUT)
query = ttnn.reshape(query, (batch_size, sequence_size, num_heads, head_size))
query = fallback_reshape(query, (batch_size, sequence_size, num_heads, head_size))
query = ttnn.to_layout(query, layout=ttnn.TILE_LAYOUT)
query = ttnn.permute(query, (0, 2, 1, 3))

key = hidden_states @ parameters.self.key.weight
key = key + parameters.self.key.bias
key = ttnn.to_layout(key, layout=ttnn.ROW_MAJOR_LAYOUT)
key = ttnn.reshape(key, (batch_size, sequence_size, num_heads, head_size))
key = fallback_reshape(key, (batch_size, sequence_size, num_heads, head_size))
key = ttnn.to_layout(key, layout=ttnn.TILE_LAYOUT)
key = ttnn.permute(key, (0, 2, 3, 1))

value = hidden_states @ parameters.self.value.weight
value = value + parameters.self.value.bias
value = ttnn.to_layout(value, layout=ttnn.ROW_MAJOR_LAYOUT)
value = ttnn.reshape(value, (batch_size, sequence_size, num_heads, head_size))
value = fallback_reshape(value, (batch_size, sequence_size, num_heads, head_size))
value = ttnn.to_layout(value, layout=ttnn.TILE_LAYOUT)
value = ttnn.permute(value, (0, 2, 1, 3))

@@ -49,7 +51,7 @@ def bert_attention(
context_layer = attention_probs @ value
context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))
context_layer = ttnn.to_layout(context_layer, ttnn.ROW_MAJOR_LAYOUT)
context_layer = ttnn.reshape(context_layer, (batch_size, sequence_size, hidden_size))
context_layer = fallback_reshape(context_layer, (batch_size, sequence_size, hidden_size))
context_layer = ttnn.to_layout(context_layer, ttnn.TILE_LAYOUT)

self_output = context_layer
@@ -129,7 +129,7 @@ def merge_heads(x: ttnn.Tensor) -> ttnn.Tensor:

# batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT)
x = ttnn.reshape(x, shape=(batch_size, seq_length, num_heads * head_size))
x = ttnn.get_fallback_function(ttnn.reshape)(x, shape=(batch_size, seq_length, num_heads * head_size))
x = ttnn.to_layout(x, ttnn.TILE_LAYOUT)
return x

4 changes: 2 additions & 2 deletions models/demos/grayskull/t5/tt/ttnn_functional_t5.py
@@ -168,7 +168,7 @@ def t5_attention(
def shape(states, head_size, is_key=False):
"""projection"""
states = ttnn.to_layout(states, layout=ttnn.ROW_MAJOR_LAYOUT)
states = ttnn.reshape(states, (batch_size, seq_length, config.num_heads, head_size))
states = ttnn.get_fallback_function(ttnn.reshape)(states, (batch_size, seq_length, config.num_heads, head_size))
if is_key:
states = ttnn.permute(states, (0, 2, 3, 1))
else:
@@ -180,7 +180,7 @@ def unshape(states, hidden_size):
"""reshape"""
states = ttnn.permute(states, (0, 2, 1, 3))
states = ttnn.to_layout(states, ttnn.ROW_MAJOR_LAYOUT)
states = ttnn.reshape(states, (batch_size, seq_length, hidden_size))
states = ttnn.get_fallback_function(ttnn.reshape)(states, (batch_size, seq_length, hidden_size))
states = ttnn.to_layout(states, ttnn.TILE_LAYOUT)
return states

@@ -83,9 +83,10 @@ def attention(config, x, bcast_freq_xq, bcast_freq_xk, positions, mask, seqlen,
xk = xk[:, :seqlen, :]
xv = xv[:, :seqlen, :]

xq = ttnn.reshape(xq, (bsz, seqlen, config.n_heads, config.head_dim))
xk = ttnn.reshape(xk, (bsz, seqlen, config.n_kv_heads, config.head_dim))
xv = ttnn.reshape(xv, (bsz, seqlen, config.n_kv_heads, config.head_dim))
fallback_reshape = ttnn.get_fallback_function(ttnn.reshape)
xq = fallback_reshape(xq, (bsz, seqlen, config.n_heads, config.head_dim))
xk = fallback_reshape(xk, (bsz, seqlen, config.n_kv_heads, config.head_dim))
xv = fallback_reshape(xv, (bsz, seqlen, config.n_kv_heads, config.head_dim))

xq, xk = apply_rotary_emb(xq, xk, bcast_freq_xq, bcast_freq_xk, device, mem_config)

@@ -148,7 +149,7 @@ def attention(config, x, bcast_freq_xq, bcast_freq_xk, positions, mask, seqlen,
output = scores @ value
output = ttnn.permute(output, (0, 2, 1, 3))
output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT)
output = ttnn.reshape(output, (1, bsz, seqlen, -1))
output = fallback_reshape(output, (1, bsz, seqlen, -1))
output = ttnn.to_layout(output, ttnn.TILE_LAYOUT)
output = output @ parameters.wo.weight
return output
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/test_model_preprocessing.py
@@ -328,7 +328,7 @@ def forward(self, x):
output_tensor = ttnn.relu(output_tensor)
output_tensor = ttnn.permute(output_tensor, (0, 3, 1, 2))
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor = ttnn.reshape(output_tensor, (-1, num_output_channels))
output_tensor = ttnn.get_fallback_function(ttnn.reshape)(output_tensor, (-1, num_output_channels))
output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
output_tensor = output_tensor @ linear.weight + linear.bias
output_tensor = ttnn.to_torch(output_tensor)
2 changes: 1 addition & 1 deletion ttnn/ttnn/__init__.py
@@ -15,7 +15,6 @@
import tt_lib as _tt_lib
import ttnn._ttnn


CPP_CONFIG: ttnn._ttnn.core.Config = ttnn._ttnn.CONFIG


@@ -240,6 +239,7 @@ def manage_config(name, value):
register_pre_operation_hook,
register_post_operation_hook,
get_golden_function,
get_fallback_function,
)

import ttnn.experimental
18 changes: 16 additions & 2 deletions ttnn/ttnn/decorators.py
@@ -558,7 +558,6 @@ def call_wrapper(*function_args, **function_kwargs):
f"{self.python_fully_qualified_name}: falling back to CPU due to {exception_message}"
)
output = golden_function(*function_args, **function_kwargs)

if ttnn.CONFIG.throw_exception_on_fallback and ran_golden_function:
raise RuntimeError(
f"Fallbacks are disabled, but {self.python_fully_qualified_name} used a fallback"
@@ -690,7 +689,6 @@ def call_wrapper(*function_args, **function_kwargs):
function = fallback_to_golden_function_decorator(function)

function = runtime_decorator(function)

self.decorated_function = function

def __call__(self, *function_args, **function_kwargs):
@@ -730,12 +728,17 @@ def query_registered_operations(include_experimental=False):


OPERATION_TO_GOLDEN_FUNCTION = {}
OPERATION_TO_FALLBACK_FUNCTION = {}


def get_golden_function(operation):
return OPERATION_TO_GOLDEN_FUNCTION[operation]


def get_fallback_function(operation):
return OPERATION_TO_FALLBACK_FUNCTION[operation]


def register_operation(
*,
name=None,
@@ -751,9 +754,19 @@
def operation_decorator(function: callable):
global REGISTERED_APIS
global OPERATION_TO_GOLDEN_FUNCTION
global OPERATION_TO_FALLBACK_FUNCTION

def fallback_function(*function_args, **function_kwargs):
updated_function_args, updated_function_kwargs = preprocess_golden_function_inputs(
function_args, function_kwargs
)
output = golden_function(*updated_function_args, **updated_function_kwargs)
output = postprocess_golden_function_outputs(output, function_args, function_kwargs)
return output

if ttnn.CONFIG.enable_fast_runtime_mode:
OPERATION_TO_GOLDEN_FUNCTION[function] = golden_function
OPERATION_TO_FALLBACK_FUNCTION[function] = fallback_function
return function

is_cpp_function = hasattr(function, "__ttnn__")
@@ -806,6 +819,7 @@ def method_call(self, *function_args, **function_kwargs):
REGISTERED_APIS.add(api)

OPERATION_TO_GOLDEN_FUNCTION[api] = golden_function
OPERATION_TO_FALLBACK_FUNCTION[api] = fallback_function

return api

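The callable returned by get_fallback_function is the nested fallback_function registered above: it reuses the operation's golden (torch) implementation but keeps the ttnn calling convention by running the operation's input pre-processing and output post-processing around it. A condensed sketch, with a hypothetical _make_fallback_function helper standing in for the inline definition in register_operation:

def _make_fallback_function(golden_function, preprocess_inputs, postprocess_outputs):
    # Mirrors the nested fallback_function above: convert ttnn inputs to
    # torch, run the golden implementation on CPU, convert the result back.
    def fallback_function(*function_args, **function_kwargs):
        updated_args, updated_kwargs = preprocess_inputs(function_args, function_kwargs)
        output = golden_function(*updated_args, **updated_kwargs)
        return postprocess_outputs(output, function_args, function_kwargs)

    return fallback_function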
11 changes: 8 additions & 3 deletions ttnn/ttnn/operations/core.py
@@ -167,17 +167,22 @@ def _postprocess_golden_function_outputs(output, args, kwargs):
"""


# Unsupported cases, which require a fallback: (found in bert, t5, bloom)
# Shape([1, 128, 512]) <-> (1, 128, 8, 64)
# Shape([1, 128, 384]) <-> (1, 128, 6, 64)
# Shape([1, 384, 1024]) <-> (1, 384, 16, 64)
# Shape([1, 11, 4096]) <-> (1, 11, 32, 128)
# Shape([1, 128, 28, 28]) <-> (-1, 128)
# Shape([1, 11, 32, 128]) <-> (1, 1, 11, -1) in ttnn_functional_attention.py test_mistral_attention_inference
reshape = ttnn.register_operation(
name="ttnn.reshape",
golden_function=_golden_function,
preprocess_golden_function_inputs=_preprocess_golden_function_inputs,
postprocess_golden_function_outputs=_postprocess_golden_function_outputs,
allow_to_fallback_to_golden_function_on_failure=True,
allow_to_fallback_to_golden_function_on_failure=False,
doc=doc,
)(ttnn._ttnn.operations.core.reshape)


# TODO(arakhmati): remove this once underlying C++ code can handle non-4D shapes
unsqueeze_to_4D = ttnn.register_operation(name="ttnn.unsqueeze_to_4D")(ttnn._ttnn.operations.core.unsqueeze_to_4D)

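With allow_to_fallback_to_golden_function_on_failure switched to False, ttnn.reshape no longer silently drops to the golden implementation for the shapes listed in the comment above; models opt in through get_fallback_function instead. A hypothetical host-only check of the first listed case (device placement and tile-layout conversion omitted; assumes ttnn.from_torch and ttnn.to_torch are available, as in the test change above):

import torch
import ttnn

# Shape([1, 128, 512]) <-> (1, 128, 8, 64), the first unsupported case above.
torch_input = torch.rand(1, 128, 512, dtype=torch.bfloat16)
ttnn_input = ttnn.from_torch(torch_input)

fallback_reshape = ttnn.get_fallback_function(ttnn.reshape)
ttnn_output = fallback_reshape(ttnn_input, (1, 128, 8, 64))

assert tuple(ttnn.to_torch(ttnn_output).shape) == (1, 128, 8, 64)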
10 changes: 6 additions & 4 deletions ttnn/tutorials/003.ipynb
@@ -125,27 +125,29 @@
" *,\n",
" num_heads,\n",
"):\n",
" fallback_reshape = ttnn.get_fallback_function(ttnn.reshape) \n",
" \n",
" batch_size, sequence_size, hidden_size = hidden_states.shape\n",
" head_size = hidden_size // num_heads\n",
"\n",
" query = hidden_states @ query_weight\n",
" query = query + query_bias\n",
" query = ttnn.to_layout(query, layout=ttnn.ROW_MAJOR_LAYOUT)\n",
" query = ttnn.reshape(query, (batch_size, sequence_size, num_heads, head_size))\n",
" query = fallback_reshape(query, (batch_size, sequence_size, num_heads, head_size))\n",
" query = ttnn.to_layout(query, layout=ttnn.TILE_LAYOUT)\n",
" query = ttnn.permute(query, (0, 2, 1, 3))\n",
"\n",
" key = hidden_states @ key_weight\n",
" key = key + key_bias\n",
" key = ttnn.to_layout(key, layout=ttnn.ROW_MAJOR_LAYOUT)\n",
" key = ttnn.reshape(key, (batch_size, sequence_size, num_heads, head_size))\n",
" key = fallback_reshape(key, (batch_size, sequence_size, num_heads, head_size))\n",
" key = ttnn.to_layout(key, layout=ttnn.TILE_LAYOUT)\n",
" key = ttnn.permute(key, (0, 2, 3, 1))\n",
"\n",
" value = hidden_states @ value_weight\n",
" value = value + value_bias\n",
" value = ttnn.to_layout(value, layout=ttnn.ROW_MAJOR_LAYOUT)\n",
" value = ttnn.reshape(value, (batch_size, sequence_size, num_heads, head_size))\n",
" value = fallback_reshape(value, (batch_size, sequence_size, num_heads, head_size))\n",
" value = ttnn.to_layout(value, layout=ttnn.TILE_LAYOUT)\n",
" value = ttnn.permute(value, (0, 2, 1, 3))\n",
"\n",
@@ -157,7 +159,7 @@
" context_layer = attention_probs @ value\n",
" context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))\n",
" context_layer = ttnn.to_layout(context_layer, layout=ttnn.ROW_MAJOR_LAYOUT)\n",
" context_layer = ttnn.reshape(context_layer, (batch_size, sequence_size, hidden_size))\n",
" context_layer = fallback_reshape(context_layer, (batch_size, sequence_size, hidden_size))\n",
" context_layer = ttnn.to_layout(context_layer, layout=ttnn.TILE_LAYOUT)\n",
"\n",
" self_output = context_layer @ output_weight\n",
