Commit

#9336: change the output tensor to an optional output
hschoi4448 committed Jun 24, 2024
1 parent 385ceea commit 42dbdda
Showing 7 changed files with 284 additions and 52 deletions.
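
This change makes the output tensor of moreh_layernorm an optional argument: the op now returns the output together with the optional mean and rstd tensors, as exercised by the updated test below. A minimal sketch of the new calling convention, using the tensor names from the test in this diff (whether the output tensor may be omitted entirely and allocated by the op itself is an assumption and is not shown):

npu_output, npu_mean, npu_rstd = ttl.operations.primary.moreh_layernorm(
    npu_input,              # input tensor on device
    normalized_dims,        # number of trailing dims to normalize over
    eps,
    npu_gamma,              # optional affine weight (None to skip)
    npu_beta,               # optional affine bias (None to skip)
    output=npu_output,      # optional preallocated output tensor
    mean=npu_mean,          # optional: per-row mean is written here if provided
    rstd=npu_rstd,          # optional: per-row 1/std is written here if provided
    compute_kernel_config=compute_kernel_config,
)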
@@ -5,6 +5,7 @@
import pytest
import torch
import torch.nn.functional as F
import copy

import tt_lib as ttl
from models.utility_functions import comp_allclose
@@ -25,7 +26,15 @@ def to_cpu(npu_tensor, shape, *, cpu_layout=ttl.tensor.Layout.ROW_MAJOR):
return None
if not isinstance(shape, (list, tuple)):
shape = tuple(shape)
cpu_tensor = npu_tensor.cpu().to(cpu_layout).unpad_from_tile(shape).to_torch()

unpad_shape = copy.copy(shape)
if shape == []:
unpad_shape = [1, 1]
if len(shape) == 1:
unpad_shape = [1] + shape

cpu_tensor = npu_tensor.cpu().to(cpu_layout).unpad_from_tile(unpad_shape).to_torch().reshape(shape)

return cpu_tensor


@@ -45,6 +54,9 @@ def to_npu(
if len(cpu_tensor.shape) == 1:
cpu_tensor = cpu_tensor.reshape([1, len(cpu_tensor)])

if len(cpu_tensor.shape) == 0:
cpu_tensor = cpu_tensor.reshape([1, 1])

npu_tensor = ttl.tensor.Tensor(cpu_tensor, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
return npu_tensor

@@ -54,9 +66,11 @@ def torch_layernorm(input, *, normalized_dims=1, eps=1e-5, gamma=None, beta=None
mean_rstd_dims = tuple(range(-normalized_dims, 0))

mean = input.clone().mean(dim=mean_rstd_dims, keepdim=True)
var = ((input.clone() - mean) ** 2).mean(dim=mean_rstd_dims, keepdim=True)
var = ((input.clone() - mean) ** 2).mean(dim=mean_rstd_dims)
rstd = (var + eps).rsqrt()

mean = torch.squeeze(mean, mean_rstd_dims)

output = F.layer_norm(input, normalized_shape, weight=gamma, bias=beta, eps=eps)

return output, mean, rstd
@@ -90,14 +104,18 @@ def tt_layernorm(input, *, normalized_dims=1, eps=1e-5, gamma=None, beta=None, d
input_shape = list(input.shape)

# mean_rstd_shape
mean_rstd_shape = input_shape[:-normalized_dims] + [1] * normalized_dims
mean_rstd_shape = input_shape[:-normalized_dims]

# dtype
cpu_dtype = torch.bfloat16

# input
npu_input = to_npu(input, device)

# output
output = torch.empty_like(input)
npu_output = to_npu(output, device)

# gamma
npu_gamma = to_npu(gamma, device)

Expand All @@ -113,12 +131,13 @@ def tt_layernorm(input, *, normalized_dims=1, eps=1e-5, gamma=None, beta=None, d
npu_rstd = to_npu(cpu_rstd, device)

# Forward
npu_output = ttl.operations.primary.moreh_layernorm(
npu_output, npu_mean, npu_rstd = ttl.operations.primary.moreh_layernorm(
npu_input,
normalized_dims,
eps,
npu_gamma,
npu_beta,
output=npu_output,
mean=npu_mean,
rstd=npu_rstd,
compute_kernel_config=compute_kernel_config,
@@ -355,9 +374,10 @@ def run_moreh_layernorm_backward(
[
([1, 20], 1), # test 2d
([10, 20], 2), # test 2d
([3, TILE_HEIGHT * 4, TILE_WIDTH * 5], 1), # test 3d
([2, 3, 2 * TILE_HEIGHT, 2 * TILE_WIDTH], 4), # test 4d
([3, TILE_HEIGHT * 1, TILE_WIDTH * 5], 1), # test 3d
([3, 3, 4 * TILE_HEIGHT, 5 * TILE_WIDTH], 4), # test 4d
([5, 2, 3, 4, 2 * TILE_HEIGHT + 13, 3 * TILE_WIDTH + 13], 4), # test 6d
([2, TILE_HEIGHT + 13, 200 * TILE_WIDTH * 2 + 15], 1), # test 3d
],
)
def test_moreh_layernorm(input_shape_normalized_dims, elementwise_affine, eps, device):
@@ -183,7 +183,7 @@ void MAIN {
tile_regs_acquire();
cb_reserve_back(cb_mean, onetile);

copy_tile_init_with_dt(cb_ex);
copy_tile_init_with_dt(cb_ex, is_lastdim_layernorm);
copy_tile(cb_ex, first_tile, dst0);
tile_regs_commit();

@@ -367,7 +367,7 @@ void MAIN {
tile_regs_acquire();
cb_reserve_back(cb_rstd, onetile);

copy_tile_init_with_dt(cb_recip_std);
copy_tile_init_with_dt(cb_recip_std, is_lastdim_layernorm);
copy_tile(cb_recip_std, first_tile, dst0);
tile_regs_commit();

@@ -187,7 +187,7 @@ void MAIN {
tile_regs_acquire();
cb_reserve_back(cb_mean, onetile);

copy_tile_init_with_dt(cb_ex);
copy_tile_init_with_dt(cb_ex, is_lastdim_layernorm);
copy_tile(cb_ex, first_tile, dst0);
tile_regs_commit();

@@ -346,7 +346,7 @@ void MAIN {
tile_regs_acquire();
cb_reserve_back(cb_rstd, onetile);

copy_tile_init_with_dt(cb_recip_std);
copy_tile_init_with_dt(cb_recip_std, is_lastdim_layernorm);
copy_tile(cb_recip_std, first_tile, dst0);
tile_regs_commit();

@@ -3,6 +3,99 @@
// SPDX-License-Identifier: Apache-2.0

#include "dataflow_api.h"
#include "tt_eager/tt_dnn/kernels/dataflow/moreh_common.hpp"

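// Computes the unaligned byte offset of element (h, w) within a tilized tile that is
// stored as four FACE_HEIGHT x FACE_WIDTH faces laid out back to back (faces 0..3).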
void get_noc_offset_no_align(uint32_t h, uint32_t w, uint32_t element_size, uint32_t &noc_offset) {
noc_offset = 0;

// compute h, w in tile
h = h - (h / TILE_HEIGHT) * TILE_HEIGHT;
w = w - (w / TILE_WIDTH) * TILE_WIDTH;

const bool is_even_face = (w < FACE_HEIGHT);
const bool is_odd_face = !is_even_face;

const uint32_t face_width_bytes = FACE_WIDTH * element_size;

if (h < FACE_WIDTH && is_even_face)
noc_offset += h * face_width_bytes + w * element_size; // face 0
else if (h < FACE_WIDTH && is_odd_face)
noc_offset += (FACE_HEIGHT + h) * face_width_bytes + (w - FACE_WIDTH) * element_size; // face 1
else if (h >= FACE_WIDTH && is_even_face)
noc_offset += (FACE_HEIGHT + h) * face_width_bytes + w * element_size; // face 2
else if (h >= FACE_WIDTH && is_odd_face)
noc_offset += (2 * FACE_HEIGHT + h) * face_width_bytes + (w - FACE_WIDTH) * element_size; // face 3
}


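// Pops one tile of mean/rstd from the given CB and writes its valid values straight to
// their tilized positions in the destination buffer addrg: a full tile row (two
// FACE_HEIGHT-wide segments) when normalized_dim == 1, otherwise a single element.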
template <typename T>
void write_mean_rstd(uint32_t cb_id, uint32_t cb_tile_bytes, uint32_t tile_offset, uint32_t num_inner, uint32_t normalized_dim, uint32_t outer_idx, uint32_t output_height, uint32_t output_width, uint32_t Ht, uint32_t Wt, T addrg)
{
constexpr uint32_t onetile = 1;

const auto cb_dtype_bytes = cb_tile_bytes / (TILE_HEIGHT * TILE_WIDTH);

cb_wait_front(cb_id, onetile);

uint32_t output_l1_write_addr = get_read_ptr(cb_id);
auto l1_ptr = reinterpret_cast<volatile tt_l1_ptr uint16_t *>(output_l1_write_addr);

uint32_t output_tile_offset = tile_offset / num_inner;

if (normalized_dim == 1) {
for (uint32_t src_h = 0; src_h < 2; src_h++) {
auto output_tile_idx = output_tile_offset + outer_idx;

auto wt = output_tile_idx % Wt;
auto nh = output_tile_idx / Wt;
auto h = nh % output_height;
auto n = nh / output_height;

auto w = src_h * FACE_HEIGHT;

auto tilized_idx = get_tilized_idx(h % TILE_HEIGHT, w);

auto ht = h / TILE_HEIGHT;
auto noc_id = n * Ht * Wt + ht * Wt + wt;

auto src_idx = get_tilized_idx(0, src_h * FACE_WIDTH);

auto dst_noc_addr = get_noc_addr(noc_id, addrg);
noc_async_write(
output_l1_write_addr + src_idx * cb_dtype_bytes,
dst_noc_addr + tilized_idx * cb_dtype_bytes,
cb_dtype_bytes * FACE_HEIGHT);
noc_async_write_barrier();
}
} else {
auto output_idx = output_tile_offset + outer_idx;

auto w = output_idx % output_width;
auto nh = output_idx / output_width;
auto h = nh % output_height;
auto n = nh / output_height;

auto tilized_idx = get_tilized_idx(h % TILE_HEIGHT, w % TILE_WIDTH);

auto wt = w / TILE_WIDTH;
auto ht = h / TILE_HEIGHT;

auto noc_id = n * Ht * Wt + ht * Wt + wt;

if (output_idx != 0) {
l1_ptr[tilized_idx] = l1_ptr[0];
}

auto dst_noc_addr = get_noc_addr(noc_id, addrg);
noc_async_write(
output_l1_write_addr + tilized_idx * cb_dtype_bytes,
dst_noc_addr + tilized_idx * cb_dtype_bytes,
cb_dtype_bytes);
noc_async_write_barrier();
}

cb_pop_front(cb_id, onetile);
}

void kernel_main() {
const auto output_addr = get_arg_val<uint32_t>(0);
@@ -11,6 +104,9 @@ void kernel_main() {
const auto num_rows_per_core = get_arg_val<uint32_t>(3);
const auto num_inner = get_arg_val<uint32_t>(4);
const auto tile_offset = get_arg_val<uint32_t>(5);
const auto mean_rstd_height = get_arg_val<uint32_t>(6);
const auto mean_rstd_width = get_arg_val<uint32_t>(7);
const auto normalized_dim = get_arg_val<uint32_t>(8);

constexpr bool output_is_dram = get_compile_time_arg_val(0) == 1;
constexpr bool mean_is_dram = get_compile_time_arg_val(1) == 1;
@@ -47,24 +143,17 @@ void kernel_main() {
uint32_t offs = 0;
constexpr uint32_t onetile = 1;

uint32_t Wt = (mean_rstd_width + TILE_WIDTH - 1) / TILE_WIDTH;
uint32_t Ht = (mean_rstd_height + TILE_HEIGHT - 1) / TILE_HEIGHT;

for (uint32_t outer_idx = 0; outer_idx < num_rows_per_core; outer_idx++) {
if (mean_has_value) {
// mean
const auto mean_l1_read_addr = get_read_ptr(cb_id_mean);
cb_wait_front(cb_id_mean, onetile);
noc_async_write_tile((offs + tile_offset) / num_inner, mean_addrg, mean_l1_read_addr);
noc_async_write_barrier();
cb_pop_front(cb_id_mean, onetile);
} // mean_has_value
write_mean_rstd(cb_id_mean, mean_tile_bytes, tile_offset, num_inner, normalized_dim, outer_idx, mean_rstd_height, mean_rstd_width, Ht, Wt, mean_addrg);
}

if (rstd_has_value) {
// rstd
const auto rstd_l1_read_addr = get_read_ptr(cb_id_rstd);
cb_wait_front(cb_id_rstd, onetile);
noc_async_write_tile((offs + tile_offset) / num_inner, rstd_addrg, rstd_l1_read_addr);
noc_async_write_barrier();
cb_pop_front(cb_id_rstd, onetile);
} // rstd_has_value
write_mean_rstd(cb_id_rstd, mean_tile_bytes, tile_offset, num_inner, normalized_dim, outer_idx, mean_rstd_height, mean_rstd_width, Ht, Wt, rstd_addrg);
}

// output
for (uint32_t inner_idx = 0; inner_idx < num_inner; inner_idx += block_size) {